Add mlflow callback for pushing config to mlflow artifacts (#1125)
* Update callbacks.py adding callback for mlflow * Update trainer_builder.py * clean up
This commit is contained in:
@@ -28,6 +28,7 @@ from axolotl.utils.callbacks import (
|
|||||||
EvalFirstStepCallback,
|
EvalFirstStepCallback,
|
||||||
GPUStatsCallback,
|
GPUStatsCallback,
|
||||||
LossWatchDogCallback,
|
LossWatchDogCallback,
|
||||||
|
SaveAxolotlConfigtoMlflowCallback,
|
||||||
SaveAxolotlConfigtoWandBCallback,
|
SaveAxolotlConfigtoWandBCallback,
|
||||||
SaveBetterTransformerModelCallback,
|
SaveBetterTransformerModelCallback,
|
||||||
bench_eval_callback_factory,
|
bench_eval_callback_factory,
|
||||||
@@ -543,6 +544,10 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
callbacks.append(
|
callbacks.append(
|
||||||
SaveAxolotlConfigtoWandBCallback(self.cfg.axolotl_config_path)
|
SaveAxolotlConfigtoWandBCallback(self.cfg.axolotl_config_path)
|
||||||
)
|
)
|
||||||
|
if self.cfg.use_mlflow:
|
||||||
|
callbacks.append(
|
||||||
|
SaveAxolotlConfigtoMlflowCallback(self.cfg.axolotl_config_path)
|
||||||
|
)
|
||||||
|
|
||||||
if self.cfg.loss_watchdog_threshold is not None:
|
if self.cfg.loss_watchdog_threshold is not None:
|
||||||
callbacks.append(LossWatchDogCallback(self.cfg))
|
callbacks.append(LossWatchDogCallback(self.cfg))
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ from tempfile import NamedTemporaryFile
|
|||||||
from typing import TYPE_CHECKING, Dict, List
|
from typing import TYPE_CHECKING, Dict, List
|
||||||
|
|
||||||
import evaluate
|
import evaluate
|
||||||
|
import mlflow
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
import torch
|
import torch
|
||||||
@@ -575,3 +576,31 @@ class SaveAxolotlConfigtoWandBCallback(TrainerCallback):
|
|||||||
except (FileNotFoundError, ConnectionError) as err:
|
except (FileNotFoundError, ConnectionError) as err:
|
||||||
LOG.warning(f"Error while saving Axolotl config to WandB: {err}")
|
LOG.warning(f"Error while saving Axolotl config to WandB: {err}")
|
||||||
return control
|
return control
|
||||||
|
|
||||||
|
|
||||||
|
class SaveAxolotlConfigtoMlflowCallback(TrainerCallback):
    """Callback to save axolotl config to mlflow"""

    def __init__(self, axolotl_config_path):
        """Remember the path of the Axolotl config file to upload.

        Args:
            axolotl_config_path: Path of the config file that is copied and
                logged as an MLflow artifact when training begins.
        """
        self.axolotl_config_path = axolotl_config_path

    def on_train_begin(
        self,
        args: AxolotlTrainingArguments,  # pylint: disable=unused-argument
        state: TrainerState,  # pylint: disable=unused-argument
        control: TrainerControl,
        **kwargs,  # pylint: disable=unused-argument
    ):
        """Log a copy of the Axolotl config as an MLflow artifact.

        The upload is only attempted when ``is_main_process()`` is true. A
        missing config file or a connection error is logged as a warning
        instead of being raised.

        Returns:
            The ``control`` object that was passed in, unmodified.
        """
        # Imported locally so this block does not depend on a module-level
        # ``import os`` being present.
        import os  # pylint: disable=import-outside-toplevel

        if is_main_process():
            temp_path = None
            try:
                # delete=False keeps the file on disk after the handle is
                # closed; it is removed explicitly in the ``finally`` below.
                with NamedTemporaryFile(
                    mode="w", delete=False, suffix=".yml", prefix="axolotl_config_"
                ) as temp_file:
                    temp_path = temp_file.name
                    copyfile(self.axolotl_config_path, temp_path)
                    mlflow.log_artifact(temp_path, artifact_path="")
                    LOG.info(
                        "The Axolotl config has been saved to the MLflow artifacts."
                    )
            except (FileNotFoundError, ConnectionError) as err:
                LOG.warning(f"Error while saving Axolotl config to MLflow: {err}")
            finally:
                # Clean up the temporary copy so repeated runs do not leave
                # stray config files behind in the temp directory.
                if temp_path is not None:
                    try:
                        os.remove(temp_path)
                    except OSError as err:
                        LOG.debug(
                            f"Could not remove temporary config copy {temp_path}: {err}"
                        )
        return control
|
||||||
|
|||||||
Reference in New Issue
Block a user