feat: remove evalfirst callback with built-in trainer arg (#2797)
This commit is contained in:
@@ -386,8 +386,10 @@ class TrainerBuilderBase(abc.ABC):
|
|||||||
elif self.cfg.eval_steps:
|
elif self.cfg.eval_steps:
|
||||||
training_args_kwargs["eval_strategy"] = "steps"
|
training_args_kwargs["eval_strategy"] = "steps"
|
||||||
training_args_kwargs["eval_steps"] = self.cfg.eval_steps
|
training_args_kwargs["eval_steps"] = self.cfg.eval_steps
|
||||||
|
training_args_kwargs["eval_on_start"] = True
|
||||||
elif self.cfg.eval_strategy:
|
elif self.cfg.eval_strategy:
|
||||||
training_args_kwargs["eval_strategy"] = self.cfg.eval_strategy
|
training_args_kwargs["eval_strategy"] = self.cfg.eval_strategy
|
||||||
|
training_args_kwargs["eval_on_start"] = True
|
||||||
|
|
||||||
def _configure_reporting(self, training_args_kwargs: dict):
|
def _configure_reporting(self, training_args_kwargs: dict):
|
||||||
report_to = []
|
report_to = []
|
||||||
|
|||||||
@@ -27,7 +27,6 @@ from axolotl.monkeypatch.relora import ReLoRACallback
|
|||||||
from axolotl.processing_strategies import get_processing_strategy
|
from axolotl.processing_strategies import get_processing_strategy
|
||||||
from axolotl.utils import is_comet_available, is_mlflow_available
|
from axolotl.utils import is_comet_available, is_mlflow_available
|
||||||
from axolotl.utils.callbacks import (
|
from axolotl.utils.callbacks import (
|
||||||
EvalFirstStepCallback,
|
|
||||||
LossWatchDogCallback,
|
LossWatchDogCallback,
|
||||||
SaveBetterTransformerModelCallback,
|
SaveBetterTransformerModelCallback,
|
||||||
bench_eval_callback_factory,
|
bench_eval_callback_factory,
|
||||||
@@ -58,7 +57,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
|
|
||||||
def get_callbacks(self):
|
def get_callbacks(self):
|
||||||
callbacks = super().get_callbacks()
|
callbacks = super().get_callbacks()
|
||||||
callbacks.append(EvalFirstStepCallback())
|
|
||||||
|
|
||||||
if self.cfg.relora_steps:
|
if self.cfg.relora_steps:
|
||||||
callbacks.append(ReLoRACallback(self.cfg))
|
callbacks.append(ReLoRACallback(self.cfg))
|
||||||
|
|||||||
@@ -53,25 +53,6 @@ IGNORE_INDEX = -100
|
|||||||
LOG = get_logger(__name__)
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class EvalFirstStepCallback(
|
|
||||||
TrainerCallback
|
|
||||||
): # pylint: disable=too-few-public-methods disable=unused-argument
|
|
||||||
"""
|
|
||||||
Callback to trigger evals on the first step
|
|
||||||
"""
|
|
||||||
|
|
||||||
def on_step_end(
|
|
||||||
self,
|
|
||||||
args: TrainingArguments,
|
|
||||||
state: TrainerState,
|
|
||||||
control: TrainerControl,
|
|
||||||
**kwargs,
|
|
||||||
):
|
|
||||||
if args.eval_strategy == IntervalStrategy.STEPS and state.global_step == 1:
|
|
||||||
control.should_evaluate = True
|
|
||||||
return control
|
|
||||||
|
|
||||||
|
|
||||||
class SaveBetterTransformerModelCallback(
|
class SaveBetterTransformerModelCallback(
|
||||||
TrainerCallback
|
TrainerCallback
|
||||||
): # pylint: disable=too-few-public-methods
|
): # pylint: disable=too-few-public-methods
|
||||||
|
|||||||
Reference in New Issue
Block a user