GRPO fixes (peft) (#2676)
* don't set peft_config on grpo to prevent double peft wrap * remove overrides needed to support bug * fix grpo tests * require more CPU for multigpu to help with torch compile for vllm
This commit is contained in:
@@ -70,7 +70,7 @@ def run_cmd(cmd: str, run_folder: str):
|
|||||||
image=cicd_image,
|
image=cicd_image,
|
||||||
gpu=GPU_CONFIG,
|
gpu=GPU_CONFIG,
|
||||||
timeout=90 * 60,
|
timeout=90 * 60,
|
||||||
cpu=8.0,
|
cpu=16.0,
|
||||||
memory=131072 * N_GPUS,
|
memory=131072 * N_GPUS,
|
||||||
volumes=VOLUME_CONFIG,
|
volumes=VOLUME_CONFIG,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1174,7 +1174,8 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
|||||||
if self.eval_dataset:
|
if self.eval_dataset:
|
||||||
trainer_kwargs["eval_dataset"] = self.eval_dataset
|
trainer_kwargs["eval_dataset"] = self.eval_dataset
|
||||||
if self.cfg.adapter and self.peft_config:
|
if self.cfg.adapter and self.peft_config:
|
||||||
trainer_kwargs["peft_config"] = self.peft_config
|
if self.cfg.rl is not RLType.GRPO:
|
||||||
|
trainer_kwargs["peft_config"] = self.peft_config
|
||||||
if self.cfg.precompute_ref_log_probs is not None:
|
if self.cfg.precompute_ref_log_probs is not None:
|
||||||
trainer_kwargs["precompute_ref_log_probs"] = (
|
trainer_kwargs["precompute_ref_log_probs"] = (
|
||||||
self.cfg.precompute_ref_log_probs
|
self.cfg.precompute_ref_log_probs
|
||||||
|
|||||||
@@ -3,7 +3,6 @@
|
|||||||
# pylint: disable=too-many-lines,duplicate-code,protected-access,no-member
|
# pylint: disable=too-many-lines,duplicate-code,protected-access,no-member
|
||||||
|
|
||||||
import warnings
|
import warnings
|
||||||
from contextlib import nullcontext
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import datasets
|
import datasets
|
||||||
@@ -14,7 +13,7 @@ from accelerate.utils import (
|
|||||||
broadcast_object_list,
|
broadcast_object_list,
|
||||||
gather,
|
gather,
|
||||||
gather_object,
|
gather_object,
|
||||||
is_peft_model,
|
is_peft_available,
|
||||||
)
|
)
|
||||||
from datasets import Dataset, IterableDataset
|
from datasets import Dataset, IterableDataset
|
||||||
from torch import nn
|
from torch import nn
|
||||||
@@ -30,15 +29,13 @@ from transformers import (
|
|||||||
TrainerCallback,
|
TrainerCallback,
|
||||||
)
|
)
|
||||||
from transformers.trainer_utils import seed_worker
|
from transformers.trainer_utils import seed_worker
|
||||||
from transformers.utils import is_peft_available
|
|
||||||
from trl import GRPOTrainer
|
from trl import GRPOTrainer
|
||||||
from trl.data_utils import (
|
from trl.data_utils import (
|
||||||
apply_chat_template,
|
apply_chat_template,
|
||||||
is_conversational,
|
is_conversational,
|
||||||
maybe_apply_chat_template,
|
maybe_apply_chat_template,
|
||||||
)
|
)
|
||||||
from trl.extras.profiling import profiling_context, profiling_decorator
|
from trl.extras.profiling import profiling_context
|
||||||
from trl.import_utils import is_deepspeed_available
|
|
||||||
from trl.models import unwrap_model_for_generation
|
from trl.models import unwrap_model_for_generation
|
||||||
from trl.trainer.grpo_config import GRPOConfig
|
from trl.trainer.grpo_config import GRPOConfig
|
||||||
from trl.trainer.grpo_trainer import RewardFunc, nanstd
|
from trl.trainer.grpo_trainer import RewardFunc, nanstd
|
||||||
@@ -52,62 +49,12 @@ if is_peft_available():
|
|||||||
# pylint: disable=unused-import
|
# pylint: disable=unused-import
|
||||||
from peft import PeftConfig
|
from peft import PeftConfig
|
||||||
|
|
||||||
if is_deepspeed_available():
|
|
||||||
import deepspeed
|
|
||||||
|
|
||||||
|
|
||||||
class AxolotlGRPOTrainer(RngLoaderMixin, SchedulerMixin, GRPOTrainer):
|
class AxolotlGRPOTrainer(RngLoaderMixin, SchedulerMixin, GRPOTrainer):
|
||||||
"""Extend the base GRPOTrainer for axolotl helpers"""
|
"""Extend the base GRPOTrainer for axolotl helpers"""
|
||||||
|
|
||||||
_tag_names = ["trl", "grpo", "axolotl"]
|
_tag_names = ["trl", "grpo", "axolotl"]
|
||||||
|
|
||||||
@profiling_decorator
|
|
||||||
def _move_model_to_vllm(self):
|
|
||||||
# For DeepSpeed ZeRO-3, we need to gather all parameters before operations
|
|
||||||
deepspeed_plugin = self.accelerator.state.deepspeed_plugin
|
|
||||||
zero_stage_3 = deepspeed_plugin is not None and deepspeed_plugin.zero_stage == 3
|
|
||||||
gather_if_zero3 = (
|
|
||||||
deepspeed.zero.GatheredParameters if zero_stage_3 else nullcontext
|
|
||||||
)
|
|
||||||
|
|
||||||
if is_peft_model(self.model):
|
|
||||||
# With PEFT and DeepSpeed ZeRO Stage 3, we must gather the full model at once before merging, as merging
|
|
||||||
# adapters in a sharded manner is not supported.
|
|
||||||
with gather_if_zero3(list(self.model.parameters())):
|
|
||||||
self.model.merge_adapter()
|
|
||||||
|
|
||||||
# Update vLLM weights while parameters are gathered
|
|
||||||
for name, param in self.model.named_parameters():
|
|
||||||
# When using PEFT, we need to recover the original parameter name and discard some parameters
|
|
||||||
name = (
|
|
||||||
name.removeprefix("base_model.model.")
|
|
||||||
.removeprefix("base_model.model.")
|
|
||||||
.replace(".base_layer", "")
|
|
||||||
)
|
|
||||||
if self.model.prefix in name:
|
|
||||||
continue
|
|
||||||
# When module to save, remove its prefix and discard the original module
|
|
||||||
if "original_module" in name:
|
|
||||||
continue
|
|
||||||
name = name.replace("modules_to_save.default.", "")
|
|
||||||
|
|
||||||
if self.accelerator.is_main_process:
|
|
||||||
self.vllm_client.update_named_param(name, param.data)
|
|
||||||
|
|
||||||
# Unmerge adapters while parameters are still gathered
|
|
||||||
self.model.unmerge_adapter()
|
|
||||||
# Parameters will automatically be repartitioned when exiting the context
|
|
||||||
else:
|
|
||||||
# For non-PEFT models, simply gather and update each parameter individually.
|
|
||||||
for name, param in self.model.named_parameters():
|
|
||||||
with gather_if_zero3([param]):
|
|
||||||
if self.accelerator.is_main_process:
|
|
||||||
self.vllm_client.update_named_param(name, param.data)
|
|
||||||
|
|
||||||
# Reset cache on main process
|
|
||||||
if self.accelerator.is_main_process:
|
|
||||||
self.vllm_client.reset_prefix_cache()
|
|
||||||
|
|
||||||
|
|
||||||
class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
|
class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
|
||||||
"""Extend the base GRPOTrainer for sequence parallelism handling"""
|
"""Extend the base GRPOTrainer for sequence parallelism handling"""
|
||||||
|
|||||||
@@ -166,7 +166,6 @@ def oai_gsm8k_transform(cfg, *args, **kwargs):
|
|||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
|
|
||||||
@pytest.mark.skip(reason="flaky test")
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"num_gpus",
|
"num_gpus",
|
||||||
[1, 2],
|
[1, 2],
|
||||||
@@ -231,8 +230,6 @@ def oai_gsm8k_transform(cfg, *args, **kwargs):
|
|||||||
"NCCL_P2P_LEVEL": "LOC",
|
"NCCL_P2P_LEVEL": "LOC",
|
||||||
**current_env,
|
**current_env,
|
||||||
"CUDA_VISIBLE_DEVICES": "1",
|
"CUDA_VISIBLE_DEVICES": "1",
|
||||||
"VLLM_DISABLE_COMPILE_CACHE": "1",
|
|
||||||
# "VLLM_USE_V1": "0",
|
|
||||||
}
|
}
|
||||||
vllm_process = start_vllm(
|
vllm_process = start_vllm(
|
||||||
cfg.base_model,
|
cfg.base_model,
|
||||||
@@ -266,7 +263,6 @@ def oai_gsm8k_transform(cfg, *args, **kwargs):
|
|||||||
finally:
|
finally:
|
||||||
recursive_kill(vllm_process)
|
recursive_kill(vllm_process)
|
||||||
|
|
||||||
@pytest.mark.skip(reason="flaky test")
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"num_gpus",
|
"num_gpus",
|
||||||
[1, 2],
|
[1, 2],
|
||||||
@@ -325,8 +321,6 @@ def oai_gsm8k_transform(cfg, *args, **kwargs):
|
|||||||
"NCCL_P2P_LEVEL": "LOC", # nccl can be brittle, assume P2P isn't reliable
|
"NCCL_P2P_LEVEL": "LOC", # nccl can be brittle, assume P2P isn't reliable
|
||||||
**current_env,
|
**current_env,
|
||||||
"CUDA_VISIBLE_DEVICES": "1",
|
"CUDA_VISIBLE_DEVICES": "1",
|
||||||
"VLLM_DISABLE_COMPILE_CACHE": "1",
|
|
||||||
# "VLLM_USE_V1": "0",
|
|
||||||
}
|
}
|
||||||
vllm_process = start_vllm(
|
vllm_process = start_vllm(
|
||||||
cfg.base_model,
|
cfg.base_model,
|
||||||
|
|||||||
Reference in New Issue
Block a user