refactor cfg.grpo_* to use cfg.trl.*

This commit is contained in:
Wing Lian
2025-02-05 20:41:14 -05:00
parent 3659d812f7
commit aded9c500d
5 changed files with 61 additions and 41 deletions

View File

@@ -969,8 +969,8 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
if self.cfg.dataset_processes:
training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
if self.cfg.rl_beta:
training_args_kwargs["beta"] = self.cfg.rl_beta
if self.cfg.trl.beta or self.cfg.rl_beta:
training_args_kwargs["beta"] = self.cfg.trl.beta or self.cfg.rl_beta
if self.cfg.orpo_alpha:
# trl does some odd mapping of alpha to beta to reuse the beta parameter ???
training_args_kwargs["beta"] = self.cfg.orpo_alpha

View File

@@ -31,33 +31,43 @@ class GRPOStrategy:
@classmethod
def set_training_args_kwargs(cls, cfg):
grpo_args_kwargs = {}
if cfg.grpo_use_vllm:
grpo_args_kwargs["use_vllm"] = cfg.grpo_use_vllm
if cfg.grpo_vllm_device:
grpo_args_kwargs["vllm_device"] = cfg.grpo_vllm_device
if cfg.trl and cfg.trl.use_vllm:
grpo_args_kwargs["use_vllm"] = cfg.trl.use_vllm
if cfg.trl.vllm_device:
grpo_args_kwargs["vllm_device"] = cfg.trl.vllm_device
else:
grpo_args_kwargs["vllm_device"] = "auto"
if cfg.grpo_vllm_gpu_memory_utilization:
if cfg.trl.vllm_gpu_memory_utilization:
grpo_args_kwargs[
"vllm_gpu_memory_utilization"
] = cfg.grpo_vllm_gpu_memory_utilization
if cfg.grpo_num_generations:
grpo_args_kwargs["num_generations"] = cfg.grpo_num_generations
grpo_args_kwargs["max_completion_length"] = cfg.max_completion_length
] = cfg.trl.vllm_gpu_memory_utilization
if cfg.trl.vllm_max_model_len:
grpo_args_kwargs["vllm_max_model_len"] = cfg.trl.vllm_max_model_len
if cfg.trl and cfg.trl.num_generations:
grpo_args_kwargs["num_generations"] = cfg.trl.num_generations
if cfg.trl and cfg.trl.sync_ref_model:
grpo_args_kwargs["sync_ref_model"] = cfg.trl.sync_ref_model
if cfg.trl.ref_model_mixup_alpha:
grpo_args_kwargs[
"ref_model_mixup_alpha"
] = cfg.trl.ref_model_mixup_alpha
if cfg.trl.ref_model_sync_steps:
grpo_args_kwargs["ref_model_sync_steps"] = cfg.trl.ref_model_sync_steps
grpo_args_kwargs["max_completion_length"] = cfg.trl.max_completion_length
return grpo_args_kwargs
@classmethod
def set_trainer_kwargs(cls, cfg):
trainer_kwargs = {}
if cfg.grpo_reward_funcs:
if cfg.trl and cfg.trl.reward_funcs:
reward_funcs = []
for reward_func_fqn in cfg.grpo_reward_funcs:
for reward_func_fqn in cfg.trl.reward_funcs:
reward_funcs.append(cls.get_reward_func(reward_func_fqn))
trainer_kwargs["reward_funcs"] = reward_funcs
if cfg.grpo_reward_processing_classes:
if cfg.trl and cfg.trl.reward_processing_classes:
trainer_kwargs[
"reward_processing_classes"
] = cfg.grpo_reward_processing_classes
] = cfg.trl.reward_processing_classes
return trainer_kwargs
@classmethod

View File

