diff --git a/src/axolotl/core/trainers/grpo/trainer.py b/src/axolotl/core/trainers/grpo/trainer.py index 8f8b9fcf9..98e01eb1d 100644 --- a/src/axolotl/core/trainers/grpo/trainer.py +++ b/src/axolotl/core/trainers/grpo/trainer.py @@ -39,6 +39,15 @@ class AxolotlGRPOTrainer(SchedulerMixin, GRPOTrainer): self.model = self._enable_gradient_checkpointing(self.model, kwargs["args"]) # pylint: enable=access-member-before-definition + # cleanup the ref_model if we have a peft model passed in + # TODO remove this after next major trl release + if ( + self.ref_model # pylint: disable=access-member-before-definition + and is_peft_model(self.model) + ): + del self.ref_model + self.ref_model = None + def _enable_gradient_checkpointing( self, model: PreTrainedModel, args: GRPOConfig ) -> PreTrainedModel: