ADD: push checkpoints to mlflow artifact registry (#1295) [skip ci]
* Add checkpoint logging to mlflow artifact registry * clean up * Update README.md Co-authored-by: NanoCode012 <kevinvong@rocketmail.com> * update pydantic config from rebase --------- Co-authored-by: NanoCode012 <kevinvong@rocketmail.com> Co-authored-by: Wing Lian <wing.lian@gmail.com>
This commit is contained in:
@@ -763,6 +763,7 @@ wandb_log_model: # "checkpoint" to log model to wandb Artifacts every `save_step
|
||||
# mlflow configuration if you're using it
|
||||
mlflow_tracking_uri: # URI to mlflow
|
||||
mlflow_experiment_name: # Your experiment name
|
||||
hf_mlflow_log_artifacts: # set to true to copy each saved checkpoint on each save to mlflow artifact registry
|
||||
|
||||
# Where to save the full-finetuned model to
|
||||
output_dir: ./completed-model
|
||||
|
||||
@@ -305,6 +305,7 @@ class MLFlowConfig(BaseModel):
|
||||
use_mlflow: Optional[str] = None
|
||||
mlflow_tracking_uri: Optional[str] = None
|
||||
mlflow_experiment_name: Optional[str] = None
|
||||
hf_mlflow_log_artifacts: Optional[bool] = None
|
||||
|
||||
|
||||
class WandbConfig(BaseModel):
|
||||
|
||||
@@ -7,7 +7,7 @@ from axolotl.utils.dict import DictDefault
|
||||
|
||||
def setup_mlflow_env_vars(cfg: DictDefault):
|
||||
for key in cfg.keys():
|
||||
if key.startswith("mlflow_"):
|
||||
if key.startswith("mlflow_") or key.startswith("hf_mlflow_"):
|
||||
value = cfg.get(key, "")
|
||||
|
||||
if value and isinstance(value, str) and len(value) > 0:
|
||||
|
||||
Reference in New Issue
Block a user