From c3f2b1c5c2aad469443ae37cce7af475ec70525f Mon Sep 17 00:00:00 2001 From: Dhruv Mullick Date: Wed, 30 Apr 2025 19:00:30 -0600 Subject: [PATCH] Add num_completions_to_print for trl and grpo (#2604) --- src/axolotl/core/trainers/grpo/__init__.py | 1 + src/axolotl/utils/schemas/trl.py | 6 ++++++ 2 files changed, 7 insertions(+) diff --git a/src/axolotl/core/trainers/grpo/__init__.py b/src/axolotl/core/trainers/grpo/__init__.py index b64d087b0..078fdcb22 100644 --- a/src/axolotl/core/trainers/grpo/__init__.py +++ b/src/axolotl/core/trainers/grpo/__init__.py @@ -63,6 +63,7 @@ class GRPOStrategy: grpo_args_kwargs["max_completion_length"] = trl.max_completion_length grpo_args_kwargs["log_completions"] = trl.log_completions + grpo_args_kwargs["num_completions_to_print"] = trl.num_completions_to_print if trl.reward_weights: grpo_args_kwargs["reward_weights"] = trl.reward_weights diff --git a/src/axolotl/utils/schemas/trl.py b/src/axolotl/utils/schemas/trl.py index c581b265e..37b71dba8 100644 --- a/src/axolotl/utils/schemas/trl.py +++ b/src/axolotl/utils/schemas/trl.py @@ -67,6 +67,12 @@ class TRLConfig(BaseModel): default=False, json_schema_extra={"description": "Whether to log completions"}, ) + num_completions_to_print: int | None = Field( + default=None, + json_schema_extra={ + "description": "Number of completions to print. If `log_completions` is `True`, this will be the number of completions logged." + }, + ) sync_ref_model: bool | None = Field( default=False, json_schema_extra={