additional args for grpo config/trainer (#2598)

This commit is contained in:
Wing Lian
2025-04-30 13:11:12 -04:00
committed by GitHub
parent 5e949eaa07
commit 24ff5f53f8
3 changed files with 46 additions and 0 deletions

View File

@@ -70,6 +70,13 @@ class GRPOStrategy:
if trl.scale_rewards is not None:
grpo_args_kwargs["scale_rewards"] = trl.scale_rewards
if trl.loss_type is not None:
grpo_args_kwargs["loss_type"] = trl.loss_type
if trl.mask_truncated_completions is not None:
grpo_args_kwargs["mask_truncated_completions"] = (
trl.mask_truncated_completions
)
if trl.temperature is not None:
grpo_args_kwargs["temperature"] = trl.temperature
if trl.top_p is not None:
@@ -85,6 +92,11 @@ class GRPOStrategy:
grpo_args_kwargs["num_iterations"] = trl.num_iterations
if trl.epsilon is not None:
grpo_args_kwargs["epsilon"] = trl.epsilon
if trl.epsilon_high is not None:
grpo_args_kwargs["epsilon_high"] = trl.epsilon_high
if trl.use_liger_loss is not None:
grpo_args_kwargs["use_liger_loss"] = trl.use_liger_loss
return grpo_args_kwargs

View File

@@ -1150,6 +1150,18 @@ class AxolotlInputConfig(
return data
@model_validator(mode="before")
@classmethod
def check_grpo_peft_liger(cls, data):
if (
data.get("rl") == "grpo"
and data.get("trl", {})
and data.get("trl").get("use_liger_loss")
and data.get("adapter")
):
raise ValueError("PEFT + GRPO + Liger is not yet supported")
return data
@model_validator(mode="after")
def check_sequence_parallel_degree(self):
if not self.sequence_parallel_degree:

View File

@@ -133,3 +133,25 @@ class TRLConfig(BaseModel):
"description": "Epsilon value for clipping in the GRPO algorithm."
},
)
epsilon_high: float | None = Field(
default=None,
json_schema_extra={
"description": "Upper-bound epsilon value for clipping in the GRPO algorithm."
},
)
use_liger_loss: bool | None = Field(
default=None,
json_schema_extra={"description": "Whether to use Liger loss for GRPO."},
)
loss_type: str | None = Field(
default=None,
json_schema_extra={
"description": "Specifies the loss formulation to use. Supported values are `grpo`, `bnpo`, and `dr_grpo`."
},
)
mask_truncated_completions: bool = Field(
default=False,
json_schema_extra={
"description": "When enabled, truncated completions are excluded from the loss calculation."
},
)