checkpoint model on first step callback (#2906)
* checkpoint model on first step callback
* remove debug
* add test cases; update existing tests not to save on first step
* move test out of solo
* delete
* default to False
* typo
This commit is contained in:
@@ -36,6 +36,7 @@ from axolotl.utils.callbacks import (
|
||||
GCCallback,
|
||||
GPUStatsCallback,
|
||||
SaveAxolotlConfigtoWandBCallback,
|
||||
SaveModelOnFirstStepCallback,
|
||||
)
|
||||
from axolotl.utils.callbacks.profiler import PytorchProfilerCallback
|
||||
from axolotl.utils.schemas.enums import CustomSupportedOptimizers
|
||||
@@ -135,6 +136,8 @@ class TrainerBuilderBase(abc.ABC):
|
||||
callbacks.append(
|
||||
SaveAxolotlConfigtoCometCallback(self.cfg.axolotl_config_path)
|
||||
)
|
||||
if self.cfg.save_first_step:
|
||||
callbacks.append(SaveModelOnFirstStepCallback())
|
||||
|
||||
callbacks.append(GPUStatsCallback(cfg=self.cfg))
|
||||
|
||||
|
||||
@@ -64,7 +64,7 @@ class SaveBetterTransformerModelCallback(
|
||||
state: TrainerState,
|
||||
control: TrainerControl,
|
||||
**kwargs,
|
||||
):
|
||||
) -> TrainerControl:
|
||||
# Save
|
||||
if (
|
||||
args.save_strategy == IntervalStrategy.STEPS
|
||||
@@ -100,11 +100,11 @@ class GPUStatsCallback(
|
||||
|
||||
def on_step_end(
|
||||
self,
|
||||
args: TrainingArguments,
|
||||
args: TrainingArguments, # pylint: disable=unused-argument
|
||||
state: TrainerState,
|
||||
control: TrainerControl,
|
||||
**kwargs,
|
||||
):
|
||||
) -> TrainerControl:
|
||||
if not self.logged and state.global_step > 1:
|
||||
log_gpu_memory_usage(LOG, "while training", self.cfg.device)
|
||||
self.logged = True
|
||||
@@ -116,18 +116,17 @@ class LossWatchDogCallback(TrainerCallback):
|
||||
|
||||
def __init__(self, cfg):
|
||||
self.cfg = cfg
|
||||
self.logged = False
|
||||
self.violations = 0
|
||||
self.threshold = cfg.loss_watchdog_threshold
|
||||
self.patience = cfg.loss_watchdog_patience or 3
|
||||
|
||||
def on_step_end(
|
||||
self,
|
||||
_args: TrainingArguments,
|
||||
args: TrainingArguments, # pylint: disable=unused-argument
|
||||
state: TrainerState,
|
||||
control: TrainerControl,
|
||||
**_kwargs,
|
||||
):
|
||||
) -> TrainerControl:
|
||||
if len(state.log_history) > 0 and "loss" in state.log_history[-1]:
|
||||
if state.log_history[-1]["loss"] > self.threshold:
|
||||
self.violations += 1
|
||||
@@ -141,6 +140,21 @@ class LossWatchDogCallback(TrainerCallback):
|
||||
return control
|
||||
|
||||
|
||||
class SaveModelOnFirstStepCallback(TrainerCallback):
    """Trainer callback that flags a checkpoint save once the first step ends."""

    def on_step_end(
        self,
        args: TrainingArguments,  # pylint: disable=unused-argument
        state: TrainerState,
        control: TrainerControl,
        **_kwargs,
    ) -> TrainerControl:
        """Set ``control.should_save`` when ``state.global_step`` equals 1.

        The ``control`` object is handed back to the caller in every case;
        on any step other than the first it is returned untouched.
        """
        # Guard clause: any step other than the first leaves the control as is.
        if state.global_step != 1:
            return control

        control.should_save = True
        return control
|
||||
|
||||
|
||||
def bench_eval_callback_factory(trainer, tokenizer):
|
||||
accuracy = evaluate.load("accuracy")
|
||||
abcd_idx = [
|
||||
|
||||
@@ -706,6 +706,7 @@ class AxolotlInputConfig(
|
||||
"description": "Set to `no` to skip evaluation, `epoch` at end of each epoch, leave empty to infer from `eval_steps`"
|
||||
},
|
||||
)
|
||||
|
||||
save_steps: int | float | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
@@ -727,6 +728,13 @@ class AxolotlInputConfig(
|
||||
save_total_limit: int | None = Field(
|
||||
default=None, json_schema_extra={"description": "Checkpoints saved at a time"}
|
||||
)
|
||||
save_first_step: bool | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "Whether to checkpoint a model after the first step of training. Defaults to False."
|
||||
},
|
||||
)
|
||||
|
||||
logging_steps: int | None = Field(
|
||||
default=None, json_schema_extra={"description": "Logging frequency"}
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user