remove ref_model when peft model is passed into grpo trainer
This commit is contained in:
@@ -39,6 +39,15 @@ class AxolotlGRPOTrainer(SchedulerMixin, GRPOTrainer):
|
||||
self.model = self._enable_gradient_checkpointing(self.model, kwargs["args"])
|
||||
# pylint: enable=access-member-before-definition
|
||||
|
||||
# cleanup the ref_model if we have a peft model passed in
|
||||
# TODO remove this after next major trl release
|
||||
if (
|
||||
self.ref_model # pylint: disable=access-member-before-definition
|
||||
and is_peft_model(self.model)
|
||||
):
|
||||
del self.ref_model
|
||||
self.ref_model = None
|
||||
|
||||
def _enable_gradient_checkpointing(
|
||||
self, model: PreTrainedModel, args: GRPOConfig
|
||||
) -> PreTrainedModel:
|
||||
|
||||
Reference in New Issue
Block a user