Updates for trl 0.16.0 - mostly for GRPO (#2437) [skip ci]

* add grpo scale_rewards config for trl#3135

* options to connect to vllm server directly w grpo trl#3094

* temperature support trl#3029

* sampling/generation kwargs for grpo trl#2989

* make vllm_enable_prefix_caching a config param trl#2900

* grpo multi-step optimizeations trl#2899

* remove overrides for grpo trainer

* bump trl to 0.16.0

* add cli  to start vllm-serve via trl

* call the python module directly

* update to use vllm with 2.6.0 too now and call trl vllm serve from module

* vllm 0.8.1

* use python3

* use sys.executable

* remove context and wait for start

* fixes to make it actually work

* fixes so the grpo tests pass with new vllm paradigm

* explicit host/port and check in start vllm

* make sure that vllm doesn't hang by setting quiet so outouts go to dev null

* also bump bnb to latest release

* add option for wait from cli and nccl debugging for ci

* grpo + vllm test on separate devices for now

* make sure grpo + vllm tests runs single worker since pynccl comms would conflict

* fix cli

* remove wait and add caching for argilla dataset

* refactoring configs

* chore: lint

* add vllm config

* fixup vllm grpo args

* fix one more incorrect schema/config path

* fix another vlllm reference and increase timeout

* make the tests run a bit faster

* change mbsz back so it is correct for grpo

* another change mbsz back so it is correct for grpo

* fixing cli args

* nits

* adding docs

* docs

* include tensor parallel size for vllm in pydantic schema

* moving start_vllm, more docs

* limit output len for grpo vllm

* vllm enable_prefix_caching isn't a bool cli arg

* fix env ordering in tests and also use pid check when looking for vllm

---------

Co-authored-by: Salman Mohammadi <salman.mohammadi@outlook.com>
This commit is contained in:
Wing Lian
2025-03-31 15:47:11 -04:00
committed by GitHub
parent b35992262e
commit b6fc46ada8
24 changed files with 703 additions and 349 deletions

View File

@@ -42,8 +42,7 @@ jobs:
cuda_version: 12.4.1 cuda_version: 12.4.1
python_version: "3.11" python_version: "3.11"
pytorch: 2.6.0 pytorch: 2.6.0
# awaiting vllm#12721 axolotl_extras: vllm
axolotl_extras:
num_gpus: 2 num_gpus: 2
nightly_build: "true" nightly_build: "true"
runs-on: [self-hosted, modal] runs-on: [self-hosted, modal]

View File

@@ -256,7 +256,7 @@ jobs:
python_version: "3.11" python_version: "3.11"
pytorch: 2.6.0 pytorch: 2.6.0
num_gpus: 1 num_gpus: 1
axolotl_extras: axolotl_extras: vllm
steps: steps:
- name: Checkout - name: Checkout
uses: actions/checkout@v4 uses: actions/checkout@v4

View File

@@ -40,6 +40,7 @@ quartodoc:
- cli.preprocess - cli.preprocess
- cli.sweeps - cli.sweeps
- cli.utils - cli.utils
- cli.vllm_serve
- cli.cloud.base - cli.cloud.base
- cli.cloud.modal_ - cli.cloud.modal_
- title: Trainers - title: Trainers

View File

@@ -2,4 +2,5 @@
set -e set -e
# only run one test at a time so as not to OOM the GPU # only run one test at a time so as not to OOM the GPU
pytest -v -n2 /workspace/axolotl/tests/e2e/multigpu/ pytest -v -n2 /workspace/axolotl/tests/e2e/multigpu/ --ignore=/workspace/axolotl/tests/e2e/multigpu/solo/
pytest -v -n1 /workspace/axolotl/tests/e2e/multigpu/solo/

View File

@@ -238,10 +238,10 @@ simpo_gamma: 0.5 # Target reward margin for the SimPO loss
# grpo # grpo
trl: trl:
use_vllm: # Optional[bool]. Whether to use VLLM for RL training. use_vllm: # Optional[bool]. Whether to use VLLM for RL training.
vllm_device: # Optional[str]. Device to use for VLLM. vllm_server_host: # Optional[str]. Host of the vLLM server to connect to.
vllm_gpu_memory_utilization: # Optional[float]. GPU memory utilization for VLLM. vllm_server_port: # Optional[int]. Port of the vLLM server to connect to.
vllm_max_model_len: # Optional[int]. Maximum length of the model for VLLM. vllm_server_timeout: # Optional[int]. Total timeout (in seconds) to wait for the vLLM server to respond.
vllm_dtype: # Optional[str]. Data type for VLLM. vllm_guided_decoding_regex: # Optional[str]. Regex for vLLM guided decoding.
beta: # Optional[float]. Beta parameter for the RL training. Same as `rl_beta`. Use beta: # Optional[float]. Beta parameter for the RL training. Same as `rl_beta`. Use
max_completion_length: # Optional[int]. Maximum length of the completion for RL training. max_completion_length: # Optional[int]. Maximum length of the completion for RL training.

View File

@@ -502,9 +502,48 @@ The input format is a simple JSON input with customizable fields based on the ab
Check out our [GRPO cookbook](https://github.com/axolotl-ai-cloud/axolotl-cookbook/tree/main/grpo#training-an-r1-style-large-language-model-using-grpo). Check out our [GRPO cookbook](https://github.com/axolotl-ai-cloud/axolotl-cookbook/tree/main/grpo#training-an-r1-style-large-language-model-using-grpo).
::: :::
If you have multiple GPUs available, we reccomend using `vLLM` with the `GRPOTrainer` to significantly speedup trajectory generation during training.
First, launch a `vLLM` server using `trl vllm-serve` - you may use a config file or CLI overrides to configure your vLLM server. In this example, we're
using 4 GPUs - 2 for training, and 2 for vLLM:
::: {.callout-important}
Make sure you've installed the correct version of vLLM by including it as an extra when installing axolotl, e.g. `pip install axolotl[vllm]`.
:::
```yaml
base_model: Qwen/Qwen2.5-1.5B-Instruct
vllm:
host: 0.0.0.0
port: 8000
tensor_parallel_size: 2
gpu_memory_utilization: 0.85
dtype: auto
# max_model_len: # you may find it useful to set the vLLM model context length if you know this beforehand
rl: grpo
trl:
use_vllm: true
vllm_server_host: 0.0.0.0
vllm_server_port: 8000
vllm_server_timeout: 300
```
```bash
CUDA_VISIBLE_DEVICES=2,3 axolotl vllm_serve grpo.yaml
```
Your `vLLM` instance will now attempt to spin up, and it's time to kick off training utilizing our remaining two GPUs. In another terminal, execute:
```bash
CUDA_VISIBLE_DEVICES=0,1 axolotl train grpo.yaml --num-processes 2
```
#### Reward functions
GRPO uses custom reward functions and transformations. Please have them ready locally. GRPO uses custom reward functions and transformations. Please have them ready locally.
For ex, to load OpenAI's GSM8K and use a random reward for completions: For example, to load OpenAI's GSM8K and use a random reward for completions:
```python ```python
# rewards.py # rewards.py
@@ -530,8 +569,6 @@ trl:
beta: 0.001 beta: 0.001
max_completion_length: 256 max_completion_length: 256
use_vllm: True use_vllm: True
vllm_device: auto
vllm_gpu_memory_utilization: 0.15
num_generations: 4 num_generations: 4
reward_funcs: ["rewards.rand_reward_func"] # format: '{file_name}.{fn_name}' reward_funcs: ["rewards.rand_reward_func"] # format: '{file_name}.{fn_name}'
reward_weights: [1.0] reward_weights: [1.0]

View File

@@ -1,7 +1,7 @@
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/ --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
# START section of dependencies that don't install on Darwin/MacOS # START section of dependencies that don't install on Darwin/MacOS
bitsandbytes==0.45.3 bitsandbytes==0.45.4
triton>=3.0.0 triton>=3.0.0
mamba-ssm==1.2.0.post1 mamba-ssm==1.2.0.post1
xformers>=0.0.23.post1 xformers>=0.0.23.post1
@@ -17,7 +17,7 @@ tokenizers>=0.21.1
accelerate==1.5.2 accelerate==1.5.2
datasets==3.5.0 datasets==3.5.0
deepspeed==0.16.4 deepspeed==0.16.4
trl==0.15.1 trl==0.16.0
optimum==1.16.2 optimum==1.16.2
hf_transfer hf_transfer

View File

@@ -10,7 +10,7 @@ from pathlib import Path
from setuptools import find_packages, setup from setuptools import find_packages, setup
def parse_requirements(): def parse_requirements(extras_require_map):
_install_requires = [] _install_requires = []
_dependency_links = [] _dependency_links = []
with open("./requirements.txt", encoding="utf-8") as requirements_file: with open("./requirements.txt", encoding="utf-8") as requirements_file:
@@ -67,6 +67,7 @@ def parse_requirements():
if (major, minor) >= (2, 6): if (major, minor) >= (2, 6):
_install_requires.pop(_install_requires.index(xformers_version)) _install_requires.pop(_install_requires.index(xformers_version))
_install_requires.append("xformers==0.0.29.post2") _install_requires.append("xformers==0.0.29.post2")
extras_require_map["vllm"] = ["vllm==0.8.1"]
elif (major, minor) >= (2, 5): elif (major, minor) >= (2, 5):
_install_requires.pop(_install_requires.index(xformers_version)) _install_requires.pop(_install_requires.index(xformers_version))
if patch == 0: if patch == 0:
@@ -86,7 +87,7 @@ def parse_requirements():
except PackageNotFoundError: except PackageNotFoundError:
pass pass
return _install_requires, _dependency_links return _install_requires, _dependency_links, extras_require_map
def get_package_version(): def get_package_version():
@@ -103,7 +104,46 @@ def get_package_version():
return version_ return version_
install_requires, dependency_links = parse_requirements() extras_require = {
"flash-attn": ["flash-attn==2.7.4.post1"],
"ring-flash-attn": ["ring-flash-attn>=0.1.4", "yunchang==0.6.0"],
"deepspeed": [
"deepspeed==0.16.4",
"deepspeed-kernels",
],
"mamba-ssm": [
"mamba-ssm==1.2.0.post1",
"causal_conv1d",
],
"auto-gptq": [
"auto-gptq==0.5.1",
],
"mlflow": [
"mlflow",
],
"galore": [
"galore_torch",
],
"apollo": [
"apollo-torch",
],
"optimizers": [
"galore_torch",
"apollo-torch",
"lomo-optim==0.1.1",
"torch-optimi==0.2.1",
],
"ray": [
"ray[train]",
],
"vllm": [
"vllm==0.7.2",
],
}
install_requires, dependency_links, extras_require_build = parse_requirements(
extras_require
)
setup( setup(
version=get_package_version(), version=get_package_version(),
@@ -116,40 +156,5 @@ setup(
"axolotl=axolotl.cli.main:main", "axolotl=axolotl.cli.main:main",
], ],
}, },
extras_require={ extras_require=extras_require_build,
"flash-attn": ["flash-attn==2.7.4.post1"],
"ring-flash-attn": ["ring-flash-attn>=0.1.4", "yunchang==0.6.0"],
"deepspeed": [
"deepspeed==0.16.4",
"deepspeed-kernels",
],
"mamba-ssm": [
"mamba-ssm==1.2.0.post1",
"causal_conv1d",
],
"auto-gptq": [
"auto-gptq==0.5.1",
],
"mlflow": [
"mlflow",
],
"galore": [
"galore_torch",
],
"apollo": [
"apollo-torch",
],
"optimizers": [
"galore_torch",
"apollo-torch",
"lomo-optim==0.1.1",
"torch-optimi==0.2.1",
],
"ray": [
"ray[train]",
],
"vllm": [
"vllm==0.7.2",
],
},
) )

