use kwargs to support patch release

This commit is contained in:
Wing Lian
2024-11-06 09:43:35 -05:00
parent 6b617a4fd5
commit 613f238e56

View File

@@ -896,13 +896,13 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
for key, value in metrics.items():
self._stored_metrics[train_eval][key].append(value)
def _save_checkpoint(self, model, trial):
def _save_checkpoint(self, model, trial, **kwargs):
# make sure the checkpoint dir exists, since trainer is flakey
checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
run_dir = self._get_output_dir(trial=trial)
output_dir = os.path.join(run_dir, checkpoint_folder)
os.makedirs(output_dir, exist_ok=True)
return super()._save_checkpoint(model, trial)
return super()._save_checkpoint(model, trial, **kwargs)
class AxolotlMambaTrainer(AxolotlTrainer):