feat: move epoch setting to base
This commit is contained in:
@@ -304,6 +304,8 @@ class TrainerBuilderBase(abc.ABC):
|
||||
|
||||
training_args_kwargs["max_steps"] = self.cfg.max_steps or total_num_steps or -1
|
||||
|
||||
training_args_kwargs["num_train_epochs"] = self.cfg.num_epochs
|
||||
|
||||
# max_length is not used in CausalTrainer
|
||||
if self.cfg.reward_model or self.cfg.rl:
|
||||
training_args_kwargs["max_length"] = self.cfg.sequence_len
|
||||
|
||||
@@ -156,7 +156,6 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
if blocklist_key in training_args_kwargs:
|
||||
del training_args_kwargs[blocklist_key]
|
||||
|
||||
training_args_kwargs["num_train_epochs"] = self.cfg.num_epochs
|
||||
training_args = training_args_cls( # pylint: disable=unexpected-keyword-arg
|
||||
logging_first_step=True,
|
||||
**training_args_kwargs,
|
||||
|
||||
@@ -148,7 +148,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
trainer_kwargs = {}
|
||||
|
||||
# Pop optimizer_cls_and_kwargs to trainer_kwargs
|
||||
if "optimizer_cls_and_kwargs" in trainer_kwargs:
|
||||
if "optimizer_cls_and_kwargs" in training_arguments_kwargs:
|
||||
trainer_kwargs["optimizer_cls_and_kwargs"] = training_arguments_kwargs.pop(
|
||||
"optimizer_cls_and_kwargs"
|
||||
)
|
||||
@@ -219,7 +219,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
training_arguments_kwargs["eval_accumulation_steps"] = (
|
||||
self.cfg.gradient_accumulation_steps
|
||||
)
|
||||
training_arguments_kwargs["num_train_epochs"] = self.cfg.num_epochs
|
||||
|
||||
training_arguments_kwargs["load_best_model_at_end"] = (
|
||||
(
|
||||
|
||||
Reference in New Issue
Block a user