diff --git a/README.md b/README.md index 7e53da2e1..12fad21b1 100644 --- a/README.md +++ b/README.md @@ -10,7 +10,7 @@ Features: - Integrated with xformer, flash attention, rope scaling, and multipacking - Works with single GPU or multiple GPUs via FSDP or Deepspeed - Easily run with Docker locally or on the cloud -- Log results and optionally checkpoints to wandb +- Log results and optionally checkpoints to wandb or mlflow - And more! @@ -695,6 +695,10 @@ wandb_name: # Set the name of your wandb run wandb_run_id: # Set the ID of your wandb run wandb_log_model: # "checkpoint" to log model to wandb Artifacts every `save_steps` or "end" to log only at the end of training +# mlflow configuration if you're using it +mlflow_tracking_uri: # URI of the mlflow tracking server +mlflow_experiment_name: # Your experiment name + # Where to save the full-finetuned model to output_dir: ./completed-model diff --git a/requirements.txt b/requirements.txt index 391bb52d9..b2f5958de 100644 --- a/requirements.txt +++ b/requirements.txt @@ -22,6 +22,7 @@ hf_transfer colorama numba numpy>=1.24.4 +mlflow # qlora things bert-score==0.3.13 evaluate==0.4.0 diff --git a/src/axolotl/cli/__init__.py b/src/axolotl/cli/__init__.py index 4f441f527..15902516d 100644 --- a/src/axolotl/cli/__init__.py +++ b/src/axolotl/cli/__init__.py @@ -29,6 +29,7 @@ from axolotl.utils.config import normalize_config, validate_config from axolotl.utils.data import prepare_dataset from axolotl.utils.dict import DictDefault from axolotl.utils.distributed import is_main_process +from axolotl.utils.mlflow_ import setup_mlflow_env_vars from axolotl.utils.models import load_tokenizer from axolotl.utils.tokenization import check_dataset_labels from axolotl.utils.trainer import prepare_optim_env @@ -289,6 +290,9 @@ def load_cfg(config: Path = Path("examples/"), **kwargs): normalize_config(cfg) setup_wandb_env_vars(cfg) + + setup_mlflow_env_vars(cfg) + return cfg diff --git a/src/axolotl/core/trainer_builder.py
def setup_mlflow_env_vars(cfg: "DictDefault"):
    """Export mlflow_* config options as MLFLOW_* environment variables.

    Every ``mlflow_``-prefixed key in the config whose value is a non-empty
    string is exported upper-cased (e.g. ``mlflow_tracking_uri`` becomes
    ``MLFLOW_TRACKING_URI``), which is how the mlflow client library picks
    up its configuration. If ``mlflow_experiment_name`` is set, mlflow
    reporting is enabled by setting ``cfg.use_mlflow = True``.

    Args:
        cfg: the axolotl config mapping (DictDefault); mutated in place
            when mlflow is enabled.
    """
    for key in cfg.keys():
        if key.startswith("mlflow_"):
            value = cfg.get(key, "")

            # Only non-empty string values make sense as environment
            # variables; silently skip None/other types.
            if value and isinstance(value, str):
                os.environ[key.upper()] = value

    # Enable mlflow reporting if an experiment name is configured.
    # (Truthiness already covers both None and the empty string.)
    if cfg.mlflow_experiment_name:
        cfg.use_mlflow = True