refactor a bit for better grpo support

This commit is contained in:
Wing Lian
2025-02-02 23:22:36 -05:00
parent 57264b6491
commit 59ad21f2de
11 changed files with 160 additions and 22 deletions

View File

@@ -13,7 +13,7 @@ liger-kernel==0.5.2
packaging==23.2 packaging==23.2
peft==0.14.0 peft==0.14.0
transformers==4.48.1 transformers==4.48.2
tokenizers>=0.21.0 tokenizers>=0.21.0
accelerate==1.3.0 accelerate==1.3.0
datasets==3.2.0 datasets==3.2.0

View File

@@ -35,6 +35,7 @@ from transformers import (
EarlyStoppingCallback, EarlyStoppingCallback,
TrainerCallback, TrainerCallback,
) )
from trl import DPOConfig
from trl.trainer.utils import RewardDataCollatorWithPadding from trl.trainer.utils import RewardDataCollatorWithPadding
from axolotl.core.trainers.base import ( from axolotl.core.trainers.base import (
@@ -47,10 +48,10 @@ from axolotl.core.trainers.base import (
AxolotlTrainer, AxolotlTrainer,
ReLoRATrainer, ReLoRATrainer,
) )
from axolotl.core.trainers.dpo_trainer import AxolotlDPOTrainer from axolotl.core.trainers.dpo import DPOStrategy
from axolotl.core.trainers.grpo import GRPOStrategy
from axolotl.core.training_args import ( from axolotl.core.training_args import (
AxolotlCPOConfig, AxolotlCPOConfig,
AxolotlDPOConfig,
AxolotlKTOConfig, AxolotlKTOConfig,
AxolotlORPOConfig, AxolotlORPOConfig,
AxolotlPRMConfig, AxolotlPRMConfig,
@@ -1006,16 +1007,13 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
if self.cfg.max_prompt_len: if self.cfg.max_prompt_len:
training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len
elif self.cfg.rl == "grpo":
training_args_cls = GRPOStrategy.get_training_args_class()
training_args_kwargs.update(GRPOStrategy.set_training_args_kwargs(self.cfg))
else: else:
training_args_cls = AxolotlDPOConfig training_args_cls = DPOConfig.get_training_args_class()
if self.cfg.rl == "ipo": training_args_kwargs.update(DPOConfig.set_training_args_kwargs(self.cfg))
training_args_kwargs["loss_type"] = "ipo"
training_args_kwargs["max_length"] = self.cfg.sequence_len
training_args_kwargs["max_completion_length"] = None
training_args_kwargs["max_prompt_length"] = self.cfg.sequence_len
training_args_kwargs["generate_during_eval"] = self.cfg.use_wandb
if self.cfg.dpo_use_weighting is not None:
training_args_kwargs["use_weighting"] = self.cfg.dpo_use_weighting
training_args = training_args_cls( # pylint: disable=unexpected-keyword-arg training_args = training_args_cls( # pylint: disable=unexpected-keyword-arg
output_dir=self.cfg.output_dir, output_dir=self.cfg.output_dir,
@@ -1048,7 +1046,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
"precompute_ref_log_probs" "precompute_ref_log_probs"
] = self.cfg.precompute_ref_log_probs ] = self.cfg.precompute_ref_log_probs
if self.cfg.rl in ["dpo", "ipo"]: if self.cfg.rl in ["dpo", "ipo"]:
trainer_cls = AxolotlDPOTrainer trainer_cls = DPOStrategy.get_trainer_class()
trainer_cls_args = [self.model, self.model_ref] trainer_cls_args = [self.model, self.model_ref]
elif self.cfg.rl == "orpo": elif self.cfg.rl == "orpo":
trainer_cls = AxolotlORPOTrainer trainer_cls = AxolotlORPOTrainer
@@ -1068,7 +1066,9 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
else: else:
dpo_trainer_kwargs["tokenizer"] = self.tokenizer dpo_trainer_kwargs["tokenizer"] = self.tokenizer
if self.cfg.datasets is not None and (trainer_cls is AxolotlDPOTrainer): if self.cfg.datasets is not None and (
trainer_cls is DPOStrategy.get_trainer_class()
):
dpo_trainer_kwargs["dataset_tags"] = [ dpo_trainer_kwargs["dataset_tags"] = [
d["path"] for d in self.cfg.datasets if not Path(d["path"]).is_dir() d["path"] for d in self.cfg.datasets if not Path(d["path"]).is_dir()
] ]

View File

@@ -0,0 +1,33 @@
"""
DPO Specific Strategy for training
"""
from axolotl.core.trainers.dpo.trainer import AxolotlDPOTrainer
class DPOStrategy:
"""
Strategy for DPO training
"""
@classmethod
def get_trainer_class(cls):
return AxolotlDPOTrainer
@classmethod
def get_training_args_class(cls):
from axolotl.core.trainers.dpo.args import AxolotlDPOConfig
return AxolotlDPOConfig
@classmethod
def set_training_args_kwargs(cls, cfg):
training_args_kwargs = {}
if cfg.rl == "ipo":
training_args_kwargs["loss_type"] = "ipo"
training_args_kwargs["max_length"] = cfg.sequence_len
training_args_kwargs["max_completion_length"] = None
training_args_kwargs["max_prompt_length"] = cfg.sequence_len
training_args_kwargs["generate_during_eval"] = cfg.use_wandb
if cfg.dpo_use_weighting is not None:
training_args_kwargs["use_weighting"] = cfg.dpo_use_weighting
return training_args_kwargs

View File

@@ -0,0 +1,15 @@
"""
Axolotl specific DPO args
"""
from dataclasses import dataclass
from trl import DPOConfig
from axolotl.core.training_args import AxolotlTrainingMixins
@dataclass
class AxolotlDPOConfig(AxolotlTrainingMixins, DPOConfig):
"""
DPO config for DPO training
"""

View File

@@ -0,0 +1,52 @@
"""
GRPO Specific Strategy for training
"""
import importlib
from axolotl.core.trainers.grpo.trainer import AxolotlGRPOTrainer
class GRPOStrategy:
"""
Strategy for GRPO training
"""
@classmethod
def get_trainer_class(cls):
return AxolotlGRPOTrainer
@classmethod
def get_training_args_class(cls):
from axolotl.core.trainers.grpo.args import AxolotlGRPOConfig
return AxolotlGRPOConfig
@classmethod
def set_training_args_kwargs(cls, cfg):
grpo_args_kwargs = {}
if cfg.grpo_use_vllm:
grpo_args_kwargs["use_vllm"] = cfg.grpo_use_vllm
if cfg.grpo_vllm_device:
grpo_args_kwargs["vllm_device"] = cfg.grpo_vllm_device
else:
grpo_args_kwargs["vllm_device"] = "auto"
if cfg.grpo_vllm_gpu_memory_utilization:
grpo_args_kwargs[
"vllm_gpu_memory_utilization"
] = cfg.grpo_vllm_gpu_memory_utilization
return grpo_args_kwargs
@classmethod
def set_trainer_kwargs(cls, cfg):
trainer_kwargs = {}
if cfg.grpo_reward_funcs:
reward_funcs = []
for reward_func_module in cfg.grpo_reward_funcs:
# use importlib to dynamically load the reward function from the module
reward_func_module_name = reward_func_module.split(".")[-1]
reward_func_module = importlib.import_module(reward_func_module)
reward_func = getattr(reward_func_module, reward_func_module_name)
reward_funcs.append(reward_func)
trainer_kwargs["reward_funcs"] = reward_funcs
return trainer_kwargs

View File

@@ -0,0 +1,12 @@
"""
Axolotl Specific Training Args
"""
from trl import GRPOConfig
from axolotl.core.training_args import AxolotlTrainingMixins
class AxolotlGRPOConfig(AxolotlTrainingMixins, GRPOConfig):
"""
Axolotl GRPO Config for GRPO training
"""

View File

@@ -0,0 +1,12 @@
"""
Axolotl GRPO trainer
"""
from trl import GRPOTrainer
from axolotl.core.trainers.base import SchedulerMixin
class AxolotlGRPOTrainer(SchedulerMixin, GRPOTrainer):
"""
Extend the base GRPOTrainer for axolotl helpers
"""

View File

@@ -5,7 +5,7 @@ from dataclasses import dataclass, field
from typing import Optional from typing import Optional
from transformers import TrainingArguments from transformers import TrainingArguments
from trl import CPOConfig, DPOConfig, KTOConfig, ORPOConfig, PRMConfig, RewardConfig from trl import CPOConfig, KTOConfig, ORPOConfig, PRMConfig, RewardConfig
@dataclass @dataclass
@@ -217,13 +217,6 @@ class AxolotlTrainingArguments(AxolotlTrainingMixins, TrainingArguments):
""" """
@dataclass
class AxolotlDPOConfig(AxolotlTrainingMixins, DPOConfig):
"""
DPO config for DPO training
"""
@dataclass @dataclass
class AxolotlORPOConfig(AxolotlTrainingMixins, ORPOConfig): class AxolotlORPOConfig(AxolotlTrainingMixins, ORPOConfig):
""" """

View File

@@ -24,6 +24,8 @@ from transformers.utils.import_utils import is_torch_npu_available
from axolotl.utils.config.models.internals import EnvCapabilities, GPUCapabilities from axolotl.utils.config.models.internals import EnvCapabilities, GPUCapabilities
from .grpo import GRPOConfig
LOG = logging.getLogger("axolotl.utils.config.models.input") LOG = logging.getLogger("axolotl.utils.config.models.input")
SUPPORTED_METRICS = {"sacrebleu", "comet", "ter", "chrf", "perplexity"} SUPPORTED_METRICS = {"sacrebleu", "comet", "ter", "chrf", "perplexity"}
@@ -33,6 +35,7 @@ class RLType(str, Enum):
"""RL trainer type configuration subset""" """RL trainer type configuration subset"""
dpo = "dpo" # pylint: disable=invalid-name dpo = "dpo" # pylint: disable=invalid-name
grpo = "grpo" # pylint: disable=invalid-name
ipo = "ipo" # pylint: disable=invalid-name ipo = "ipo" # pylint: disable=invalid-name
orpo = "orpo" # pylint: disable=invalid-name orpo = "orpo" # pylint: disable=invalid-name
kto = "kto" # pylint: disable=invalid-name kto = "kto" # pylint: disable=invalid-name
@@ -645,6 +648,7 @@ class AxolotlInputConfig(
MLFlowConfig, MLFlowConfig,
CometConfig, CometConfig,
LISAConfig, LISAConfig,
GRPOConfig,
GradioConfig, GradioConfig,
RayConfig, RayConfig,
RemappedParameters, RemappedParameters,

View File

@@ -0,0 +1,17 @@
"""
GRPO specific configuration args
"""
from typing import List, Optional
from pydantic import BaseModel
class GRPOConfig(BaseModel):
"""
Input args for GRPO.
"""
grpo_use_vllm: Optional[bool] = False
grpo_vllm_device: Optional[str] = "auto"
grpo_vllm_gpu_memory_utilization: Optional[float] = 0.9
grpo_reward_funcs: Optional[List[str]] = None