View File

@@ -35,6 +35,55 @@ class TrainerCliArgs:
num_processes: Optional[int] = field(default=None) num_processes: Optional[int] = field(default=None)
@dataclass
class VllmServeCliArgs:
"""Dataclass with CLI arguments for `axolotl vllm-serve` command."""
tensor_parallel_size: int = field(
default=1,
metadata={"help": "Number of tensor parallel workers to use."},
)
host: str = field(
default="0.0.0.0", # nosec B104
metadata={"help": "Host address to run the server on."},
)
port: int = field(
default=8000,
metadata={"help": "Port to run the server on."},
)
gpu_memory_utilization: Optional[float] = field(
default=None,
metadata={
"help": "Ratio (between 0 and 1) of GPU memory to reserve for the model weights, activations, and KV "
"cache on the device dedicated to generation powered by vLLM. Higher values will increase the KV cache "
"size and thus improve the model's throughput. However, if the value is too high, it may cause "
"out-of-memory (OOM) errors during initialization."
},
)
dtype: Optional[str] = field(
default=None,
metadata={
"help": "Data type to use for vLLM generation. If set to 'auto', the data type will be automatically "
"determined based on the model configuration. Find the supported values in the vLLM documentation."
},
)
max_model_len: Optional[int] = field(
default=None,
metadata={
"help": "If set, the `max_model_len` to use for vLLM. This can be useful when running with reduced "
"`vllm_gpu_memory_utilization`, leading to a reduced KV cache size. If not set, vLLM will use the model "
"context size, which might be much larger than the KV cache, leading to inefficiencies."
},
)
enable_prefix_caching: Optional[bool] = field(
default=None,
metadata={
"help": "Whether to enable prefix caching in vLLM. If set to `True`, ensure that the model and the "
"hardware support this feature."
},
)
@dataclass @dataclass
class EvaluateCliArgs: class EvaluateCliArgs:
"""Dataclass with CLI arguments for `axolotl evaluate` command.""" """Dataclass with CLI arguments for `axolotl evaluate` command."""

