From 29b366b2e1facb0067da5d2d519c896206bc0f6c Mon Sep 17 00:00:00 2001
From: salman
Date: Fri, 21 Feb 2025 03:56:04 +0000
Subject: [PATCH] Bumping 0.15.1 TRL version for GRPO+PEFT fix (#2344)

* bumping TRL version

* apply upstream fixes to our custom fix

---------

Co-authored-by: Wing Lian
---
 requirements.txt                          |  2 +-
 src/axolotl/core/trainers/grpo/trainer.py | 13 +++++++------
 2 files changed, 8 insertions(+), 7 deletions(-)

diff --git a/requirements.txt b/requirements.txt
index b54c6e8d6..535d79933 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -18,7 +18,7 @@ tokenizers>=0.21.0
 accelerate==1.3.0
 datasets==3.2.0
 deepspeed==0.16.1
-trl==0.15.0
+trl==0.15.1
 optimum==1.16.2
 hf_transfer
diff --git a/src/axolotl/core/trainers/grpo/trainer.py b/src/axolotl/core/trainers/grpo/trainer.py
index 8f8b9fcf9..6c8f39ac6 100644
--- a/src/axolotl/core/trainers/grpo/trainer.py
+++ b/src/axolotl/core/trainers/grpo/trainer.py
@@ -78,7 +78,6 @@ class AxolotlGRPOTrainer(SchedulerMixin, GRPOTrainer):
             if is_peft_model(unwrapped_model):
                 unwrapped_model.merge_adapter()
                 state_dict = unwrapped_model.state_dict()
-                unwrapped_model.unmerge_adapter()
                 # Remove base_model and base_layer prefixes
                 state_dict = {
                     k.removeprefix("base_model.model.")
@@ -100,8 +99,10 @@
                 }
             else:
                 state_dict = unwrapped_model.state_dict()
-        if self.accelerator.is_main_process:
-            llm_model = (
-                self.llm.llm_engine.model_executor.driver_worker.model_runner.model
-            )
-            llm_model.load_weights(state_dict.items())
+            if self.accelerator.is_main_process:
+                llm_model = (
+                    self.llm.llm_engine.model_executor.driver_worker.model_runner.model
+                )
+                llm_model.load_weights(state_dict.items())
+            if is_peft_model(unwrapped_model):
+                unwrapped_model.unmerge_adapter()