diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index c5136ce0a..18c67919c 100755 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -1019,7 +1019,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase): training_args_kwargs.update(DPOConfig.set_training_args_kwargs(self.cfg)) training_args = training_args_cls( # pylint: disable=unexpected-keyword-arg - output_dir=self.cfg.output_dir, + self.cfg.output_dir, per_device_train_batch_size=self.cfg.micro_batch_size, max_steps=self.cfg.max_steps or total_num_steps, gradient_accumulation_steps=self.cfg.gradient_accumulation_steps, diff --git a/src/axolotl/core/trainers/grpo/args.py b/src/axolotl/core/trainers/grpo/args.py index e5f2cc254..e14e6b0dc 100644 --- a/src/axolotl/core/trainers/grpo/args.py +++ b/src/axolotl/core/trainers/grpo/args.py @@ -9,7 +9,7 @@ from axolotl.core.training_args import AxolotlTrainingMixins @dataclass -class AxolotlGRPOConfig(GRPOConfig, AxolotlTrainingMixins): +class AxolotlGRPOConfig(AxolotlTrainingMixins, GRPOConfig): """ Axolotl GRPO Config for GRPO training """