feat: Add GDPO Support (#3353)

* gdpo support - test left

* lint

* fixxes for vllm serv

* test advantages

* docss

* lint

* lint =

* gdpo simple + lint

* lint nit

* example

* lint

* trl 0.27.0

* blocklist

* test assert rmv

* add validation check for GDPO + sum_then_normalize

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
This commit is contained in:
VED
2026-01-22 03:52:45 +05:30
committed by GitHub
parent 8623dd8a72
commit d0d26d5064
11 changed files with 742 additions and 6 deletions

View File

@@ -52,12 +52,11 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
trainer_cls = None
trainer_cls_args = [self.model]
if self.cfg.rl is RLType.GRPO:
if self.cfg.rl in {RLType.GRPO, RLType.GDPO}:
trainer_cls = GRPOStrategy.get_trainer_class(
sequence_parallel=self.cfg.context_parallel_size > 1
)
trainer_cls_args.extend(GRPOStrategy.set_trainer_args(self.cfg))
trainer_kwargs.update(GRPOStrategy.set_trainer_kwargs(self.cfg))
elif self.cfg.rl in [RLType.DPO, RLType.IPO]:
@@ -147,6 +146,8 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
elif self.cfg.rl is RLType.KTO:
training_args_cls = AxolotlKTOConfig
# KTOConfig in TRL >= 0.27.0 no longer accepts max_prompt_length
blocklist_args_kwargs = ["max_prompt_length"]
training_args_kwargs["desirable_weight"] = (
self.cfg.kto_desirable_weight or 1.0
@@ -155,10 +156,14 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
self.cfg.kto_undesirable_weight or 1.0
)
elif self.cfg.rl is RLType.GRPO:
elif self.cfg.rl in {RLType.GRPO, RLType.GDPO}:
training_args_cls = GRPOStrategy.get_training_args_class()
training_args_kwargs.update(GRPOStrategy.set_training_args_kwargs(self.cfg))
blocklist_args_kwargs = GRPOStrategy.get_blocklist_args_kwargs()
if self.cfg.rl is RLType.GDPO:
training_args_kwargs.setdefault(
"multi_objective_aggregation", "normalize_then_sum"
)
elif self.cfg.rl in [RLType.DPO, RLType.IPO]:
training_args_cls = AxolotlDPOConfig

View File

@@ -129,6 +129,11 @@ class GRPOStrategy:
if trl.rollout_func:
grpo_args_kwargs["rollout_func"] = cls.get_rollout_func(trl.rollout_func)
if trl.multi_objective_aggregation is not None:
grpo_args_kwargs["multi_objective_aggregation"] = (
trl.multi_objective_aggregation
)
return grpo_args_kwargs
@classmethod

View File

@@ -173,7 +173,7 @@ def _drop_long_sequences(
return (len_prompt + len_completion) <= sequence_len
if rl is RLType.GRPO:
if rl in {RLType.GRPO, RLType.GDPO}:
return True
raise ValueError("Unknown RL type")

View File

@@ -26,6 +26,7 @@ class RLType(str, Enum):
"""RL trainer type configuration subset"""
DPO = "dpo"
GDPO = "gdpo"
GRPO = "grpo"
IPO = "ipo"
ORPO = "orpo"

View File

@@ -179,3 +179,13 @@ class TRLConfig(BaseModel):
"description": "Path to custom rollout function. Must be importable from current dir."
},
)
multi_objective_aggregation: (
Literal["sum_then_normalize", "normalize_then_sum"] | None
) = Field(
default=None,
json_schema_extra={
"description": "Multi-objective reward aggregation strategy. "
"'sum_then_normalize' (GRPO default): weights and sums rewards first, then normalizes. "
"'normalize_then_sum' (GDPO): normalizes each reward independently, then sums."
},
)

View File

@@ -746,6 +746,19 @@ class RLValidationMixin:
)
return data
@model_validator(mode="before")
@classmethod
def check_gdpo(cls, data):
if (
data.get("rl") == "gdpo"
and data.get("trl", {}).get("multi_objective_aggregation")
== "sum_then_normalize"
):
raise ValueError(
"`multi_objective_aggregation` value set as `sum_then_normalize` => GRPO, but GDPO was selected"
)
return data
class OptimizationValidationMixin:
"""Validation methods related to optimization and performance."""