View File

@@ -14,7 +14,12 @@ import yaml
from dotenv import load_dotenv from dotenv import load_dotenv
import axolotl import axolotl
from axolotl.cli.args import EvaluateCliArgs, PreprocessCliArgs, TrainerCliArgs from axolotl.cli.args import (
EvaluateCliArgs,
PreprocessCliArgs,
TrainerCliArgs,
VllmServeCliArgs,
)
from axolotl.cli.sweeps import generate_sweep_configs from axolotl.cli.sweeps import generate_sweep_configs
from axolotl.cli.utils import ( from axolotl.cli.utils import (
add_options_from_config, add_options_from_config,
@@ -23,6 +28,7 @@ from axolotl.cli.utils import (
fetch_from_github, fetch_from_github,
filter_none_kwargs, filter_none_kwargs,
) )
from axolotl.cli.vllm_serve import do_vllm_serve
from axolotl.integrations.lm_eval.cli import lm_eval from axolotl.integrations.lm_eval.cli import lm_eval
from axolotl.utils import set_pytorch_cuda_alloc_conf from axolotl.utils import set_pytorch_cuda_alloc_conf
from axolotl.utils.schemas.config import AxolotlInputConfig from axolotl.utils.schemas.config import AxolotlInputConfig
@@ -316,6 +322,14 @@ def fetch(directory: str, dest: Optional[str]) -> None:
fetch_from_github(f"{directory}/", dest) fetch_from_github(f"{directory}/", dest)
@cli.command()
@click.argument("config", type=click.Path(exists=True, path_type=str))
@add_options_from_dataclass(VllmServeCliArgs)
@filter_none_kwargs
def vllm_serve(config: str, **cli_args: VllmServeCliArgs):
do_vllm_serve(config, cli_args)
cli.add_command(lm_eval) cli.add_command(lm_eval)

View File

@@ -0,0 +1,55 @@
"""
CLI to start the vllm server for online RL
"""
from pathlib import Path
from typing import Union
from trl.scripts.vllm_serve import ScriptArguments
from trl.scripts.vllm_serve import main as vllm_serve_main
from axolotl.cli.config import load_cfg
def do_vllm_serve(
config: Union[Path, str],
cli_args: dict,
):
"""
Starts the VLLM server for serving LLM models used for online RL
Args
:param cfg: Parsed doct of the YAML config
:param cli_args: dict of additional command-line arguments of type VllmServeCliArgs
Returns:
process_id: the process id of the started VLLM server
"""
cfg = load_cfg(config)
model = cfg.base_model
tensor_parallel_size = (
cli_args.get("tensor_parallel_size") or cfg.vllm.tensor_parallel_size
)
host = cli_args.get("host") or cfg.vllm.host
port = cli_args.get("port") or cfg.vllm.port
gpu_memory_utilization = (
cli_args.get("gpu_memory_utilization") or cfg.vllm.gpu_memory_utilization
)
dtype = cli_args.get("dtype") or cfg.vllm.dtype
max_model_len = cli_args.get("max_model_len") or cfg.vllm.max_model_len
enable_prefix_caching = (
cli_args.get("enable_prefix_caching") or cfg.vllm.enable_prefix_caching
)
vllm_script_args = ScriptArguments(
model,
tensor_parallel_size=tensor_parallel_size,
host=host,
port=port,
gpu_memory_utilization=gpu_memory_utilization,
dtype=dtype,
max_model_len=max_model_len,
enable_prefix_caching=enable_prefix_caching,
)
vllm_serve_main(vllm_script_args)

View File

@@ -40,18 +40,15 @@ class GRPOStrategy:
if trl.use_vllm: if trl.use_vllm:
grpo_args_kwargs["use_vllm"] = trl.use_vllm grpo_args_kwargs["use_vllm"] = trl.use_vllm
grpo_args_kwargs["vllm_device"] = ( grpo_args_kwargs["vllm_server_host"] = trl.vllm_server_host
trl.vllm_device if trl.vllm_device else "auto" grpo_args_kwargs["vllm_server_port"] = trl.vllm_server_port
) if trl.vllm_server_timeout:
grpo_args_kwargs["vllm_server_timeout"] = trl.vllm_server_timeout
if trl.vllm_gpu_memory_utilization: if trl.vllm_guided_decoding_regex:
grpo_args_kwargs["vllm_gpu_memory_utilization"] = ( grpo_args_kwargs["vllm_guided_decoding_regex"] = (
trl.vllm_gpu_memory_utilization trl.vllm_guided_decoding_regex
) )
if trl.vllm_max_model_len:
grpo_args_kwargs["vllm_max_model_len"] = trl.vllm_max_model_len
if trl.num_generations: if trl.num_generations:
grpo_args_kwargs["num_generations"] = trl.num_generations grpo_args_kwargs["num_generations"] = trl.num_generations
@@ -70,6 +67,25 @@ class GRPOStrategy:
if trl.reward_weights: if trl.reward_weights:
grpo_args_kwargs["reward_weights"] = trl.reward_weights grpo_args_kwargs["reward_weights"] = trl.reward_weights
if trl.scale_rewards is not None:
grpo_args_kwargs["scale_rewards"] = trl.scale_rewards
if trl.temperature is not None:
grpo_args_kwargs["temperature"] = trl.temperature
if trl.top_p is not None:
grpo_args_kwargs["top_p"] = trl.top_p
if trl.top_k is not None:
grpo_args_kwargs["top_k"] = trl.top_k
if trl.min_p is not None:
grpo_args_kwargs["min_p"] = trl.min_p
if trl.repetition_penalty is not None:
grpo_args_kwargs["repetition_penalty"] = trl.repetition_penalty
if trl.num_iterations is not None:
grpo_args_kwargs["num_iterations"] = trl.num_iterations
if trl.epsilon is not None:
grpo_args_kwargs["epsilon"] = trl.epsilon
return grpo_args_kwargs return grpo_args_kwargs
@classmethod @classmethod

View File

@@ -2,16 +2,18 @@
Axolotl GRPO trainer Axolotl GRPO trainer
""" """
from accelerate.utils import is_peft_model from contextlib import nullcontext
from accelerate.utils.other import is_compiled_module
from transformers import PreTrainedModel from accelerate.utils import is_deepspeed_available, is_peft_model
from trl import GRPOConfig, GRPOTrainer from trl import GRPOTrainer
from trl.models import unwrap_model_for_generation from trl.extras.profiling import profiling_decorator
from axolotl.core.trainers.base import SchedulerMixin from axolotl.core.trainers.base import SchedulerMixin
if is_deepspeed_available():
import deepspeed
# mypy: ignore-errors
class AxolotlGRPOTrainer(SchedulerMixin, GRPOTrainer): class AxolotlGRPOTrainer(SchedulerMixin, GRPOTrainer):
""" """
Extend the base GRPOTrainer for axolotl helpers Extend the base GRPOTrainer for axolotl helpers
@@ -19,91 +21,49 @@ class AxolotlGRPOTrainer(SchedulerMixin, GRPOTrainer):
_tag_names = ["trl", "grpo", "axolotl"] _tag_names = ["trl", "grpo", "axolotl"]
def __init__(self, *args, **kwargs): @profiling_decorator
super().__init__(*args, **kwargs) def _move_model_to_vllm(self):
# For DeepSpeed ZeRO-3, we need to gather all parameters before operations
# pylint: disable=access-member-before-definition deepspeed_plugin = self.accelerator.state.deepspeed_plugin
# Enable gradient checkpointing if requested zero_stage_3 = deepspeed_plugin is not None and deepspeed_plugin.zero_stage == 3
if kwargs["args"].gradient_checkpointing: gather_if_zero3 = (
# Ensure use_cache is disabled deepspeed.zero.GatheredParameters if zero_stage_3 else nullcontext
if hasattr(self.model, "config"):
self.model.config.use_cache = False
# Enable gradient checkpointing on the base model for PEFT
if is_peft_model(self.model) and hasattr(
self.model.base_model, "gradient_checkpointing_enable"
):
self.model.base_model.gradient_checkpointing_enable()
# Enable gradient checkpointing for non-PEFT models
elif hasattr(self.model, "gradient_checkpointing_enable"):
self.model.gradient_checkpointing_enable()
self.model = self._enable_gradient_checkpointing(self.model, kwargs["args"])
# pylint: enable=access-member-before-definition
def _enable_gradient_checkpointing(
self, model: PreTrainedModel, args: GRPOConfig
) -> PreTrainedModel:
"""Enables gradient checkpointing for the model."""
# pylint: disable=unused-argument,redefined-builtin
gradient_checkpointing_kwargs = args.gradient_checkpointing_kwargs or {}
use_reentrant = (
"use_reentrant" not in gradient_checkpointing_kwargs
or gradient_checkpointing_kwargs["use_reentrant"]
) )
if use_reentrant: if is_peft_model(self.model):
if hasattr(model, "enable_input_require_grads"): # With PEFT and DeepSpeed ZeRO Stage 3, we must gather the full model at once before merging, as merging
model.enable_input_require_grads() # adapters in a sharded manner is not supported.
else: with gather_if_zero3(list(self.model.parameters())):
self.model.merge_adapter()
def make_inputs_require_grad(module, input, output): # Update vLLM weights while parameters are gathered
output.requires_grad_(True) for name, param in self.model.named_parameters():
# When using PEFT, we need to recover the original parameter name and discard some parameters
name = (
name.removeprefix("base_model.model.")
.removeprefix("base_model.model.")
.replace(".base_layer", "")
)
if self.model.prefix in name:
continue
# When module to save, remove its prefix and discard the original module
if "original_module" in name:
continue
name = name.replace("modules_to_save.default.", "")
model.get_input_embeddings().register_forward_hook( if self.accelerator.is_main_process:
make_inputs_require_grad self.vllm_client.update_named_param(name, param.data)
)
return model # Unmerge adapters while parameters are still gathered
# pylint: enable=unused-argument,redefined-builtin self.model.unmerge_adapter()
# Parameters will automatically be repartitioned when exiting the context
else:
# For non-PEFT models, simply gather and update each parameter individually.
for name, param in self.model.named_parameters():
with gather_if_zero3([param]):
if self.accelerator.is_main_process:
self.vllm_client.update_named_param(name, param.data)
def _move_model_to_vllm(self): # Reset cache on main process
with unwrap_model_for_generation( if self.accelerator.is_main_process:
self.model, self.vllm_client.reset_prefix_cache()
self.accelerator,
gather_deepspeed3_params=self.args.ds3_gather_for_generation,
) as unwrapped_model:
if is_compiled_module(unwrapped_model):
unwrapped_model = (
unwrapped_model._orig_mod # pylint: disable=protected-access
)
if is_peft_model(unwrapped_model):
unwrapped_model.merge_adapter()
state_dict = unwrapped_model.state_dict()
# Remove base_model and base_layer prefixes
state_dict = {
k.removeprefix("base_model.model.")
.removeprefix("base_model.model.")
.replace(".base_layer", ""): v
for k, v in state_dict.items()
}
# Remove values with adapter prefix (example: "_lora")
state_dict = {
k: v
for k, v in state_dict.items()
if unwrapped_model.prefix not in k
}
# When module to save, remove its prefix and discard the original module
state_dict = {
k.replace("modules_to_save.default.", ""): v
for k, v in state_dict.items()
if "original_module" not in k
}
else:
state_dict = unwrapped_model.state_dict()
if self.accelerator.is_main_process:
llm_model = (
self.llm.llm_engine.model_executor.driver_worker.model_runner.model
)
llm_model.load_weights(state_dict.items())
if is_peft_model(unwrapped_model):
unwrapped_model.unmerge_adapter()

View File

@@ -46,6 +46,7 @@ from axolotl.utils.schemas.multimodal import MultiModalConfig
from axolotl.utils.schemas.peft import LoraConfig, ReLoRAConfig from axolotl.utils.schemas.peft import LoraConfig, ReLoRAConfig
from axolotl.utils.schemas.training import HyperparametersConfig from axolotl.utils.schemas.training import HyperparametersConfig
from axolotl.utils.schemas.trl import TRLConfig from axolotl.utils.schemas.trl import TRLConfig
from axolotl.utils.schemas.vllm import VllmConfig
LOG = logging.getLogger(__name__) LOG = logging.getLogger(__name__)
@@ -86,6 +87,9 @@ class AxolotlInputConfig(
trl: TRLConfig | None = Field( trl: TRLConfig | None = Field(
default_factory=lambda: TRLConfig(), # pylint: disable=unnecessary-lambda default_factory=lambda: TRLConfig(), # pylint: disable=unnecessary-lambda
) )
vllm: VllmConfig | None = Field(
default_factory=lambda: VllmConfig(), # pylint: disable=unnecessary-lambda
)
reward_model: bool | None = None reward_model: bool | None = None
process_reward_model: bool | None = None process_reward_model: bool | None = None
num_labels: int | None = None num_labels: int | None = None

View File

@@ -20,27 +20,30 @@ class TRLConfig(BaseModel):
) )
# GRPO specific args # GRPO specific args
# Ref: https://github.com/huggingface/trl/blob/e3244d2d096ff1e2e248c931d06d39e165e20623/trl/trainer/grpo_config.py#L22 # Ref: https://github.com/huggingface/trl/blob/26d86757a7c7e24e397ea44f57ecce6031dfac01/trl/trainer/grpo_config.py#L23
use_vllm: bool | None = Field( use_vllm: bool = Field(
default=False, default=False,
json_schema_extra={"description": "Whether to use VLLM for RL training"}, json_schema_extra={"description": "Whether to use VLLM for RL training"},
) )
vllm_device: str | None = Field( vllm_server_host: str | None = Field(
default="auto", default="0.0.0.0", # nosec B104
json_schema_extra={"description": "Device to use for VLLM"}, json_schema_extra={"description": "Host of the vLLM server to connect to"},
) )
vllm_gpu_memory_utilization: float | None = Field( vllm_server_port: int | None = Field(
default=0.9, default=8000,
json_schema_extra={"description": "GPU memory utilization for VLLM"}, json_schema_extra={"description": "Port of the vLLM server to connect to"},
) )
vllm_dtype: str | None = Field( vllm_server_timeout: int | None = Field(
default="auto",
json_schema_extra={"description": "Data type for VLLM"},
)
vllm_max_model_len: int | None = Field(
default=None, default=None,
json_schema_extra={ json_schema_extra={
"description": "Maximum length of the model context for VLLM" "description": "Total timeout duration in seconds to wait for the vLLM server to be up. If the server is not up "
"after the timeout, a `ConnectionError` is raised."
},
)
vllm_guided_decoding_regex: str | None = Field(
default=None,
json_schema_extra={
"description": "Regex for vLLM guided decoding. If `None` (default), guided decoding is disabled."
}, },
) )
@@ -85,3 +88,48 @@ class TRLConfig(BaseModel):
"description": "Sync steps for the reference model. Requires `sync_ref_model=True`." "description": "Sync steps for the reference model. Requires `sync_ref_model=True`."
}, },
) )
scale_rewards: bool = Field(
default=True,
json_schema_extra={
"description": "Whether to scale the rewards for GRPO by dividing them by their standard deviation."
},
)
temperature: float | None = Field(
default=None,
json_schema_extra={"description": "Sampling temperature for the GRPO policy."},
)
top_p: float | None = Field(
default=None,
json_schema_extra={
"description": "Top-p sampling probability for the generation policy."
},
)
top_k: int | None = Field(
default=None,
json_schema_extra={"description": "Top-k sampling for the generation policy."},
)
min_p: float | None = Field(
default=None,
json_schema_extra={
"description": "Minimum probability for the generation policy."
},
)
repetition_penalty: float | None = Field(
default=None,
json_schema_extra={
"description": "Float that penalizes new tokens based on whether they appear in the prompt and the generated text so far."
},
)
num_iterations: int | None = Field(
default=None,
json_schema_extra={
"description": "Number of iterations per batch (denoted as μ in the algorithm) for GRPO."
},
)
epsilon: float | None = Field(
default=None,
json_schema_extra={
"description": "Epsilon value for clipping in the GRPO algorithm."
},
)

