additional args for grpo config/trainer (#2598)
This commit is contained in:
@@ -70,6 +70,13 @@ class GRPOStrategy:
|
|||||||
if trl.scale_rewards is not None:
|
if trl.scale_rewards is not None:
|
||||||
grpo_args_kwargs["scale_rewards"] = trl.scale_rewards
|
grpo_args_kwargs["scale_rewards"] = trl.scale_rewards
|
||||||
|
|
||||||
|
if trl.loss_type is not None:
|
||||||
|
grpo_args_kwargs["loss_type"] = trl.loss_type
|
||||||
|
if trl.mask_truncated_completions is not None:
|
||||||
|
grpo_args_kwargs["mask_truncated_completions"] = (
|
||||||
|
trl.mask_truncated_completions
|
||||||
|
)
|
||||||
|
|
||||||
if trl.temperature is not None:
|
if trl.temperature is not None:
|
||||||
grpo_args_kwargs["temperature"] = trl.temperature
|
grpo_args_kwargs["temperature"] = trl.temperature
|
||||||
if trl.top_p is not None:
|
if trl.top_p is not None:
|
||||||
@@ -85,6 +92,11 @@ class GRPOStrategy:
|
|||||||
grpo_args_kwargs["num_iterations"] = trl.num_iterations
|
grpo_args_kwargs["num_iterations"] = trl.num_iterations
|
||||||
if trl.epsilon is not None:
|
if trl.epsilon is not None:
|
||||||
grpo_args_kwargs["epsilon"] = trl.epsilon
|
grpo_args_kwargs["epsilon"] = trl.epsilon
|
||||||
|
if trl.epsilon_high is not None:
|
||||||
|
grpo_args_kwargs["epsilon_high"] = trl.epsilon_high
|
||||||
|
|
||||||
|
if trl.use_liger_loss is not None:
|
||||||
|
grpo_args_kwargs["use_liger_loss"] = trl.use_liger_loss
|
||||||
|
|
||||||
return grpo_args_kwargs
|
return grpo_args_kwargs
|
||||||
|
|
||||||
|
|||||||
@@ -1150,6 +1150,18 @@ class AxolotlInputConfig(
|
|||||||
|
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
@model_validator(mode="before")
|
||||||
|
@classmethod
|
||||||
|
def check_grpo_peft_liger(cls, data):
|
||||||
|
if (
|
||||||
|
data.get("rl") == "grpo"
|
||||||
|
and data.get("trl", {})
|
||||||
|
and data.get("trl").get("use_liger_loss")
|
||||||
|
and data.get("adapter")
|
||||||
|
):
|
||||||
|
raise ValueError("PEFT + GRPO + Liger is not yet supported")
|
||||||
|
return data
|
||||||
|
|
||||||
@model_validator(mode="after")
|
@model_validator(mode="after")
|
||||||
def check_sequence_parallel_degree(self):
|
def check_sequence_parallel_degree(self):
|
||||||
if not self.sequence_parallel_degree:
|
if not self.sequence_parallel_degree:
|
||||||
|
|||||||
@@ -133,3 +133,25 @@ class TRLConfig(BaseModel):
|
|||||||
"description": "Epsilon value for clipping in the GRPO algorithm."
|
"description": "Epsilon value for clipping in the GRPO algorithm."
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
epsilon_high: float | None = Field(
|
||||||
|
default=None,
|
||||||
|
json_schema_extra={
|
||||||
|
"description": "Upper-bound epsilon value for clipping in the GRPO algorithm."
|
||||||
|
},
|
||||||
|
)
|
||||||
|
use_liger_loss: bool | None = Field(
|
||||||
|
default=None,
|
||||||
|
json_schema_extra={"description": "Whether to use Liger loss for GRPO."},
|
||||||
|
)
|
||||||
|
loss_type: str | None = Field(
|
||||||
|
default=None,
|
||||||
|
json_schema_extra={
|
||||||
|
"description": "Specifies the loss formulation to use. Supported values are `grpo`, `bnpo`, and `dr_grpo`."
|
||||||
|
},
|
||||||
|
)
|
||||||
|
mask_truncated_completions: bool = Field(
|
||||||
|
default=False,
|
||||||
|
json_schema_extra={
|
||||||
|
"description": "When enabled, truncated completions are excluded from the loss calculation."
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user