remove ref_model when peft model is passed into grpo trainer

This commit is contained in:
Wing Lian
2025-02-20 21:53:20 -05:00
parent b53a41372f
commit a9ebff087c

View File

@@ -39,6 +39,15 @@ class AxolotlGRPOTrainer(SchedulerMixin, GRPOTrainer):
self.model = self._enable_gradient_checkpointing(self.model, kwargs["args"])
# pylint: enable=access-member-before-definition
# cleanup the ref_model if we have a peft model passed in
# TODO remove this after next major trl release
if (
self.ref_model # pylint: disable=access-member-before-definition
and is_peft_model(self.model)
):
del self.ref_model
self.ref_model = None
def _enable_gradient_checkpointing(
self, model: PreTrainedModel, args: GRPOConfig
) -> PreTrainedModel: