use cfg.max_completion_length, not sequence_len
This commit is contained in:
@@ -38,14 +38,12 @@ class GRPOStrategy:
|
||||
else:
|
||||
grpo_args_kwargs["vllm_device"] = "auto"
|
||||
if cfg.grpo_vllm_gpu_memory_utilization:
|
||||
grpo_args_kwargs["vllm_gpu_memory_utilization"] = (
|
||||
cfg.grpo_vllm_gpu_memory_utilization
|
||||
)
|
||||
grpo_args_kwargs[
|
||||
"vllm_gpu_memory_utilization"
|
||||
] = cfg.grpo_vllm_gpu_memory_utilization
|
||||
if cfg.grpo_num_generations:
|
||||
grpo_args_kwargs["num_generations"] = cfg.grpo_num_generations
|
||||
grpo_args_kwargs["max_completion_length"] = (
|
||||
cfg.max_completion_length or cfg.sequence_len
|
||||
)
|
||||
grpo_args_kwargs["max_completion_length"] = cfg.max_completion_length
|
||||
return grpo_args_kwargs
|
||||
|
||||
@classmethod
|
||||
@@ -57,9 +55,9 @@ class GRPOStrategy:
|
||||
reward_funcs.append(cls.get_reward_func(reward_func_fqn))
|
||||
trainer_kwargs["reward_funcs"] = reward_funcs
|
||||
if cfg.grpo_reward_processing_classes:
|
||||
trainer_kwargs["reward_processing_classes"] = (
|
||||
cfg.grpo_reward_processing_classes
|
||||
)
|
||||
trainer_kwargs[
|
||||
"reward_processing_classes"
|
||||
] = cfg.grpo_reward_processing_classes
|
||||
return trainer_kwargs
|
||||
|
||||
@classmethod
|
||||
|
||||
Reference in New Issue
Block a user