From 59ad21f2de0e38057200a5a182fe97ecbefcc94d Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Sun, 2 Feb 2025 23:22:36 -0500 Subject: [PATCH] refactor a bit for better grpo support --- requirements.txt | 2 +- src/axolotl/core/trainer_builder.py | 26 +++++----- src/axolotl/core/trainers/dpo/__init__.py | 33 ++++++++++++ src/axolotl/core/trainers/dpo/args.py | 15 ++++++ .../{dpo_trainer.py => dpo/trainer.py} | 0 src/axolotl/core/trainers/grpo/__init__.py | 52 +++++++++++++++++++ src/axolotl/core/trainers/grpo/args.py | 12 +++++ src/axolotl/core/trainers/grpo/trainer.py | 12 +++++ src/axolotl/core/training_args.py | 9 +--- .../config/models/input/v0_4_1/__init__.py | 4 ++ .../utils/config/models/input/v0_4_1/grpo.py | 17 ++++++ 11 files changed, 160 insertions(+), 22 deletions(-) create mode 100644 src/axolotl/core/trainers/dpo/__init__.py create mode 100644 src/axolotl/core/trainers/dpo/args.py rename src/axolotl/core/trainers/{dpo_trainer.py => dpo/trainer.py} (100%) create mode 100644 src/axolotl/core/trainers/grpo/__init__.py create mode 100644 src/axolotl/core/trainers/grpo/args.py create mode 100644 src/axolotl/core/trainers/grpo/trainer.py create mode 100644 src/axolotl/utils/config/models/input/v0_4_1/grpo.py diff --git a/requirements.txt b/requirements.txt index 822950847..061749902 100644 --- a/requirements.txt +++ b/requirements.txt @@ -13,7 +13,7 @@ liger-kernel==0.5.2 packaging==23.2 peft==0.14.0 -transformers==4.48.1 +transformers==4.48.2 tokenizers>=0.21.0 accelerate==1.3.0 datasets==3.2.0 diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index aeee49f5a..19481d22c 100755 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -35,6 +35,7 @@ from transformers import ( EarlyStoppingCallback, TrainerCallback, ) +from trl import DPOConfig from trl.trainer.utils import RewardDataCollatorWithPadding from axolotl.core.trainers.base import ( @@ -47,10 +48,10 @@ from axolotl.core.trainers.base 
import ( AxolotlTrainer, ReLoRATrainer, ) -from axolotl.core.trainers.dpo_trainer import AxolotlDPOTrainer +from axolotl.core.trainers.dpo import DPOStrategy +from axolotl.core.trainers.grpo import GRPOStrategy from axolotl.core.training_args import ( AxolotlCPOConfig, - AxolotlDPOConfig, AxolotlKTOConfig, AxolotlORPOConfig, AxolotlPRMConfig, @@ -1006,16 +1007,13 @@ class HFRLTrainerBuilder(TrainerBuilderBase): if self.cfg.max_prompt_len: training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len + elif self.cfg.rl == "grpo": + training_args_cls = GRPOStrategy.get_training_args_class() + training_args_kwargs.update(GRPOStrategy.set_training_args_kwargs(self.cfg)) + else: - training_args_cls = AxolotlDPOConfig - if self.cfg.rl == "ipo": - training_args_kwargs["loss_type"] = "ipo" - training_args_kwargs["max_length"] = self.cfg.sequence_len - training_args_kwargs["max_completion_length"] = None - training_args_kwargs["max_prompt_length"] = self.cfg.sequence_len - training_args_kwargs["generate_during_eval"] = self.cfg.use_wandb - if self.cfg.dpo_use_weighting is not None: - training_args_kwargs["use_weighting"] = self.cfg.dpo_use_weighting + training_args_cls = DPOStrategy.get_training_args_class() + training_args_kwargs.update(DPOStrategy.set_training_args_kwargs(self.cfg)) training_args = training_args_cls( # pylint: disable=unexpected-keyword-arg output_dir=self.cfg.output_dir, @@ -1048,7 +1046,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase): "precompute_ref_log_probs" ] = self.cfg.precompute_ref_log_probs if self.cfg.rl in ["dpo", "ipo"]: - trainer_cls = AxolotlDPOTrainer + trainer_cls = DPOStrategy.get_trainer_class() trainer_cls_args = [self.model, self.model_ref] elif self.cfg.rl == "orpo": trainer_cls = AxolotlORPOTrainer @@ -1068,7 +1066,9 @@ class HFRLTrainerBuilder(TrainerBuilderBase): else: dpo_trainer_kwargs["tokenizer"] = self.tokenizer - if self.cfg.datasets is not None and (trainer_cls is AxolotlDPOTrainer): + if self.cfg.datasets is not None 
and ( + trainer_cls is DPOStrategy.get_trainer_class() + ): dpo_trainer_kwargs["dataset_tags"] = [ d["path"] for d in self.cfg.datasets if not Path(d["path"]).is_dir() ] diff --git a/src/axolotl/core/trainers/dpo/__init__.py b/src/axolotl/core/trainers/dpo/__init__.py new file mode 100644 index 000000000..8187a7fb5 --- /dev/null +++ b/src/axolotl/core/trainers/dpo/__init__.py @@ -0,0 +1,33 @@ +""" +DPO Specific Strategy for training +""" +from axolotl.core.trainers.dpo.trainer import AxolotlDPOTrainer + + +class DPOStrategy: + """ + Strategy for DPO training + """ + + @classmethod + def get_trainer_class(cls): + return AxolotlDPOTrainer + + @classmethod + def get_training_args_class(cls): + from axolotl.core.trainers.dpo.args import AxolotlDPOConfig + + return AxolotlDPOConfig + + @classmethod + def set_training_args_kwargs(cls, cfg): + training_args_kwargs = {} + if cfg.rl == "ipo": + training_args_kwargs["loss_type"] = "ipo" + training_args_kwargs["max_length"] = cfg.sequence_len + training_args_kwargs["max_completion_length"] = None + training_args_kwargs["max_prompt_length"] = cfg.sequence_len + training_args_kwargs["generate_during_eval"] = cfg.use_wandb + if cfg.dpo_use_weighting is not None: + training_args_kwargs["use_weighting"] = cfg.dpo_use_weighting + return training_args_kwargs diff --git a/src/axolotl/core/trainers/dpo/args.py b/src/axolotl/core/trainers/dpo/args.py new file mode 100644 index 000000000..4cae67d3e --- /dev/null +++ b/src/axolotl/core/trainers/dpo/args.py @@ -0,0 +1,15 @@ +""" +Axolotl specific DPO args +""" +from dataclasses import dataclass + +from trl import DPOConfig + +from axolotl.core.training_args import AxolotlTrainingMixins + + +@dataclass +class AxolotlDPOConfig(AxolotlTrainingMixins, DPOConfig): + """ + DPO config for DPO training + """ diff --git a/src/axolotl/core/trainers/dpo_trainer.py b/src/axolotl/core/trainers/dpo/trainer.py similarity index 100% rename from src/axolotl/core/trainers/dpo_trainer.py rename to 
src/axolotl/core/trainers/dpo/trainer.py diff --git a/src/axolotl/core/trainers/grpo/__init__.py b/src/axolotl/core/trainers/grpo/__init__.py new file mode 100644 index 000000000..bab8605e8 --- /dev/null +++ b/src/axolotl/core/trainers/grpo/__init__.py @@ -0,0 +1,52 @@ +""" +GRPO Specific Strategy for training +""" +import importlib + +from axolotl.core.trainers.grpo.trainer import AxolotlGRPOTrainer + + +class GRPOStrategy: + """ + Strategy for GRPO training + """ + + @classmethod + def get_trainer_class(cls): + return AxolotlGRPOTrainer + + @classmethod + def get_training_args_class(cls): + from axolotl.core.trainers.grpo.args import AxolotlGRPOConfig + + return AxolotlGRPOConfig + + @classmethod + def set_training_args_kwargs(cls, cfg): + grpo_args_kwargs = {} + if cfg.grpo_use_vllm: + grpo_args_kwargs["use_vllm"] = cfg.grpo_use_vllm + if cfg.grpo_vllm_device: + grpo_args_kwargs["vllm_device"] = cfg.grpo_vllm_device + else: + grpo_args_kwargs["vllm_device"] = "auto" + if cfg.grpo_vllm_gpu_memory_utilization: + grpo_args_kwargs[ + "vllm_gpu_memory_utilization" + ] = cfg.grpo_vllm_gpu_memory_utilization + return grpo_args_kwargs + + @classmethod + def set_trainer_kwargs(cls, cfg): + trainer_kwargs = {} + if cfg.grpo_reward_funcs: + reward_funcs = [] + for reward_func_module in cfg.grpo_reward_funcs: + # use importlib to dynamically load the reward function from the module + reward_func_module_name = reward_func_module.split(".")[-1] + reward_func_module = importlib.import_module(reward_func_module) + reward_func = getattr(reward_func_module, reward_func_module_name) + reward_funcs.append(reward_func) + trainer_kwargs["reward_funcs"] = reward_funcs + + return trainer_kwargs diff --git a/src/axolotl/core/trainers/grpo/args.py b/src/axolotl/core/trainers/grpo/args.py new file mode 100644 index 000000000..3be21a4ce --- /dev/null +++ b/src/axolotl/core/trainers/grpo/args.py @@ -0,0 +1,12 @@ +""" +Axolotl Specific Training Args +""" +from trl import GRPOConfig + +from 
axolotl.core.training_args import AxolotlTrainingMixins + + +class AxolotlGRPOConfig(AxolotlTrainingMixins, GRPOConfig): + """ + Axolotl GRPO Config for GRPO training + """ diff --git a/src/axolotl/core/trainers/grpo/trainer.py b/src/axolotl/core/trainers/grpo/trainer.py new file mode 100644 index 000000000..c3cb23f8c --- /dev/null +++ b/src/axolotl/core/trainers/grpo/trainer.py @@ -0,0 +1,12 @@ +""" +Axolotl GRPO trainer +""" +from trl import GRPOTrainer + +from axolotl.core.trainers.base import SchedulerMixin + + +class AxolotlGRPOTrainer(SchedulerMixin, GRPOTrainer): + """ + Extend the base GRPOTrainer for axolotl helpers + """ diff --git a/src/axolotl/core/training_args.py b/src/axolotl/core/training_args.py index 9eae52162..7cace7643 100644 --- a/src/axolotl/core/training_args.py +++ b/src/axolotl/core/training_args.py @@ -5,7 +5,7 @@ from dataclasses import dataclass, field from typing import Optional from transformers import TrainingArguments -from trl import CPOConfig, DPOConfig, KTOConfig, ORPOConfig, PRMConfig, RewardConfig +from trl import CPOConfig, KTOConfig, ORPOConfig, PRMConfig, RewardConfig @dataclass @@ -217,13 +217,6 @@ class AxolotlTrainingArguments(AxolotlTrainingMixins, TrainingArguments): """ -@dataclass -class AxolotlDPOConfig(AxolotlTrainingMixins, DPOConfig): - """ - DPO config for DPO training - """ - - @dataclass class AxolotlORPOConfig(AxolotlTrainingMixins, ORPOConfig): """ diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py index 028b7ea18..0607f8af1 100644 --- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py +++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py @@ -24,6 +24,8 @@ from transformers.utils.import_utils import is_torch_npu_available from axolotl.utils.config.models.internals import EnvCapabilities, GPUCapabilities +from .grpo import GRPOConfig + LOG = logging.getLogger("axolotl.utils.config.models.input") SUPPORTED_METRICS = 
{"sacrebleu", "comet", "ter", "chrf", "perplexity"} @@ -33,6 +35,7 @@ class RLType(str, Enum): """RL trainer type configuration subset""" dpo = "dpo" # pylint: disable=invalid-name + grpo = "grpo" # pylint: disable=invalid-name ipo = "ipo" # pylint: disable=invalid-name orpo = "orpo" # pylint: disable=invalid-name kto = "kto" # pylint: disable=invalid-name @@ -645,6 +648,7 @@ class AxolotlInputConfig( MLFlowConfig, CometConfig, LISAConfig, + GRPOConfig, GradioConfig, RayConfig, RemappedParameters, diff --git a/src/axolotl/utils/config/models/input/v0_4_1/grpo.py b/src/axolotl/utils/config/models/input/v0_4_1/grpo.py new file mode 100644 index 000000000..857d93cf9 --- /dev/null +++ b/src/axolotl/utils/config/models/input/v0_4_1/grpo.py @@ -0,0 +1,17 @@ +""" +GRPO specific configuration args +""" +from typing import List, Optional + +from pydantic import BaseModel + + +class GRPOConfig(BaseModel): + """ + Input args for GRPO. + """ + + grpo_use_vllm: Optional[bool] = False + grpo_vllm_device: Optional[str] = "auto" + grpo_vllm_gpu_memory_utilization: Optional[float] = 0.9 + grpo_reward_funcs: Optional[List[str]] = None