Add: mlflow for experiment tracking (#1059) [skip ci]
* Update requirements.txt adding mlflow * Update __init__.py Imports for mlflow * Update README.md * Create mlflow_.py (#1) * Update README.md * fix precommits * Update README.md Update mlflow_tracking_uri * Update trainer_builder.py update trainer building * chore: lint * make ternary a bit more readable --------- Co-authored-by: Wing Lian <wing.lian@gmail.com>
This commit is contained in:
@@ -10,7 +10,7 @@ Features:
|
|||||||
- Integrated with xformer, flash attention, rope scaling, and multipacking
|
- Integrated with xformer, flash attention, rope scaling, and multipacking
|
||||||
- Works with single GPU or multiple GPUs via FSDP or Deepspeed
|
- Works with single GPU or multiple GPUs via FSDP or Deepspeed
|
||||||
- Easily run with Docker locally or on the cloud
|
- Easily run with Docker locally or on the cloud
|
||||||
- Log results and optionally checkpoints to wandb
|
- Log results and optionally checkpoints to wandb or mlflow
|
||||||
- And more!
|
- And more!
|
||||||
|
|
||||||
|
|
||||||
@@ -695,6 +695,10 @@ wandb_name: # Set the name of your wandb run
|
|||||||
wandb_run_id: # Set the ID of your wandb run
|
wandb_run_id: # Set the ID of your wandb run
|
||||||
wandb_log_model: # "checkpoint" to log model to wandb Artifacts every `save_steps` or "end" to log only at the end of training
|
wandb_log_model: # "checkpoint" to log model to wandb Artifacts every `save_steps` or "end" to log only at the end of training
|
||||||
|
|
||||||
|
# mlflow configuration if you're using it
|
||||||
|
mlflow_tracking_uri: # URI to mlflow
|
||||||
|
mlflow_experiment_name: # Your experiment name
|
||||||
|
|
||||||
# Where to save the full-finetuned model to
|
# Where to save the full-finetuned model to
|
||||||
output_dir: ./completed-model
|
output_dir: ./completed-model
|
||||||
|
|
||||||
|
|||||||
@@ -22,6 +22,7 @@ hf_transfer
|
|||||||
colorama
|
colorama
|
||||||
numba
|
numba
|
||||||
numpy>=1.24.4
|
numpy>=1.24.4
|
||||||
|
mlflow
|
||||||
# qlora things
|
# qlora things
|
||||||
bert-score==0.3.13
|
bert-score==0.3.13
|
||||||
evaluate==0.4.0
|
evaluate==0.4.0
|
||||||
|
|||||||
@@ -29,6 +29,7 @@ from axolotl.utils.config import normalize_config, validate_config
|
|||||||
from axolotl.utils.data import prepare_dataset
|
from axolotl.utils.data import prepare_dataset
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
from axolotl.utils.distributed import is_main_process
|
from axolotl.utils.distributed import is_main_process
|
||||||
|
from axolotl.utils.mlflow_ import setup_mlflow_env_vars
|
||||||
from axolotl.utils.models import load_tokenizer
|
from axolotl.utils.models import load_tokenizer
|
||||||
from axolotl.utils.tokenization import check_dataset_labels
|
from axolotl.utils.tokenization import check_dataset_labels
|
||||||
from axolotl.utils.trainer import prepare_optim_env
|
from axolotl.utils.trainer import prepare_optim_env
|
||||||
@@ -289,6 +290,9 @@ def load_cfg(config: Path = Path("examples/"), **kwargs):
|
|||||||
normalize_config(cfg)
|
normalize_config(cfg)
|
||||||
|
|
||||||
setup_wandb_env_vars(cfg)
|
setup_wandb_env_vars(cfg)
|
||||||
|
|
||||||
|
setup_mlflow_env_vars(cfg)
|
||||||
|
|
||||||
return cfg
|
return cfg
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -747,7 +747,12 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
False if self.cfg.ddp else None
|
False if self.cfg.ddp else None
|
||||||
)
|
)
|
||||||
training_arguments_kwargs["group_by_length"] = self.cfg.group_by_length
|
training_arguments_kwargs["group_by_length"] = self.cfg.group_by_length
|
||||||
training_arguments_kwargs["report_to"] = "wandb" if self.cfg.use_wandb else None
|
report_to = None
|
||||||
|
if self.cfg.use_wandb:
|
||||||
|
report_to = "wandb"
|
||||||
|
if self.cfg.use_mlflow:
|
||||||
|
report_to = "mlflow"
|
||||||
|
training_arguments_kwargs["report_to"] = report_to
|
||||||
training_arguments_kwargs["run_name"] = (
|
training_arguments_kwargs["run_name"] = (
|
||||||
self.cfg.wandb_name if self.cfg.use_wandb else None
|
self.cfg.wandb_name if self.cfg.use_wandb else None
|
||||||
)
|
)
|
||||||
|
|||||||
18
src/axolotl/utils/mlflow_.py
Normal file
18
src/axolotl/utils/mlflow_.py
Normal file
@@ -0,0 +1,18 @@
|
|||||||
|
"""Module for mlflow utilities"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
|
||||||
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
|
|
||||||
|
def setup_mlflow_env_vars(cfg: DictDefault):
|
||||||
|
for key in cfg.keys():
|
||||||
|
if key.startswith("mlflow_"):
|
||||||
|
value = cfg.get(key, "")
|
||||||
|
|
||||||
|
if value and isinstance(value, str) and len(value) > 0:
|
||||||
|
os.environ[key.upper()] = value
|
||||||
|
|
||||||
|
# Enable mlflow if experiment name is present
|
||||||
|
if cfg.mlflow_experiment_name and len(cfg.mlflow_experiment_name) > 0:
|
||||||
|
cfg.use_mlflow = True
|
||||||
Reference in New Issue
Block a user