refactor: trl grpo configs to have descriptions (#2386)

* refactor: trl grpo configs to have descriptions

* chore: caps
This commit is contained in:
NanoCode012
2025-03-07 20:58:53 +07:00
committed by GitHub
parent fa7c79b3b9
commit 16dc6ee68d
2 changed files with 74 additions and 15 deletions

View File

@@ -3,6 +3,7 @@ title: "RLHF (Beta)"
description: "Reinforcement Learning from Human Feedback is a method whereby a language model is optimized from data using human feedback." description: "Reinforcement Learning from Human Feedback is a method whereby a language model is optimized from data using human feedback."
back-to-top-navigation: true back-to-top-navigation: true
toc: true toc: true
toc-expand: 2
toc-depth: 4 toc-depth: 4
--- ---
@@ -528,6 +529,7 @@ trl:
vllm_gpu_memory_utilization: 0.15 vllm_gpu_memory_utilization: 0.15
num_generations: 4 num_generations: 4
reward_funcs: ["rewards.rand_reward_func"] # format: '{file_name}.{fn_name}' reward_funcs: ["rewards.rand_reward_func"] # format: '{file_name}.{fn_name}'
reward_weights: [1.0]
datasets: datasets:
- path: openai/gsm8k - path: openai/gsm8k
name: main name: main
@@ -536,6 +538,8 @@ datasets:
To see other examples of custom reward functions, please see [TRL GRPO Docs](https://github.com/huggingface/trl/blob/main/docs/source/grpo_trainer.md#using-a-custom-reward-function). To see other examples of custom reward functions, please see [TRL GRPO Docs](https://github.com/huggingface/trl/blob/main/docs/source/grpo_trainer.md#using-a-custom-reward-function).
To see description of the configs, please see [TRLConfig](https://github.com/axolotl-ai-cloud/axolotl/blob/main/src/axolotl/utils/config/models/input/v0_4_1/trl.py).
### Using local dataset files ### Using local dataset files
```yaml ```yaml

View File

@@ -1,7 +1,8 @@
""" """
GRPO specific configuration args GRPO specific configuration args
""" """
from typing import List, Optional
from typing import Optional
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
@@ -11,7 +12,10 @@ class TRLConfig(BaseModel):
Input args for TRL. Input args for TRL.
""" """
beta: Optional[float] = None beta: Optional[float] = Field(
default=None,
json_schema_extra={"description": "Beta for RL training"},
)
max_completion_length: Optional[int] = Field( max_completion_length: Optional[int] = Field(
default=None, default=None,
json_schema_extra={ json_schema_extra={
@@ -20,17 +24,68 @@ class TRLConfig(BaseModel):
) )
# GRPO specific args # GRPO specific args
use_vllm: Optional[bool] = False # Ref: https://github.com/huggingface/trl/blob/e3244d2d096ff1e2e248c931d06d39e165e20623/trl/trainer/grpo_config.py#L22
vllm_device: Optional[str] = "auto" use_vllm: Optional[bool] = Field(
vllm_gpu_memory_utilization: Optional[float] = 0.9 default=False,
vllm_max_model_len: Optional[int] = None json_schema_extra={"description": "Whether to use VLLM for RL training"},
vllm_dtype: Optional[str] = "auto" )
vllm_device: Optional[str] = Field(
default="auto",
json_schema_extra={"description": "Device to use for VLLM"},
)
vllm_gpu_memory_utilization: Optional[float] = Field(
default=0.9,
json_schema_extra={"description": "GPU memory utilization for VLLM"},
)
vllm_dtype: Optional[str] = Field(
default="auto",
json_schema_extra={"description": "Data type for VLLM"},
)
vllm_max_model_len: Optional[int] = Field(
default=None,
json_schema_extra={
"description": "Maximum length of the model context for VLLM"
},
)
reward_funcs: Optional[List[str]] = None reward_funcs: Optional[list[str]] = Field(
reward_weights: Optional[List[float]] = None default=None,
num_generations: Optional[int] = None json_schema_extra={"description": "List of reward functions to load"},
log_completions: Optional[bool] = False )
reward_weights: Optional[list[float]] = Field(
sync_ref_model: Optional[bool] = False default=None,
ref_model_mixup_alpha: Optional[float] = 0.9 json_schema_extra={
ref_model_sync_steps: Optional[int] = 64 "description": "Weights for each reward function. Must match the number of reward functions."
},
)
num_generations: Optional[int] = Field(
default=None,
json_schema_extra={
"description": "Number of generations to sample. The global batch size (num_processes * per_device_batch_size) must be divisible by this value."
},
)
log_completions: Optional[bool] = Field(
default=False,
json_schema_extra={"description": "Whether to log completions"},
)
sync_ref_model: Optional[bool] = Field(
default=False,
json_schema_extra={
"description": (
"Whether to sync the reference model every `ref_model_sync_steps` "
"steps, using the `ref_model_mixup_alpha` parameter."
)
},
)
ref_model_mixup_alpha: Optional[float] = Field(
default=0.9,
json_schema_extra={
"description": "Mixup alpha for the reference model. Requires `sync_ref_model=True`."
},
)
ref_model_sync_steps: Optional[int] = Field(
default=64,
json_schema_extra={
"description": "Sync steps for the reference model. Requires `sync_ref_model=True`."
},
)