From cfc7fe0df298ffaff123621ebf259061045c07b0 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Mon, 3 Feb 2025 00:54:52 -0500 Subject: [PATCH] remove ununsable args kwargs --- src/axolotl/core/trainer_builder.py | 4 ++++ src/axolotl/core/trainers/grpo/__init__.py | 4 ++++ 2 files changed, 8 insertions(+) diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index b4480c0bd..b95263f05 100755 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -1013,6 +1013,10 @@ class HFRLTrainerBuilder(TrainerBuilderBase): elif self.cfg.rl == "grpo": training_args_cls = GRPOStrategy.get_training_args_class() training_args_kwargs.update(GRPOStrategy.set_training_args_kwargs(self.cfg)) + blocklist_args_kwargs = GRPOStrategy.get_blocklist_args_kwargs() + for blocklist_key in blocklist_args_kwargs: + if blocklist_key in training_args_kwargs: + del training_args_kwargs[blocklist_key] else: training_args_cls = DPOConfig.get_training_args_class() diff --git a/src/axolotl/core/trainers/grpo/__init__.py b/src/axolotl/core/trainers/grpo/__init__.py index 9c2fa583b..01e19976f 100644 --- a/src/axolotl/core/trainers/grpo/__init__.py +++ b/src/axolotl/core/trainers/grpo/__init__.py @@ -57,3 +57,7 @@ class GRPOStrategy: return features return data_collator + + @classmethod + def get_blocklist_args_kwargs(cls): + return ["dataset_num_proc"]