Remove ref_model when a PEFT model is passed into the GRPO trainer
This commit is contained in:
@@ -39,6 +39,15 @@ class AxolotlGRPOTrainer(SchedulerMixin, GRPOTrainer):
|
|||||||
self.model = self._enable_gradient_checkpointing(self.model, kwargs["args"])
|
self.model = self._enable_gradient_checkpointing(self.model, kwargs["args"])
|
||||||
# pylint: enable=access-member-before-definition
|
# pylint: enable=access-member-before-definition
|
||||||
|
|
||||||
|
# clean up the ref_model if a PEFT model was passed in
|
||||||
|
# TODO: remove this after the next major trl release
|
||||||
|
if (
|
||||||
|
self.ref_model # pylint: disable=access-member-before-definition
|
||||||
|
and is_peft_model(self.model)
|
||||||
|
):
|
||||||
|
del self.ref_model
|
||||||
|
self.ref_model = None
|
||||||
|
|
||||||
def _enable_gradient_checkpointing(
|
def _enable_gradient_checkpointing(
|
||||||
self, model: PreTrainedModel, args: GRPOConfig
|
self, model: PreTrainedModel, args: GRPOConfig
|
||||||
) -> PreTrainedModel:
|
) -> PreTrainedModel:
|
||||||
|
|||||||
Reference in New Issue
Block a user