View File

@@ -0,0 +1,38 @@
"""
Pydantic models for VLLM configuration, used primarily for RL training with TRL + grpo
"""
from pydantic import BaseModel, Field
class VllmConfig(BaseModel):
"""
Configuration for VLLM server
"""
device: str | None = Field(
default="auto",
json_schema_extra={"description": "Device to use for VLLM"},
)
tensor_parallel_size: int | None = Field(
default=None,
json_schema_extra={"description": "Tensor parallel size for VLLM"},
)
gpu_memory_utilization: float | None = Field(
default=0.9,
json_schema_extra={"description": "GPU memory utilization for VLLM"},
)
dtype: str | None = Field(
default="auto",
json_schema_extra={"description": "Data type for VLLM"},
)
max_model_len: int | None = Field(
default=None,
json_schema_extra={
"description": "Maximum length of the model context for VLLM"
},
)
enable_prefix_caching: bool | None = Field(
default=None,
json_schema_extra={"description": "Enable prefix caching for VLLM"},
)

View File

@@ -100,6 +100,14 @@ def download_argilla_distilabel_capybara_dpo_7k_binarized_dataset():
) )
@pytest.fixture(scope="session", autouse=True)
def download_argilla_distilabel_intel_orca_dpo_dataset():
# download the dataset
snapshot_download_w_retry(
"argilla/distilabel-intel-orca-dpo-pairs", repo_type="dataset"
)
@pytest.fixture(scope="session", autouse=True) @pytest.fixture(scope="session", autouse=True)
def download_argilla_ultrafeedback_binarized_preferences_cleaned_dataset(): def download_argilla_ultrafeedback_binarized_preferences_cleaned_dataset():
# download the dataset # download the dataset

