checkpoint model on first step callback (#2906)

* checkpoint model on first step callback

* remove debug

* add test cases; update existing tests not to save on first step

* move test out of solo

* delete

* default to False

* typo
This commit is contained in:
Dan Saunders
2025-07-15 15:00:48 -04:00
committed by GitHub
parent d320ef6199
commit 10ba1622f7
146 changed files with 419 additions and 9 deletions

View File

@@ -36,6 +36,7 @@ from axolotl.utils.callbacks import (
GCCallback,
GPUStatsCallback,
SaveAxolotlConfigtoWandBCallback,
SaveModelOnFirstStepCallback,
)
from axolotl.utils.callbacks.profiler import PytorchProfilerCallback
from axolotl.utils.schemas.enums import CustomSupportedOptimizers
@@ -135,6 +136,8 @@ class TrainerBuilderBase(abc.ABC):
callbacks.append(
SaveAxolotlConfigtoCometCallback(self.cfg.axolotl_config_path)
)
if self.cfg.save_first_step:
callbacks.append(SaveModelOnFirstStepCallback())
callbacks.append(GPUStatsCallback(cfg=self.cfg))

View File

@@ -64,7 +64,7 @@ class SaveBetterTransformerModelCallback(
state: TrainerState,
control: TrainerControl,
**kwargs,
):
) -> TrainerControl:
# Save
if (
args.save_strategy == IntervalStrategy.STEPS
@@ -100,11 +100,11 @@ class GPUStatsCallback(
def on_step_end(
self,
args: TrainingArguments,
args: TrainingArguments, # pylint: disable=unused-argument
state: TrainerState,
control: TrainerControl,
**kwargs,
):
) -> TrainerControl:
if not self.logged and state.global_step > 1:
log_gpu_memory_usage(LOG, "while training", self.cfg.device)
self.logged = True
@@ -116,18 +116,17 @@ class LossWatchDogCallback(TrainerCallback):
def __init__(self, cfg):
self.cfg = cfg
self.logged = False
self.violations = 0
self.threshold = cfg.loss_watchdog_threshold
self.patience = cfg.loss_watchdog_patience or 3
def on_step_end(
self,
_args: TrainingArguments,
args: TrainingArguments, # pylint: disable=unused-argument
state: TrainerState,
control: TrainerControl,
**_kwargs,
):
) -> TrainerControl:
if len(state.log_history) > 0 and "loss" in state.log_history[-1]:
if state.log_history[-1]["loss"] > self.threshold:
self.violations += 1
@@ -141,6 +140,21 @@ class LossWatchDogCallback(TrainerCallback):
return control
class SaveModelOnFirstStepCallback(TrainerCallback):
"""Callback to save the model on the first step of training if enabled"""
def on_step_end(
self,
args: TrainingArguments, # pylint: disable=unused-argument
state: TrainerState,
control: TrainerControl,
**_kwargs,
) -> TrainerControl:
if state.global_step == 1:
control.should_save = True
return control
def bench_eval_callback_factory(trainer, tokenizer):
accuracy = evaluate.load("accuracy")
abcd_idx = [

View File

@@ -706,6 +706,7 @@ class AxolotlInputConfig(
"description": "Set to `no` to skip evaluation, `epoch` at end of each epoch, leave empty to infer from `eval_steps`"
},
)
save_steps: int | float | None = Field(
default=None,
json_schema_extra={
@@ -727,6 +728,13 @@ class AxolotlInputConfig(
save_total_limit: int | None = Field(
default=None, json_schema_extra={"description": "Checkpoints saved at a time"}
)
save_first_step: bool | None = Field(
default=None,
json_schema_extra={
"description": "Whether to checkpoint a model after the first step of training. Defaults to False."
},
)
logging_steps: int | None = Field(
default=None, json_schema_extra={"description": "Logging frequency"}
)