Updates for trl 0.16.0 - mostly for GRPO (#2437) [skip ci]
* add grpo scale_rewards config for trl#3135 * options to connect to vllm server directly w grpo trl#3094 * temperature support trl#3029 * sampling/generation kwargs for grpo trl#2989 * make vllm_enable_prefix_caching a config param trl#2900 * grpo multi-step optimizeations trl#2899 * remove overrides for grpo trainer * bump trl to 0.16.0 * add cli to start vllm-serve via trl * call the python module directly * update to use vllm with 2.6.0 too now and call trl vllm serve from module * vllm 0.8.1 * use python3 * use sys.executable * remove context and wait for start * fixes to make it actually work * fixes so the grpo tests pass with new vllm paradigm * explicit host/port and check in start vllm * make sure that vllm doesn't hang by setting quiet so outouts go to dev null * also bump bnb to latest release * add option for wait from cli and nccl debugging for ci * grpo + vllm test on separate devices for now * make sure grpo + vllm tests runs single worker since pynccl comms would conflict * fix cli * remove wait and add caching for argilla dataset * refactoring configs * chore: lint * add vllm config * fixup vllm grpo args * fix one more incorrect schema/config path * fix another vlllm reference and increase timeout * make the tests run a bit faster * change mbsz back so it is correct for grpo * another change mbsz back so it is correct for grpo * fixing cli args * nits * adding docs * docs * include tensor parallel size for vllm in pydantic schema * moving start_vllm, more docs * limit output len for grpo vllm * vllm enable_prefix_caching isn't a bool cli arg * fix env ordering in tests and also use pid check when looking for vllm --------- Co-authored-by: Salman Mohammadi <salman.mohammadi@outlook.com>
This commit is contained in:
3
.github/workflows/multi-gpu-e2e.yml
vendored
3
.github/workflows/multi-gpu-e2e.yml
vendored
@@ -42,8 +42,7 @@ jobs:
|
|||||||
cuda_version: 12.4.1
|
cuda_version: 12.4.1
|
||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.6.0
|
pytorch: 2.6.0
|
||||||
# awaiting vllm#12721
|
axolotl_extras: vllm
|
||||||
axolotl_extras:
|
|
||||||
num_gpus: 2
|
num_gpus: 2
|
||||||
nightly_build: "true"
|
nightly_build: "true"
|
||||||
runs-on: [self-hosted, modal]
|
runs-on: [self-hosted, modal]
|
||||||
|
|||||||
2
.github/workflows/tests.yml
vendored
2
.github/workflows/tests.yml
vendored
@@ -256,7 +256,7 @@ jobs:
|
|||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.6.0
|
pytorch: 2.6.0
|
||||||
num_gpus: 1
|
num_gpus: 1
|
||||||
axolotl_extras:
|
axolotl_extras: vllm
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v4
|
||||||
|
|||||||
@@ -40,6 +40,7 @@ quartodoc:
|
|||||||
- cli.preprocess
|
- cli.preprocess
|
||||||
- cli.sweeps
|
- cli.sweeps
|
||||||
- cli.utils
|
- cli.utils
|
||||||
|
- cli.vllm_serve
|
||||||
- cli.cloud.base
|
- cli.cloud.base
|
||||||
- cli.cloud.modal_
|
- cli.cloud.modal_
|
||||||
- title: Trainers
|
- title: Trainers
|
||||||
|
|||||||
@@ -2,4 +2,5 @@
|
|||||||
set -e
|
set -e
|
||||||
|
|
||||||
# only run one test at a time so as not to OOM the GPU
|
# only run one test at a time so as not to OOM the GPU
|
||||||
pytest -v -n2 /workspace/axolotl/tests/e2e/multigpu/
|
pytest -v -n2 /workspace/axolotl/tests/e2e/multigpu/ --ignore=/workspace/axolotl/tests/e2e/multigpu/solo/
|
||||||
|
pytest -v -n1 /workspace/axolotl/tests/e2e/multigpu/solo/
|
||||||
|
|||||||
@@ -238,10 +238,10 @@ simpo_gamma: 0.5 # Target reward margin for the SimPO loss
|
|||||||
# grpo
|
# grpo
|
||||||
trl:
|
trl:
|
||||||
use_vllm: # Optional[bool]. Whether to use VLLM for RL training.
|
use_vllm: # Optional[bool]. Whether to use VLLM for RL training.
|
||||||
vllm_device: # Optional[str]. Device to use for VLLM.
|
vllm_server_host: # Optional[str]. Host of the vLLM server to connect to.
|
||||||
vllm_gpu_memory_utilization: # Optional[float]. GPU memory utilization for VLLM.
|
vllm_server_port: # Optional[int]. Port of the vLLM server to connect to.
|
||||||
vllm_max_model_len: # Optional[int]. Maximum length of the model for VLLM.
|
vllm_server_timeout: # Optional[int]. Total timeout (in seconds) to wait for the vLLM server to respond.
|
||||||
vllm_dtype: # Optional[str]. Data type for VLLM.
|
vllm_guided_decoding_regex: # Optional[str]. Regex for vLLM guided decoding.
|
||||||
|
|
||||||
beta: # Optional[float]. Beta parameter for the RL training. Same as `rl_beta`. Use
|
beta: # Optional[float]. Beta parameter for the RL training. Same as `rl_beta`. Use
|
||||||
max_completion_length: # Optional[int]. Maximum length of the completion for RL training.
|
max_completion_length: # Optional[int]. Maximum length of the completion for RL training.
|
||||||
|
|||||||
@@ -502,9 +502,48 @@ The input format is a simple JSON input with customizable fields based on the ab
|
|||||||
Check out our [GRPO cookbook](https://github.com/axolotl-ai-cloud/axolotl-cookbook/tree/main/grpo#training-an-r1-style-large-language-model-using-grpo).
|
Check out our [GRPO cookbook](https://github.com/axolotl-ai-cloud/axolotl-cookbook/tree/main/grpo#training-an-r1-style-large-language-model-using-grpo).
|
||||||
:::
|
:::
|
||||||
|
|
||||||
|
If you have multiple GPUs available, we reccomend using `vLLM` with the `GRPOTrainer` to significantly speedup trajectory generation during training.
|
||||||
|
First, launch a `vLLM` server using `trl vllm-serve` - you may use a config file or CLI overrides to configure your vLLM server. In this example, we're
|
||||||
|
using 4 GPUs - 2 for training, and 2 for vLLM:
|
||||||
|
|
||||||
|
::: {.callout-important}
|
||||||
|
Make sure you've installed the correct version of vLLM by including it as an extra when installing axolotl, e.g. `pip install axolotl[vllm]`.
|
||||||
|
:::
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
base_model: Qwen/Qwen2.5-1.5B-Instruct
|
||||||
|
|
||||||
|
vllm:
|
||||||
|
host: 0.0.0.0
|
||||||
|
port: 8000
|
||||||
|
tensor_parallel_size: 2
|
||||||
|
gpu_memory_utilization: 0.85
|
||||||
|
dtype: auto
|
||||||
|
# max_model_len: # you may find it useful to set the vLLM model context length if you know this beforehand
|
||||||
|
|
||||||
|
rl: grpo
|
||||||
|
trl:
|
||||||
|
use_vllm: true
|
||||||
|
vllm_server_host: 0.0.0.0
|
||||||
|
vllm_server_port: 8000
|
||||||
|
vllm_server_timeout: 300
|
||||||
|
```
|
||||||
|
|
||||||
|
```bash
|
||||||
|
CUDA_VISIBLE_DEVICES=2,3 axolotl vllm_serve grpo.yaml
|
||||||
|
```
|
||||||
|
|
||||||
|
Your `vLLM` instance will now attempt to spin up, and it's time to kick off training utilizing our remaining two GPUs. In another terminal, execute:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
CUDA_VISIBLE_DEVICES=0,1 axolotl train grpo.yaml --num-processes 2
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Reward functions
|
||||||
|
|
||||||
GRPO uses custom reward functions and transformations. Please have them ready locally.
|
GRPO uses custom reward functions and transformations. Please have them ready locally.
|
||||||
|
|
||||||
For ex, to load OpenAI's GSM8K and use a random reward for completions:
|
For example, to load OpenAI's GSM8K and use a random reward for completions:
|
||||||
|
|
||||||
```python
|
```python
|
||||||
# rewards.py
|
# rewards.py
|
||||||
@@ -530,8 +569,6 @@ trl:
|
|||||||
beta: 0.001
|
beta: 0.001
|
||||||
max_completion_length: 256
|
max_completion_length: 256
|
||||||
use_vllm: True
|
use_vllm: True
|
||||||
vllm_device: auto
|
|
||||||
vllm_gpu_memory_utilization: 0.15
|
|
||||||
num_generations: 4
|
num_generations: 4
|
||||||
reward_funcs: ["rewards.rand_reward_func"] # format: '{file_name}.{fn_name}'
|
reward_funcs: ["rewards.rand_reward_func"] # format: '{file_name}.{fn_name}'
|
||||||
reward_weights: [1.0]
|
reward_weights: [1.0]
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
|
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
|
||||||
|
|
||||||
# START section of dependencies that don't install on Darwin/MacOS
|
# START section of dependencies that don't install on Darwin/MacOS
|
||||||
bitsandbytes==0.45.3
|
bitsandbytes==0.45.4
|
||||||
triton>=3.0.0
|
triton>=3.0.0
|
||||||
mamba-ssm==1.2.0.post1
|
mamba-ssm==1.2.0.post1
|
||||||
xformers>=0.0.23.post1
|
xformers>=0.0.23.post1
|
||||||
@@ -17,7 +17,7 @@ tokenizers>=0.21.1
|
|||||||
accelerate==1.5.2
|
accelerate==1.5.2
|
||||||
datasets==3.5.0
|
datasets==3.5.0
|
||||||
deepspeed==0.16.4
|
deepspeed==0.16.4
|
||||||
trl==0.15.1
|
trl==0.16.0
|
||||||
|
|
||||||
optimum==1.16.2
|
optimum==1.16.2
|
||||||
hf_transfer
|
hf_transfer
|
||||||
|
|||||||
83
setup.py
83
setup.py
@@ -10,7 +10,7 @@ from pathlib import Path
|
|||||||
from setuptools import find_packages, setup
|
from setuptools import find_packages, setup
|
||||||
|
|
||||||
|
|
||||||
def parse_requirements():
|
def parse_requirements(extras_require_map):
|
||||||
_install_requires = []
|
_install_requires = []
|
||||||
_dependency_links = []
|
_dependency_links = []
|
||||||
with open("./requirements.txt", encoding="utf-8") as requirements_file:
|
with open("./requirements.txt", encoding="utf-8") as requirements_file:
|
||||||
@@ -67,6 +67,7 @@ def parse_requirements():
|
|||||||
if (major, minor) >= (2, 6):
|
if (major, minor) >= (2, 6):
|
||||||
_install_requires.pop(_install_requires.index(xformers_version))
|
_install_requires.pop(_install_requires.index(xformers_version))
|
||||||
_install_requires.append("xformers==0.0.29.post2")
|
_install_requires.append("xformers==0.0.29.post2")
|
||||||
|
extras_require_map["vllm"] = ["vllm==0.8.1"]
|
||||||
elif (major, minor) >= (2, 5):
|
elif (major, minor) >= (2, 5):
|
||||||
_install_requires.pop(_install_requires.index(xformers_version))
|
_install_requires.pop(_install_requires.index(xformers_version))
|
||||||
if patch == 0:
|
if patch == 0:
|
||||||
@@ -86,7 +87,7 @@ def parse_requirements():
|
|||||||
|
|
||||||
except PackageNotFoundError:
|
except PackageNotFoundError:
|
||||||
pass
|
pass
|
||||||
return _install_requires, _dependency_links
|
return _install_requires, _dependency_links, extras_require_map
|
||||||
|
|
||||||
|
|
||||||
def get_package_version():
|
def get_package_version():
|
||||||
@@ -103,7 +104,46 @@ def get_package_version():
|
|||||||
return version_
|
return version_
|
||||||
|
|
||||||
|
|
||||||
install_requires, dependency_links = parse_requirements()
|
extras_require = {
|
||||||
|
"flash-attn": ["flash-attn==2.7.4.post1"],
|
||||||
|
"ring-flash-attn": ["ring-flash-attn>=0.1.4", "yunchang==0.6.0"],
|
||||||
|
"deepspeed": [
|
||||||
|
"deepspeed==0.16.4",
|
||||||
|
"deepspeed-kernels",
|
||||||
|
],
|
||||||
|
"mamba-ssm": [
|
||||||
|
"mamba-ssm==1.2.0.post1",
|
||||||
|
"causal_conv1d",
|
||||||
|
],
|
||||||
|
"auto-gptq": [
|
||||||
|
"auto-gptq==0.5.1",
|
||||||
|
],
|
||||||
|
"mlflow": [
|
||||||
|
"mlflow",
|
||||||
|
],
|
||||||
|
"galore": [
|
||||||
|
"galore_torch",
|
||||||
|
],
|
||||||
|
"apollo": [
|
||||||
|
"apollo-torch",
|
||||||
|
],
|
||||||
|
"optimizers": [
|
||||||
|
"galore_torch",
|
||||||
|
"apollo-torch",
|
||||||
|
"lomo-optim==0.1.1",
|
||||||
|
"torch-optimi==0.2.1",
|
||||||
|
],
|
||||||
|
"ray": [
|
||||||
|
"ray[train]",
|
||||||
|
],
|
||||||
|
"vllm": [
|
||||||
|
"vllm==0.7.2",
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
install_requires, dependency_links, extras_require_build = parse_requirements(
|
||||||
|
extras_require
|
||||||
|
)
|
||||||
|
|
||||||
setup(
|
setup(
|
||||||
version=get_package_version(),
|
version=get_package_version(),
|
||||||
@@ -116,40 +156,5 @@ setup(
|
|||||||
"axolotl=axolotl.cli.main:main",
|
"axolotl=axolotl.cli.main:main",
|
||||||
],
|
],
|
||||||
},
|
},
|
||||||
extras_require={
|
extras_require=extras_require_build,
|
||||||
"flash-attn": ["flash-attn==2.7.4.post1"],
|
|
||||||
"ring-flash-attn": ["ring-flash-attn>=0.1.4", "yunchang==0.6.0"],
|
|
||||||
"deepspeed": [
|
|
||||||
"deepspeed==0.16.4",
|
|
||||||
"deepspeed-kernels",
|
|
||||||
],
|
|
||||||
"mamba-ssm": [
|
|
||||||
"mamba-ssm==1.2.0.post1",
|
|
||||||
"causal_conv1d",
|
|
||||||
],
|
|
||||||
"auto-gptq": [
|
|
||||||
"auto-gptq==0.5.1",
|
|
||||||
],
|
|
||||||
"mlflow": [
|
|
||||||
"mlflow",
|
|
||||||
],
|
|
||||||
"galore": [
|
|
||||||
"galore_torch",
|
|
||||||
],
|
|
||||||
"apollo": [
|
|
||||||
"apollo-torch",
|
|
||||||
],
|
|
||||||
"optimizers": [
|
|
||||||
"galore_torch",
|
|
||||||
"apollo-torch",
|
|
||||||
"lomo-optim==0.1.1",
|
|
||||||
"torch-optimi==0.2.1",
|
|
||||||
],
|
|
||||||
"ray": [
|
|
||||||
"ray[train]",
|
|
||||||
],
|
|
||||||
"vllm": [
|
|
||||||
"vllm==0.7.2",
|
|
||||||
],
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -35,6 +35,55 @@ class TrainerCliArgs:
|
|||||||
num_processes: Optional[int] = field(default=None)
|
num_processes: Optional[int] = field(default=None)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class VllmServeCliArgs:
|
||||||
|
"""Dataclass with CLI arguments for `axolotl vllm-serve` command."""
|
||||||
|
|
||||||
|
tensor_parallel_size: int = field(
|
||||||
|
default=1,
|
||||||
|
metadata={"help": "Number of tensor parallel workers to use."},
|
||||||
|
)
|
||||||
|
host: str = field(
|
||||||
|
default="0.0.0.0", # nosec B104
|
||||||
|
metadata={"help": "Host address to run the server on."},
|
||||||
|
)
|
||||||
|
port: int = field(
|
||||||
|
default=8000,
|
||||||
|
metadata={"help": "Port to run the server on."},
|
||||||
|
)
|
||||||
|
gpu_memory_utilization: Optional[float] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={
|
||||||
|
"help": "Ratio (between 0 and 1) of GPU memory to reserve for the model weights, activations, and KV "
|
||||||
|
"cache on the device dedicated to generation powered by vLLM. Higher values will increase the KV cache "
|
||||||
|
"size and thus improve the model's throughput. However, if the value is too high, it may cause "
|
||||||
|
"out-of-memory (OOM) errors during initialization."
|
||||||
|
},
|
||||||
|
)
|
||||||
|
dtype: Optional[str] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={
|
||||||
|
"help": "Data type to use for vLLM generation. If set to 'auto', the data type will be automatically "
|
||||||
|
"determined based on the model configuration. Find the supported values in the vLLM documentation."
|
||||||
|
},
|
||||||
|
)
|
||||||
|
max_model_len: Optional[int] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={
|
||||||
|
"help": "If set, the `max_model_len` to use for vLLM. This can be useful when running with reduced "
|
||||||
|
"`vllm_gpu_memory_utilization`, leading to a reduced KV cache size. If not set, vLLM will use the model "
|
||||||
|
"context size, which might be much larger than the KV cache, leading to inefficiencies."
|
||||||
|
},
|
||||||
|
)
|
||||||
|
enable_prefix_caching: Optional[bool] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={
|
||||||
|
"help": "Whether to enable prefix caching in vLLM. If set to `True`, ensure that the model and the "
|
||||||
|
"hardware support this feature."
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class EvaluateCliArgs:
|
class EvaluateCliArgs:
|
||||||
"""Dataclass with CLI arguments for `axolotl evaluate` command."""
|
"""Dataclass with CLI arguments for `axolotl evaluate` command."""
|
||||||
|
|||||||
@@ -14,7 +14,12 @@ import yaml
|
|||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
import axolotl
|
import axolotl
|
||||||
from axolotl.cli.args import EvaluateCliArgs, PreprocessCliArgs, TrainerCliArgs
|
from axolotl.cli.args import (
|
||||||
|
EvaluateCliArgs,
|
||||||
|
PreprocessCliArgs,
|
||||||
|
TrainerCliArgs,
|
||||||
|
VllmServeCliArgs,
|
||||||
|
)
|
||||||
from axolotl.cli.sweeps import generate_sweep_configs
|
from axolotl.cli.sweeps import generate_sweep_configs
|
||||||
from axolotl.cli.utils import (
|
from axolotl.cli.utils import (
|
||||||
add_options_from_config,
|
add_options_from_config,
|
||||||
@@ -23,6 +28,7 @@ from axolotl.cli.utils import (
|
|||||||
fetch_from_github,
|
fetch_from_github,
|
||||||
filter_none_kwargs,
|
filter_none_kwargs,
|
||||||
)
|
)
|
||||||
|
from axolotl.cli.vllm_serve import do_vllm_serve
|
||||||
from axolotl.integrations.lm_eval.cli import lm_eval
|
from axolotl.integrations.lm_eval.cli import lm_eval
|
||||||
from axolotl.utils import set_pytorch_cuda_alloc_conf
|
from axolotl.utils import set_pytorch_cuda_alloc_conf
|
||||||
from axolotl.utils.schemas.config import AxolotlInputConfig
|
from axolotl.utils.schemas.config import AxolotlInputConfig
|
||||||
@@ -316,6 +322,14 @@ def fetch(directory: str, dest: Optional[str]) -> None:
|
|||||||
fetch_from_github(f"{directory}/", dest)
|
fetch_from_github(f"{directory}/", dest)
|
||||||
|
|
||||||
|
|
||||||
|
@cli.command()
|
||||||
|
@click.argument("config", type=click.Path(exists=True, path_type=str))
|
||||||
|
@add_options_from_dataclass(VllmServeCliArgs)
|
||||||
|
@filter_none_kwargs
|
||||||
|
def vllm_serve(config: str, **cli_args: VllmServeCliArgs):
|
||||||
|
do_vllm_serve(config, cli_args)
|
||||||
|
|
||||||
|
|
||||||
cli.add_command(lm_eval)
|
cli.add_command(lm_eval)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
55
src/axolotl/cli/vllm_serve.py
Normal file
55
src/axolotl/cli/vllm_serve.py
Normal file
@@ -0,0 +1,55 @@
|
|||||||
|
"""
|
||||||
|
CLI to start the vllm server for online RL
|
||||||
|
"""
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
|
from trl.scripts.vllm_serve import ScriptArguments
|
||||||
|
from trl.scripts.vllm_serve import main as vllm_serve_main
|
||||||
|
|
||||||
|
from axolotl.cli.config import load_cfg
|
||||||
|
|
||||||
|
|
||||||
|
def do_vllm_serve(
|
||||||
|
config: Union[Path, str],
|
||||||
|
cli_args: dict,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Starts the VLLM server for serving LLM models used for online RL
|
||||||
|
|
||||||
|
Args
|
||||||
|
:param cfg: Parsed doct of the YAML config
|
||||||
|
:param cli_args: dict of additional command-line arguments of type VllmServeCliArgs
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
process_id: the process id of the started VLLM server
|
||||||
|
"""
|
||||||
|
cfg = load_cfg(config)
|
||||||
|
model = cfg.base_model
|
||||||
|
|
||||||
|
tensor_parallel_size = (
|
||||||
|
cli_args.get("tensor_parallel_size") or cfg.vllm.tensor_parallel_size
|
||||||
|
)
|
||||||
|
host = cli_args.get("host") or cfg.vllm.host
|
||||||
|
port = cli_args.get("port") or cfg.vllm.port
|
||||||
|
gpu_memory_utilization = (
|
||||||
|
cli_args.get("gpu_memory_utilization") or cfg.vllm.gpu_memory_utilization
|
||||||
|
)
|
||||||
|
dtype = cli_args.get("dtype") or cfg.vllm.dtype
|
||||||
|
max_model_len = cli_args.get("max_model_len") or cfg.vllm.max_model_len
|
||||||
|
enable_prefix_caching = (
|
||||||
|
cli_args.get("enable_prefix_caching") or cfg.vllm.enable_prefix_caching
|
||||||
|
)
|
||||||
|
|
||||||
|
vllm_script_args = ScriptArguments(
|
||||||
|
model,
|
||||||
|
tensor_parallel_size=tensor_parallel_size,
|
||||||
|
host=host,
|
||||||
|
port=port,
|
||||||
|
gpu_memory_utilization=gpu_memory_utilization,
|
||||||
|
dtype=dtype,
|
||||||
|
max_model_len=max_model_len,
|
||||||
|
enable_prefix_caching=enable_prefix_caching,
|
||||||
|
)
|
||||||
|
vllm_serve_main(vllm_script_args)
|
||||||
@@ -40,18 +40,15 @@ class GRPOStrategy:
|
|||||||
|
|
||||||
if trl.use_vllm:
|
if trl.use_vllm:
|
||||||
grpo_args_kwargs["use_vllm"] = trl.use_vllm
|
grpo_args_kwargs["use_vllm"] = trl.use_vllm
|
||||||
grpo_args_kwargs["vllm_device"] = (
|
grpo_args_kwargs["vllm_server_host"] = trl.vllm_server_host
|
||||||
trl.vllm_device if trl.vllm_device else "auto"
|
grpo_args_kwargs["vllm_server_port"] = trl.vllm_server_port
|
||||||
)
|
if trl.vllm_server_timeout:
|
||||||
|
grpo_args_kwargs["vllm_server_timeout"] = trl.vllm_server_timeout
|
||||||
if trl.vllm_gpu_memory_utilization:
|
if trl.vllm_guided_decoding_regex:
|
||||||
grpo_args_kwargs["vllm_gpu_memory_utilization"] = (
|
grpo_args_kwargs["vllm_guided_decoding_regex"] = (
|
||||||
trl.vllm_gpu_memory_utilization
|
trl.vllm_guided_decoding_regex
|
||||||
)
|
)
|
||||||
|
|
||||||
if trl.vllm_max_model_len:
|
|
||||||
grpo_args_kwargs["vllm_max_model_len"] = trl.vllm_max_model_len
|
|
||||||
|
|
||||||
if trl.num_generations:
|
if trl.num_generations:
|
||||||
grpo_args_kwargs["num_generations"] = trl.num_generations
|
grpo_args_kwargs["num_generations"] = trl.num_generations
|
||||||
|
|
||||||
@@ -70,6 +67,25 @@ class GRPOStrategy:
|
|||||||
if trl.reward_weights:
|
if trl.reward_weights:
|
||||||
grpo_args_kwargs["reward_weights"] = trl.reward_weights
|
grpo_args_kwargs["reward_weights"] = trl.reward_weights
|
||||||
|
|
||||||
|
if trl.scale_rewards is not None:
|
||||||
|
grpo_args_kwargs["scale_rewards"] = trl.scale_rewards
|
||||||
|
|
||||||
|
if trl.temperature is not None:
|
||||||
|
grpo_args_kwargs["temperature"] = trl.temperature
|
||||||
|
if trl.top_p is not None:
|
||||||
|
grpo_args_kwargs["top_p"] = trl.top_p
|
||||||
|
if trl.top_k is not None:
|
||||||
|
grpo_args_kwargs["top_k"] = trl.top_k
|
||||||
|
if trl.min_p is not None:
|
||||||
|
grpo_args_kwargs["min_p"] = trl.min_p
|
||||||
|
if trl.repetition_penalty is not None:
|
||||||
|
grpo_args_kwargs["repetition_penalty"] = trl.repetition_penalty
|
||||||
|
|
||||||
|
if trl.num_iterations is not None:
|
||||||
|
grpo_args_kwargs["num_iterations"] = trl.num_iterations
|
||||||
|
if trl.epsilon is not None:
|
||||||
|
grpo_args_kwargs["epsilon"] = trl.epsilon
|
||||||
|
|
||||||
return grpo_args_kwargs
|
return grpo_args_kwargs
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|||||||
@@ -2,16 +2,18 @@
|
|||||||
Axolotl GRPO trainer
|
Axolotl GRPO trainer
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from accelerate.utils import is_peft_model
|
from contextlib import nullcontext
|
||||||
from accelerate.utils.other import is_compiled_module
|
|
||||||
from transformers import PreTrainedModel
|
from accelerate.utils import is_deepspeed_available, is_peft_model
|
||||||
from trl import GRPOConfig, GRPOTrainer
|
from trl import GRPOTrainer
|
||||||
from trl.models import unwrap_model_for_generation
|
from trl.extras.profiling import profiling_decorator
|
||||||
|
|
||||||
from axolotl.core.trainers.base import SchedulerMixin
|
from axolotl.core.trainers.base import SchedulerMixin
|
||||||
|
|
||||||
|
if is_deepspeed_available():
|
||||||
|
import deepspeed
|
||||||
|
|
||||||
|
|
||||||
# mypy: ignore-errors
|
|
||||||
class AxolotlGRPOTrainer(SchedulerMixin, GRPOTrainer):
|
class AxolotlGRPOTrainer(SchedulerMixin, GRPOTrainer):
|
||||||
"""
|
"""
|
||||||
Extend the base GRPOTrainer for axolotl helpers
|
Extend the base GRPOTrainer for axolotl helpers
|
||||||
@@ -19,91 +21,49 @@ class AxolotlGRPOTrainer(SchedulerMixin, GRPOTrainer):
|
|||||||
|
|
||||||
_tag_names = ["trl", "grpo", "axolotl"]
|
_tag_names = ["trl", "grpo", "axolotl"]
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
@profiling_decorator
|
||||||
super().__init__(*args, **kwargs)
|
def _move_model_to_vllm(self):
|
||||||
|
# For DeepSpeed ZeRO-3, we need to gather all parameters before operations
|
||||||
# pylint: disable=access-member-before-definition
|
deepspeed_plugin = self.accelerator.state.deepspeed_plugin
|
||||||
# Enable gradient checkpointing if requested
|
zero_stage_3 = deepspeed_plugin is not None and deepspeed_plugin.zero_stage == 3
|
||||||
if kwargs["args"].gradient_checkpointing:
|
gather_if_zero3 = (
|
||||||
# Ensure use_cache is disabled
|
deepspeed.zero.GatheredParameters if zero_stage_3 else nullcontext
|
||||||
if hasattr(self.model, "config"):
|
|
||||||
self.model.config.use_cache = False
|
|
||||||
|
|
||||||
# Enable gradient checkpointing on the base model for PEFT
|
|
||||||
if is_peft_model(self.model) and hasattr(
|
|
||||||
self.model.base_model, "gradient_checkpointing_enable"
|
|
||||||
):
|
|
||||||
self.model.base_model.gradient_checkpointing_enable()
|
|
||||||
# Enable gradient checkpointing for non-PEFT models
|
|
||||||
elif hasattr(self.model, "gradient_checkpointing_enable"):
|
|
||||||
self.model.gradient_checkpointing_enable()
|
|
||||||
self.model = self._enable_gradient_checkpointing(self.model, kwargs["args"])
|
|
||||||
# pylint: enable=access-member-before-definition
|
|
||||||
|
|
||||||
def _enable_gradient_checkpointing(
|
|
||||||
self, model: PreTrainedModel, args: GRPOConfig
|
|
||||||
) -> PreTrainedModel:
|
|
||||||
"""Enables gradient checkpointing for the model."""
|
|
||||||
# pylint: disable=unused-argument,redefined-builtin
|
|
||||||
gradient_checkpointing_kwargs = args.gradient_checkpointing_kwargs or {}
|
|
||||||
use_reentrant = (
|
|
||||||
"use_reentrant" not in gradient_checkpointing_kwargs
|
|
||||||
or gradient_checkpointing_kwargs["use_reentrant"]
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if use_reentrant:
|
if is_peft_model(self.model):
|
||||||
if hasattr(model, "enable_input_require_grads"):
|
# With PEFT and DeepSpeed ZeRO Stage 3, we must gather the full model at once before merging, as merging
|
||||||
model.enable_input_require_grads()
|
# adapters in a sharded manner is not supported.
|
||||||
else:
|
with gather_if_zero3(list(self.model.parameters())):
|
||||||
|
self.model.merge_adapter()
|
||||||
|
|
||||||
def make_inputs_require_grad(module, input, output):
|
# Update vLLM weights while parameters are gathered
|
||||||
output.requires_grad_(True)
|
for name, param in self.model.named_parameters():
|
||||||
|
# When using PEFT, we need to recover the original parameter name and discard some parameters
|
||||||
|
name = (
|
||||||
|
name.removeprefix("base_model.model.")
|
||||||
|
.removeprefix("base_model.model.")
|
||||||
|
.replace(".base_layer", "")
|
||||||
|
)
|
||||||
|
if self.model.prefix in name:
|
||||||
|
continue
|
||||||
|
# When module to save, remove its prefix and discard the original module
|
||||||
|
if "original_module" in name:
|
||||||
|
continue
|
||||||
|
name = name.replace("modules_to_save.default.", "")
|
||||||
|
|
||||||
model.get_input_embeddings().register_forward_hook(
|
if self.accelerator.is_main_process:
|
||||||
make_inputs_require_grad
|
self.vllm_client.update_named_param(name, param.data)
|
||||||
)
|
|
||||||
|
|
||||||
return model
|
# Unmerge adapters while parameters are still gathered
|
||||||
# pylint: enable=unused-argument,redefined-builtin
|
self.model.unmerge_adapter()
|
||||||
|
# Parameters will automatically be repartitioned when exiting the context
|
||||||
|
else:
|
||||||
|
# For non-PEFT models, simply gather and update each parameter individually.
|
||||||
|
for name, param in self.model.named_parameters():
|
||||||
|
with gather_if_zero3([param]):
|
||||||
|
if self.accelerator.is_main_process:
|
||||||
|
self.vllm_client.update_named_param(name, param.data)
|
||||||
|
|
||||||
def _move_model_to_vllm(self):
|
# Reset cache on main process
|
||||||
with unwrap_model_for_generation(
|
if self.accelerator.is_main_process:
|
||||||
self.model,
|
self.vllm_client.reset_prefix_cache()
|
||||||
self.accelerator,
|
|
||||||
gather_deepspeed3_params=self.args.ds3_gather_for_generation,
|
|
||||||
) as unwrapped_model:
|
|
||||||
if is_compiled_module(unwrapped_model):
|
|
||||||
unwrapped_model = (
|
|
||||||
unwrapped_model._orig_mod # pylint: disable=protected-access
|
|
||||||
)
|
|
||||||
if is_peft_model(unwrapped_model):
|
|
||||||
unwrapped_model.merge_adapter()
|
|
||||||
state_dict = unwrapped_model.state_dict()
|
|
||||||
# Remove base_model and base_layer prefixes
|
|
||||||
state_dict = {
|
|
||||||
k.removeprefix("base_model.model.")
|
|
||||||
.removeprefix("base_model.model.")
|
|
||||||
.replace(".base_layer", ""): v
|
|
||||||
for k, v in state_dict.items()
|
|
||||||
}
|
|
||||||
# Remove values with adapter prefix (example: "_lora")
|
|
||||||
state_dict = {
|
|
||||||
k: v
|
|
||||||
for k, v in state_dict.items()
|
|
||||||
if unwrapped_model.prefix not in k
|
|
||||||
}
|
|
||||||
# When module to save, remove its prefix and discard the original module
|
|
||||||
state_dict = {
|
|
||||||
k.replace("modules_to_save.default.", ""): v
|
|
||||||
for k, v in state_dict.items()
|
|
||||||
if "original_module" not in k
|
|
||||||
}
|
|
||||||
else:
|
|
||||||
state_dict = unwrapped_model.state_dict()
|
|
||||||
if self.accelerator.is_main_process:
|
|
||||||
llm_model = (
|
|
||||||
self.llm.llm_engine.model_executor.driver_worker.model_runner.model
|
|
||||||
)
|
|
||||||
llm_model.load_weights(state_dict.items())
|
|
||||||
if is_peft_model(unwrapped_model):
|
|
||||||
unwrapped_model.unmerge_adapter()
|
|
||||||
|
|||||||
@@ -46,6 +46,7 @@ from axolotl.utils.schemas.multimodal import MultiModalConfig
|
|||||||
from axolotl.utils.schemas.peft import LoraConfig, ReLoRAConfig
|
from axolotl.utils.schemas.peft import LoraConfig, ReLoRAConfig
|
||||||
from axolotl.utils.schemas.training import HyperparametersConfig
|
from axolotl.utils.schemas.training import HyperparametersConfig
|
||||||
from axolotl.utils.schemas.trl import TRLConfig
|
from axolotl.utils.schemas.trl import TRLConfig
|
||||||
|
from axolotl.utils.schemas.vllm import VllmConfig
|
||||||
|
|
||||||
LOG = logging.getLogger(__name__)
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -86,6 +87,9 @@ class AxolotlInputConfig(
|
|||||||
trl: TRLConfig | None = Field(
|
trl: TRLConfig | None = Field(
|
||||||
default_factory=lambda: TRLConfig(), # pylint: disable=unnecessary-lambda
|
default_factory=lambda: TRLConfig(), # pylint: disable=unnecessary-lambda
|
||||||
)
|
)
|
||||||
|
vllm: VllmConfig | None = Field(
|
||||||
|
default_factory=lambda: VllmConfig(), # pylint: disable=unnecessary-lambda
|
||||||
|
)
|
||||||
reward_model: bool | None = None
|
reward_model: bool | None = None
|
||||||
process_reward_model: bool | None = None
|
process_reward_model: bool | None = None
|
||||||
num_labels: int | None = None
|
num_labels: int | None = None
|
||||||
|
|||||||
@@ -20,27 +20,30 @@ class TRLConfig(BaseModel):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# GRPO specific args
|
# GRPO specific args
|
||||||
# Ref: https://github.com/huggingface/trl/blob/e3244d2d096ff1e2e248c931d06d39e165e20623/trl/trainer/grpo_config.py#L22
|
# Ref: https://github.com/huggingface/trl/blob/26d86757a7c7e24e397ea44f57ecce6031dfac01/trl/trainer/grpo_config.py#L23
|
||||||
use_vllm: bool | None = Field(
|
use_vllm: bool = Field(
|
||||||
default=False,
|
default=False,
|
||||||
json_schema_extra={"description": "Whether to use VLLM for RL training"},
|
json_schema_extra={"description": "Whether to use VLLM for RL training"},
|
||||||
)
|
)
|
||||||
vllm_device: str | None = Field(
|
vllm_server_host: str | None = Field(
|
||||||
default="auto",
|
default="0.0.0.0", # nosec B104
|
||||||
json_schema_extra={"description": "Device to use for VLLM"},
|
json_schema_extra={"description": "Host of the vLLM server to connect to"},
|
||||||
)
|
)
|
||||||
vllm_gpu_memory_utilization: float | None = Field(
|
vllm_server_port: int | None = Field(
|
||||||
default=0.9,
|
default=8000,
|
||||||
json_schema_extra={"description": "GPU memory utilization for VLLM"},
|
json_schema_extra={"description": "Port of the vLLM server to connect to"},
|
||||||
)
|
)
|
||||||
vllm_dtype: str | None = Field(
|
vllm_server_timeout: int | None = Field(
|
||||||
default="auto",
|
|
||||||
json_schema_extra={"description": "Data type for VLLM"},
|
|
||||||
)
|
|
||||||
vllm_max_model_len: int | None = Field(
|
|
||||||
default=None,
|
default=None,
|
||||||
json_schema_extra={
|
json_schema_extra={
|
||||||
"description": "Maximum length of the model context for VLLM"
|
"description": "Total timeout duration in seconds to wait for the vLLM server to be up. If the server is not up "
|
||||||
|
"after the timeout, a `ConnectionError` is raised."
|
||||||
|
},
|
||||||
|
)
|
||||||
|
vllm_guided_decoding_regex: str | None = Field(
|
||||||
|
default=None,
|
||||||
|
json_schema_extra={
|
||||||
|
"description": "Regex for vLLM guided decoding. If `None` (default), guided decoding is disabled."
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -85,3 +88,48 @@ class TRLConfig(BaseModel):
|
|||||||
"description": "Sync steps for the reference model. Requires `sync_ref_model=True`."
|
"description": "Sync steps for the reference model. Requires `sync_ref_model=True`."
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
scale_rewards: bool = Field(
|
||||||
|
default=True,
|
||||||
|
json_schema_extra={
|
||||||
|
"description": "Whether to scale the rewards for GRPO by dividing them by their standard deviation."
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
temperature: float | None = Field(
|
||||||
|
default=None,
|
||||||
|
json_schema_extra={"description": "Sampling temperature for the GRPO policy."},
|
||||||
|
)
|
||||||
|
top_p: float | None = Field(
|
||||||
|
default=None,
|
||||||
|
json_schema_extra={
|
||||||
|
"description": "Top-p sampling probability for the generation policy."
|
||||||
|
},
|
||||||
|
)
|
||||||
|
top_k: int | None = Field(
|
||||||
|
default=None,
|
||||||
|
json_schema_extra={"description": "Top-k sampling for the generation policy."},
|
||||||
|
)
|
||||||
|
min_p: float | None = Field(
|
||||||
|
default=None,
|
||||||
|
json_schema_extra={
|
||||||
|
"description": "Minimum probability for the generation policy."
|
||||||
|
},
|
||||||
|
)
|
||||||
|
repetition_penalty: float | None = Field(
|
||||||
|
default=None,
|
||||||
|
json_schema_extra={
|
||||||
|
"description": "Float that penalizes new tokens based on whether they appear in the prompt and the generated text so far."
|
||||||
|
},
|
||||||
|
)
|
||||||
|
num_iterations: int | None = Field(
|
||||||
|
default=None,
|
||||||
|
json_schema_extra={
|
||||||
|
"description": "Number of iterations per batch (denoted as μ in the algorithm) for GRPO."
|
||||||
|
},
|
||||||
|
)
|
||||||
|
epsilon: float | None = Field(
|
||||||
|
default=None,
|
||||||
|
json_schema_extra={
|
||||||
|
"description": "Epsilon value for clipping in the GRPO algorithm."
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|||||||
38
src/axolotl/utils/schemas/vllm.py
Normal file
38
src/axolotl/utils/schemas/vllm.py
Normal file
@@ -0,0 +1,38 @@
|
|||||||
|
"""
|
||||||
|
Pydantic models for VLLM configuration, used primarily for RL training with TRL + grpo
|
||||||
|
"""
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
|
class VllmConfig(BaseModel):
|
||||||
|
"""
|
||||||
|
Configuration for VLLM server
|
||||||
|
"""
|
||||||
|
|
||||||
|
device: str | None = Field(
|
||||||
|
default="auto",
|
||||||
|
json_schema_extra={"description": "Device to use for VLLM"},
|
||||||
|
)
|
||||||
|
tensor_parallel_size: int | None = Field(
|
||||||
|
default=None,
|
||||||
|
json_schema_extra={"description": "Tensor parallel size for VLLM"},
|
||||||
|
)
|
||||||
|
gpu_memory_utilization: float | None = Field(
|
||||||
|
default=0.9,
|
||||||
|
json_schema_extra={"description": "GPU memory utilization for VLLM"},
|
||||||
|
)
|
||||||
|
dtype: str | None = Field(
|
||||||
|
default="auto",
|
||||||
|
json_schema_extra={"description": "Data type for VLLM"},
|
||||||
|
)
|
||||||
|
max_model_len: int | None = Field(
|
||||||
|
default=None,
|
||||||
|
json_schema_extra={
|
||||||
|
"description": "Maximum length of the model context for VLLM"
|
||||||
|
},
|
||||||
|
)
|
||||||
|
enable_prefix_caching: bool | None = Field(
|
||||||
|
default=None,
|
||||||
|
json_schema_extra={"description": "Enable prefix caching for VLLM"},
|
||||||
|
)
|
||||||
@@ -100,6 +100,14 @@ def download_argilla_distilabel_capybara_dpo_7k_binarized_dataset():
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session", autouse=True)
|
||||||
|
def download_argilla_distilabel_intel_orca_dpo_dataset():
|
||||||
|
# download the dataset
|
||||||
|
snapshot_download_w_retry(
|
||||||
|
"argilla/distilabel-intel-orca-dpo-pairs", repo_type="dataset"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session", autouse=True)
|
@pytest.fixture(scope="session", autouse=True)
|
||||||
def download_argilla_ultrafeedback_binarized_preferences_cleaned_dataset():
|
def download_argilla_ultrafeedback_binarized_preferences_cleaned_dataset():
|
||||||
# download the dataset
|
# download the dataset
|
||||||
|
|||||||
0
tests/e2e/multigpu/solo/__init__.py
Normal file
0
tests/e2e/multigpu/solo/__init__.py
Normal file
294
tests/e2e/multigpu/solo/test_grpo.py
Normal file
294
tests/e2e/multigpu/solo/test_grpo.py
Normal file
@@ -0,0 +1,294 @@
|
|||||||
|
"""
|
||||||
|
GRPO test suite
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import random
|
||||||
|
import subprocess # nosec B404
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import requests
|
||||||
|
import yaml
|
||||||
|
from accelerate.test_utils import execute_subprocess_async
|
||||||
|
from transformers.testing_utils import get_torch_dist_unique_port
|
||||||
|
|
||||||
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
|
from tests.e2e.utils import require_vllm
|
||||||
|
|
||||||
|
|
||||||
|
def start_vllm(
|
||||||
|
model: str, env: dict | None = None, wait: int | None = None, quiet=False, **kwargs
|
||||||
|
) -> int:
|
||||||
|
"""
|
||||||
|
helper function to start the VLLM server in the background, mostly for testing purposes
|
||||||
|
"""
|
||||||
|
cmd = [sys.executable, "-m", "trl.scripts.vllm_serve", "--model", model]
|
||||||
|
|
||||||
|
if tensor_parallel_size := kwargs.get("tensor_parallel_size"):
|
||||||
|
cmd.extend(["--tensor-parallel-size", str(tensor_parallel_size)])
|
||||||
|
if host := kwargs.get("host"):
|
||||||
|
cmd.extend(["--host", host])
|
||||||
|
if port := kwargs.get("port"):
|
||||||
|
cmd.extend(["--port", str(port)])
|
||||||
|
if gpu_memory_utilization := kwargs.get("gpu_memory_utilization"):
|
||||||
|
cmd.extend(["--gpu-memory-utilization", str(gpu_memory_utilization)])
|
||||||
|
if dtype := kwargs.get("dtype"):
|
||||||
|
cmd.extend(["--dtype", dtype])
|
||||||
|
if max_model_len := kwargs.get("max_model_len"):
|
||||||
|
cmd.extend(["--max-model-len", str(max_model_len)])
|
||||||
|
if kwargs.get("enable_prefix_caching"):
|
||||||
|
cmd.extend(["--enable-prefix-caching", "True"])
|
||||||
|
|
||||||
|
# print out the command to be executed
|
||||||
|
print(" ".join(cmd))
|
||||||
|
|
||||||
|
# start `trl vllm-serve` command in the background and capture the process id
|
||||||
|
process = subprocess.Popen( # pylint: disable=consider-using-with
|
||||||
|
cmd,
|
||||||
|
env=env,
|
||||||
|
stdout=subprocess.DEVNULL if quiet else subprocess.PIPE,
|
||||||
|
stderr=subprocess.DEVNULL if quiet else subprocess.PIPE,
|
||||||
|
) # nosec B603
|
||||||
|
|
||||||
|
# print out the process id so the user can easily kill it later
|
||||||
|
print(f"VLLM server process started (PID: {process.pid})")
|
||||||
|
|
||||||
|
# wait until the http server is ready, even if it 404s, but timeout after 60 seconds
|
||||||
|
started = False
|
||||||
|
if wait and host and port:
|
||||||
|
for _ in range(int(wait)):
|
||||||
|
try:
|
||||||
|
response = requests.get(f"http://{host}:{port}", timeout=1)
|
||||||
|
if int(response.status_code) in [200, 404]:
|
||||||
|
started = True
|
||||||
|
break
|
||||||
|
except requests.exceptions.RequestException:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# also check if the process.pid is still running
|
||||||
|
if not process.poll() is None:
|
||||||
|
break
|
||||||
|
|
||||||
|
time.sleep(1)
|
||||||
|
|
||||||
|
if wait and not started:
|
||||||
|
print(
|
||||||
|
f"VLLM server process did not start within {wait} seconds. Please check your server logs."
|
||||||
|
)
|
||||||
|
process.kill()
|
||||||
|
raise RuntimeError(f"VLLM server process did not start within {wait} seconds.")
|
||||||
|
|
||||||
|
# return the process id
|
||||||
|
return process.pid
|
||||||
|
|
||||||
|
|
||||||
|
class TestGRPO:
|
||||||
|
"""
|
||||||
|
Test case for GRPO training using multilpe GPUs
|
||||||
|
"""
|
||||||
|
|
||||||
|
def _utils_write_yaml_and_rewards(self, cfg, temp_dir, suffix=""):
|
||||||
|
# write cfg to yaml file
|
||||||
|
Path(temp_dir).mkdir(parents=True, exist_ok=True)
|
||||||
|
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
|
||||||
|
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
|
||||||
|
with open(f"rewards_{suffix}.py", "w", encoding="utf-8") as fout:
|
||||||
|
fout.write(
|
||||||
|
"""import random
|
||||||
|
def rand_reward_func(completions, **kwargs) -> list[float]:
|
||||||
|
return [random.uniform(0, 1) for _ in completions]
|
||||||
|
|
||||||
|
def oai_gsm8k_transform(cfg, *args, **kwargs):
|
||||||
|
def transform_fn(example, tokenizer=None):
|
||||||
|
label = example["answer"].split("####")[-1].strip().replace(",", "")
|
||||||
|
return {
|
||||||
|
"prompt": [{"role": "user", "content": example["question"]},],
|
||||||
|
"answer": label,
|
||||||
|
}
|
||||||
|
return transform_fn, {"remove_columns": ["question"]}
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"num_gpus",
|
||||||
|
[1, 2],
|
||||||
|
)
|
||||||
|
@require_vllm
|
||||||
|
def test_llama_dora(self, temp_dir, num_gpus):
|
||||||
|
rnd_reward_suffix = str(random.randint(1000, 9999))
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||||
|
"chat_template": "llama3",
|
||||||
|
"rl": "grpo",
|
||||||
|
"trl": {
|
||||||
|
"beta": 0.001,
|
||||||
|
"max_completion_length": 256,
|
||||||
|
"use_vllm": True,
|
||||||
|
"num_generations": 4,
|
||||||
|
"reward_funcs": [f"rewards_{rnd_reward_suffix}.rand_reward_func"],
|
||||||
|
},
|
||||||
|
"vllm": {
|
||||||
|
"max_model_len": 800,
|
||||||
|
"enable_prefix_caching": True,
|
||||||
|
},
|
||||||
|
"datasets": [
|
||||||
|
{
|
||||||
|
"path": "openai/gsm8k",
|
||||||
|
"name": "main",
|
||||||
|
"type": f"rewards_{rnd_reward_suffix}.oai_gsm8k_transform",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
"adapter": "lora",
|
||||||
|
"lora_r": 8,
|
||||||
|
"lora_alpha": 16,
|
||||||
|
"lora_dropout": 0.05,
|
||||||
|
"lora_target_linear": True,
|
||||||
|
"peft_use_dora": True,
|
||||||
|
"flash_attention": True,
|
||||||
|
"sequence_len": 1024,
|
||||||
|
"special_tokens": {
|
||||||
|
"pad_token": "<|endoftext|>",
|
||||||
|
},
|
||||||
|
"max_steps": 3,
|
||||||
|
"num_epochs": 1,
|
||||||
|
"micro_batch_size": 4,
|
||||||
|
"gradient_accumulation_steps": 2,
|
||||||
|
"warmup_steps": 10,
|
||||||
|
"val_set_size": 0.0,
|
||||||
|
"output_dir": temp_dir,
|
||||||
|
"learning_rate": 0.0001,
|
||||||
|
"optimizer": "adamw_torch_fused",
|
||||||
|
"lr_scheduler": "cosine",
|
||||||
|
"save_safetensors": True,
|
||||||
|
"bf16": "auto",
|
||||||
|
"use_tensorboard": True,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
self._utils_write_yaml_and_rewards(cfg, temp_dir, suffix=rnd_reward_suffix)
|
||||||
|
|
||||||
|
current_env = os.environ.copy()
|
||||||
|
env = {
|
||||||
|
"NCCL_P2P_LEVEL": "LOC",
|
||||||
|
**current_env,
|
||||||
|
"CUDA_VISIBLE_DEVICES": "1",
|
||||||
|
}
|
||||||
|
vllm_process_id = start_vllm(
|
||||||
|
cfg.base_model,
|
||||||
|
env=env,
|
||||||
|
quiet=True,
|
||||||
|
wait=120,
|
||||||
|
gpu_memory_utilization=0.15,
|
||||||
|
max_model_len=cfg.vllm.max_model_len,
|
||||||
|
enable_prefix_caching=cfg.vllm.enable_prefix_caching,
|
||||||
|
host="0.0.0.0",
|
||||||
|
port=8000,
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
execute_subprocess_async(
|
||||||
|
[
|
||||||
|
"axolotl",
|
||||||
|
"train",
|
||||||
|
str(Path(temp_dir) / "config.yaml"),
|
||||||
|
"--num-processes",
|
||||||
|
str(num_gpus),
|
||||||
|
"--main-process-port",
|
||||||
|
f"{get_torch_dist_unique_port()}",
|
||||||
|
],
|
||||||
|
env={"NCCL_P2P_LEVEL": "LOC", "NCCL_DEBUG": "INFO", **current_env},
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
os.kill(vllm_process_id, 9)
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"num_gpus",
|
||||||
|
[1, 2],
|
||||||
|
)
|
||||||
|
@require_vllm
|
||||||
|
def test_llama_fft(self, temp_dir, num_gpus):
|
||||||
|
rnd_reward_suffix = str(random.randint(1000, 9999))
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||||
|
"chat_template": "llama3",
|
||||||
|
"rl": "grpo",
|
||||||
|
"trl": {
|
||||||
|
"beta": 0.001,
|
||||||
|
"max_completion_length": 256,
|
||||||
|
"use_vllm": True,
|
||||||
|
"num_generations": 4,
|
||||||
|
"reward_funcs": [f"rewards_{rnd_reward_suffix}.rand_reward_func"],
|
||||||
|
},
|
||||||
|
"vllm": {
|
||||||
|
"max_model_len": 800,
|
||||||
|
"enable_prefix_caching": True,
|
||||||
|
},
|
||||||
|
"datasets": [
|
||||||
|
{
|
||||||
|
"path": "openai/gsm8k",
|
||||||
|
"name": "main",
|
||||||
|
"type": f"rewards_{rnd_reward_suffix}.oai_gsm8k_transform",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
"flash_attention": True,
|
||||||
|
"sequence_len": 1024,
|
||||||
|
"special_tokens": {
|
||||||
|
"pad_token": "<|endoftext|>",
|
||||||
|
},
|
||||||
|
"max_steps": 3,
|
||||||
|
"num_epochs": 1,
|
||||||
|
"micro_batch_size": 4,
|
||||||
|
"gradient_accumulation_steps": 2,
|
||||||
|
"warmup_steps": 10,
|
||||||
|
"val_set_size": 0.0,
|
||||||
|
"output_dir": temp_dir,
|
||||||
|
"learning_rate": 0.0001,
|
||||||
|
"optimizer": "adamw_torch_fused",
|
||||||
|
"lr_scheduler": "cosine",
|
||||||
|
"save_safetensors": True,
|
||||||
|
"bf16": "auto",
|
||||||
|
"use_tensorboard": True,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
self._utils_write_yaml_and_rewards(cfg, temp_dir, suffix=rnd_reward_suffix)
|
||||||
|
|
||||||
|
current_env = os.environ.copy()
|
||||||
|
env = {
|
||||||
|
"NCCL_P2P_LEVEL": "LOC", # nccl can be brittle, assume P2P isn't reliable
|
||||||
|
**current_env,
|
||||||
|
"CUDA_VISIBLE_DEVICES": "1",
|
||||||
|
}
|
||||||
|
vllm_process_id = start_vllm(
|
||||||
|
cfg.base_model,
|
||||||
|
env=env,
|
||||||
|
quiet=True,
|
||||||
|
wait=120,
|
||||||
|
gpu_memory_utilization=0.15,
|
||||||
|
max_model_len=cfg.vllm.max_model_len,
|
||||||
|
enable_prefix_caching=cfg.vllm.enable_prefix_caching,
|
||||||
|
host="0.0.0.0",
|
||||||
|
port=8000,
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
execute_subprocess_async(
|
||||||
|
[
|
||||||
|
"axolotl",
|
||||||
|
"train",
|
||||||
|
str(Path(temp_dir) / "config.yaml"),
|
||||||
|
"--num-processes",
|
||||||
|
str(num_gpus),
|
||||||
|
"--main-process-port",
|
||||||
|
f"{get_torch_dist_unique_port()}",
|
||||||
|
],
|
||||||
|
env={"NCCL_P2P_LEVEL": "LOC", "NCCL_DEBUG": "INFO", **current_env},
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
os.kill(vllm_process_id, 9)
|
||||||
@@ -52,9 +52,9 @@ class TestMultiGPUEval:
|
|||||||
},
|
},
|
||||||
],
|
],
|
||||||
"num_epochs": 1,
|
"num_epochs": 1,
|
||||||
"max_steps": 5,
|
"max_steps": 2,
|
||||||
"micro_batch_size": 2,
|
"micro_batch_size": 2,
|
||||||
"gradient_accumulation_steps": 4,
|
"gradient_accumulation_steps": 2,
|
||||||
"output_dir": temp_dir,
|
"output_dir": temp_dir,
|
||||||
"learning_rate": 0.00001,
|
"learning_rate": 0.00001,
|
||||||
"optimizer": "adamw_8bit",
|
"optimizer": "adamw_8bit",
|
||||||
@@ -121,9 +121,9 @@ class TestMultiGPUEval:
|
|||||||
},
|
},
|
||||||
],
|
],
|
||||||
"num_epochs": 1,
|
"num_epochs": 1,
|
||||||
"max_steps": 5,
|
"max_steps": 2,
|
||||||
"micro_batch_size": 2,
|
"micro_batch_size": 2,
|
||||||
"gradient_accumulation_steps": 4,
|
"gradient_accumulation_steps": 2,
|
||||||
"output_dir": temp_dir,
|
"output_dir": temp_dir,
|
||||||
"learning_rate": 0.00001,
|
"learning_rate": 0.00001,
|
||||||
"optimizer": "adamw_8bit",
|
"optimizer": "adamw_8bit",
|
||||||
|
|||||||
@@ -1,175 +0,0 @@
|
|||||||
"""
|
|
||||||
GRPO test suite
|
|
||||||
"""
|
|
||||||
|
|
||||||
import random
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
import yaml
|
|
||||||
from accelerate.test_utils import execute_subprocess_async
|
|
||||||
from transformers.testing_utils import get_torch_dist_unique_port
|
|
||||||
|
|
||||||
from axolotl.utils.dict import DictDefault
|
|
||||||
|
|
||||||
from tests.e2e.utils import require_vllm
|
|
||||||
|
|
||||||
|
|
||||||
class TestGRPO:
|
|
||||||
"""
|
|
||||||
Test case for GRPO training using multilpe GPUs
|
|
||||||
"""
|
|
||||||
|
|
||||||
def _utils_write_yaml_and_rewards(self, cfg, temp_dir, suffix=""):
|
|
||||||
# write cfg to yaml file
|
|
||||||
Path(temp_dir).mkdir(parents=True, exist_ok=True)
|
|
||||||
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
|
|
||||||
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
|
|
||||||
with open(f"rewards_{suffix}.py", "w", encoding="utf-8") as fout:
|
|
||||||
fout.write(
|
|
||||||
"""import random
|
|
||||||
def rand_reward_func(completions, **kwargs) -> list[float]:
|
|
||||||
return [random.uniform(0, 1) for _ in completions]
|
|
||||||
|
|
||||||
def oai_gsm8k_transform(cfg, *args, **kwargs):
|
|
||||||
def transform_fn(example, tokenizer=None):
|
|
||||||
label = example["answer"].split("####")[-1].strip().replace(",", "")
|
|
||||||
return {
|
|
||||||
"prompt": [{"role": "user", "content": example["question"]},],
|
|
||||||
"answer": label,
|
|
||||||
}
|
|
||||||
return transform_fn, {"remove_columns": ["question"]}
|
|
||||||
"""
|
|
||||||
)
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
"num_gpus",
|
|
||||||
[1, 2],
|
|
||||||
)
|
|
||||||
@require_vllm
|
|
||||||
def test_llama_dora(self, temp_dir, num_gpus):
|
|
||||||
rnd_reward_suffix = str(random.randint(1000, 9999))
|
|
||||||
cfg = DictDefault(
|
|
||||||
{
|
|
||||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
|
||||||
"chat_template": "llama3",
|
|
||||||
"rl": "grpo",
|
|
||||||
"trl": {
|
|
||||||
"beta": 0.001,
|
|
||||||
"max_completion_length": 256,
|
|
||||||
"use_vllm": True,
|
|
||||||
"vllm_device": "auto" if num_gpus == 1 else "cuda:1",
|
|
||||||
"vllm_gpu_memory_utilization": 0.15,
|
|
||||||
"num_generations": 4,
|
|
||||||
"reward_funcs": [f"rewards_{rnd_reward_suffix}.rand_reward_func"],
|
|
||||||
},
|
|
||||||
"datasets": [
|
|
||||||
{
|
|
||||||
"path": "openai/gsm8k",
|
|
||||||
"name": "main",
|
|
||||||
"type": f"rewards_{rnd_reward_suffix}.oai_gsm8k_transform",
|
|
||||||
},
|
|
||||||
],
|
|
||||||
"adapter": "lora",
|
|
||||||
"lora_r": 8,
|
|
||||||
"lora_alpha": 16,
|
|
||||||
"lora_dropout": 0.05,
|
|
||||||
"lora_target_linear": True,
|
|
||||||
"peft_use_dora": True,
|
|
||||||
"flash_attention": True,
|
|
||||||
"sequence_len": 1024,
|
|
||||||
"special_tokens": {
|
|
||||||
"pad_token": "<|endoftext|>",
|
|
||||||
},
|
|
||||||
"max_steps": 5,
|
|
||||||
"num_epochs": 1,
|
|
||||||
"micro_batch_size": 4,
|
|
||||||
"gradient_accumulation_steps": 2,
|
|
||||||
"warmup_steps": 10,
|
|
||||||
"val_set_size": 0.0,
|
|
||||||
"output_dir": temp_dir,
|
|
||||||
"learning_rate": 0.0001,
|
|
||||||
"optimizer": "adamw_torch_fused",
|
|
||||||
"lr_scheduler": "cosine",
|
|
||||||
"save_safetensors": True,
|
|
||||||
"bf16": "auto",
|
|
||||||
"use_tensorboard": True,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
self._utils_write_yaml_and_rewards(cfg, temp_dir, suffix=rnd_reward_suffix)
|
|
||||||
|
|
||||||
execute_subprocess_async(
|
|
||||||
[
|
|
||||||
"axolotl",
|
|
||||||
"train",
|
|
||||||
str(Path(temp_dir) / "config.yaml"),
|
|
||||||
"--num-processes",
|
|
||||||
str(num_gpus),
|
|
||||||
"--main-process-port",
|
|
||||||
f"{get_torch_dist_unique_port()}",
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
"num_gpus",
|
|
||||||
[1, 2],
|
|
||||||
)
|
|
||||||
@require_vllm
|
|
||||||
def test_llama_fft(self, temp_dir, num_gpus):
|
|
||||||
rnd_reward_suffix = str(random.randint(1000, 9999))
|
|
||||||
cfg = DictDefault(
|
|
||||||
{
|
|
||||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
|
||||||
"chat_template": "llama3",
|
|
||||||
"rl": "grpo",
|
|
||||||
"trl": {
|
|
||||||
"beta": 0.001,
|
|
||||||
"max_completion_length": 256,
|
|
||||||
"use_vllm": True,
|
|
||||||
"vllm_device": "auto" if num_gpus == 1 else "cuda:1",
|
|
||||||
"vllm_gpu_memory_utilization": 0.15,
|
|
||||||
"num_generations": 4,
|
|
||||||
"reward_funcs": [f"rewards_{rnd_reward_suffix}.rand_reward_func"],
|
|
||||||
},
|
|
||||||
"datasets": [
|
|
||||||
{
|
|
||||||
"path": "openai/gsm8k",
|
|
||||||
"name": "main",
|
|
||||||
"type": f"rewards_{rnd_reward_suffix}.oai_gsm8k_transform",
|
|
||||||
},
|
|
||||||
],
|
|
||||||
"flash_attention": True,
|
|
||||||
"sequence_len": 1024,
|
|
||||||
"special_tokens": {
|
|
||||||
"pad_token": "<|endoftext|>",
|
|
||||||
},
|
|
||||||
"max_steps": 5,
|
|
||||||
"num_epochs": 1,
|
|
||||||
"micro_batch_size": 4,
|
|
||||||
"gradient_accumulation_steps": 2,
|
|
||||||
"warmup_steps": 10,
|
|
||||||
"val_set_size": 0.0,
|
|
||||||
"output_dir": temp_dir,
|
|
||||||
"learning_rate": 0.0001,
|
|
||||||
"optimizer": "adamw_torch_fused",
|
|
||||||
"lr_scheduler": "cosine",
|
|
||||||
"save_safetensors": True,
|
|
||||||
"bf16": "auto",
|
|
||||||
"use_tensorboard": True,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
self._utils_write_yaml_and_rewards(cfg, temp_dir, suffix=rnd_reward_suffix)
|
|
||||||
|
|
||||||
execute_subprocess_async(
|
|
||||||
[
|
|
||||||
"axolotl",
|
|
||||||
"train",
|
|
||||||
str(Path(temp_dir) / "config.yaml"),
|
|
||||||
"--num-processes",
|
|
||||||
str(num_gpus),
|
|
||||||
"--main-process-port",
|
|
||||||
f"{get_torch_dist_unique_port()}",
|
|
||||||
]
|
|
||||||
)
|
|
||||||
@@ -399,7 +399,7 @@ class TestMultiGPULlama:
|
|||||||
"num_epochs": 1,
|
"num_epochs": 1,
|
||||||
"max_steps": 2,
|
"max_steps": 2,
|
||||||
"micro_batch_size": 4,
|
"micro_batch_size": 4,
|
||||||
"gradient_accumulation_steps": 4,
|
"gradient_accumulation_steps": 2,
|
||||||
"output_dir": temp_dir,
|
"output_dir": temp_dir,
|
||||||
"learning_rate": 0.00001,
|
"learning_rate": 0.00001,
|
||||||
"optimizer": "adamw_torch_fused",
|
"optimizer": "adamw_torch_fused",
|
||||||
@@ -478,7 +478,7 @@ class TestMultiGPULlama:
|
|||||||
"num_epochs": 1,
|
"num_epochs": 1,
|
||||||
"max_steps": 2,
|
"max_steps": 2,
|
||||||
"micro_batch_size": 4,
|
"micro_batch_size": 4,
|
||||||
"gradient_accumulation_steps": 4,
|
"gradient_accumulation_steps": 2,
|
||||||
"output_dir": temp_dir,
|
"output_dir": temp_dir,
|
||||||
"learning_rate": 0.00001,
|
"learning_rate": 0.00001,
|
||||||
"optimizer": "adamw_torch_fused",
|
"optimizer": "adamw_torch_fused",
|
||||||
@@ -778,7 +778,7 @@ class TestMultiGPULlama:
|
|||||||
},
|
},
|
||||||
],
|
],
|
||||||
"num_epochs": 1,
|
"num_epochs": 1,
|
||||||
"max_steps": 5,
|
"max_steps": 2,
|
||||||
"micro_batch_size": 1,
|
"micro_batch_size": 1,
|
||||||
"gradient_accumulation_steps": 1,
|
"gradient_accumulation_steps": 1,
|
||||||
"output_dir": temp_dir,
|
"output_dir": temp_dir,
|
||||||
|
|||||||
@@ -46,7 +46,7 @@ class TestMultiGPUQwen2:
|
|||||||
},
|
},
|
||||||
],
|
],
|
||||||
"num_epochs": 1,
|
"num_epochs": 1,
|
||||||
"max_steps": 5,
|
"max_steps": 2,
|
||||||
"warmup_steps": 20,
|
"warmup_steps": 20,
|
||||||
"micro_batch_size": 2,
|
"micro_batch_size": 2,
|
||||||
"gradient_accumulation_steps": 2,
|
"gradient_accumulation_steps": 2,
|
||||||
|
|||||||
@@ -50,7 +50,7 @@ class TestMultiGPURay:
|
|||||||
"num_epochs": 1,
|
"num_epochs": 1,
|
||||||
"max_steps": 2,
|
"max_steps": 2,
|
||||||
"micro_batch_size": 4,
|
"micro_batch_size": 4,
|
||||||
"gradient_accumulation_steps": 4,
|
"gradient_accumulation_steps": 2,
|
||||||
"output_dir": temp_dir,
|
"output_dir": temp_dir,
|
||||||
"learning_rate": 0.00001,
|
"learning_rate": 0.00001,
|
||||||
"optimizer": "adamw_8bit",
|
"optimizer": "adamw_8bit",
|
||||||
|
|||||||
Reference in New Issue
Block a user