diff --git a/src/axolotl/core/trainers/grpo/__init__.py b/src/axolotl/core/trainers/grpo/__init__.py
index 0a7a22f2d..8631e6ff5 100644
--- a/src/axolotl/core/trainers/grpo/__init__.py
+++ b/src/axolotl/core/trainers/grpo/__init__.py
@@ -38,11 +38,14 @@ class GRPOStrategy:
             else:
                 grpo_args_kwargs["vllm_device"] = "auto"
             if cfg.grpo_vllm_gpu_memory_utilization:
-                grpo_args_kwargs[
-                    "vllm_gpu_memory_utilization"
-                ] = cfg.grpo_vllm_gpu_memory_utilization
+                grpo_args_kwargs["vllm_gpu_memory_utilization"] = (
+                    cfg.grpo_vllm_gpu_memory_utilization
+                )
         if cfg.grpo_num_generations:
             grpo_args_kwargs["num_generations"] = cfg.grpo_num_generations
+        grpo_args_kwargs["max_completion_length"] = (
+            cfg.max_completion_length or cfg.sequence_len
+        )
         return grpo_args_kwargs
 
     @classmethod
@@ -54,9 +57,9 @@ class GRPOStrategy:
                 reward_funcs.append(cls.get_reward_func(reward_func_fqn))
             trainer_kwargs["reward_funcs"] = reward_funcs
         if cfg.grpo_reward_processing_classes:
-            trainer_kwargs[
-                "reward_processing_classes"
-            ] = cfg.grpo_reward_processing_classes
+            trainer_kwargs["reward_processing_classes"] = (
+                cfg.grpo_reward_processing_classes
+            )
         return trainer_kwargs
 
     @classmethod
diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
index b09d89bd5..bb881bfd5 100644
--- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
+++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
@@ -757,6 +757,12 @@ class AxolotlInputConfig(
         default=512,
         json_schema_extra={"description": "maximum prompt length for RL training"},
     )
+    max_completion_length: Optional[int] = Field(
+        default=None,
+        json_schema_extra={
+            "description": "Maximum length of the completion for RL training"
+        },
+    )
     sample_packing: Optional[bool] = None
     sample_packing_group_size: Optional[int] = 100_000
     sample_packing_bin_size: Optional[int] = 200
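
For reviewers, a minimal sketch (not axolotl code) of the fallback semantics the first hunk introduces: when max_completion_length is unset, the completion budget falls back to sequence_len. The resolve_max_completion_length helper below is hypothetical, written only to illustrate the "or" expression from the diff.

    from types import SimpleNamespace

    def resolve_max_completion_length(cfg):
        # Mirrors `cfg.max_completion_length or cfg.sequence_len` from the diff.
        # Caveat of `or`: any falsy value (e.g. 0), not just None, triggers
        # the sequence_len fallback.
        return cfg.max_completion_length or cfg.sequence_len

    # Explicit setting wins; omitted (None) falls back to sequence_len.
    assert resolve_max_completion_length(
        SimpleNamespace(max_completion_length=1024, sequence_len=2048)
    ) == 1024
    assert resolve_max_completion_length(
        SimpleNamespace(max_completion_length=None, sequence_len=2048)
    ) == 2048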
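
These kwargs are ultimately splatted into the training args class, trl's GRPOConfig. A hedged sketch of that downstream consumption, assuming a trl release in the 0.15 line where GRPOConfig still accepts vllm_device and vllm_gpu_memory_utilization directly; all literal values here are invented for illustration.

    from trl import GRPOConfig

    # Shape of the dict returned by GRPOStrategy (values illustrative only).
    grpo_args_kwargs = {
        "use_vllm": True,
        "vllm_device": "auto",
        "vllm_gpu_memory_utilization": 0.85,
        "num_generations": 4,
        # New in this PR: explicit user cap, else cfg.sequence_len.
        "max_completion_length": 1024,
    }

    training_args = GRPOConfig(
        output_dir="./outputs",  # hypothetical path
        **grpo_args_kwargs,
    )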
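
On the config-model side, the new field validates like any other optional int. A simplified stand-in (MiniConfig is hypothetical, not the real AxolotlInputConfig) demonstrating the default-None behavior that makes the sequence_len fallback kick in:

    from typing import Optional

    from pydantic import BaseModel, Field

    class MiniConfig(BaseModel):
        sequence_len: int = 2048
        max_completion_length: Optional[int] = Field(
            default=None,
            json_schema_extra={
                "description": "Maximum length of the completion for RL training"
            },
        )

    cfg = MiniConfig()  # key omitted from the YAML -> None
    assert cfg.max_completion_length is None
    assert (cfg.max_completion_length or cfg.sequence_len) == 2048

    cfg = MiniConfig(max_completion_length=512)  # explicit setting wins
    assert (cfg.max_completion_length or cfg.sequence_len) == 512

Since axolotl maps top-level YAML keys onto AxolotlInputConfig fields, users should be able to set max_completion_length directly in their config file, though that end-to-end path is an assumption here rather than something this diff shows.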