diff --git a/src/axolotl/utils/callbacks/mlflow_.py b/src/axolotl/utils/callbacks/mlflow_.py index 47679001f..15ca1ca47 100644 --- a/src/axolotl/utils/callbacks/mlflow_.py +++ b/src/axolotl/utils/callbacks/mlflow_.py @@ -1,6 +1,7 @@ """MLFlow module for trainer callbacks""" import logging +import os from shutil import copyfile from tempfile import NamedTemporaryFile from typing import TYPE_CHECKING @@ -16,6 +17,11 @@ if TYPE_CHECKING: LOG = logging.getLogger("axolotl.callbacks") +def should_log_artifacts() -> bool: + truths = ["TRUE", "1", "YES"] + return os.getenv("HF_MLFLOW_LOG_ARTIFACTS", "FALSE").upper() in truths + + class SaveAxolotlConfigtoMlflowCallback(TrainerCallback): # pylint: disable=duplicate-code """Callback to save axolotl config to mlflow""" @@ -32,13 +38,18 @@ class SaveAxolotlConfigtoMlflowCallback(TrainerCallback): ): if is_main_process(): try: - with NamedTemporaryFile( - mode="w", delete=False, suffix=".yml", prefix="axolotl_config_" - ) as temp_file: - copyfile(self.axolotl_config_path, temp_file.name) - mlflow.log_artifact(temp_file.name, artifact_path="") + if should_log_artifacts(): + with NamedTemporaryFile( + mode="w", delete=False, suffix=".yml", prefix="axolotl_config_" + ) as temp_file: + copyfile(self.axolotl_config_path, temp_file.name) + mlflow.log_artifact(temp_file.name, artifact_path="") + LOG.info( + "The Axolotl config has been saved to the MLflow artifacts." + ) + else: LOG.info( - "The Axolotl config has been saved to the MLflow artifacts." + "Skipping logging artifacts to MLflow (hf_mlflow_log_artifacts is false)" ) except (FileNotFoundError, ConnectionError) as err: LOG.warning(f"Error while saving Axolotl config to MLflow: {err}")