Add mlflow callback for pushing config to mlflow artifacts (#1125)
* Update callbacks.py adding callback for mlflow * Update trainer_builder.py * clean up
This commit is contained in:
@@ -28,6 +28,7 @@ from axolotl.utils.callbacks import (
|
||||
EvalFirstStepCallback,
|
||||
GPUStatsCallback,
|
||||
LossWatchDogCallback,
|
||||
SaveAxolotlConfigtoMlflowCallback,
|
||||
SaveAxolotlConfigtoWandBCallback,
|
||||
SaveBetterTransformerModelCallback,
|
||||
bench_eval_callback_factory,
|
||||
@@ -543,6 +544,10 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
callbacks.append(
|
||||
SaveAxolotlConfigtoWandBCallback(self.cfg.axolotl_config_path)
|
||||
)
|
||||
if self.cfg.use_mlflow:
|
||||
callbacks.append(
|
||||
SaveAxolotlConfigtoMlflowCallback(self.cfg.axolotl_config_path)
|
||||
)
|
||||
|
||||
if self.cfg.loss_watchdog_threshold is not None:
|
||||
callbacks.append(LossWatchDogCallback(self.cfg))
|
||||
|
||||
@@ -9,6 +9,7 @@ from tempfile import NamedTemporaryFile
|
||||
from typing import TYPE_CHECKING, Dict, List
|
||||
|
||||
import evaluate
|
||||
import mlflow
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import torch
|
||||
@@ -575,3 +576,31 @@ class SaveAxolotlConfigtoWandBCallback(TrainerCallback):
|
||||
except (FileNotFoundError, ConnectionError) as err:
|
||||
LOG.warning(f"Error while saving Axolotl config to WandB: {err}")
|
||||
return control
|
||||
|
||||
|
||||
class SaveAxolotlConfigtoMlflowCallback(TrainerCallback):
|
||||
"""Callback to save axolotl config to mlflow"""
|
||||
|
||||
def __init__(self, axolotl_config_path):
|
||||
self.axolotl_config_path = axolotl_config_path
|
||||
|
||||
def on_train_begin(
|
||||
self,
|
||||
args: AxolotlTrainingArguments, # pylint: disable=unused-argument
|
||||
state: TrainerState, # pylint: disable=unused-argument
|
||||
control: TrainerControl,
|
||||
**kwargs, # pylint: disable=unused-argument
|
||||
):
|
||||
if is_main_process():
|
||||
try:
|
||||
with NamedTemporaryFile(
|
||||
mode="w", delete=False, suffix=".yml", prefix="axolotl_config_"
|
||||
) as temp_file:
|
||||
copyfile(self.axolotl_config_path, temp_file.name)
|
||||
mlflow.log_artifact(temp_file.name, artifact_path="")
|
||||
LOG.info(
|
||||
"The Axolotl config has been saved to the MLflow artifacts."
|
||||
)
|
||||
except (FileNotFoundError, ConnectionError) as err:
|
||||
LOG.warning(f"Error while saving Axolotl config to MLflow: {err}")
|
||||
return control
|
||||
|
||||
Reference in New Issue
Block a user