From a27b909c5c1c2c561a8d503024b89afcce15226f Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Fri, 16 May 2025 15:47:03 -0400 Subject: [PATCH] GRPO fixes (peft) (#2676) * don't set peft_config on grpo to prevent double peft wrap * remove overrides needed to support bug * fix grpo tests * require more CPU for multigpu to help with torch compile for vllm --- cicd/multigpu.py | 2 +- src/axolotl/core/trainer_builder.py | 3 +- src/axolotl/core/trainers/grpo/trainer.py | 57 +---------------------- tests/e2e/multigpu/solo/test_grpo.py | 6 --- 4 files changed, 5 insertions(+), 63 deletions(-) diff --git a/cicd/multigpu.py b/cicd/multigpu.py index 90d4ce1ee..7de4ae0a7 100644 --- a/cicd/multigpu.py +++ b/cicd/multigpu.py @@ -70,7 +70,7 @@ def run_cmd(cmd: str, run_folder: str): image=cicd_image, gpu=GPU_CONFIG, timeout=90 * 60, - cpu=8.0, + cpu=16.0, memory=131072 * N_GPUS, volumes=VOLUME_CONFIG, ) diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index 6bd4ef996..d82e4d20b 100755 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -1174,7 +1174,8 @@ class HFRLTrainerBuilder(TrainerBuilderBase): if self.eval_dataset: trainer_kwargs["eval_dataset"] = self.eval_dataset if self.cfg.adapter and self.peft_config: - trainer_kwargs["peft_config"] = self.peft_config + if self.cfg.rl is not RLType.GRPO: + trainer_kwargs["peft_config"] = self.peft_config if self.cfg.precompute_ref_log_probs is not None: trainer_kwargs["precompute_ref_log_probs"] = ( self.cfg.precompute_ref_log_probs diff --git a/src/axolotl/core/trainers/grpo/trainer.py b/src/axolotl/core/trainers/grpo/trainer.py index bc3d140b1..8a89de333 100644 --- a/src/axolotl/core/trainers/grpo/trainer.py +++ b/src/axolotl/core/trainers/grpo/trainer.py @@ -3,7 +3,6 @@ # pylint: disable=too-many-lines,duplicate-code,protected-access,no-member import warnings -from contextlib import nullcontext from typing import Any import datasets @@ -14,7 +13,7 @@ from accelerate.utils import ( broadcast_object_list, gather, gather_object, - is_peft_model, + is_peft_available, ) from datasets import Dataset, IterableDataset from torch import nn @@ -30,15 +29,13 @@ from transformers import ( TrainerCallback, ) from transformers.trainer_utils import seed_worker -from transformers.utils import is_peft_available from trl import GRPOTrainer from trl.data_utils import ( apply_chat_template, is_conversational, maybe_apply_chat_template, ) -from trl.extras.profiling import profiling_context, profiling_decorator -from trl.import_utils import is_deepspeed_available +from trl.extras.profiling import profiling_context from trl.models import unwrap_model_for_generation from trl.trainer.grpo_config import GRPOConfig from trl.trainer.grpo_trainer import RewardFunc, nanstd @@ -52,62 +49,12 @@ if is_peft_available(): # pylint: disable=unused-import from peft import PeftConfig -if is_deepspeed_available(): - import deepspeed - class AxolotlGRPOTrainer(RngLoaderMixin, SchedulerMixin, GRPOTrainer): """Extend the base GRPOTrainer for axolotl helpers""" _tag_names = ["trl", "grpo", "axolotl"] - @profiling_decorator - def _move_model_to_vllm(self): - # For DeepSpeed ZeRO-3, we need to gather all parameters before operations - deepspeed_plugin = self.accelerator.state.deepspeed_plugin - zero_stage_3 = deepspeed_plugin is not None and deepspeed_plugin.zero_stage == 3 - gather_if_zero3 = ( - deepspeed.zero.GatheredParameters if zero_stage_3 else nullcontext - ) - - if is_peft_model(self.model): - # With PEFT and DeepSpeed ZeRO Stage 3, we must gather the full model at once before merging, as merging - # adapters in a sharded manner is not supported. - with gather_if_zero3(list(self.model.parameters())): - self.model.merge_adapter() - - # Update vLLM weights while parameters are gathered - for name, param in self.model.named_parameters(): - # When using PEFT, we need to recover the original parameter name and discard some parameters - name = ( - name.removeprefix("base_model.model.") - .removeprefix("base_model.model.") - .replace(".base_layer", "") - ) - if self.model.prefix in name: - continue - # When module to save, remove its prefix and discard the original module - if "original_module" in name: - continue - name = name.replace("modules_to_save.default.", "") - - if self.accelerator.is_main_process: - self.vllm_client.update_named_param(name, param.data) - - # Unmerge adapters while parameters are still gathered - self.model.unmerge_adapter() - # Parameters will automatically be repartitioned when exiting the context - else: - # For non-PEFT models, simply gather and update each parameter individually. - for name, param in self.model.named_parameters(): - with gather_if_zero3([param]): - if self.accelerator.is_main_process: - self.vllm_client.update_named_param(name, param.data) - - # Reset cache on main process - if self.accelerator.is_main_process: - self.vllm_client.reset_prefix_cache() - class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer): """Extend the base GRPOTrainer for sequence parallelism handling""" diff --git a/tests/e2e/multigpu/solo/test_grpo.py b/tests/e2e/multigpu/solo/test_grpo.py index a1eade531..575b7a620 100644 --- a/tests/e2e/multigpu/solo/test_grpo.py +++ b/tests/e2e/multigpu/solo/test_grpo.py @@ -166,7 +166,6 @@ def oai_gsm8k_transform(cfg, *args, **kwargs): """ ) - @pytest.mark.skip(reason="flaky test") @pytest.mark.parametrize( "num_gpus", [1, 2], @@ -231,8 +230,6 @@ def oai_gsm8k_transform(cfg, *args, **kwargs): "NCCL_P2P_LEVEL": "LOC", **current_env, "CUDA_VISIBLE_DEVICES": "1", - "VLLM_DISABLE_COMPILE_CACHE": "1", - # "VLLM_USE_V1": "0", } vllm_process = start_vllm( cfg.base_model, @@ -266,7 +263,6 @@ def oai_gsm8k_transform(cfg, *args, **kwargs): finally: recursive_kill(vllm_process) - @pytest.mark.skip(reason="flaky test") @pytest.mark.parametrize( "num_gpus", [1, 2], @@ -325,8 +321,6 @@ def oai_gsm8k_transform(cfg, *args, **kwargs): "NCCL_P2P_LEVEL": "LOC", # nccl can be brittle, assume P2P isn't reliable **current_env, "CUDA_VISIBLE_DEVICES": "1", - "VLLM_DISABLE_COMPILE_CACHE": "1", - # "VLLM_USE_V1": "0", } vllm_process = start_vllm( cfg.base_model,