View File

View File

@@ -0,0 +1,294 @@
"""
GRPO test suite
"""
import os
import random
import subprocess # nosec B404
import sys
import time
from pathlib import Path
import pytest
import requests
import yaml
from accelerate.test_utils import execute_subprocess_async
from transformers.testing_utils import get_torch_dist_unique_port
from axolotl.utils.dict import DictDefault
from tests.e2e.utils import require_vllm
def start_vllm(
model: str, env: dict | None = None, wait: int | None = None, quiet=False, **kwargs
) -> int:
"""
helper function to start the VLLM server in the background, mostly for testing purposes
"""
cmd = [sys.executable, "-m", "trl.scripts.vllm_serve", "--model", model]
if tensor_parallel_size := kwargs.get("tensor_parallel_size"):
cmd.extend(["--tensor-parallel-size", str(tensor_parallel_size)])
if host := kwargs.get("host"):
cmd.extend(["--host", host])
if port := kwargs.get("port"):
cmd.extend(["--port", str(port)])
if gpu_memory_utilization := kwargs.get("gpu_memory_utilization"):
cmd.extend(["--gpu-memory-utilization", str(gpu_memory_utilization)])
if dtype := kwargs.get("dtype"):
cmd.extend(["--dtype", dtype])
if max_model_len := kwargs.get("max_model_len"):
cmd.extend(["--max-model-len", str(max_model_len)])
if kwargs.get("enable_prefix_caching"):
cmd.extend(["--enable-prefix-caching", "True"])
# print out the command to be executed
print(" ".join(cmd))
# start `trl vllm-serve` command in the background and capture the process id
process = subprocess.Popen( # pylint: disable=consider-using-with
cmd,
env=env,
stdout=subprocess.DEVNULL if quiet else subprocess.PIPE,
stderr=subprocess.DEVNULL if quiet else subprocess.PIPE,
) # nosec B603
# print out the process id so the user can easily kill it later
print(f"VLLM server process started (PID: {process.pid})")
# wait until the http server is ready, even if it 404s, but timeout after 60 seconds
started = False
if wait and host and port:
for _ in range(int(wait)):
try:
response = requests.get(f"http://{host}:{port}", timeout=1)
if int(response.status_code) in [200, 404]:
started = True
break
except requests.exceptions.RequestException:
pass
# also check if the process.pid is still running
if not process.poll() is None:
break
time.sleep(1)
if wait and not started:
print(
f"VLLM server process did not start within {wait} seconds. Please check your server logs."
)
process.kill()
raise RuntimeError(f"VLLM server process did not start within {wait} seconds.")
# return the process id
return process.pid
class TestGRPO:
"""
Test case for GRPO training using multilpe GPUs
"""
def _utils_write_yaml_and_rewards(self, cfg, temp_dir, suffix=""):
# write cfg to yaml file
Path(temp_dir).mkdir(parents=True, exist_ok=True)
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
with open(f"rewards_{suffix}.py", "w", encoding="utf-8") as fout:
fout.write(
"""import random
def rand_reward_func(completions, **kwargs) -> list[float]:
return [random.uniform(0, 1) for _ in completions]
def oai_gsm8k_transform(cfg, *args, **kwargs):
def transform_fn(example, tokenizer=None):
label = example["answer"].split("####")[-1].strip().replace(",", "")
return {
"prompt": [{"role": "user", "content": example["question"]},],
"answer": label,
}
return transform_fn, {"remove_columns": ["question"]}
"""
)
@pytest.mark.parametrize(
"num_gpus",
[1, 2],
)
@require_vllm
def test_llama_dora(self, temp_dir, num_gpus):
rnd_reward_suffix = str(random.randint(1000, 9999))
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"chat_template": "llama3",
"rl": "grpo",
"trl": {
"beta": 0.001,
"max_completion_length": 256,
"use_vllm": True,
"num_generations": 4,
"reward_funcs": [f"rewards_{rnd_reward_suffix}.rand_reward_func"],
},
"vllm": {
"max_model_len": 800,
"enable_prefix_caching": True,
},
"datasets": [
{
"path": "openai/gsm8k",
"name": "main",
"type": f"rewards_{rnd_reward_suffix}.oai_gsm8k_transform",
},
],
"adapter": "lora",
"lora_r": 8,
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_linear": True,
"peft_use_dora": True,
"flash_attention": True,
"sequence_len": 1024,
"special_tokens": {
"pad_token": "<|endoftext|>",
},
"max_steps": 3,
"num_epochs": 1,
"micro_batch_size": 4,
"gradient_accumulation_steps": 2,
"warmup_steps": 10,
"val_set_size": 0.0,
"output_dir": temp_dir,
"learning_rate": 0.0001,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"save_safetensors": True,
"bf16": "auto",
"use_tensorboard": True,
}
)
self._utils_write_yaml_and_rewards(cfg, temp_dir, suffix=rnd_reward_suffix)
current_env = os.environ.copy()
env = {
"NCCL_P2P_LEVEL": "LOC",
**current_env,
"CUDA_VISIBLE_DEVICES": "1",
}
vllm_process_id = start_vllm(
cfg.base_model,
env=env,
quiet=True,
wait=120,
gpu_memory_utilization=0.15,
max_model_len=cfg.vllm.max_model_len,
enable_prefix_caching=cfg.vllm.enable_prefix_caching,
host="0.0.0.0",
port=8000,
)
try:
execute_subprocess_async(
[
"axolotl",
"train",
str(Path(temp_dir) / "config.yaml"),
"--num-processes",
str(num_gpus),
"--main-process-port",
f"{get_torch_dist_unique_port()}",
],
env={"NCCL_P2P_LEVEL": "LOC", "NCCL_DEBUG": "INFO", **current_env},
)
finally:
os.kill(vllm_process_id, 9)
@pytest.mark.parametrize(
"num_gpus",
[1, 2],
)
@require_vllm
def test_llama_fft(self, temp_dir, num_gpus):
rnd_reward_suffix = str(random.randint(1000, 9999))
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"chat_template": "llama3",
"rl": "grpo",
"trl": {
"beta": 0.001,
"max_completion_length": 256,
"use_vllm": True,
"num_generations": 4,
"reward_funcs": [f"rewards_{rnd_reward_suffix}.rand_reward_func"],
},
"vllm": {
"max_model_len": 800,
"enable_prefix_caching": True,
},
"datasets": [
{
"path": "openai/gsm8k",
"name": "main",
"type": f"rewards_{rnd_reward_suffix}.oai_gsm8k_transform",
},
],
"flash_attention": True,
"sequence_len": 1024,
"special_tokens": {
"pad_token": "<|endoftext|>",
},
"max_steps": 3,
"num_epochs": 1,
"micro_batch_size": 4,
"gradient_accumulation_steps": 2,
"warmup_steps": 10,
"val_set_size": 0.0,
"output_dir": temp_dir,
"learning_rate": 0.0001,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"save_safetensors": True,
"bf16": "auto",
"use_tensorboard": True,
}
)
self._utils_write_yaml_and_rewards(cfg, temp_dir, suffix=rnd_reward_suffix)
current_env = os.environ.copy()
env = {
"NCCL_P2P_LEVEL": "LOC", # nccl can be brittle, assume P2P isn't reliable
**current_env,
"CUDA_VISIBLE_DEVICES": "1",
}
vllm_process_id = start_vllm(
cfg.base_model,
env=env,
quiet=True,
wait=120,
gpu_memory_utilization=0.15,
max_model_len=cfg.vllm.max_model_len,
enable_prefix_caching=cfg.vllm.enable_prefix_caching,
host="0.0.0.0",
port=8000,
)
try:
execute_subprocess_async(
[
"axolotl",
"train",
str(Path(temp_dir) / "config.yaml"),
"--num-processes",
str(num_gpus),
"--main-process-port",
f"{get_torch_dist_unique_port()}",
],
env={"NCCL_P2P_LEVEL": "LOC", "NCCL_DEBUG": "INFO", **current_env},
)
finally:
os.kill(vllm_process_id, 9)

