Add: mlflow for experiment tracking (#1059) [skip ci]

* Update requirements.txt

adding mlflow

* Update __init__.py

Imports for mlflow

* Update README.md

* Create mlflow_.py (#1)

* Update README.md

* fix precommits

* Update README.md

Update mlflow_tracking_uri

* Update trainer_builder.py

update trainer building

* chore: lint

* make ternary a bit more readable

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>
This commit is contained in:
Johan Hansson
2024-01-09 15:34:09 +01:00
committed by GitHub
parent 651b7a31fc
commit 090c24dcb0
5 changed files with 34 additions and 2 deletions

View File

@@ -29,6 +29,7 @@ from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.data import prepare_dataset
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import is_main_process
from axolotl.utils.mlflow_ import setup_mlflow_env_vars
from axolotl.utils.models import load_tokenizer
from axolotl.utils.tokenization import check_dataset_labels
from axolotl.utils.trainer import prepare_optim_env
@@ -289,6 +290,9 @@ def load_cfg(config: Path = Path("examples/"), **kwargs):
normalize_config(cfg)
setup_wandb_env_vars(cfg)
setup_mlflow_env_vars(cfg)
return cfg

View File

@@ -747,7 +747,12 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
False if self.cfg.ddp else None
)
training_arguments_kwargs["group_by_length"] = self.cfg.group_by_length
training_arguments_kwargs["report_to"] = "wandb" if self.cfg.use_wandb else None
report_to = None
if self.cfg.use_wandb:
report_to = "wandb"
if self.cfg.use_mlflow:
report_to = "mlflow"
training_arguments_kwargs["report_to"] = report_to
training_arguments_kwargs["run_name"] = (
self.cfg.wandb_name if self.cfg.use_wandb else None
)

View File

@@ -0,0 +1,18 @@
"""Module for mlflow utilities"""
import os
from axolotl.utils.dict import DictDefault
def setup_mlflow_env_vars(cfg: DictDefault):
for key in cfg.keys():
if key.startswith("mlflow_"):
value = cfg.get(key, "")
if value and isinstance(value, str) and len(value) > 0:
os.environ[key.upper()] = value
# Enable mlflow if experiment name is present
if cfg.mlflow_experiment_name and len(cfg.mlflow_experiment_name) > 0:
cfg.use_mlflow = True