diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index f4a2f9001..4fadd7eb4 100755 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -896,13 +896,13 @@ class AxolotlTrainer(SchedulerMixin, Trainer): for key, value in metrics.items(): self._stored_metrics[train_eval][key].append(value) - def _save_checkpoint(self, model, trial): + def _save_checkpoint(self, model, trial, **kwargs): # make sure the checkpoint dir exists, since trainer is flakey checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}" run_dir = self._get_output_dir(trial=trial) output_dir = os.path.join(run_dir, checkpoint_folder) os.makedirs(output_dir, exist_ok=True) - return super()._save_checkpoint(model, trial) + return super()._save_checkpoint(model, trial, **kwargs) class AxolotlMambaTrainer(AxolotlTrainer):