additional args for grpo config/trainer (#2598)
This commit is contained in:
@@ -70,6 +70,13 @@ class GRPOStrategy:
|
||||
if trl.scale_rewards is not None:
|
||||
grpo_args_kwargs["scale_rewards"] = trl.scale_rewards
|
||||
|
||||
if trl.loss_type is not None:
|
||||
grpo_args_kwargs["loss_type"] = trl.loss_type
|
||||
if trl.mask_truncated_completions is not None:
|
||||
grpo_args_kwargs["mask_truncated_completions"] = (
|
||||
trl.mask_truncated_completions
|
||||
)
|
||||
|
||||
if trl.temperature is not None:
|
||||
grpo_args_kwargs["temperature"] = trl.temperature
|
||||
if trl.top_p is not None:
|
||||
@@ -85,6 +92,11 @@ class GRPOStrategy:
|
||||
grpo_args_kwargs["num_iterations"] = trl.num_iterations
|
||||
if trl.epsilon is not None:
|
||||
grpo_args_kwargs["epsilon"] = trl.epsilon
|
||||
if trl.epsilon_high is not None:
|
||||
grpo_args_kwargs["epsilon_high"] = trl.epsilon_high
|
||||
|
||||
if trl.use_liger_loss is not None:
|
||||
grpo_args_kwargs["use_liger_loss"] = trl.use_liger_loss
|
||||
|
||||
return grpo_args_kwargs
|
||||
|
||||
|
||||
@@ -1150,6 +1150,18 @@ class AxolotlInputConfig(
|
||||
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_grpo_peft_liger(cls, data):
|
||||
if (
|
||||
data.get("rl") == "grpo"
|
||||
and data.get("trl", {})
|
||||
and data.get("trl").get("use_liger_loss")
|
||||
and data.get("adapter")
|
||||
):
|
||||
raise ValueError("PEFT + GRPO + Liger is not yet supported")
|
||||
return data
|
||||
|
||||
@model_validator(mode="after")
|
||||
def check_sequence_parallel_degree(self):
|
||||
if not self.sequence_parallel_degree:
|
||||
|
||||
@@ -133,3 +133,25 @@ class TRLConfig(BaseModel):
|
||||
"description": "Epsilon value for clipping in the GRPO algorithm."
|
||||
},
|
||||
)
|
||||
epsilon_high: float | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "Upper-bound epsilon value for clipping in the GRPO algorithm."
|
||||
},
|
||||
)
|
||||
use_liger_loss: bool | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={"description": "Whether to use Liger loss for GRPO."},
|
||||
)
|
||||
loss_type: str | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "Specifies the loss formulation to use. Supported values are `grpo`, `bnpo`, and `dr_grpo`."
|
||||
},
|
||||
)
|
||||
mask_truncated_completions: bool = Field(
|
||||
default=False,
|
||||
json_schema_extra={
|
||||
"description": "When enabled, truncated completions are excluded from the loss calculation."
|
||||
},
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user