feat: add handling for seed and sequence-parallel (SP) / ring-attention config

This commit is contained in:
NanoCode012
2025-05-14 09:49:46 +07:00
parent 9af4bffd5d
commit 67b1df21aa

View File

@@ -331,6 +331,9 @@ class TrainerBuilderBase(abc.ABC):
"save_only_model",
"include_tokens_per_second",
"weight_decay",
"sequence_parallel_degree",
"ring_attn_func",
"seed",
]:
if hasattr(self.cfg, arg) and getattr(self.cfg, arg) is not None:
training_args_kwargs[arg] = getattr(self.cfg, arg)
@@ -593,9 +596,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
def build(self, total_num_steps):
training_arguments_kwargs = self._set_base_training_args(total_num_steps)
if self.cfg.seed is not None:
training_arguments_kwargs["seed"] = self.cfg.seed
if self.cfg.fsdp:
training_arguments_kwargs["fsdp"] = self.cfg.fsdp
if self.cfg.fsdp_config:
@@ -806,11 +806,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
self.cfg.kd_top_k_before_softmax
)
training_arguments_kwargs["sequence_parallel_degree"] = (
self.cfg.sequence_parallel_degree
)
training_arguments_kwargs["ring_attn_func"] = self.cfg.ring_attn_func
if self.cfg.reward_model:
training_args_cls = AxolotlRewardConfig
elif self.cfg.process_reward_model:
@@ -1010,10 +1005,6 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
if self.cfg.use_wandb:
training_args_kwargs["run_name"] = self.cfg.wandb_name
training_args_kwargs["sequence_parallel_degree"] = (
self.cfg.sequence_parallel_degree
)
training_args_cls = None
blocklist_args_kwargs = []
if self.cfg.rl is RLType.SIMPO: