refactor cfg.grpo_* to use cfg.trl.*
This commit is contained in:
@@ -969,8 +969,8 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
|||||||
if self.cfg.dataset_processes:
|
if self.cfg.dataset_processes:
|
||||||
training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
|
training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
|
||||||
|
|
||||||
if self.cfg.rl_beta:
|
if self.cfg.trl.beta or self.cfg.rl_beta:
|
||||||
training_args_kwargs["beta"] = self.cfg.rl_beta
|
training_args_kwargs["beta"] = self.cfg.trl.beta or self.cfg.rl_beta
|
||||||
if self.cfg.orpo_alpha:
|
if self.cfg.orpo_alpha:
|
||||||
# trl does some odd mapping of alpha to beta to reuse the beta parameter ???
|
# trl does some odd mapping of alpha to beta to reuse the beta parameter ???
|
||||||
training_args_kwargs["beta"] = self.cfg.orpo_alpha
|
training_args_kwargs["beta"] = self.cfg.orpo_alpha
|
||||||
|
|||||||
@@ -31,33 +31,43 @@ class GRPOStrategy:
|
|||||||
@classmethod
|
@classmethod
|
||||||
def set_training_args_kwargs(cls, cfg):
|
def set_training_args_kwargs(cls, cfg):
|
||||||
grpo_args_kwargs = {}
|
grpo_args_kwargs = {}
|
||||||
if cfg.grpo_use_vllm:
|
if cfg.trl and cfg.trl.use_vllm:
|
||||||
grpo_args_kwargs["use_vllm"] = cfg.grpo_use_vllm
|
grpo_args_kwargs["use_vllm"] = cfg.trl.use_vllm
|
||||||
if cfg.grpo_vllm_device:
|
if cfg.trl.vllm_device:
|
||||||
grpo_args_kwargs["vllm_device"] = cfg.grpo_vllm_device
|
grpo_args_kwargs["vllm_device"] = cfg.trl.vllm_device
|
||||||
else:
|
else:
|
||||||
grpo_args_kwargs["vllm_device"] = "auto"
|
grpo_args_kwargs["vllm_device"] = "auto"
|
||||||
if cfg.grpo_vllm_gpu_memory_utilization:
|
if cfg.trl.vllm_gpu_memory_utilization:
|
||||||
grpo_args_kwargs[
|
grpo_args_kwargs[
|
||||||
"vllm_gpu_memory_utilization"
|
"vllm_gpu_memory_utilization"
|
||||||
] = cfg.grpo_vllm_gpu_memory_utilization
|
] = cfg.trl.vllm_gpu_memory_utilization
|
||||||
if cfg.grpo_num_generations:
|
if cfg.trl.vllm_max_model_len:
|
||||||
grpo_args_kwargs["num_generations"] = cfg.grpo_num_generations
|
grpo_args_kwargs["vllm_max_model_len"] = cfg.trl.vllm_max_model_len
|
||||||
grpo_args_kwargs["max_completion_length"] = cfg.max_completion_length
|
if cfg.trl and cfg.trl.num_generations:
|
||||||
|
grpo_args_kwargs["num_generations"] = cfg.trl.num_generations
|
||||||
|
if cfg.trl and cfg.trl.sync_ref_model:
|
||||||
|
grpo_args_kwargs["sync_ref_model"] = cfg.trl.sync_ref_model
|
||||||
|
if cfg.trl.ref_model_mixup_alpha:
|
||||||
|
grpo_args_kwargs[
|
||||||
|
"ref_model_mixup_alpha"
|
||||||
|
] = cfg.trl.ref_model_mixup_alpha
|
||||||
|
if cfg.trl.ref_model_sync_steps:
|
||||||
|
grpo_args_kwargs["ref_model_sync_steps"] = cfg.trl.ref_model_sync_steps
|
||||||
|
grpo_args_kwargs["max_completion_length"] = cfg.trl.max_completion_length
|
||||||
return grpo_args_kwargs
|
return grpo_args_kwargs
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def set_trainer_kwargs(cls, cfg):
|
def set_trainer_kwargs(cls, cfg):
|
||||||
trainer_kwargs = {}
|
trainer_kwargs = {}
|
||||||
if cfg.grpo_reward_funcs:
|
if cfg.trl and cfg.trl.reward_funcs:
|
||||||
reward_funcs = []
|
reward_funcs = []
|
||||||
for reward_func_fqn in cfg.grpo_reward_funcs:
|
for reward_func_fqn in cfg.trl.reward_funcs:
|
||||||
reward_funcs.append(cls.get_reward_func(reward_func_fqn))
|
reward_funcs.append(cls.get_reward_func(reward_func_fqn))
|
||||||
trainer_kwargs["reward_funcs"] = reward_funcs
|
trainer_kwargs["reward_funcs"] = reward_funcs
|
||||||
if cfg.grpo_reward_processing_classes:
|
if cfg.trl and cfg.trl.reward_processing_classes:
|
||||||
trainer_kwargs[
|
trainer_kwargs[
|
||||||
"reward_processing_classes"
|
"reward_processing_classes"
|
||||||
] = cfg.grpo_reward_processing_classes
|
] = cfg.trl.reward_processing_classes
|
||||||
return trainer_kwargs
|
return trainer_kwargs
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|||||||
@@ -24,7 +24,7 @@ from transformers.utils.import_utils import is_torch_npu_available
|
|||||||
|
|
||||||
from axolotl.utils.config.models.internals import EnvCapabilities, GPUCapabilities
|
from axolotl.utils.config.models.internals import EnvCapabilities, GPUCapabilities
|
||||||
|
|
||||||
from .grpo import GRPOConfig
|
from .trl import TrlConfig
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl.utils.config.models.input")
|
LOG = logging.getLogger("axolotl.utils.config.models.input")
|
||||||
|
|
||||||
@@ -648,7 +648,6 @@ class AxolotlInputConfig(
|
|||||||
MLFlowConfig,
|
MLFlowConfig,
|
||||||
CometConfig,
|
CometConfig,
|
||||||
LISAConfig,
|
LISAConfig,
|
||||||
GRPOConfig,
|
|
||||||
GradioConfig,
|
GradioConfig,
|
||||||
RayConfig,
|
RayConfig,
|
||||||
RemappedParameters,
|
RemappedParameters,
|
||||||
@@ -671,6 +670,9 @@ class AxolotlInputConfig(
|
|||||||
shrink_embeddings: Optional[bool] = None
|
shrink_embeddings: Optional[bool] = None
|
||||||
|
|
||||||
rl: Optional[RLType] = None
|
rl: Optional[RLType] = None
|
||||||
|
trl: Optional[TrlConfig] = Field(
|
||||||
|
default_factory=lambda: TrlConfig() # pylint: disable=unnecessary-lambda
|
||||||
|
)
|
||||||
reward_model: Optional[bool] = None
|
reward_model: Optional[bool] = None
|
||||||
process_reward_model: Optional[bool] = None
|
process_reward_model: Optional[bool] = None
|
||||||
num_labels: Optional[int] = None
|
num_labels: Optional[int] = None
|
||||||
@@ -757,12 +759,6 @@ class AxolotlInputConfig(
|
|||||||
default=512,
|
default=512,
|
||||||
json_schema_extra={"description": "maximum prompt length for RL training"},
|
json_schema_extra={"description": "maximum prompt length for RL training"},
|
||||||
)
|
)
|
||||||
max_completion_length: Optional[int] = Field(
|
|
||||||
default=None,
|
|
||||||
json_schema_extra={
|
|
||||||
"description": "Maximum length of the completion for RL training"
|
|
||||||
},
|
|
||||||
)
|
|
||||||
sample_packing: Optional[bool] = None
|
sample_packing: Optional[bool] = None
|
||||||
sample_packing_group_size: Optional[int] = 100_000
|
sample_packing_group_size: Optional[int] = 100_000
|
||||||
sample_packing_bin_size: Optional[int] = 200
|
sample_packing_bin_size: Optional[int] = 200
|
||||||
|
|||||||
@@ -1,18 +0,0 @@
|
|||||||
"""
|
|
||||||
GRPO specific configuration args
|
|
||||||
"""
|
|
||||||
from typing import List, Optional
|
|
||||||
|
|
||||||
from pydantic import BaseModel
|
|
||||||
|
|
||||||
|
|
||||||
class GRPOConfig(BaseModel):
|
|
||||||
"""
|
|
||||||
Input args for GRPO.
|
|
||||||
"""
|
|
||||||
|
|
||||||
grpo_use_vllm: Optional[bool] = False
|
|
||||||
grpo_vllm_device: Optional[str] = "auto"
|
|
||||||
grpo_vllm_gpu_memory_utilization: Optional[float] = 0.9
|
|
||||||
grpo_reward_funcs: Optional[List[str]] = None
|
|
||||||
grpo_num_generations: Optional[int] = None
|
|
||||||
32
src/axolotl/utils/config/models/input/v0_4_1/trl.py
Normal file
32
src/axolotl/utils/config/models/input/v0_4_1/trl.py
Normal file
@@ -0,0 +1,32 @@
|
|||||||
|
"""
|
||||||
|
GRPO specific configuration args
|
||||||
|
"""
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
|
class TrlConfig(BaseModel):
|
||||||
|
"""
|
||||||
|
Input args for TRL.
|
||||||
|
"""
|
||||||
|
|
||||||
|
beta: Optional[float] = None
|
||||||
|
max_completion_length: Optional[int] = Field(
|
||||||
|
default=None,
|
||||||
|
json_schema_extra={
|
||||||
|
"description": "Maximum length of the completion for RL training"
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
# GRPO specific args
|
||||||
|
use_vllm: Optional[bool] = False
|
||||||
|
vllm_device: Optional[str] = "auto"
|
||||||
|
vllm_gpu_memory_utilization: Optional[float] = 0.9
|
||||||
|
vllm_max_model_len: Optional[int] = None
|
||||||
|
vllm_dtype: Optional[str] = "auto"
|
||||||
|
reward_funcs: Optional[List[str]] = None
|
||||||
|
num_generations: Optional[int] = None
|
||||||
|
sync_ref_model: Optional[bool] = False
|
||||||
|
ref_model_mixup_alpha: Optional[float] = 0.9
|
||||||
|
ref_model_sync_steps: Optional[int] = 64
|
||||||
Reference in New Issue
Block a user