run eval on the first step to get a baseline (#617)
* run eval on the first step to get a baseline * wandb keeps getting moved around by pre-commit ...
This commit is contained in:
@@ -66,6 +66,29 @@ class SavePeftModelCallback(TrainerCallback): # pylint: disable=too-few-public-
|
|||||||
return control
|
return control
|
||||||
|
|
||||||
|
|
||||||
|
class EvalFirstStepCallback(
    TrainerCallback
):  # pylint: disable=too-few-public-methods disable=unused-argument
    """
    Trainer callback that requests an evaluation pass after the first step.

    The request is made only when evaluation is step-based and
    ``eval_steps`` is below 1.0; in every other case the control object is
    handed back untouched.
    """

    def on_step_end(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs,
    ):
        """
        Flag the trainer to evaluate once the first global step has finished.

        Returns the same ``control`` object it was given, with
        ``should_evaluate`` set to True when all three conditions hold.
        """
        # Checks run in the same order as a chained ``and`` so that
        # ``eval_steps`` is only compared when the strategy is step-based.
        if args.evaluation_strategy != IntervalStrategy.STEPS:
            return control

        # NOTE(review): a value below 1.0 presumably means eval_steps is
        # expressed as a fraction of total steps - confirm against the
        # TrainingArguments documentation.
        if not args.eval_steps < 1.0:
            return control

        if state.global_step == 1:
            control.should_evaluate = True

        return control
|
||||||
|
|
||||||
|
|
||||||
class SaveBetterTransformerModelCallback(
|
class SaveBetterTransformerModelCallback(
|
||||||
TrainerCallback
|
TrainerCallback
|
||||||
): # pylint: disable=too-few-public-methods
|
): # pylint: disable=too-few-public-methods
|
||||||
|
|||||||
@@ -28,6 +28,7 @@ from transformers.trainer_pt_utils import SequentialDistributedSampler
|
|||||||
|
|
||||||
from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
|
from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
|
||||||
from axolotl.utils.callbacks import (
|
from axolotl.utils.callbacks import (
|
||||||
|
EvalFirstStepCallback,
|
||||||
GPUStatsCallback,
|
GPUStatsCallback,
|
||||||
SaveBetterTransformerModelCallback,
|
SaveBetterTransformerModelCallback,
|
||||||
SavePeftModelCallback,
|
SavePeftModelCallback,
|
||||||
@@ -704,6 +705,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
|
|||||||
|
|
||||||
callbacks = []
|
callbacks = []
|
||||||
callbacks.append(GPUStatsCallback(cfg))
|
callbacks.append(GPUStatsCallback(cfg))
|
||||||
|
callbacks.append(EvalFirstStepCallback)
|
||||||
|
|
||||||
if cfg.relora_steps:
|
if cfg.relora_steps:
|
||||||
callbacks.append(ReLoRACallback(cfg))
|
callbacks.append(ReLoRACallback(cfg))
|
||||||
|
|||||||
Reference in New Issue
Block a user