diff --git a/src/axolotl/core/trainers/grpo/__init__.py b/src/axolotl/core/trainers/grpo/__init__.py index b64d087b0..078fdcb22 100644 --- a/src/axolotl/core/trainers/grpo/__init__.py +++ b/src/axolotl/core/trainers/grpo/__init__.py @@ -63,6 +63,7 @@ class GRPOStrategy: grpo_args_kwargs["max_completion_length"] = trl.max_completion_length grpo_args_kwargs["log_completions"] = trl.log_completions + grpo_args_kwargs["num_completions_to_print"] = trl.num_completions_to_print if trl.reward_weights: grpo_args_kwargs["reward_weights"] = trl.reward_weights diff --git a/src/axolotl/utils/schemas/trl.py b/src/axolotl/utils/schemas/trl.py index c581b265e..37b71dba8 100644 --- a/src/axolotl/utils/schemas/trl.py +++ b/src/axolotl/utils/schemas/trl.py @@ -67,6 +67,12 @@ class TRLConfig(BaseModel): default=False, json_schema_extra={"description": "Whether to log completions"}, ) + num_completions_to_print: int | None = Field( + default=None, + json_schema_extra={ + "description": "Number of completions to print. If `log_completions` is `True`, this will be the number of completions logged." + }, + ) sync_ref_model: bool | None = Field( default=False, json_schema_extra={