adding 'reward_processing_classes'

This commit is contained in:
Salman Mohammadi
2025-02-05 18:18:42 +00:00
2 changed files with 15 additions and 6 deletions

View File

@@ -38,11 +38,14 @@ class GRPOStrategy:
else:
grpo_args_kwargs["vllm_device"] = "auto"
if cfg.grpo_vllm_gpu_memory_utilization:
grpo_args_kwargs[
"vllm_gpu_memory_utilization"
] = cfg.grpo_vllm_gpu_memory_utilization
grpo_args_kwargs["vllm_gpu_memory_utilization"] = (
cfg.grpo_vllm_gpu_memory_utilization
)
if cfg.grpo_num_generations:
grpo_args_kwargs["num_generations"] = cfg.grpo_num_generations
grpo_args_kwargs["max_completion_length"] = (
cfg.max_completion_length or cfg.sequence_len
)
return grpo_args_kwargs
@classmethod
@@ -54,9 +57,9 @@ class GRPOStrategy:
reward_funcs.append(cls.get_reward_func(reward_func_fqn))
trainer_kwargs["reward_funcs"] = reward_funcs
if cfg.grpo_reward_processing_classes:
trainer_kwargs[
"reward_processing_classes"
] = cfg.grpo_reward_processing_classes
trainer_kwargs["reward_processing_classes"] = (
cfg.grpo_reward_processing_classes
)
return trainer_kwargs
@classmethod

View File

@@ -757,6 +757,12 @@ class AxolotlInputConfig(
default=512,
json_schema_extra={"description": "maximum prompt length for RL training"},
)
max_completion_length: Optional[int] = Field(
default=None,
json_schema_extra={
"description": "Maximum length of the completion for RL training"
},
)
sample_packing: Optional[bool] = None
sample_packing_group_size: Optional[int] = 100_000
sample_packing_bin_size: Optional[int] = 200