From 67b1df21aa507e10d72e37d033538c217bb931d5 Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Wed, 14 May 2025 09:49:46 +0700 Subject: [PATCH] feat: add handling for seed and SP/ring-attn config --- src/axolotl/core/trainer_builder.py | 15 +++------------ 1 file changed, 3 insertions(+), 12 deletions(-) diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index 13bac93b7..f0e10bf4f 100755 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -331,6 +331,9 @@ class TrainerBuilderBase(abc.ABC): "save_only_model", "include_tokens_per_second", "weight_decay", + "sequence_parallel_degree", + "ring_attn_func", + "seed", ]: if hasattr(self.cfg, arg) and getattr(self.cfg, arg) is not None: training_args_kwargs[arg] = getattr(self.cfg, arg) @@ -593,9 +596,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase): def build(self, total_num_steps): training_arguments_kwargs = self._set_base_training_args(total_num_steps) - if self.cfg.seed is not None: - training_arguments_kwargs["seed"] = self.cfg.seed - if self.cfg.fsdp: training_arguments_kwargs["fsdp"] = self.cfg.fsdp if self.cfg.fsdp_config: @@ -806,11 +806,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase): self.cfg.kd_top_k_before_softmax ) - training_arguments_kwargs["sequence_parallel_degree"] = ( - self.cfg.sequence_parallel_degree - ) - training_arguments_kwargs["ring_attn_func"] = self.cfg.ring_attn_func - if self.cfg.reward_model: training_args_cls = AxolotlRewardConfig elif self.cfg.process_reward_model: @@ -1010,10 +1005,6 @@ class HFRLTrainerBuilder(TrainerBuilderBase): if self.cfg.use_wandb: training_args_kwargs["run_name"] = self.cfg.wandb_name - training_args_kwargs["sequence_parallel_degree"] = ( - self.cfg.sequence_parallel_degree - ) - training_args_cls = None blocklist_args_kwargs = [] if self.cfg.rl is RLType.SIMPO: