refactor: trl grpo configs to have descriptions (#2386)
* refactor: trl grpo configs to have descriptions * chore: caps
This commit is contained in:
@@ -3,6 +3,7 @@ title: "RLHF (Beta)"
|
|||||||
description: "Reinforcement Learning from Human Feedback is a method whereby a language model is optimized from data using human feedback."
|
description: "Reinforcement Learning from Human Feedback is a method whereby a language model is optimized from data using human feedback."
|
||||||
back-to-top-navigation: true
|
back-to-top-navigation: true
|
||||||
toc: true
|
toc: true
|
||||||
|
toc-expand: 2
|
||||||
toc-depth: 4
|
toc-depth: 4
|
||||||
---
|
---
|
||||||
|
|
||||||
@@ -528,6 +529,7 @@ trl:
|
|||||||
vllm_gpu_memory_utilization: 0.15
|
vllm_gpu_memory_utilization: 0.15
|
||||||
num_generations: 4
|
num_generations: 4
|
||||||
reward_funcs: ["rewards.rand_reward_func"] # format: '{file_name}.{fn_name}'
|
reward_funcs: ["rewards.rand_reward_func"] # format: '{file_name}.{fn_name}'
|
||||||
|
reward_weights: [1.0]
|
||||||
datasets:
|
datasets:
|
||||||
- path: openai/gsm8k
|
- path: openai/gsm8k
|
||||||
name: main
|
name: main
|
||||||
@@ -536,6 +538,8 @@ datasets:
|
|||||||
|
|
||||||
To see other examples of custom reward functions, please see [TRL GRPO Docs](https://github.com/huggingface/trl/blob/main/docs/source/grpo_trainer.md#using-a-custom-reward-function).
|
To see other examples of custom reward functions, please see [TRL GRPO Docs](https://github.com/huggingface/trl/blob/main/docs/source/grpo_trainer.md#using-a-custom-reward-function).
|
||||||
|
|
||||||
|
For a description of each config option, see [TRLConfig](https://github.com/axolotl-ai-cloud/axolotl/blob/main/src/axolotl/utils/config/models/input/v0_4_1/trl.py).
|
||||||
|
|
||||||
### Using local dataset files
|
### Using local dataset files
|
||||||
|
|
||||||
```yaml
|
```yaml
|
||||||
|
|||||||
@@ -1,7 +1,8 @@
|
|||||||
"""
|
"""
|
||||||
GRPO specific configuration args
|
GRPO specific configuration args
|
||||||
"""
|
"""
|
||||||
from typing import List, Optional
|
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
@@ -11,7 +12,10 @@ class TRLConfig(BaseModel):
|
|||||||
Input args for TRL.
|
Input args for TRL.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
beta: Optional[float] = None
|
beta: Optional[float] = Field(
|
||||||
|
default=None,
|
||||||
|
json_schema_extra={"description": "Beta for RL training"},
|
||||||
|
)
|
||||||
max_completion_length: Optional[int] = Field(
|
max_completion_length: Optional[int] = Field(
|
||||||
default=None,
|
default=None,
|
||||||
json_schema_extra={
|
json_schema_extra={
|
||||||
@@ -20,17 +24,68 @@ class TRLConfig(BaseModel):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# GRPO specific args
|
# GRPO specific args
|
||||||
use_vllm: Optional[bool] = False
|
# Ref: https://github.com/huggingface/trl/blob/e3244d2d096ff1e2e248c931d06d39e165e20623/trl/trainer/grpo_config.py#L22
|
||||||
vllm_device: Optional[str] = "auto"
|
use_vllm: Optional[bool] = Field(
|
||||||
vllm_gpu_memory_utilization: Optional[float] = 0.9
|
default=False,
|
||||||
vllm_max_model_len: Optional[int] = None
|
json_schema_extra={"description": "Whether to use VLLM for RL training"},
|
||||||
vllm_dtype: Optional[str] = "auto"
|
)
|
||||||
|
vllm_device: Optional[str] = Field(
|
||||||
|
default="auto",
|
||||||
|
json_schema_extra={"description": "Device to use for VLLM"},
|
||||||
|
)
|
||||||
|
vllm_gpu_memory_utilization: Optional[float] = Field(
|
||||||
|
default=0.9,
|
||||||
|
json_schema_extra={"description": "GPU memory utilization for VLLM"},
|
||||||
|
)
|
||||||
|
vllm_dtype: Optional[str] = Field(
|
||||||
|
default="auto",
|
||||||
|
json_schema_extra={"description": "Data type for VLLM"},
|
||||||
|
)
|
||||||
|
vllm_max_model_len: Optional[int] = Field(
|
||||||
|
default=None,
|
||||||
|
json_schema_extra={
|
||||||
|
"description": "Maximum length of the model context for VLLM"
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
reward_funcs: Optional[List[str]] = None
|
reward_funcs: Optional[list[str]] = Field(
|
||||||
reward_weights: Optional[List[float]] = None
|
default=None,
|
||||||
num_generations: Optional[int] = None
|
json_schema_extra={"description": "List of reward functions to load"},
|
||||||
log_completions: Optional[bool] = False
|
)
|
||||||
|
reward_weights: Optional[list[float]] = Field(
|
||||||
sync_ref_model: Optional[bool] = False
|
default=None,
|
||||||
ref_model_mixup_alpha: Optional[float] = 0.9
|
json_schema_extra={
|
||||||
ref_model_sync_steps: Optional[int] = 64
|
"description": "Weights for each reward function. Must match the number of reward functions."
|
||||||
|
},
|
||||||
|
)
|
||||||
|
num_generations: Optional[int] = Field(
|
||||||
|
default=None,
|
||||||
|
json_schema_extra={
|
||||||
|
"description": "Number of generations to sample. The global batch size (num_processes * per_device_batch_size) must be divisible by this value."
|
||||||
|
},
|
||||||
|
)
|
||||||
|
log_completions: Optional[bool] = Field(
|
||||||
|
default=False,
|
||||||
|
json_schema_extra={"description": "Whether to log completions"},
|
||||||
|
)
|
||||||
|
sync_ref_model: Optional[bool] = Field(
|
||||||
|
default=False,
|
||||||
|
json_schema_extra={
|
||||||
|
"description": (
|
||||||
|
"Whether to sync the reference model every `ref_model_sync_steps` "
|
||||||
|
"steps, using the `ref_model_mixup_alpha` parameter."
|
||||||
|
)
|
||||||
|
},
|
||||||
|
)
|
||||||
|
ref_model_mixup_alpha: Optional[float] = Field(
|
||||||
|
default=0.9,
|
||||||
|
json_schema_extra={
|
||||||
|
"description": "Mixup alpha for the reference model. Requires `sync_ref_model=True`."
|
||||||
|
},
|
||||||
|
)
|
||||||
|
ref_model_sync_steps: Optional[int] = Field(
|
||||||
|
default=64,
|
||||||
|
json_schema_extra={
|
||||||
|
"description": "Sync steps for the reference model. Requires `sync_ref_model=True`."
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user