feat: Add GDPO Support (#3353)

* gdpo support - test left

* lint

* fixes for vllm serv

* test advantages

* docs

* lint

* lint

* gdpo simple + lint

* lint nit

* example

* lint

* trl 0.27.0

* blocklist

* remove test assert

* add validation check for GDPO + sum_then_normalize

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
VED
2026-01-22 03:52:45 +05:30
committed by GitHub
parent 8623dd8a72
commit d0d26d5064
11 changed files with 742 additions and 6 deletions

View File

@@ -17,6 +17,7 @@ feedback. Various methods include, but not limited to:
- [Kahneman-Tversky Optimization (KTO)](#kto)
- [Odds Ratio Preference Optimization (ORPO)](#orpo)
- [Group Relative Policy Optimization (GRPO)](#grpo)
- [Group Reward-Decoupled Policy Optimization (GDPO)](#gdpo)
## RLHF using Axolotl
@@ -720,6 +721,102 @@ trl:
For more information, see [GRPO docs](https://huggingface.co/docs/trl/v0.17.0/en/grpo_trainer#loss-types).
### GDPO
GDPO (Group Reward-Decoupled Policy Optimization) extends GRPO for multi-reward training. It addresses the **reward advantage collapse** problem by normalizing each reward function independently before combining them.
::: {.callout-tip}
Use GDPO when training with multiple reward functions. With a single reward, GRPO and GDPO produce equivalent results.
:::
Paper: [https://arxiv.org/pdf/2501.05242](https://arxiv.org/pdf/2501.05242)
GDPO uses TRL's native `multi_objective_aggregation` parameter under the hood. When you set `rl: gdpo`, Axolotl automatically configures TRL to use `normalize_then_sum` aggregation.
```yaml
base_model: Qwen/Qwen2.5-1.5B-Instruct

vllm:
  host: 0.0.0.0
  port: 8000
  tensor_parallel_size: 2
  gpu_memory_utilization: 0.85

rl: gdpo
trl:
  beta: 0.001
  max_completion_length: 256
  use_vllm: true
  num_generations: 4
  reward_funcs:
    - rewards.format_reward
    - rewards.correctness_reward
  reward_weights: [1.0, 2.0]

datasets:
  - path: openai/gsm8k
    name: main
    type: rewards.oai_gsm8k_transform
```
You can also use GRPO with explicit aggregation control:
```yaml
rl: grpo
trl:
  multi_objective_aggregation: normalize_then_sum # GDPO behavior
  # or: sum_then_normalize # Default GRPO behavior
```
#### GDPO vs GRPO
| Aspect | GRPO | GDPO |
|--------|------|------|
| **Aggregation** | `sum_then_normalize` | `normalize_then_sum` |
| **Multi-reward** | May collapse advantages | Preserves reward signals |
| **Single reward** | Standard behavior | Equivalent to GRPO |
#### Why GDPO?
When using multiple rewards with GRPO, different reward combinations can produce identical advantages:
```
# Example: format + correctness rewards (equal weights)
[format=0, correct=3] → sum=3
[format=1, correct=2] → sum=3 ← GRPO sees these as equal!
[format=2, correct=1] → sum=3
[format=3, correct=0] → sum=3
```
GDPO normalizes each reward independently, preserving their relative differences.
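Beyond exact ties, the same coupling shows up whenever one reward has a much larger spread than another. Here is a minimal standalone sketch of the two aggregation orders (illustrative only, not Axolotl or TRL internals; it assumes the per-group z-score normalization used for GRPO-style advantages):

```python
# Standalone sketch: compare the two aggregation orders on one generation
# group, using per-group z-score normalization for the advantages.

def znorm(xs: list[float]) -> list[float]:
    mean = sum(xs) / len(xs)
    std = (sum((x - mean) ** 2 for x in xs) / len(xs)) ** 0.5
    return [(x - mean) / (std + 1e-8) for x in xs]

# One group of 4 completions, two rewards on very different scales:
fmt = [1.0, 0.0, 1.0, 0.0]   # binary format reward
acc = [0.0, 0.2, 9.8, 10.0]  # correctness reward with ~10x larger spread

# sum_then_normalize (GRPO): the large-scale reward dominates the advantage,
# so the format signal is almost invisible.
grpo = znorm([f + a for f, a in zip(fmt, acc)])

# normalize_then_sum (GDPO): each reward is put on a comparable scale first,
# so both signals survive in the combined advantage.
gdpo = [f + a for f, a in zip(znorm(fmt), znorm(acc))]

print(grpo)  # ≈ [-0.92, -1.08, 1.08, 0.92] -> format barely matters
print(gdpo)  # ≈ [-0.02, -1.98, 1.98, 0.02] -> format clearly matters
```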
#### Reward Functions
GDPO uses the same reward function format as GRPO:
```python
# rewards.py
def format_reward(completions, **kwargs) -> list[float]:
    return [1.0 if len(c) > 10 else 0.0 for c in completions]


def correctness_reward(completions, answers, **kwargs) -> list[float]:
    rewards = []
    for completion, answer in zip(completions, answers):
        # Your scoring logic here, e.g. exact match against the reference:
        rewards.append(1.0 if answer in completion else 0.0)
    return rewards
```
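Reward functions can be sanity-checked in isolation before a run; the snippet below is a sketch with hypothetical inputs, relying only on the contract above (a batch of completions in, one float per completion out):

```python
# Hypothetical local check for the rewards.py example above.
from rewards import correctness_reward, format_reward

completions = ["The answer is 42.", "idk"]
answers = ["42", "7"]

# Each function must return one float per completion in the batch.
print(format_reward(completions))                # [1.0, 0.0]
print(correctness_reward(completions, answers))  # depends on your scoring logic
```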
#### Sequence Parallelism
GDPO supports sequence parallelism for long-context training:
```yaml
rl: gdpo
context_parallel_size: 2
```
### SimPO
SimPO uses [CPOTrainer](https://huggingface.co/docs/trl/main/en/cpo_trainer) but with an alternative loss function.

View File

@@ -0,0 +1,68 @@
base_model: meta-llama/Llama-3.2-1B-Instruct
chat_template: llama3

rl: gdpo
trl:
  beta: 0.001
  max_completion_length: 128
  num_generations: 2
  temperature: 0.7
  top_p: 0.95
  use_vllm: false
  multi_objective_aggregation: normalize_then_sum
  reward_funcs:
    - rwd.format_reward
    - rwd.correctness_reward
  reward_weights: [1.0, 2.0]
  log_completions: true
  num_completions_to_print: 3
  scale_rewards: true

datasets:
  - path: openai/gsm8k
    name: main
    split: train[:1000]
    type: rwd.gsm8k_transform
val_set_size: 0.0

output_dir: ./outputs/llama3-gdpo-out

sequence_len: 512
sample_packing: false
pad_to_sequence_len: false

gradient_accumulation_steps: 8
micro_batch_size: 1
num_epochs: 1
max_steps: 100

optimizer: adamw_torch_fused
lr_scheduler: cosine
learning_rate: 5e-5
weight_decay: 0.01
warmup_steps: 10

bf16: auto
tf32: true

gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: false
flash_attention: true

logging_steps: 1
save_steps: 50
save_safetensors: true

special_tokens:
  pad_token: "<|end_of_text|>"

seed: 42

View File

@@ -17,7 +17,7 @@ transformers==4.57.6
accelerate==1.12.0
datasets==4.5.0
deepspeed>=0.18.3
trl==0.25.1
trl==0.27.0
hf_xet==1.2.0
kernels==0.11.5
trackio>=0.13.0

View File

@@ -52,12 +52,11 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
trainer_cls = None
trainer_cls_args = [self.model]
if self.cfg.rl is RLType.GRPO:
if self.cfg.rl in {RLType.GRPO, RLType.GDPO}:
trainer_cls = GRPOStrategy.get_trainer_class(
sequence_parallel=self.cfg.context_parallel_size > 1
)
trainer_cls_args.extend(GRPOStrategy.set_trainer_args(self.cfg))
trainer_kwargs.update(GRPOStrategy.set_trainer_kwargs(self.cfg))
elif self.cfg.rl in [RLType.DPO, RLType.IPO]:
@@ -147,6 +146,8 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
elif self.cfg.rl is RLType.KTO:
training_args_cls = AxolotlKTOConfig
# KTOConfig in TRL >= 0.27.0 no longer accepts max_prompt_length
blocklist_args_kwargs = ["max_prompt_length"]
training_args_kwargs["desirable_weight"] = (
self.cfg.kto_desirable_weight or 1.0
@@ -155,10 +156,14 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
self.cfg.kto_undesirable_weight or 1.0
)
elif self.cfg.rl is RLType.GRPO:
elif self.cfg.rl in {RLType.GRPO, RLType.GDPO}:
training_args_cls = GRPOStrategy.get_training_args_class()
training_args_kwargs.update(GRPOStrategy.set_training_args_kwargs(self.cfg))
blocklist_args_kwargs = GRPOStrategy.get_blocklist_args_kwargs()
if self.cfg.rl is RLType.GDPO:
training_args_kwargs.setdefault(
"multi_objective_aggregation", "normalize_then_sum"
)
elif self.cfg.rl in [RLType.DPO, RLType.IPO]:
training_args_cls = AxolotlDPOConfig

View File

@@ -129,6 +129,11 @@ class GRPOStrategy:
if trl.rollout_func:
grpo_args_kwargs["rollout_func"] = cls.get_rollout_func(trl.rollout_func)
if trl.multi_objective_aggregation is not None:
grpo_args_kwargs["multi_objective_aggregation"] = (
trl.multi_objective_aggregation
)
return grpo_args_kwargs
@classmethod

View File

@@ -173,7 +173,7 @@ def _drop_long_sequences(
return (len_prompt + len_completion) <= sequence_len
if rl is RLType.GRPO:
if rl in {RLType.GRPO, RLType.GDPO}:
return True
raise ValueError("Unknown RL type")

View File

@@ -26,6 +26,7 @@ class RLType(str, Enum):
"""RL trainer type configuration subset"""
DPO = "dpo"
GDPO = "gdpo"
GRPO = "grpo"
IPO = "ipo"
ORPO = "orpo"

View File

@@ -179,3 +179,13 @@ class TRLConfig(BaseModel):
"description": "Path to custom rollout function. Must be importable from current dir."
},
)
multi_objective_aggregation: (
Literal["sum_then_normalize", "normalize_then_sum"] | None
) = Field(
default=None,
json_schema_extra={
"description": "Multi-objective reward aggregation strategy. "
"'sum_then_normalize' (GRPO default): weights and sums rewards first, then normalizes. "
"'normalize_then_sum' (GDPO): normalizes each reward independently, then sums."
},
)

View File

@@ -746,6 +746,19 @@ class RLValidationMixin:
)
return data
@model_validator(mode="before")
@classmethod
def check_gdpo(cls, data):
if (
data.get("rl") == "gdpo"
and data.get("trl", {}).get("multi_objective_aggregation")
== "sum_then_normalize"
):
raise ValueError(
"`multi_objective_aggregation` value set as `sum_then_normalize` => GRPO, but GDPO was selected"
)
return data
class OptimizationValidationMixin:
"""Validation methods related to optimization and performance."""

View File

@@ -311,7 +311,6 @@ class TestHFRLTrainerBuilder:
# KTO specific
assert training_arguments.desirable_weight == 1.0
assert training_arguments.undesirable_weight == 1.0
assert training_arguments.max_prompt_length == 512
def _write_rewards_file(self, rewards_dir: Path):
"""

View File

@@ -0,0 +1,538 @@
"""
GDPO test suite
GDPO uses TRL's multi_objective_aggregation="normalize_then_sum" for
per-reward normalization in multi-reward RL training.
"""
import os
import random
from pathlib import Path
import pytest
import yaml
from accelerate.test_utils import execute_subprocess_async
from transformers.testing_utils import get_torch_dist_unique_port
from axolotl.utils.dict import DictDefault
from tests.e2e.multigpu.solo.test_grpo import recursive_kill, start_vllm
from tests.e2e.utils import require_vllm
@pytest.mark.skip(reason="flaky vllm tests in modal")
class TestGDPO:
"""Test case for GDPO training using TRL's native multi-objective aggregation."""
def _utils_write_yaml_and_rewards(self, cfg, temp_dir, suffix=""):
Path(temp_dir).mkdir(parents=True, exist_ok=True)
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
with open(f"rewards_gdpo_{suffix}.py", "w", encoding="utf-8") as fout:
fout.write(
"""import random
def format_reward(prompts, completions, **kwargs) -> list[float]:
return [1.0 if len(c) > 10 else 0.0 for c in completions]
def correctness_reward(prompts, completions, **kwargs) -> list[float]:
return [random.uniform(-1, 3) for _ in completions]
def safety_reward(prompts, completions, **kwargs) -> list[float]:
return [1.0 if 'error' not in c.lower() else 0.0 for c in completions]
def single_reward(prompts, completions, **kwargs) -> list[float]:
return [random.uniform(0, 1) for _ in completions]
def oai_gsm8k_transform(cfg, *args, **kwargs):
def transform_fn(example, tokenizer=None):
label = example["answer"].split("####")[-1].strip().replace(",", "")
return {
"prompt": [{"role": "user", "content": example["question"]}],
"answer": label,
}
return transform_fn, {"remove_columns": ["question"]}
"""
)
@pytest.mark.parametrize("num_gpus", [1, 2])
@require_vllm
def test_gdpo_multi_reward_lora(self, temp_dir, num_gpus):
"""Test GDPO with multiple reward functions using LoRA."""
rnd_suffix = str(random.randint(1000, 9999))
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"chat_template": "llama3",
"rl": "gdpo",
"trl": {
"beta": 0.001,
"max_completion_length": 256,
"use_vllm": True,
"num_generations": 4,
"reward_funcs": [
f"rewards_gdpo_{rnd_suffix}.format_reward",
f"rewards_gdpo_{rnd_suffix}.correctness_reward",
],
"reward_weights": [1.0, 2.0],
"scale_rewards": True,
},
"vllm": {
"max_model_len": 800,
"enable_prefix_caching": True,
},
"datasets": [
{
"path": "openai/gsm8k",
"name": "main",
"type": f"rewards_gdpo_{rnd_suffix}.oai_gsm8k_transform",
},
],
"adapter": "lora",
"lora_r": 8,
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_linear": True,
"flash_attention": True,
"sequence_len": 1024,
"special_tokens": {
"pad_token": "<|endoftext|>",
},
"max_steps": 3,
"num_epochs": 1,
"micro_batch_size": 4,
"gradient_accumulation_steps": 2,
"warmup_steps": 10,
"val_set_size": 0.0,
"output_dir": temp_dir,
"learning_rate": 0.0001,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"save_safetensors": True,
"bf16": "auto",
"use_tensorboard": True,
"save_first_step": False,
}
)
self._utils_write_yaml_and_rewards(cfg, temp_dir, suffix=rnd_suffix)
current_env = os.environ.copy()
env = {
"NCCL_P2P_LEVEL": "LOC",
**current_env,
"CUDA_VISIBLE_DEVICES": "1",
}
vllm_process = start_vllm(
cfg.base_model,
env=env,
quiet=True,
wait=300,
gpu_memory_utilization=0.15,
max_model_len=cfg.vllm.max_model_len,
enable_prefix_caching=cfg.vllm.enable_prefix_caching,
host="0.0.0.0",
port=8000,
)
try:
execute_subprocess_async(
[
"axolotl",
"train",
str(Path(temp_dir) / "config.yaml"),
"--num-processes",
str(num_gpus),
"--main-process-port",
f"{get_torch_dist_unique_port()}",
],
env={
"NCCL_P2P_LEVEL": "LOC",
"NCCL_DEBUG": "INFO",
**current_env,
},
)
finally:
recursive_kill(vllm_process)
@require_vllm
def test_gdpo_three_rewards(self, temp_dir):
"""Test GDPO with three reward functions (format, correctness, safety)."""
rnd_suffix = str(random.randint(1000, 9999))
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"chat_template": "llama3",
"rl": "gdpo",
"trl": {
"beta": 0.001,
"max_completion_length": 256,
"use_vllm": True,
"num_generations": 4,
"reward_funcs": [
f"rewards_gdpo_{rnd_suffix}.format_reward",
f"rewards_gdpo_{rnd_suffix}.correctness_reward",
f"rewards_gdpo_{rnd_suffix}.safety_reward",
],
"reward_weights": [1.0, 2.0, 1.5],
},
"vllm": {
"max_model_len": 800,
"enable_prefix_caching": True,
},
"datasets": [
{
"path": "openai/gsm8k",
"name": "main",
"type": f"rewards_gdpo_{rnd_suffix}.oai_gsm8k_transform",
},
],
"adapter": "lora",
"lora_r": 8,
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_linear": True,
"flash_attention": True,
"sequence_len": 1024,
"special_tokens": {
"pad_token": "<|endoftext|>",
},
"max_steps": 3,
"num_epochs": 1,
"micro_batch_size": 4,
"gradient_accumulation_steps": 2,
"warmup_steps": 10,
"val_set_size": 0.0,
"output_dir": temp_dir,
"learning_rate": 0.0001,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"save_safetensors": True,
"bf16": "auto",
}
)
self._utils_write_yaml_and_rewards(cfg, temp_dir, suffix=rnd_suffix)
current_env = os.environ.copy()
env = {
"NCCL_P2P_LEVEL": "LOC",
**current_env,
"CUDA_VISIBLE_DEVICES": "1",
}
vllm_process = start_vllm(
cfg.base_model,
env=env,
quiet=True,
wait=300,
gpu_memory_utilization=0.15,
max_model_len=cfg.vllm.max_model_len,
enable_prefix_caching=cfg.vllm.enable_prefix_caching,
host="0.0.0.0",
port=8000,
)
try:
execute_subprocess_async(
[
"axolotl",
"train",
str(Path(temp_dir) / "config.yaml"),
"--num-processes",
"1",
"--main-process-port",
f"{get_torch_dist_unique_port()}",
],
env={
"NCCL_P2P_LEVEL": "LOC",
"NCCL_DEBUG": "INFO",
**current_env,
},
)
finally:
recursive_kill(vllm_process)
@require_vllm
def test_gdpo_single_reward_fallback(self, temp_dir):
"""Test GDPO with single reward."""
rnd_suffix = str(random.randint(1000, 9999))
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"chat_template": "llama3",
"rl": "gdpo",
"trl": {
"beta": 0.001,
"max_completion_length": 256,
"use_vllm": True,
"num_generations": 4,
"reward_funcs": [
f"rewards_gdpo_{rnd_suffix}.single_reward",
],
"reward_weights": [1.0],
},
"vllm": {
"max_model_len": 800,
"enable_prefix_caching": True,
},
"datasets": [
{
"path": "openai/gsm8k",
"name": "main",
"type": f"rewards_gdpo_{rnd_suffix}.oai_gsm8k_transform",
},
],
"adapter": "lora",
"lora_r": 8,
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_linear": True,
"flash_attention": True,
"sequence_len": 1024,
"special_tokens": {
"pad_token": "<|endoftext|>",
},
"max_steps": 3,
"num_epochs": 1,
"micro_batch_size": 4,
"gradient_accumulation_steps": 2,
"warmup_steps": 10,
"val_set_size": 0.0,
"output_dir": temp_dir,
"learning_rate": 0.0001,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"save_safetensors": True,
"bf16": "auto",
}
)
self._utils_write_yaml_and_rewards(cfg, temp_dir, suffix=rnd_suffix)
current_env = os.environ.copy()
env = {
"NCCL_P2P_LEVEL": "LOC",
**current_env,
"CUDA_VISIBLE_DEVICES": "1",
}
vllm_process = start_vllm(
cfg.base_model,
env=env,
quiet=True,
wait=300,
gpu_memory_utilization=0.15,
max_model_len=cfg.vllm.max_model_len,
enable_prefix_caching=cfg.vllm.enable_prefix_caching,
host="0.0.0.0",
port=8000,
)
try:
execute_subprocess_async(
[
"axolotl",
"train",
str(Path(temp_dir) / "config.yaml"),
"--num-processes",
"1",
"--main-process-port",
f"{get_torch_dist_unique_port()}",
],
env={
"NCCL_P2P_LEVEL": "LOC",
"NCCL_DEBUG": "INFO",
**current_env,
},
)
finally:
recursive_kill(vllm_process)
@require_vllm
def test_gdpo_fft(self, temp_dir):
"""Test GDPO with full fine-tuning (no adapter)."""
rnd_suffix = str(random.randint(1000, 9999))
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"chat_template": "llama3",
"rl": "gdpo",
"trl": {
"beta": 0.001,
"max_completion_length": 256,
"use_vllm": True,
"num_generations": 4,
"reward_funcs": [
f"rewards_gdpo_{rnd_suffix}.format_reward",
f"rewards_gdpo_{rnd_suffix}.correctness_reward",
],
"reward_weights": [1.0, 2.0],
},
"vllm": {
"max_model_len": 800,
"enable_prefix_caching": True,
},
"datasets": [
{
"path": "openai/gsm8k",
"name": "main",
"type": f"rewards_gdpo_{rnd_suffix}.oai_gsm8k_transform",
},
],
# No adapter - full fine-tuning
"flash_attention": True,
"sequence_len": 1024,
"special_tokens": {
"pad_token": "<|endoftext|>",
},
"max_steps": 3,
"num_epochs": 1,
"micro_batch_size": 4,
"gradient_accumulation_steps": 2,
"warmup_steps": 10,
"val_set_size": 0.0,
"output_dir": temp_dir,
"learning_rate": 0.0001,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"save_safetensors": True,
"bf16": "auto",
}
)
self._utils_write_yaml_and_rewards(cfg, temp_dir, suffix=rnd_suffix)
current_env = os.environ.copy()
env = {
"NCCL_P2P_LEVEL": "LOC",
**current_env,
"CUDA_VISIBLE_DEVICES": "1",
}
vllm_process = start_vllm(
cfg.base_model,
env=env,
quiet=True,
wait=300,
gpu_memory_utilization=0.15,
max_model_len=cfg.vllm.max_model_len,
enable_prefix_caching=cfg.vllm.enable_prefix_caching,
host="0.0.0.0",
port=8000,
)
try:
execute_subprocess_async(
[
"axolotl",
"train",
str(Path(temp_dir) / "config.yaml"),
"--num-processes",
"1",
"--main-process-port",
f"{get_torch_dist_unique_port()}",
],
env={
"NCCL_P2P_LEVEL": "LOC",
"NCCL_DEBUG": "INFO",
**current_env,
},
)
finally:
recursive_kill(vllm_process)
@require_vllm
def test_gdpo_sequence_parallel(self, temp_dir):
"""Test GDPO with sequence parallelism."""
rnd_suffix = str(random.randint(1000, 9999))
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"chat_template": "llama3",
"rl": "gdpo",
"context_parallel_size": 2,
"trl": {
"beta": 0.001,
"max_completion_length": 256,
"use_vllm": True,
"num_generations": 4,
"reward_funcs": [
f"rewards_gdpo_{rnd_suffix}.format_reward",
f"rewards_gdpo_{rnd_suffix}.correctness_reward",
],
"reward_weights": [1.0, 2.0],
},
"vllm": {
"max_model_len": 800,
"enable_prefix_caching": True,
},
"datasets": [
{
"path": "openai/gsm8k",
"name": "main",
"type": f"rewards_gdpo_{rnd_suffix}.oai_gsm8k_transform",
},
],
"adapter": "lora",
"lora_r": 8,
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_linear": True,
"flash_attention": True,
"sequence_len": 1024,
"special_tokens": {
"pad_token": "<|endoftext|>",
},
"max_steps": 3,
"num_epochs": 1,
"micro_batch_size": 4,
"gradient_accumulation_steps": 2,
"warmup_steps": 10,
"val_set_size": 0.0,
"output_dir": temp_dir,
"dataset_prepared_path": temp_dir + "/last_run_prepared",
"learning_rate": 0.0001,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"save_safetensors": True,
"bf16": "auto",
}
)
self._utils_write_yaml_and_rewards(cfg, temp_dir, suffix=rnd_suffix)
current_env = os.environ.copy()
env = {
"NCCL_P2P_LEVEL": "LOC",
**current_env,
"CUDA_VISIBLE_DEVICES": "1",
}
vllm_process = start_vllm(
cfg.base_model,
env=env,
quiet=True,
wait=300,
gpu_memory_utilization=0.15,
max_model_len=cfg.vllm.max_model_len,
enable_prefix_caching=cfg.vllm.enable_prefix_caching,
host="0.0.0.0",
port=8000,
)
try:
execute_subprocess_async(
[
"axolotl",
"train",
str(Path(temp_dir) / "config.yaml"),
"--num-processes",
"2",
"--main-process-port",
f"{get_torch_dist_unique_port()}",
],
env={
"NCCL_P2P_LEVEL": "LOC",
"NCCL_DEBUG": "INFO",
**current_env,
},
)
finally:
recursive_kill(vllm_process)