Updates for trl 0.16.0 - mostly for GRPO (#2437) [skip ci]

* add grpo scale_rewards config for trl#3135

* options to connect to vllm server directly w grpo trl#3094

* temperature support trl#3029

* sampling/generation kwargs for grpo trl#2989

* make vllm_enable_prefix_caching a config param trl#2900

* grpo multi-step optimizeations trl#2899

* remove overrides for grpo trainer

* bump trl to 0.16.0

* add cli  to start vllm-serve via trl

* call the python module directly

* update to use vllm with 2.6.0 too now and call trl vllm serve from module

* vllm 0.8.1

* use python3

* use sys.executable

* remove context and wait for start

* fixes to make it actually work

* fixes so the grpo tests pass with new vllm paradigm

* explicit host/port and check in start vllm

* make sure that vllm doesn't hang by setting quiet so outouts go to dev null

* also bump bnb to latest release

* add option for wait from cli and nccl debugging for ci

* grpo + vllm test on separate devices for now

* make sure grpo + vllm tests runs single worker since pynccl comms would conflict

* fix cli

* remove wait and add caching for argilla dataset

* refactoring configs

* chore: lint

* add vllm config

* fixup vllm grpo args

* fix one more incorrect schema/config path

* fix another vlllm reference and increase timeout

* make the tests run a bit faster

* change mbsz back so it is correct for grpo

* another change mbsz back so it is correct for grpo

* fixing cli args

* nits

* adding docs

* docs

* include tensor parallel size for vllm in pydantic schema

* moving start_vllm, more docs

* limit output len for grpo vllm

* vllm enable_prefix_caching isn't a bool cli arg

* fix env ordering in tests and also use pid check when looking for vllm

---------

Co-authored-by: Salman Mohammadi <salman.mohammadi@outlook.com>
This commit is contained in:
Wing Lian
2025-03-31 15:47:11 -04:00
committed by GitHub
parent b35992262e
commit b6fc46ada8
24 changed files with 703 additions and 349 deletions

View File

@@ -238,10 +238,10 @@ simpo_gamma: 0.5 # Target reward margin for the SimPO loss
# grpo
trl:
use_vllm: # Optional[bool]. Whether to use VLLM for RL training.
vllm_device: # Optional[str]. Device to use for VLLM.
vllm_gpu_memory_utilization: # Optional[float]. GPU memory utilization for VLLM.
vllm_max_model_len: # Optional[int]. Maximum length of the model for VLLM.
vllm_dtype: # Optional[str]. Data type for VLLM.
vllm_server_host: # Optional[str]. Host of the vLLM server to connect to.
vllm_server_port: # Optional[int]. Port of the vLLM server to connect to.
vllm_server_timeout: # Optional[int]. Total timeout (in seconds) to wait for the vLLM server to respond.
vllm_guided_decoding_regex: # Optional[str]. Regex for vLLM guided decoding.
beta: # Optional[float]. Beta parameter for the RL training. Same as `rl_beta`. Use
max_completion_length: # Optional[int]. Maximum length of the completion for RL training.

View File

@@ -502,9 +502,48 @@ The input format is a simple JSON input with customizable fields based on the ab
Check out our [GRPO cookbook](https://github.com/axolotl-ai-cloud/axolotl-cookbook/tree/main/grpo#training-an-r1-style-large-language-model-using-grpo).
:::
If you have multiple GPUs available, we reccomend using `vLLM` with the `GRPOTrainer` to significantly speedup trajectory generation during training.
First, launch a `vLLM` server using `trl vllm-serve` - you may use a config file or CLI overrides to configure your vLLM server. In this example, we're
using 4 GPUs - 2 for training, and 2 for vLLM:
::: {.callout-important}
Make sure you've installed the correct version of vLLM by including it as an extra when installing axolotl, e.g. `pip install axolotl[vllm]`.
:::
```yaml
base_model: Qwen/Qwen2.5-1.5B-Instruct
vllm:
host: 0.0.0.0
port: 8000
tensor_parallel_size: 2
gpu_memory_utilization: 0.85
dtype: auto
# max_model_len: # you may find it useful to set the vLLM model context length if you know this beforehand
rl: grpo
trl:
use_vllm: true
vllm_server_host: 0.0.0.0
vllm_server_port: 8000
vllm_server_timeout: 300
```
```bash
CUDA_VISIBLE_DEVICES=2,3 axolotl vllm_serve grpo.yaml
```
Your `vLLM` instance will now attempt to spin up, and it's time to kick off training utilizing our remaining two GPUs. In another terminal, execute:
```bash
CUDA_VISIBLE_DEVICES=0,1 axolotl train grpo.yaml --num-processes 2
```
#### Reward functions
GRPO uses custom reward functions and transformations. Please have them ready locally.
For ex, to load OpenAI's GSM8K and use a random reward for completions:
For example, to load OpenAI's GSM8K and use a random reward for completions:
```python
# rewards.py
@@ -530,8 +569,6 @@ trl:
beta: 0.001
max_completion_length: 256
use_vllm: True
vllm_device: auto
vllm_gpu_memory_utilization: 0.15
num_generations: 4
reward_funcs: ["rewards.rand_reward_func"] # format: '{file_name}.{fn_name}'
reward_weights: [1.0]