feat: Add GDPO Support (#3353)
* gdpo support - test left * lint * fixxes for vllm serv * test advantages * docss * lint * lint = * gdpo simple + lint * lint nit * example * lint * trl 0.27.0 * blocklist * test assert rmv * add validation check for GDPO + sum_then_normalize --------- Co-authored-by: Wing Lian <wing@axolotl.ai>
This commit is contained in:
@@ -52,12 +52,11 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
trainer_cls = None
|
||||
trainer_cls_args = [self.model]
|
||||
|
||||
if self.cfg.rl is RLType.GRPO:
|
||||
if self.cfg.rl in {RLType.GRPO, RLType.GDPO}:
|
||||
trainer_cls = GRPOStrategy.get_trainer_class(
|
||||
sequence_parallel=self.cfg.context_parallel_size > 1
|
||||
)
|
||||
trainer_cls_args.extend(GRPOStrategy.set_trainer_args(self.cfg))
|
||||
|
||||
trainer_kwargs.update(GRPOStrategy.set_trainer_kwargs(self.cfg))
|
||||
|
||||
elif self.cfg.rl in [RLType.DPO, RLType.IPO]:
|
||||
@@ -147,6 +146,8 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
|
||||
elif self.cfg.rl is RLType.KTO:
|
||||
training_args_cls = AxolotlKTOConfig
|
||||
# KTOConfig in TRL >= 0.27.0 no longer accepts max_prompt_length
|
||||
blocklist_args_kwargs = ["max_prompt_length"]
|
||||
|
||||
training_args_kwargs["desirable_weight"] = (
|
||||
self.cfg.kto_desirable_weight or 1.0
|
||||
@@ -155,10 +156,14 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
self.cfg.kto_undesirable_weight or 1.0
|
||||
)
|
||||
|
||||
elif self.cfg.rl is RLType.GRPO:
|
||||
elif self.cfg.rl in {RLType.GRPO, RLType.GDPO}:
|
||||
training_args_cls = GRPOStrategy.get_training_args_class()
|
||||
training_args_kwargs.update(GRPOStrategy.set_training_args_kwargs(self.cfg))
|
||||
blocklist_args_kwargs = GRPOStrategy.get_blocklist_args_kwargs()
|
||||
if self.cfg.rl is RLType.GDPO:
|
||||
training_args_kwargs.setdefault(
|
||||
"multi_objective_aggregation", "normalize_then_sum"
|
||||
)
|
||||
|
||||
elif self.cfg.rl in [RLType.DPO, RLType.IPO]:
|
||||
training_args_cls = AxolotlDPOConfig
|
||||
|
||||
@@ -129,6 +129,11 @@ class GRPOStrategy:
|
||||
if trl.rollout_func:
|
||||
grpo_args_kwargs["rollout_func"] = cls.get_rollout_func(trl.rollout_func)
|
||||
|
||||
if trl.multi_objective_aggregation is not None:
|
||||
grpo_args_kwargs["multi_objective_aggregation"] = (
|
||||
trl.multi_objective_aggregation
|
||||
)
|
||||
|
||||
return grpo_args_kwargs
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -173,7 +173,7 @@ def _drop_long_sequences(
|
||||
|
||||
return (len_prompt + len_completion) <= sequence_len
|
||||
|
||||
if rl is RLType.GRPO:
|
||||
if rl in {RLType.GRPO, RLType.GDPO}:
|
||||
return True
|
||||
|
||||
raise ValueError("Unknown RL type")
|
||||
|
||||
@@ -26,6 +26,7 @@ class RLType(str, Enum):
|
||||
"""RL trainer type configuration subset"""
|
||||
|
||||
DPO = "dpo"
|
||||
GDPO = "gdpo"
|
||||
GRPO = "grpo"
|
||||
IPO = "ipo"
|
||||
ORPO = "orpo"
|
||||
|
||||
@@ -179,3 +179,13 @@ class TRLConfig(BaseModel):
|
||||
"description": "Path to custom rollout function. Must be importable from current dir."
|
||||
},
|
||||
)
|
||||
multi_objective_aggregation: (
|
||||
Literal["sum_then_normalize", "normalize_then_sum"] | None
|
||||
) = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "Multi-objective reward aggregation strategy. "
|
||||
"'sum_then_normalize' (GRPO default): weights and sums rewards first, then normalizes. "
|
||||
"'normalize_then_sum' (GDPO): normalizes each reward independently, then sums."
|
||||
},
|
||||
)
|
||||
|
||||
@@ -746,6 +746,19 @@ class RLValidationMixin:
|
||||
)
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
@classmethod
def check_gdpo(cls, data):
    """Reject a GDPO config that requests GRPO-style reward aggregation.

    Runs as a ``mode="before"`` validator, so ``data`` is the raw config
    mapping rather than a validated model.

    Args:
        data: Raw config mapping. Only ``rl`` and
            ``trl.multi_objective_aggregation`` are inspected.

    Returns:
        ``data`` unchanged when the combination is acceptable.

    Raises:
        ValueError: If ``rl`` is ``"gdpo"`` while
            ``trl.multi_objective_aggregation`` is ``"sum_then_normalize"``.
    """
    if data.get("rl") != "gdpo":
        return data

    # `trl` can be present but explicitly null in the config; in that case
    # `data.get("trl", {})` would hand back None and `.get` would raise
    # AttributeError. Treat a null/empty section the same as a missing one.
    trl_settings = data.get("trl") or {}
    if trl_settings.get("multi_objective_aggregation") == "sum_then_normalize":
        raise ValueError(
            "`multi_objective_aggregation` value set as `sum_then_normalize` => GRPO, but GDPO was selected"
        )
    return data
|
||||
|
||||
|
||||
class OptimizationValidationMixin:
|
||||
"""Validation methods related to optimization and performance."""
|
||||
|
||||
Reference in New Issue
Block a user