feat: add handling for seed and SP/ring-attn config
This commit is contained in:
@@ -331,6 +331,9 @@ class TrainerBuilderBase(abc.ABC):
|
||||
"save_only_model",
|
||||
"include_tokens_per_second",
|
||||
"weight_decay",
|
||||
"sequence_parallel_degree",
|
||||
"ring_attn_func",
|
||||
"seed",
|
||||
]:
|
||||
if hasattr(self.cfg, arg) and getattr(self.cfg, arg) is not None:
|
||||
training_args_kwargs[arg] = getattr(self.cfg, arg)
|
||||
@@ -593,9 +596,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
def build(self, total_num_steps):
|
||||
training_arguments_kwargs = self._set_base_training_args(total_num_steps)
|
||||
|
||||
if self.cfg.seed is not None:
|
||||
training_arguments_kwargs["seed"] = self.cfg.seed
|
||||
|
||||
if self.cfg.fsdp:
|
||||
training_arguments_kwargs["fsdp"] = self.cfg.fsdp
|
||||
if self.cfg.fsdp_config:
|
||||
@@ -806,11 +806,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
self.cfg.kd_top_k_before_softmax
|
||||
)
|
||||
|
||||
training_arguments_kwargs["sequence_parallel_degree"] = (
|
||||
self.cfg.sequence_parallel_degree
|
||||
)
|
||||
training_arguments_kwargs["ring_attn_func"] = self.cfg.ring_attn_func
|
||||
|
||||
if self.cfg.reward_model:
|
||||
training_args_cls = AxolotlRewardConfig
|
||||
elif self.cfg.process_reward_model:
|
||||
@@ -1010,10 +1005,6 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
if self.cfg.use_wandb:
|
||||
training_args_kwargs["run_name"] = self.cfg.wandb_name
|
||||
|
||||
training_args_kwargs["sequence_parallel_degree"] = (
|
||||
self.cfg.sequence_parallel_degree
|
||||
)
|
||||
|
||||
training_args_cls = None
|
||||
blocklist_args_kwargs = []
|
||||
if self.cfg.rl is RLType.SIMPO:
|
||||
|
||||
Reference in New Issue
Block a user