adding 'reward_processing_classes'
This commit is contained in:
@@ -38,11 +38,14 @@ class GRPOStrategy:
|
|||||||
else:
|
else:
|
||||||
grpo_args_kwargs["vllm_device"] = "auto"
|
grpo_args_kwargs["vllm_device"] = "auto"
|
||||||
if cfg.grpo_vllm_gpu_memory_utilization:
|
if cfg.grpo_vllm_gpu_memory_utilization:
|
||||||
grpo_args_kwargs[
|
grpo_args_kwargs["vllm_gpu_memory_utilization"] = (
|
||||||
"vllm_gpu_memory_utilization"
|
cfg.grpo_vllm_gpu_memory_utilization
|
||||||
] = cfg.grpo_vllm_gpu_memory_utilization
|
)
|
||||||
if cfg.grpo_num_generations:
|
if cfg.grpo_num_generations:
|
||||||
grpo_args_kwargs["num_generations"] = cfg.grpo_num_generations
|
grpo_args_kwargs["num_generations"] = cfg.grpo_num_generations
|
||||||
|
grpo_args_kwargs["max_completion_length"] = (
|
||||||
|
cfg.max_completion_length or cfg.sequence_len
|
||||||
|
)
|
||||||
return grpo_args_kwargs
|
return grpo_args_kwargs
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -54,9 +57,9 @@ class GRPOStrategy:
|
|||||||
reward_funcs.append(cls.get_reward_func(reward_func_fqn))
|
reward_funcs.append(cls.get_reward_func(reward_func_fqn))
|
||||||
trainer_kwargs["reward_funcs"] = reward_funcs
|
trainer_kwargs["reward_funcs"] = reward_funcs
|
||||||
if cfg.grpo_reward_processing_classes:
|
if cfg.grpo_reward_processing_classes:
|
||||||
trainer_kwargs[
|
trainer_kwargs["reward_processing_classes"] = (
|
||||||
"reward_processing_classes"
|
cfg.grpo_reward_processing_classes
|
||||||
] = cfg.grpo_reward_processing_classes
|
)
|
||||||
return trainer_kwargs
|
return trainer_kwargs
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|||||||
@@ -757,6 +757,12 @@ class AxolotlInputConfig(
|
|||||||
default=512,
|
default=512,
|
||||||
json_schema_extra={"description": "maximum prompt length for RL training"},
|
json_schema_extra={"description": "maximum prompt length for RL training"},
|
||||||
)
|
)
|
||||||
|
max_completion_length: Optional[int] = Field(
|
||||||
|
default=None,
|
||||||
|
json_schema_extra={
|
||||||
|
"description": "Maximum length of the completion for RL training"
|
||||||
|
},
|
||||||
|
)
|
||||||
sample_packing: Optional[bool] = None
|
sample_packing: Optional[bool] = None
|
||||||
sample_packing_group_size: Optional[int] = 100_000
|
sample_packing_group_size: Optional[int] = 100_000
|
||||||
sample_packing_bin_size: Optional[int] = 200
|
sample_packing_bin_size: Optional[int] = 200
|
||||||
|
|||||||
Reference in New Issue
Block a user