diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index 0513595e8..a110c08ed 100755 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -969,8 +969,8 @@ class HFRLTrainerBuilder(TrainerBuilderBase): if self.cfg.dataset_processes: training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes - if self.cfg.rl_beta: - training_args_kwargs["beta"] = self.cfg.rl_beta + if self.cfg.trl.beta or self.cfg.rl_beta: + training_args_kwargs["beta"] = self.cfg.trl.beta or self.cfg.rl_beta if self.cfg.orpo_alpha: # trl does some odd mapping of alpha to beta to reuse the beta parameter ??? training_args_kwargs["beta"] = self.cfg.orpo_alpha diff --git a/src/axolotl/core/trainers/grpo/__init__.py b/src/axolotl/core/trainers/grpo/__init__.py index 15590ad1a..a58cfa19c 100644 --- a/src/axolotl/core/trainers/grpo/__init__.py +++ b/src/axolotl/core/trainers/grpo/__init__.py @@ -31,33 +31,43 @@ class GRPOStrategy: @classmethod def set_training_args_kwargs(cls, cfg): grpo_args_kwargs = {} - if cfg.grpo_use_vllm: - grpo_args_kwargs["use_vllm"] = cfg.grpo_use_vllm - if cfg.grpo_vllm_device: - grpo_args_kwargs["vllm_device"] = cfg.grpo_vllm_device + if cfg.trl and cfg.trl.use_vllm: + grpo_args_kwargs["use_vllm"] = cfg.trl.use_vllm + if cfg.trl.vllm_device: + grpo_args_kwargs["vllm_device"] = cfg.trl.vllm_device else: grpo_args_kwargs["vllm_device"] = "auto" - if cfg.grpo_vllm_gpu_memory_utilization: + if cfg.trl.vllm_gpu_memory_utilization: grpo_args_kwargs[ "vllm_gpu_memory_utilization" - ] = cfg.grpo_vllm_gpu_memory_utilization - if cfg.grpo_num_generations: - grpo_args_kwargs["num_generations"] = cfg.grpo_num_generations - grpo_args_kwargs["max_completion_length"] = cfg.max_completion_length + ] = cfg.trl.vllm_gpu_memory_utilization + if cfg.trl.vllm_max_model_len: + grpo_args_kwargs["vllm_max_model_len"] = cfg.trl.vllm_max_model_len + if cfg.trl and cfg.trl.num_generations: + grpo_args_kwargs["num_generations"] = cfg.trl.num_generations + if cfg.trl and cfg.trl.sync_ref_model: + grpo_args_kwargs["sync_ref_model"] = cfg.trl.sync_ref_model + if cfg.trl.ref_model_mixup_alpha: + grpo_args_kwargs[ + "ref_model_mixup_alpha" + ] = cfg.trl.ref_model_mixup_alpha + if cfg.trl.ref_model_sync_steps: + grpo_args_kwargs["ref_model_sync_steps"] = cfg.trl.ref_model_sync_steps + grpo_args_kwargs["max_completion_length"] = cfg.trl.max_completion_length return grpo_args_kwargs @classmethod def set_trainer_kwargs(cls, cfg): trainer_kwargs = {} - if cfg.grpo_reward_funcs: + if cfg.trl and cfg.trl.reward_funcs: reward_funcs = [] - for reward_func_fqn in cfg.grpo_reward_funcs: + for reward_func_fqn in cfg.trl.reward_funcs: reward_funcs.append(cls.get_reward_func(reward_func_fqn)) trainer_kwargs["reward_funcs"] = reward_funcs - if cfg.grpo_reward_processing_classes: + if cfg.trl and cfg.trl.reward_processing_classes: trainer_kwargs[ "reward_processing_classes" - ] = cfg.grpo_reward_processing_classes + ] = cfg.trl.reward_processing_classes return trainer_kwargs @classmethod diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py index bb881bfd5..c2c61f478 100644 --- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py +++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py @@ -24,7 +24,7 @@ from transformers.utils.import_utils import is_torch_npu_available from axolotl.utils.config.models.internals import EnvCapabilities, GPUCapabilities -from .grpo import GRPOConfig +from .trl import TrlConfig LOG = logging.getLogger("axolotl.utils.config.models.input") @@ -648,7 +648,6 @@ class AxolotlInputConfig( MLFlowConfig, CometConfig, LISAConfig, - GRPOConfig, GradioConfig, RayConfig, RemappedParameters, @@ -671,6 +670,9 @@ class AxolotlInputConfig( shrink_embeddings: Optional[bool] = None rl: Optional[RLType] = None + trl: Optional[TrlConfig] = Field( + default_factory=lambda: TrlConfig() # pylint: disable=unnecessary-lambda + ) reward_model: Optional[bool] = None process_reward_model: Optional[bool] = None num_labels: Optional[int] = None @@ -757,12 +759,6 @@ class AxolotlInputConfig( default=512, json_schema_extra={"description": "maximum prompt length for RL training"}, ) - max_completion_length: Optional[int] = Field( - default=None, - json_schema_extra={ - "description": "Maximum length of the completion for RL training" - }, - ) sample_packing: Optional[bool] = None sample_packing_group_size: Optional[int] = 100_000 sample_packing_bin_size: Optional[int] = 200 diff --git a/src/axolotl/utils/config/models/input/v0_4_1/grpo.py b/src/axolotl/utils/config/models/input/v0_4_1/grpo.py deleted file mode 100644 index 39ac341c6..000000000 --- a/src/axolotl/utils/config/models/input/v0_4_1/grpo.py +++ /dev/null @@ -1,18 +0,0 @@ -""" -GRPO specific configuration args -""" -from typing import List, Optional - -from pydantic import BaseModel - - -class GRPOConfig(BaseModel): - """ - Input args for GRPO. - """ - - grpo_use_vllm: Optional[bool] = False - grpo_vllm_device: Optional[str] = "auto" - grpo_vllm_gpu_memory_utilization: Optional[float] = 0.9 - grpo_reward_funcs: Optional[List[str]] = None - grpo_num_generations: Optional[int] = None diff --git a/src/axolotl/utils/config/models/input/v0_4_1/trl.py b/src/axolotl/utils/config/models/input/v0_4_1/trl.py new file mode 100644 index 000000000..2408e5420 --- /dev/null +++ b/src/axolotl/utils/config/models/input/v0_4_1/trl.py @@ -0,0 +1,32 @@ +""" +GRPO specific configuration args +""" +from typing import List, Optional + +from pydantic import BaseModel, Field + + +class TrlConfig(BaseModel): + """ + Input args for TRL. + """ + + beta: Optional[float] = None + max_completion_length: Optional[int] = Field( + default=None, + json_schema_extra={ + "description": "Maximum length of the completion for RL training" + }, + ) + + # GRPO specific args + use_vllm: Optional[bool] = False + vllm_device: Optional[str] = "auto" + vllm_gpu_memory_utilization: Optional[float] = 0.9 + vllm_max_model_len: Optional[int] = None + vllm_dtype: Optional[str] = "auto" + reward_funcs: Optional[List[str]] = None + num_generations: Optional[int] = None + sync_ref_model: Optional[bool] = False + ref_model_mixup_alpha: Optional[float] = 0.9 + ref_model_sync_steps: Optional[int] = 64