diff --git a/src/axolotl/core/trainers/grpo/__init__.py b/src/axolotl/core/trainers/grpo/__init__.py
index c82fe69f2..b64d087b0 100644
--- a/src/axolotl/core/trainers/grpo/__init__.py
+++ b/src/axolotl/core/trainers/grpo/__init__.py
@@ -70,6 +70,13 @@ class GRPOStrategy:
         if trl.scale_rewards is not None:
             grpo_args_kwargs["scale_rewards"] = trl.scale_rewards
 
+        if trl.loss_type is not None:
+            grpo_args_kwargs["loss_type"] = trl.loss_type
+        if trl.mask_truncated_completions is not None:
+            grpo_args_kwargs["mask_truncated_completions"] = (
+                trl.mask_truncated_completions
+            )
+
         if trl.temperature is not None:
             grpo_args_kwargs["temperature"] = trl.temperature
         if trl.top_p is not None:
@@ -85,6 +92,11 @@ class GRPOStrategy:
             grpo_args_kwargs["num_iterations"] = trl.num_iterations
         if trl.epsilon is not None:
             grpo_args_kwargs["epsilon"] = trl.epsilon
+        if trl.epsilon_high is not None:
+            grpo_args_kwargs["epsilon_high"] = trl.epsilon_high
+
+        if trl.use_liger_loss is not None:
+            grpo_args_kwargs["use_liger_loss"] = trl.use_liger_loss
 
         return grpo_args_kwargs
 
diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py
index 7f35de81d..bc7ee7e72 100644
--- a/src/axolotl/utils/schemas/config.py
+++ b/src/axolotl/utils/schemas/config.py
@@ -1150,6 +1150,20 @@ class AxolotlInputConfig(
 
         return data
 
+    @model_validator(mode="before")
+    @classmethod
+    def check_grpo_peft_liger(cls, data):
+        """Reject GRPO configs that enable the Liger loss together with an adapter."""
+        # `trl` may be missing or explicitly null in the raw config mapping.
+        trl_cfg = data.get("trl") or {}
+        if (
+            data.get("rl") == "grpo"
+            and trl_cfg.get("use_liger_loss")
+            and data.get("adapter")
+        ):
+            raise ValueError("PEFT + GRPO + Liger is not yet supported")
+        return data
+
     @model_validator(mode="after")
     def check_sequence_parallel_degree(self):
         if not self.sequence_parallel_degree:
diff --git a/src/axolotl/utils/schemas/trl.py b/src/axolotl/utils/schemas/trl.py
index a051fb0ab..c581b265e 100644
--- a/src/axolotl/utils/schemas/trl.py
+++ b/src/axolotl/utils/schemas/trl.py
@@ -133,3 +133,25 @@ class TRLConfig(BaseModel):
             "description": "Epsilon value for clipping in the GRPO algorithm."
         },
     )
+    epsilon_high: float | None = Field(
+        default=None,
+        json_schema_extra={
+            "description": "Upper-bound epsilon value for clipping in the GRPO algorithm."
+        },
+    )
+    use_liger_loss: bool | None = Field(
+        default=None,
+        json_schema_extra={"description": "Whether to use Liger loss for GRPO."},
+    )
+    loss_type: str | None = Field(
+        default=None,
+        json_schema_extra={
+            "description": "Specifies the loss formulation to use. Supported values are `grpo`, `bnpo`, and `dr_grpo`."
+        },
+    )
+    mask_truncated_completions: bool | None = Field(
+        default=None,
+        json_schema_extra={
+            "description": "When enabled, truncated completions are excluded from the loss calculation."
+        },
+    )