diff --git a/docs/rlhf.qmd b/docs/rlhf.qmd
index 741976bc6..773b159e8 100644
--- a/docs/rlhf.qmd
+++ b/docs/rlhf.qmd
@@ -3,6 +3,7 @@ title: "RLHF (Beta)"
 description: "Reinforcement Learning from Human Feedback is a method whereby a language model is optimized from data using human feedback."
 back-to-top-navigation: true
 toc: true
+toc-expand: 2
 toc-depth: 4
 ---
 
@@ -528,6 +529,7 @@ trl:
   vllm_gpu_memory_utilization: 0.15
   num_generations: 4
   reward_funcs: ["rewards.rand_reward_func"] # format: '{file_name}.{fn_name}'
+  reward_weights: [1.0]
 datasets:
   - path: openai/gsm8k
     name: main
@@ -536,6 +538,8 @@ datasets:
 
 To see other examples of custom reward functions, please see [TRL GRPO Docs](https://github.com/huggingface/trl/blob/main/docs/source/grpo_trainer.md#using-a-custom-reward-function).
 
+To see a description of each config option, please see [TRLConfig](https://github.com/axolotl-ai-cloud/axolotl/blob/main/src/axolotl/utils/config/models/input/v0_4_1/trl.py).
+
 ### Using local dataset files
 
 ```yaml
diff --git a/src/axolotl/utils/config/models/input/v0_4_1/trl.py b/src/axolotl/utils/config/models/input/v0_4_1/trl.py
index ae26e634c..f408acdba 100644
--- a/src/axolotl/utils/config/models/input/v0_4_1/trl.py
+++ b/src/axolotl/utils/config/models/input/v0_4_1/trl.py
@@ -1,7 +1,8 @@
 """
 GRPO specific configuration args
 """
-from typing import List, Optional
+
+from typing import Optional
 
 from pydantic import BaseModel, Field
 
@@ -11,7 +12,10 @@ class TRLConfig(BaseModel):
     Input args for TRL.
     """
 
-    beta: Optional[float] = None
+    beta: Optional[float] = Field(
+        default=None,
+        json_schema_extra={"description": "Beta (KL coefficient) for RL training"},
+    )
     max_completion_length: Optional[int] = Field(
         default=None,
         json_schema_extra={
@@ -20,17 +24,68 @@ class TRLConfig(BaseModel):
     )
 
     # GRPO specific args
-    use_vllm: Optional[bool] = False
-    vllm_device: Optional[str] = "auto"
-    vllm_gpu_memory_utilization: Optional[float] = 0.9
-    vllm_max_model_len: Optional[int] = None
-    vllm_dtype: Optional[str] = "auto"
+    # Ref: https://github.com/huggingface/trl/blob/e3244d2d096ff1e2e248c931d06d39e165e20623/trl/trainer/grpo_config.py#L22
+    use_vllm: Optional[bool] = Field(
+        default=False,
+        json_schema_extra={"description": "Whether to use VLLM for RL training"},
+    )
+    vllm_device: Optional[str] = Field(
+        default="auto",
+        json_schema_extra={"description": "Device to use for VLLM"},
+    )
+    vllm_gpu_memory_utilization: Optional[float] = Field(
+        default=0.9,
+        json_schema_extra={"description": "GPU memory utilization for VLLM"},
+    )
+    vllm_dtype: Optional[str] = Field(
+        default="auto",
+        json_schema_extra={"description": "Data type for VLLM"},
+    )
+    vllm_max_model_len: Optional[int] = Field(
+        default=None,
+        json_schema_extra={
+            "description": "Maximum length of the model context for VLLM"
+        },
+    )
 
-    reward_funcs: Optional[List[str]] = None
-    reward_weights: Optional[List[float]] = None
-    num_generations: Optional[int] = None
-    log_completions: Optional[bool] = False
-
-    sync_ref_model: Optional[bool] = False
-    ref_model_mixup_alpha: Optional[float] = 0.9
-    ref_model_sync_steps: Optional[int] = 64
+    reward_funcs: Optional[list[str]] = Field(
+        default=None,
+        json_schema_extra={"description": "List of reward functions to load"},
+    )
+    reward_weights: Optional[list[float]] = Field(
+        default=None,
+        json_schema_extra={
+            "description": "Weights for each reward function. Must match the number of reward functions."
+        },
+    )
+    num_generations: Optional[int] = Field(
+        default=None,
+        json_schema_extra={
+            "description": "Number of generations to sample. The global batch size (num_processes * per_device_batch_size) must be divisible by this value."
+        },
+    )
+    log_completions: Optional[bool] = Field(
+        default=False,
+        json_schema_extra={"description": "Whether to log completions"},
+    )
+    sync_ref_model: Optional[bool] = Field(
+        default=False,
+        json_schema_extra={
+            "description": (
+                "Whether to sync the reference model every `ref_model_sync_steps` "
+                "steps, using the `ref_model_mixup_alpha` parameter."
+            )
+        },
+    )
+    ref_model_mixup_alpha: Optional[float] = Field(
+        default=0.9,
+        json_schema_extra={
+            "description": "Mixup alpha for the reference model. Requires `sync_ref_model=True`."
+        },
+    )
+    ref_model_sync_steps: Optional[int] = Field(
+        default=64,
+        json_schema_extra={
+            "description": "Number of steps between reference model syncs. Requires `sync_ref_model=True`."
+        },
+    )
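For anyone wiring up the `reward_funcs: ["rewards.rand_reward_func"]` example from the docs hunk above, a minimal sketch of the referenced `rewards.py` might look like the following. It assumes TRL's custom reward-function contract (each function is called with the batch of sampled completions, plus any dataset columns as keyword arguments, and returns one float score per completion); see the TRL GRPO docs linked in the diff for the authoritative signature.

```python
# rewards.py -- minimal sketch of a custom GRPO reward function.
# A random reward only verifies the wiring; it is not a useful
# training signal.
import random


def rand_reward_func(completions, **kwargs) -> list[float]:
    """Return a random score in [0, 1) for each sampled completion."""
    return [random.random() for _ in completions]
```

With a single reward function, `reward_weights: [1.0]` is effectively a no-op; with several, TRL combines the per-function scores as a weighted sum, so the two lists must be the same length (as the `reward_weights` description in the diff notes).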