diff --git a/README.md b/README.md index b09d0b162..f5058db3b 100644 --- a/README.md +++ b/README.md @@ -763,6 +763,7 @@ wandb_log_model: # "checkpoint" to log model to wandb Artifacts every `save_step # mlflow configuration if you're using it mlflow_tracking_uri: # URI to mlflow mlflow_experiment_name: # Your experiment name +hf_mlflow_log_artifacts: # set to true to copy each saved checkpoint on each save to mlflow artifact registry # Where to save the full-finetuned model to output_dir: ./completed-model diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py index 433c84af1..4d7d8280e 100644 --- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py +++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py @@ -305,6 +305,7 @@ class MLFlowConfig(BaseModel): use_mlflow: Optional[str] = None mlflow_tracking_uri: Optional[str] = None mlflow_experiment_name: Optional[str] = None + hf_mlflow_log_artifacts: Optional[bool] = None class WandbConfig(BaseModel): diff --git a/src/axolotl/utils/mlflow_.py b/src/axolotl/utils/mlflow_.py index fec2028ba..ce7739034 100644 --- a/src/axolotl/utils/mlflow_.py +++ b/src/axolotl/utils/mlflow_.py @@ -7,7 +7,7 @@ from axolotl.utils.dict import DictDefault def setup_mlflow_env_vars(cfg: DictDefault): for key in cfg.keys(): - if key.startswith("mlflow_"): + if key.startswith("mlflow_") or key.startswith("hf_mlflow_"): value = cfg.get(key, "") if value and isinstance(value, str) and len(value) > 0: