run eval on the first step to get a baseline (#617)

* run eval on the first step to get a baseline

* wandb keeps getting moved around by pre-commit ...
Wing Lian
2023-09-21 21:51:09 -04:00
committed by GitHub
parent e85d2eb06b
commit 2844eb22b6
2 changed files with 25 additions and 0 deletions


@@ -66,6 +66,29 @@ class SavePeftModelCallback(TrainerCallback):  # pylint: disable=too-few-public-methods
        return control


class EvalFirstStepCallback(
    TrainerCallback
):  # pylint: disable=too-few-public-methods disable=unused-argument
    """
    Callback to trigger evals on the first step
    """

    def on_step_end(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs,
    ):
        if (
            args.evaluation_strategy == IntervalStrategy.STEPS
            and args.eval_steps < 1.0
            and state.global_step == 1
        ):
            control.should_evaluate = True
        return control


class SaveBetterTransformerModelCallback(
    TrainerCallback
):  # pylint: disable=too-few-public-methods
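As a quick illustration of what the new callback does, here is a minimal sketch that drives it directly with the objects a Trainer would pass right after the first optimizer step. The fractional eval_steps value is only an example of the `eval_steps < 1.0` case the callback checks, and the sketch assumes transformers and axolotl are importable:

from transformers import TrainerControl, TrainerState, TrainingArguments

from axolotl.utils.callbacks import EvalFirstStepCallback

# eval_steps below 1.0 is treated by transformers as a fraction of the total
# training steps; 0.05 here is an illustrative value, not a recommendation.
args = TrainingArguments(
    output_dir="/tmp/out",
    evaluation_strategy="steps",
    eval_steps=0.05,
)
state = TrainerState(global_step=1)  # state as seen at the end of the first step
control = TrainerControl()

EvalFirstStepCallback().on_step_end(args, state, control)
print(control.should_evaluate)  # True: an eval is forced at step 1 as a baseline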

@@ -28,6 +28,7 @@ from transformers.trainer_pt_utils import SequentialDistributedSampler
from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
from axolotl.utils.callbacks import (
    EvalFirstStepCallback,
    GPUStatsCallback,
    SaveBetterTransformerModelCallback,
    SavePeftModelCallback,
@@ -704,6 +705,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
    callbacks = []
    callbacks.append(GPUStatsCallback(cfg))
    callbacks.append(EvalFirstStepCallback)

    if cfg.relora_steps:
        callbacks.append(ReLoRACallback(cfg))
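One detail worth noting in the second hunk: EvalFirstStepCallback is appended as a bare class, while GPUStatsCallback(cfg) is appended as an instance. Both forms work because the Hugging Face Trainer instantiates any callback that is passed as a class; roughly (this is an illustration of the upstream behaviour, not the exact source), its callback handler does the equivalent of:

def add_callback(handler, callback):
    # A TrainerCallback class is instantiated with no arguments; an instance is
    # used as-is. EvalFirstStepCallback needs no constructor arguments, so the
    # bare class is enough, whereas GPUStatsCallback takes the axolotl cfg and
    # therefore has to be instantiated explicitly before being appended.
    cb = callback() if isinstance(callback, type) else callback
    handler.callbacks.append(cb)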