From d75653407cd65111742eca9f2b0b3963dc0268ea Mon Sep 17 00:00:00 2001 From: JohanWork <39947546+JohanWork@users.noreply.github.com> Date: Mon, 26 Feb 2024 19:32:39 +0100 Subject: [PATCH] ADD: push checkpoints to mlflow artifact registry (#1295) [skip ci] * Add checkpoint logging to mlflow artifact registry * clean up * Update README.md Co-authored-by: NanoCode012 * update pydantic config from rebase --------- Co-authored-by: NanoCode012 Co-authored-by: Wing Lian --- README.md | 1 + src/axolotl/utils/config/models/input/v0_4_1/__init__.py | 1 + src/axolotl/utils/mlflow_.py | 2 +- 3 files changed, 3 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index b09d0b162..f5058db3b 100644 --- a/README.md +++ b/README.md @@ -763,6 +763,7 @@ wandb_log_model: # "checkpoint" to log model to wandb Artifacts every `save_step # mlflow configuration if you're using it mlflow_tracking_uri: # URI to mlflow mlflow_experiment_name: # Your experiment name +hf_mlflow_log_artifacts: # set to true to copy each saved checkpoint on each save to mlflow artifact registry # Where to save the full-finetuned model to output_dir: ./completed-model diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py index 433c84af1..4d7d8280e 100644 --- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py +++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py @@ -305,6 +305,7 @@ class MLFlowConfig(BaseModel): use_mlflow: Optional[str] = None mlflow_tracking_uri: Optional[str] = None mlflow_experiment_name: Optional[str] = None + hf_mlflow_log_artifacts: Optional[bool] = None class WandbConfig(BaseModel): diff --git a/src/axolotl/utils/mlflow_.py b/src/axolotl/utils/mlflow_.py index fec2028ba..ce7739034 100644 --- a/src/axolotl/utils/mlflow_.py +++ b/src/axolotl/utils/mlflow_.py @@ -7,7 +7,7 @@ from axolotl.utils.dict import DictDefault def setup_mlflow_env_vars(cfg: DictDefault): for key in cfg.keys(): - if key.startswith("mlflow_"): + if key.startswith("mlflow_") or key.startswith("hf_mlflow_"): value = cfg.get(key, "") if value and isinstance(value, str) and len(value) > 0: