From 490923fb78e0646c2c0ba427628a0daa72b86996 Mon Sep 17 00:00:00 2001 From: Jan Philipp Harries <2862336+jphme@users.noreply.github.com> Date: Wed, 11 Oct 2023 13:28:12 +0200 Subject: [PATCH] Save Axolotl config as WandB artifact (#716) --- src/axolotl/cli/__init__.py | 1 + src/axolotl/utils/callbacks.py | 24 ++++++++++++++++++++++++ src/axolotl/utils/trainer.py | 4 ++++ 3 files changed, 29 insertions(+) diff --git a/src/axolotl/cli/__init__.py b/src/axolotl/cli/__init__.py index c21d93170..07a6209e4 100644 --- a/src/axolotl/cli/__init__.py +++ b/src/axolotl/cli/__init__.py @@ -194,6 +194,7 @@ def load_cfg(config: Path = Path("examples/"), **kwargs): # load the config from the yaml file with open(config, encoding="utf-8") as file: cfg: DictDefault = DictDefault(yaml.safe_load(file)) + cfg.axolotl_config_path = config # if there are any options passed in the cli, if it is something that seems valid from the yaml, # then overwrite the value cfg_keys = cfg.keys() diff --git a/src/axolotl/utils/callbacks.py b/src/axolotl/utils/callbacks.py index 9a7ebe951..458e537c6 100644 --- a/src/axolotl/utils/callbacks.py +++ b/src/axolotl/utils/callbacks.py @@ -514,3 +514,27 @@ def log_prediction_callback_factory(trainer: Trainer, tokenizer): return control return LogPredictionCallback + + +class SaveAxolotlConfigtoWandBCallback(TrainerCallback): + """Callback to save axolotl config to wandb""" + + def __init__(self, axolotl_config_path): + self.axolotl_config_path = axolotl_config_path + + def on_train_begin( + self, + args: AxolotlTrainingArguments, # pylint: disable=unused-argument + state: TrainerState, # pylint: disable=unused-argument + control: TrainerControl, + **kwargs, # pylint: disable=unused-argument + ): + if is_main_process(): + try: + artifact = wandb.Artifact(name="axolotl-config", type="config") + artifact.add_file(local_path=self.axolotl_config_path) + wandb.run.log_artifact(artifact) + LOG.info("Axolotl config has been saved to WandB as an artifact.") + except (FileNotFoundError, ConnectionError) as err: + LOG.warning(f"Error while saving Axolotl config to WandB: {err}") + return control diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py index a10a2b0e7..ee8c63496 100644 --- a/src/axolotl/utils/trainer.py +++ b/src/axolotl/utils/trainer.py @@ -30,6 +30,7 @@ from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler from axolotl.utils.callbacks import ( EvalFirstStepCallback, GPUStatsCallback, + SaveAxolotlConfigtoWandBCallback, SaveBetterTransformerModelCallback, bench_eval_callback_factory, log_prediction_callback_factory, @@ -775,6 +776,9 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_ LogPredictionCallback = log_prediction_callback_factory(trainer, tokenizer) trainer.add_callback(LogPredictionCallback(cfg)) + if cfg.use_wandb: + trainer.add_callback(SaveAxolotlConfigtoWandBCallback(cfg.axolotl_config_path)) + if cfg.do_bench_eval: trainer.add_callback(bench_eval_callback_factory(trainer, tokenizer))