Bumping TRL version to 0.15.1 for GRPO+PEFT fix (#2344)

* bumping TRL version

* apply upstream fixes to our custom fix

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
This commit is contained in:
salman
2025-02-21 03:56:04 +00:00
committed by GitHub
parent b53a41372f
commit 29b366b2e1
2 changed files with 8 additions and 7 deletions

View File

@@ -78,7 +78,6 @@ class AxolotlGRPOTrainer(SchedulerMixin, GRPOTrainer):
if is_peft_model(unwrapped_model):
unwrapped_model.merge_adapter()
state_dict = unwrapped_model.state_dict()
unwrapped_model.unmerge_adapter()
# Remove base_model and base_layer prefixes
state_dict = {
k.removeprefix("base_model.model.")
@@ -100,8 +99,10 @@ class AxolotlGRPOTrainer(SchedulerMixin, GRPOTrainer):
}
else:
state_dict = unwrapped_model.state_dict()
if self.accelerator.is_main_process:
llm_model = (
self.llm.llm_engine.model_executor.driver_worker.model_runner.model
)
llm_model.load_weights(state_dict.items())
if self.accelerator.is_main_process:
llm_model = (
self.llm.llm_engine.model_executor.driver_worker.model_runner.model
)
llm_model.load_weights(state_dict.items())
if is_peft_model(unwrapped_model):
unwrapped_model.unmerge_adapter()