@@ -24,7 +24,7 @@ from transformers.utils.import_utils import is_torch_npu_available
from axolotl.utils.config.models.internals import EnvCapabilities, GPUCapabilities
from .grpo import GRPOConfig
from .trl import TrlConfig
LOG = logging.getLogger("axolotl.utils.config.models.input")
@@ -648,7 +648,6 @@ class AxolotlInputConfig(
MLFlowConfig,
CometConfig,
LISAConfig,
GRPOConfig,
GradioConfig,
RayConfig,
RemappedParameters,
@@ -671,6 +670,9 @@ class AxolotlInputConfig(
shrink_embeddings: Optional[bool] = None
rl: Optional[RLType] = None
trl: Optional[TrlConfig] = Field(
default_factory=lambda: TrlConfig() # pylint: disable=unnecessary-lambda
)
reward_model: Optional[bool] = None
process_reward_model: Optional[bool] = None
num_labels: Optional[int] = None
@@ -757,12 +759,6 @@ class AxolotlInputConfig(
default=512,
json_schema_extra={"description": "maximum prompt length for RL training"},
)
max_completion_length: Optional[int] = Field(
default=None,
json_schema_extra={
"description": "Maximum length of the completion for RL training"
},
)
sample_packing: Optional[bool] = None
sample_packing_group_size: Optional[int] = 100_000
sample_packing_bin_size: Optional[int] = 200

View File

@@ -1,18 +0,0 @@
"""
GRPO specific configuration args
"""

from typing import List, Optional

from pydantic import BaseModel


class GRPOConfig(BaseModel):
    """
    Input args for GRPO.

    These user-facing ``grpo_*`` keys are mapped by GRPOStrategy onto trl's
    GRPO training arguments (``use_vllm``, ``vllm_device``, etc.).
    """

    # Generate rollouts with vLLM; forwarded to trl as `use_vllm`.
    grpo_use_vllm: Optional[bool] = False
    # Device for the vLLM engine; forwarded as `vllm_device` ("auto" lets trl pick).
    grpo_vllm_device: Optional[str] = "auto"
    # Fraction of GPU memory vLLM may reserve; forwarded as `vllm_gpu_memory_utilization`.
    grpo_vllm_gpu_memory_utilization: Optional[float] = 0.9
    # Fully-qualified names of reward functions, resolved via GRPOStrategy.get_reward_func.
    grpo_reward_funcs: Optional[List[str]] = None
    # Number of completions sampled per prompt; forwarded as `num_generations`.
    grpo_num_generations: Optional[int] = None

View File

@@ -0,0 +1,32 @@
"""
TRL specific configuration args
"""

from typing import List, Optional

from pydantic import BaseModel, Field


class TrlConfig(BaseModel):
    """
    Input args for TRL.

    General TRL trainer settings plus a GRPO-specific section; values are
    copied into the trainer's training_args/trainer kwargs when set.
    """

    # KL penalty coefficient; forwarded to the trainer as `beta` when set.
    beta: Optional[float] = None
    max_completion_length: Optional[int] = Field(
        default=None,
        json_schema_extra={
            "description": "Maximum length of the completion for RL training"
        },
    )

    # GRPO specific args
    # Generate rollouts with vLLM; forwarded as `use_vllm`.
    use_vllm: Optional[bool] = False
    # Device for the vLLM engine; forwarded as `vllm_device` ("auto" by default).
    vllm_device: Optional[str] = "auto"
    # Fraction of GPU memory vLLM may reserve; forwarded as `vllm_gpu_memory_utilization`.
    vllm_gpu_memory_utilization: Optional[float] = 0.9
    # Max model context length for vLLM; forwarded as `vllm_max_model_len` when set.
    vllm_max_model_len: Optional[int] = None
    # NOTE(review): not consumed in the visible builder code — presumably the
    # dtype passed to vLLM; confirm against GRPOStrategy/trl before relying on it.
    vllm_dtype: Optional[str] = "auto"
    # Fully-qualified names of reward functions, resolved via get_reward_func.
    reward_funcs: Optional[List[str]] = None
    # Number of completions sampled per prompt; forwarded as `num_generations`.
    num_generations: Optional[int] = None
    # Periodically sync the reference model; forwarded as `sync_ref_model`.
    sync_ref_model: Optional[bool] = False
    # Mixup weight used when syncing the reference model; forwarded when set.
    ref_model_mixup_alpha: Optional[float] = 0.9
    # Steps between reference-model syncs; forwarded as `ref_model_sync_steps`.
    ref_model_sync_steps: Optional[int] = 64