feat: move epoch setting to base
This commit is contained in:
@@ -304,6 +304,8 @@ class TrainerBuilderBase(abc.ABC):
|
|||||||
|
|
||||||
training_args_kwargs["max_steps"] = self.cfg.max_steps or total_num_steps or -1
|
training_args_kwargs["max_steps"] = self.cfg.max_steps or total_num_steps or -1
|
||||||
|
|
||||||
|
training_args_kwargs["num_train_epochs"] = self.cfg.num_epochs
|
||||||
|
|
||||||
# max_length is not used in CausalTrainer
|
# max_length is not used in CausalTrainer
|
||||||
if self.cfg.reward_model or self.cfg.rl:
|
if self.cfg.reward_model or self.cfg.rl:
|
||||||
training_args_kwargs["max_length"] = self.cfg.sequence_len
|
training_args_kwargs["max_length"] = self.cfg.sequence_len
|
||||||
|
|||||||
@@ -156,7 +156,6 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
|||||||
if blocklist_key in training_args_kwargs:
|
if blocklist_key in training_args_kwargs:
|
||||||
del training_args_kwargs[blocklist_key]
|
del training_args_kwargs[blocklist_key]
|
||||||
|
|
||||||
training_args_kwargs["num_train_epochs"] = self.cfg.num_epochs
|
|
||||||
training_args = training_args_cls( # pylint: disable=unexpected-keyword-arg
|
training_args = training_args_cls( # pylint: disable=unexpected-keyword-arg
|
||||||
logging_first_step=True,
|
logging_first_step=True,
|
||||||
**training_args_kwargs,
|
**training_args_kwargs,
|
||||||
|
|||||||
@@ -148,7 +148,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
trainer_kwargs = {}
|
trainer_kwargs = {}
|
||||||
|
|
||||||
# Pop optimizer_cls_and_kwargs to trainer_kwargs
|
# Pop optimizer_cls_and_kwargs to trainer_kwargs
|
||||||
if "optimizer_cls_and_kwargs" in trainer_kwargs:
|
if "optimizer_cls_and_kwargs" in training_arguments_kwargs:
|
||||||
trainer_kwargs["optimizer_cls_and_kwargs"] = training_arguments_kwargs.pop(
|
trainer_kwargs["optimizer_cls_and_kwargs"] = training_arguments_kwargs.pop(
|
||||||
"optimizer_cls_and_kwargs"
|
"optimizer_cls_and_kwargs"
|
||||||
)
|
)
|
||||||
@@ -219,7 +219,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
training_arguments_kwargs["eval_accumulation_steps"] = (
|
training_arguments_kwargs["eval_accumulation_steps"] = (
|
||||||
self.cfg.gradient_accumulation_steps
|
self.cfg.gradient_accumulation_steps
|
||||||
)
|
)
|
||||||
training_arguments_kwargs["num_train_epochs"] = self.cfg.num_epochs
|
|
||||||
|
|
||||||
training_arguments_kwargs["load_best_model_at_end"] = (
|
training_arguments_kwargs["load_best_model_at_end"] = (
|
||||||
(
|
(
|
||||||
|
|||||||
Reference in New Issue
Block a user