From 3659d812f74c9b31c13eb26415dbc7e76973f9be Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Wed, 5 Feb 2025 13:20:17 -0500
Subject: [PATCH] use cfg.max_completion_length, not sequence_len

---
NOTE(review): line breaks restored from a whitespace-collapsed copy of this
patch. Hunk line counts match the @@ headers, but the leading indentation of
the hunk lines is reconstructed from Python block structure; confirm against
blob 8631e6ff5 before applying.

 src/axolotl/core/trainers/grpo/__init__.py | 16 +++++++---------
 1 file changed, 7 insertions(+), 9 deletions(-)

diff --git a/src/axolotl/core/trainers/grpo/__init__.py b/src/axolotl/core/trainers/grpo/__init__.py
index 8631e6ff5..15590ad1a 100644
--- a/src/axolotl/core/trainers/grpo/__init__.py
+++ b/src/axolotl/core/trainers/grpo/__init__.py
@@ -38,14 +38,12 @@ class GRPOStrategy:
             else:
                 grpo_args_kwargs["vllm_device"] = "auto"
             if cfg.grpo_vllm_gpu_memory_utilization:
-                grpo_args_kwargs["vllm_gpu_memory_utilization"] = (
-                    cfg.grpo_vllm_gpu_memory_utilization
-                )
+                grpo_args_kwargs[
+                    "vllm_gpu_memory_utilization"
+                ] = cfg.grpo_vllm_gpu_memory_utilization
         if cfg.grpo_num_generations:
             grpo_args_kwargs["num_generations"] = cfg.grpo_num_generations
-        grpo_args_kwargs["max_completion_length"] = (
-            cfg.max_completion_length or cfg.sequence_len
-        )
+        grpo_args_kwargs["max_completion_length"] = cfg.max_completion_length
         return grpo_args_kwargs
 
     @classmethod
@@ -57,9 +55,9 @@ class GRPOStrategy:
                 reward_funcs.append(cls.get_reward_func(reward_func_fqn))
             trainer_kwargs["reward_funcs"] = reward_funcs
         if cfg.grpo_reward_processing_classes:
-            trainer_kwargs["reward_processing_classes"] = (
-                cfg.grpo_reward_processing_classes
-            )
+            trainer_kwargs[
+                "reward_processing_classes"
+            ] = cfg.grpo_reward_processing_classes
         return trainer_kwargs
 
     @classmethod