View File

@@ -52,9 +52,9 @@ class TestMultiGPUEval:
}, },
], ],
"num_epochs": 1, "num_epochs": 1,
"max_steps": 5, "max_steps": 2,
"micro_batch_size": 2, "micro_batch_size": 2,
"gradient_accumulation_steps": 4, "gradient_accumulation_steps": 2,
"output_dir": temp_dir, "output_dir": temp_dir,
"learning_rate": 0.00001, "learning_rate": 0.00001,
"optimizer": "adamw_8bit", "optimizer": "adamw_8bit",
@@ -121,9 +121,9 @@ class TestMultiGPUEval:
}, },
], ],
"num_epochs": 1, "num_epochs": 1,
"max_steps": 5, "max_steps": 2,
"micro_batch_size": 2, "micro_batch_size": 2,
"gradient_accumulation_steps": 4, "gradient_accumulation_steps": 2,
"output_dir": temp_dir, "output_dir": temp_dir,
"learning_rate": 0.00001, "learning_rate": 0.00001,
"optimizer": "adamw_8bit", "optimizer": "adamw_8bit",

View File

@@ -1,175 +0,0 @@
"""
GRPO test suite
"""
import random
from pathlib import Path
import pytest
import yaml
from accelerate.test_utils import execute_subprocess_async
from transformers.testing_utils import get_torch_dist_unique_port
from axolotl.utils.dict import DictDefault
from tests.e2e.utils import require_vllm
class TestGRPO:
"""
Test case for GRPO training using multilpe GPUs
"""
def _utils_write_yaml_and_rewards(self, cfg, temp_dir, suffix=""):
# write cfg to yaml file
Path(temp_dir).mkdir(parents=True, exist_ok=True)
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
with open(f"rewards_{suffix}.py", "w", encoding="utf-8") as fout:
fout.write(
"""import random
def rand_reward_func(completions, **kwargs) -> list[float]:
return [random.uniform(0, 1) for _ in completions]
def oai_gsm8k_transform(cfg, *args, **kwargs):
def transform_fn(example, tokenizer=None):
label = example["answer"].split("####")[-1].strip().replace(",", "")
return {
"prompt": [{"role": "user", "content": example["question"]},],
"answer": label,
}
return transform_fn, {"remove_columns": ["question"]}
"""
)
@pytest.mark.parametrize(
"num_gpus",
[1, 2],
)
@require_vllm
def test_llama_dora(self, temp_dir, num_gpus):
rnd_reward_suffix = str(random.randint(1000, 9999))
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"chat_template": "llama3",
"rl": "grpo",
"trl": {
"beta": 0.001,
"max_completion_length": 256,
"use_vllm": True,
"vllm_device": "auto" if num_gpus == 1 else "cuda:1",
"vllm_gpu_memory_utilization": 0.15,
"num_generations": 4,
"reward_funcs": [f"rewards_{rnd_reward_suffix}.rand_reward_func"],
},
"datasets": [
{
"path": "openai/gsm8k",
"name": "main",
"type": f"rewards_{rnd_reward_suffix}.oai_gsm8k_transform",
},
],
"adapter": "lora",
"lora_r": 8,
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_linear": True,
"peft_use_dora": True,
"flash_attention": True,
"sequence_len": 1024,
"special_tokens": {
"pad_token": "<|endoftext|>",
},
"max_steps": 5,
"num_epochs": 1,
"micro_batch_size": 4,
"gradient_accumulation_steps": 2,
"warmup_steps": 10,
"val_set_size": 0.0,
"output_dir": temp_dir,
"learning_rate": 0.0001,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"save_safetensors": True,
"bf16": "auto",
"use_tensorboard": True,
}
)
self._utils_write_yaml_and_rewards(cfg, temp_dir, suffix=rnd_reward_suffix)
execute_subprocess_async(
[
"axolotl",
"train",
str(Path(temp_dir) / "config.yaml"),
"--num-processes",
str(num_gpus),
"--main-process-port",
f"{get_torch_dist_unique_port()}",
]
)
@pytest.mark.parametrize(
"num_gpus",
[1, 2],
)
@require_vllm
def test_llama_fft(self, temp_dir, num_gpus):
rnd_reward_suffix = str(random.randint(1000, 9999))
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"chat_template": "llama3",
"rl": "grpo",
"trl": {
"beta": 0.001,
"max_completion_length": 256,
"use_vllm": True,
"vllm_device": "auto" if num_gpus == 1 else "cuda:1",
"vllm_gpu_memory_utilization": 0.15,
"num_generations": 4,
"reward_funcs": [f"rewards_{rnd_reward_suffix}.rand_reward_func"],
},
"datasets": [
{
"path": "openai/gsm8k",
"name": "main",
"type": f"rewards_{rnd_reward_suffix}.oai_gsm8k_transform",
},
],
"flash_attention": True,
"sequence_len": 1024,
"special_tokens": {
"pad_token": "<|endoftext|>",
},
"max_steps": 5,
"num_epochs": 1,
"micro_batch_size": 4,
"gradient_accumulation_steps": 2,
"warmup_steps": 10,
"val_set_size": 0.0,
"output_dir": temp_dir,
"learning_rate": 0.0001,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"save_safetensors": True,
"bf16": "auto",
"use_tensorboard": True,
}
)
self._utils_write_yaml_and_rewards(cfg, temp_dir, suffix=rnd_reward_suffix)
execute_subprocess_async(
[
"axolotl",
"train",
str(Path(temp_dir) / "config.yaml"),
"--num-processes",
str(num_gpus),
"--main-process-port",
f"{get_torch_dist_unique_port()}",
]
)

