Compare commits

...

1 Commits

Author SHA1 Message Date
Wing Lian
a9ebff087c remove ref_model when peft model is passed into grpo trainer 2025-02-20 21:53:20 -05:00

View File

@@ -39,6 +39,15 @@ class AxolotlGRPOTrainer(SchedulerMixin, GRPOTrainer):
self.model = self._enable_gradient_checkpointing(self.model, kwargs["args"]) self.model = self._enable_gradient_checkpointing(self.model, kwargs["args"])
# pylint: enable=access-member-before-definition # pylint: enable=access-member-before-definition
# cleanup the ref_model if we have a peft model passed in
# TODO remove this after next major trl release
if (
self.ref_model # pylint: disable=access-member-before-definition
and is_peft_model(self.model)
):
del self.ref_model
self.ref_model = None
def _enable_gradient_checkpointing( def _enable_gradient_checkpointing(
self, model: PreTrainedModel, args: GRPOConfig self, model: PreTrainedModel, args: GRPOConfig
) -> PreTrainedModel: ) -> PreTrainedModel: