Bumping 0.15.1 TRL version for GRPO+PEFT fix (#2344)
* bumping TRL version * apply upstream fixes to our custom fix --------- Co-authored-by: Wing Lian <wing@axolotl.ai>
This commit is contained in:
@@ -18,7 +18,7 @@ tokenizers>=0.21.0
|
|||||||
accelerate==1.3.0
|
accelerate==1.3.0
|
||||||
datasets==3.2.0
|
datasets==3.2.0
|
||||||
deepspeed==0.16.1
|
deepspeed==0.16.1
|
||||||
trl==0.15.0
|
trl==0.15.1
|
||||||
|
|
||||||
optimum==1.16.2
|
optimum==1.16.2
|
||||||
hf_transfer
|
hf_transfer
|
||||||
|
|||||||
@@ -78,7 +78,6 @@ class AxolotlGRPOTrainer(SchedulerMixin, GRPOTrainer):
|
|||||||
if is_peft_model(unwrapped_model):
|
if is_peft_model(unwrapped_model):
|
||||||
unwrapped_model.merge_adapter()
|
unwrapped_model.merge_adapter()
|
||||||
state_dict = unwrapped_model.state_dict()
|
state_dict = unwrapped_model.state_dict()
|
||||||
unwrapped_model.unmerge_adapter()
|
|
||||||
# Remove base_model and base_layer prefixes
|
# Remove base_model and base_layer prefixes
|
||||||
state_dict = {
|
state_dict = {
|
||||||
k.removeprefix("base_model.model.")
|
k.removeprefix("base_model.model.")
|
||||||
@@ -100,8 +99,10 @@ class AxolotlGRPOTrainer(SchedulerMixin, GRPOTrainer):
|
|||||||
}
|
}
|
||||||
else:
|
else:
|
||||||
state_dict = unwrapped_model.state_dict()
|
state_dict = unwrapped_model.state_dict()
|
||||||
if self.accelerator.is_main_process:
|
if self.accelerator.is_main_process:
|
||||||
llm_model = (
|
llm_model = (
|
||||||
self.llm.llm_engine.model_executor.driver_worker.model_runner.model
|
self.llm.llm_engine.model_executor.driver_worker.model_runner.model
|
||||||
)
|
)
|
||||||
llm_model.load_weights(state_dict.items())
|
llm_model.load_weights(state_dict.items())
|
||||||
|
if is_peft_model(unwrapped_model):
|
||||||
|
unwrapped_model.unmerge_adapter()
|
||||||
|
|||||||
Reference in New Issue
Block a user