View File

@@ -399,7 +399,7 @@ class TestMultiGPULlama:
"num_epochs": 1, "num_epochs": 1,
"max_steps": 2, "max_steps": 2,
"micro_batch_size": 4, "micro_batch_size": 4,
"gradient_accumulation_steps": 4, "gradient_accumulation_steps": 2,
"output_dir": temp_dir, "output_dir": temp_dir,
"learning_rate": 0.00001, "learning_rate": 0.00001,
"optimizer": "adamw_torch_fused", "optimizer": "adamw_torch_fused",
@@ -478,7 +478,7 @@ class TestMultiGPULlama:
"num_epochs": 1, "num_epochs": 1,
"max_steps": 2, "max_steps": 2,
"micro_batch_size": 4, "micro_batch_size": 4,
"gradient_accumulation_steps": 4, "gradient_accumulation_steps": 2,
"output_dir": temp_dir, "output_dir": temp_dir,
"learning_rate": 0.00001, "learning_rate": 0.00001,
"optimizer": "adamw_torch_fused", "optimizer": "adamw_torch_fused",
@@ -778,7 +778,7 @@ class TestMultiGPULlama:
}, },
], ],
"num_epochs": 1, "num_epochs": 1,
"max_steps": 5, "max_steps": 2,
"micro_batch_size": 1, "micro_batch_size": 1,
"gradient_accumulation_steps": 1, "gradient_accumulation_steps": 1,
"output_dir": temp_dir, "output_dir": temp_dir,

View File

@@ -46,7 +46,7 @@ class TestMultiGPUQwen2:
}, },
], ],
"num_epochs": 1, "num_epochs": 1,
"max_steps": 5, "max_steps": 2,
"warmup_steps": 20, "warmup_steps": 20,
"micro_batch_size": 2, "micro_batch_size": 2,
"gradient_accumulation_steps": 2, "gradient_accumulation_steps": 2,

View File

@@ -50,7 +50,7 @@ class TestMultiGPURay:
"num_epochs": 1, "num_epochs": 1,
"max_steps": 2, "max_steps": 2,
"micro_batch_size": 4, "micro_batch_size": 4,
"gradient_accumulation_steps": 4, "gradient_accumulation_steps": 2,
"output_dir": temp_dir, "output_dir": temp_dir,
"learning_rate": 0.00001, "learning_rate": 0.00001,
"optimizer": "adamw_8bit", "optimizer": "adamw_8bit",