Use cfg.max_completion_length for GRPO max_completion_length; drop the fallback to cfg.sequence_len

This commit is contained in:
Wing Lian
2025-02-05 13:20:17 -05:00
parent bdb0f97082
commit 3659d812f7

View File

@@ -38,14 +38,12 @@ class GRPOStrategy:
else:
grpo_args_kwargs["vllm_device"] = "auto"
if cfg.grpo_vllm_gpu_memory_utilization:
grpo_args_kwargs["vllm_gpu_memory_utilization"] = (
cfg.grpo_vllm_gpu_memory_utilization
)
grpo_args_kwargs[
"vllm_gpu_memory_utilization"
] = cfg.grpo_vllm_gpu_memory_utilization
if cfg.grpo_num_generations:
grpo_args_kwargs["num_generations"] = cfg.grpo_num_generations
grpo_args_kwargs["max_completion_length"] = (
cfg.max_completion_length or cfg.sequence_len
)
grpo_args_kwargs["max_completion_length"] = cfg.max_completion_length
return grpo_args_kwargs
@classmethod
@@ -57,9 +55,9 @@ class GRPOStrategy:
reward_funcs.append(cls.get_reward_func(reward_func_fqn))
trainer_kwargs["reward_funcs"] = reward_funcs
if cfg.grpo_reward_processing_classes:
trainer_kwargs["reward_processing_classes"] = (
cfg.grpo_reward_processing_classes
)
trainer_kwargs[
"reward_processing_classes"
] = cfg.grpo_reward_processing_classes
return trainer_kwargs
@classmethod