Log checkpoints as mlflow artifacts (#1976)
* Ensure hf_mlflow_log_artifacts config var is set in env
* Add transformer MLflowCallback to callbacks list when mlflow enabled
* Test hf_mlflow_log_artifacts is set correctly
* Test mlflow not being used by default
This commit is contained in:
@@ -1119,12 +1119,17 @@ class TrainerBuilderBase(abc.ABC):
|
|||||||
SaveAxolotlConfigtoWandBCallback(self.cfg.axolotl_config_path)
|
SaveAxolotlConfigtoWandBCallback(self.cfg.axolotl_config_path)
|
||||||
)
|
)
|
||||||
if self.cfg.use_mlflow and is_mlflow_available():
|
if self.cfg.use_mlflow and is_mlflow_available():
|
||||||
|
from transformers.integrations.integration_utils import MLflowCallback
|
||||||
|
|
||||||
from axolotl.utils.callbacks.mlflow_ import (
|
from axolotl.utils.callbacks.mlflow_ import (
|
||||||
SaveAxolotlConfigtoMlflowCallback,
|
SaveAxolotlConfigtoMlflowCallback,
|
||||||
)
|
)
|
||||||
|
|
||||||
callbacks.append(
|
callbacks.extend(
|
||||||
SaveAxolotlConfigtoMlflowCallback(self.cfg.axolotl_config_path)
|
[
|
||||||
|
SaveAxolotlConfigtoMlflowCallback(self.cfg.axolotl_config_path),
|
||||||
|
MLflowCallback,
|
||||||
|
]
|
||||||
)
|
)
|
||||||
if self.cfg.use_comet and is_comet_available():
|
if self.cfg.use_comet and is_comet_available():
|
||||||
from axolotl.utils.callbacks.comet_ import SaveAxolotlConfigtoCometCallback
|
from axolotl.utils.callbacks.comet_ import SaveAxolotlConfigtoCometCallback
|
||||||
|
|||||||
@@ -16,3 +16,7 @@ def setup_mlflow_env_vars(cfg: DictDefault):
|
|||||||
# Enable mlflow if experiment name is present
|
# Enable mlflow if experiment name is present
|
||||||
if cfg.mlflow_experiment_name and len(cfg.mlflow_experiment_name) > 0:
|
if cfg.mlflow_experiment_name and len(cfg.mlflow_experiment_name) > 0:
|
||||||
cfg.use_mlflow = True
|
cfg.use_mlflow = True
|
||||||
|
|
||||||
|
# Enable logging hf artifacts in mlflow only if value is exactly True
|
||||||
|
if cfg.hf_mlflow_log_artifacts is True:
|
||||||
|
os.environ["HF_MLFLOW_LOG_ARTIFACTS"] = "true"
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ from axolotl.utils import is_comet_available
|
|||||||
from axolotl.utils.config import validate_config
|
from axolotl.utils.config import validate_config
|
||||||
from axolotl.utils.config.models.input.v0_4_1 import AxolotlConfigWCapabilities
|
from axolotl.utils.config.models.input.v0_4_1 import AxolotlConfigWCapabilities
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
from axolotl.utils.mlflow_ import setup_mlflow_env_vars
|
||||||
from axolotl.utils.models import check_model_config
|
from axolotl.utils.models import check_model_config
|
||||||
from axolotl.utils.wandb_ import setup_wandb_env_vars
|
from axolotl.utils.wandb_ import setup_wandb_env_vars
|
||||||
|
|
||||||
@@ -1432,3 +1433,58 @@ class TestValidationComet(BaseValidation):
|
|||||||
|
|
||||||
for key in comet_env.keys():
|
for key in comet_env.keys():
|
||||||
os.environ.pop(key, None)
|
os.environ.pop(key, None)
|
||||||
|
|
||||||
|
|
||||||
|
class TestValidationMLflow(BaseValidation):
|
||||||
|
"""
|
||||||
|
Validation test for MLflow
|
||||||
|
"""
|
||||||
|
|
||||||
|
def test_hf_mlflow_artifacts_config_sets_env(self, minimal_cfg):
|
||||||
|
cfg = (
|
||||||
|
DictDefault(
|
||||||
|
{
|
||||||
|
"hf_mlflow_log_artifacts": True,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
| minimal_cfg
|
||||||
|
)
|
||||||
|
|
||||||
|
new_cfg = validate_config(cfg)
|
||||||
|
|
||||||
|
assert new_cfg.hf_mlflow_log_artifacts is True
|
||||||
|
|
||||||
|
# Check it's not already present in env
|
||||||
|
assert "HF_MLFLOW_LOG_ARTIFACTS" not in os.environ
|
||||||
|
|
||||||
|
setup_mlflow_env_vars(new_cfg)
|
||||||
|
|
||||||
|
assert os.environ.get("HF_MLFLOW_LOG_ARTIFACTS") == "true"
|
||||||
|
|
||||||
|
os.environ.pop("HF_MLFLOW_LOG_ARTIFACTS", None)
|
||||||
|
|
||||||
|
def test_mlflow_not_used_by_default(self, minimal_cfg):
|
||||||
|
cfg = DictDefault({}) | minimal_cfg
|
||||||
|
|
||||||
|
new_cfg = validate_config(cfg)
|
||||||
|
|
||||||
|
setup_mlflow_env_vars(new_cfg)
|
||||||
|
|
||||||
|
assert cfg.use_mlflow is not True
|
||||||
|
|
||||||
|
cfg = (
|
||||||
|
DictDefault(
|
||||||
|
{
|
||||||
|
"mlflow_experiment_name": "foo",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
| minimal_cfg
|
||||||
|
)
|
||||||
|
|
||||||
|
new_cfg = validate_config(cfg)
|
||||||
|
|
||||||
|
setup_mlflow_env_vars(new_cfg)
|
||||||
|
|
||||||
|
assert new_cfg.use_mlflow is True
|
||||||
|
|
||||||
|
os.environ.pop("MLFLOW_EXPERIMENT_NAME", None)
|
||||||
|
|||||||
Reference in New Issue
Block a user