feat: add handling for seed and sequence-parallel (SP)/ring-attn config
This commit is contained in:
@@ -331,6 +331,9 @@ class TrainerBuilderBase(abc.ABC):
|
|||||||
"save_only_model",
|
"save_only_model",
|
||||||
"include_tokens_per_second",
|
"include_tokens_per_second",
|
||||||
"weight_decay",
|
"weight_decay",
|
||||||
|
"sequence_parallel_degree",
|
||||||
|
"ring_attn_func",
|
||||||
|
"seed",
|
||||||
]:
|
]:
|
||||||
if hasattr(self.cfg, arg) and getattr(self.cfg, arg) is not None:
|
if hasattr(self.cfg, arg) and getattr(self.cfg, arg) is not None:
|
||||||
training_args_kwargs[arg] = getattr(self.cfg, arg)
|
training_args_kwargs[arg] = getattr(self.cfg, arg)
|
||||||
@@ -593,9 +596,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
def build(self, total_num_steps):
|
def build(self, total_num_steps):
|
||||||
training_arguments_kwargs = self._set_base_training_args(total_num_steps)
|
training_arguments_kwargs = self._set_base_training_args(total_num_steps)
|
||||||
|
|
||||||
if self.cfg.seed is not None:
|
|
||||||
training_arguments_kwargs["seed"] = self.cfg.seed
|
|
||||||
|
|
||||||
if self.cfg.fsdp:
|
if self.cfg.fsdp:
|
||||||
training_arguments_kwargs["fsdp"] = self.cfg.fsdp
|
training_arguments_kwargs["fsdp"] = self.cfg.fsdp
|
||||||
if self.cfg.fsdp_config:
|
if self.cfg.fsdp_config:
|
||||||
@@ -806,11 +806,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
self.cfg.kd_top_k_before_softmax
|
self.cfg.kd_top_k_before_softmax
|
||||||
)
|
)
|
||||||
|
|
||||||
training_arguments_kwargs["sequence_parallel_degree"] = (
|
|
||||||
self.cfg.sequence_parallel_degree
|
|
||||||
)
|
|
||||||
training_arguments_kwargs["ring_attn_func"] = self.cfg.ring_attn_func
|
|
||||||
|
|
||||||
if self.cfg.reward_model:
|
if self.cfg.reward_model:
|
||||||
training_args_cls = AxolotlRewardConfig
|
training_args_cls = AxolotlRewardConfig
|
||||||
elif self.cfg.process_reward_model:
|
elif self.cfg.process_reward_model:
|
||||||
@@ -1010,10 +1005,6 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
|||||||
if self.cfg.use_wandb:
|
if self.cfg.use_wandb:
|
||||||
training_args_kwargs["run_name"] = self.cfg.wandb_name
|
training_args_kwargs["run_name"] = self.cfg.wandb_name
|
||||||
|
|
||||||
training_args_kwargs["sequence_parallel_degree"] = (
|
|
||||||
self.cfg.sequence_parallel_degree
|
|
||||||
)
|
|
||||||
|
|
||||||
training_args_cls = None
|
training_args_cls = None
|
||||||
blocklist_args_kwargs = []
|
blocklist_args_kwargs = []
|
||||||
if self.cfg.rl is RLType.SIMPO:
|
if self.cfg.rl is RLType.SIMPO:
|
||||||
|
|||||||
Reference in New Issue
Block a user