From a9ebff087ca7e0ccaae65a4e85eba163765f9793 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Thu, 20 Feb 2025 21:53:20 -0500 Subject: [PATCH] remove ref_model when peft model is passed into grpo trainer --- src/axolotl/core/trainers/grpo/trainer.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/src/axolotl/core/trainers/grpo/trainer.py b/src/axolotl/core/trainers/grpo/trainer.py index 8f8b9fcf9..98e01eb1d 100644 --- a/src/axolotl/core/trainers/grpo/trainer.py +++ b/src/axolotl/core/trainers/grpo/trainer.py @@ -39,6 +39,15 @@ class AxolotlGRPOTrainer(SchedulerMixin, GRPOTrainer): self.model = self._enable_gradient_checkpointing(self.model, kwargs["args"]) # pylint: enable=access-member-before-definition + # cleanup the ref_model if we have a peft model passed in + # TODO remove this after next major trl release + if ( + self.ref_model # pylint: disable=access-member-before-definition + and is_peft_model(self.model) + ): + del self.ref_model + self.ref_model = None + def _enable_gradient_checkpointing( self, model: PreTrainedModel, args: GRPOConfig ) -> PreTrainedModel: