ADD: push checkpoints to mlflow artifact registry (#1295) [skip ci]

* Add checkpoint logging to mlflow artifact registry

* clean up

* Update README.md

Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>

* update pydantic config from rebase

---------

Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>
Co-authored-by: Wing Lian <wing.lian@gmail.com>
This commit is contained in:
JohanWork
2024-02-26 19:32:39 +01:00
committed by GitHub
parent c6b01e0f4a
commit d75653407c
3 changed files with 3 additions and 1 deletions

View File

@@ -763,6 +763,7 @@ wandb_log_model: # "checkpoint" to log model to wandb Artifacts every `save_step
# mlflow configuration if you're using it
mlflow_tracking_uri: # URI to mlflow
mlflow_experiment_name: # Your experiment name
hf_mlflow_log_artifacts: # set to true to copy each saved checkpoint on each save to mlflow artifact registry
# Where to save the full-finetuned model to
output_dir: ./completed-model

View File

@@ -305,6 +305,7 @@ class MLFlowConfig(BaseModel):
use_mlflow: Optional[str] = None
mlflow_tracking_uri: Optional[str] = None
mlflow_experiment_name: Optional[str] = None
hf_mlflow_log_artifacts: Optional[bool] = None
class WandbConfig(BaseModel):

View File

@@ -7,7 +7,7 @@ from axolotl.utils.dict import DictDefault
def setup_mlflow_env_vars(cfg: DictDefault):
for key in cfg.keys():
if key.startswith("mlflow_"):
if key.startswith("mlflow_") or key.startswith("hf_mlflow_"):
value = cfg.get(key, "")
if value and isinstance(value, str) and len(value) > 0: