refactor a bit for better grpo support

Wing Lian
2025-02-02 23:22:36 -05:00
parent 57264b6491
commit 59ad21f2de
11 changed files with 160 additions and 22 deletions

View File

@@ -13,7 +13,7 @@ liger-kernel==0.5.2
packaging==23.2
peft==0.14.0
-transformers==4.48.1
+transformers==4.48.2
tokenizers>=0.21.0
accelerate==1.3.0
datasets==3.2.0

View File

@@ -35,6 +35,7 @@ from transformers import (
    EarlyStoppingCallback,
    TrainerCallback,
)
+from trl import DPOConfig
from trl.trainer.utils import RewardDataCollatorWithPadding
from axolotl.core.trainers.base import (
@@ -47,10 +48,10 @@ from axolotl.core.trainers.base import (
    AxolotlTrainer,
    ReLoRATrainer,
)
-from axolotl.core.trainers.dpo_trainer import AxolotlDPOTrainer
+from axolotl.core.trainers.dpo import DPOStrategy
+from axolotl.core.trainers.grpo import GRPOStrategy
from axolotl.core.training_args import (
    AxolotlCPOConfig,
-    AxolotlDPOConfig,
    AxolotlKTOConfig,
    AxolotlORPOConfig,
    AxolotlPRMConfig,
@@ -1006,16 +1007,13 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
            if self.cfg.max_prompt_len:
                training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len
+        elif self.cfg.rl == "grpo":
+            training_args_cls = GRPOStrategy.get_training_args_class()
+            training_args_kwargs.update(GRPOStrategy.set_training_args_kwargs(self.cfg))
        else:
-            training_args_cls = AxolotlDPOConfig
-            if self.cfg.rl == "ipo":
-                training_args_kwargs["loss_type"] = "ipo"
-            training_args_kwargs["max_length"] = self.cfg.sequence_len
-            training_args_kwargs["max_completion_length"] = None
-            training_args_kwargs["max_prompt_length"] = self.cfg.sequence_len
-            training_args_kwargs["generate_during_eval"] = self.cfg.use_wandb
-            if self.cfg.dpo_use_weighting is not None:
-                training_args_kwargs["use_weighting"] = self.cfg.dpo_use_weighting
+            training_args_cls = DPOStrategy.get_training_args_class()
+            training_args_kwargs.update(DPOStrategy.set_training_args_kwargs(self.cfg))

        training_args = training_args_cls(  # pylint: disable=unexpected-keyword-arg
            output_dir=self.cfg.output_dir,
@@ -1048,7 +1046,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
"precompute_ref_log_probs"
] = self.cfg.precompute_ref_log_probs
if self.cfg.rl in ["dpo", "ipo"]:
trainer_cls = AxolotlDPOTrainer
trainer_cls = DPOStrategy.get_trainer_class()
trainer_cls_args = [self.model, self.model_ref]
elif self.cfg.rl == "orpo":
trainer_cls = AxolotlORPOTrainer
@@ -1068,7 +1066,9 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
        else:
            dpo_trainer_kwargs["tokenizer"] = self.tokenizer

-        if self.cfg.datasets is not None and (trainer_cls is AxolotlDPOTrainer):
+        if self.cfg.datasets is not None and (
+            trainer_cls is DPOStrategy.get_trainer_class()
+        ):
            dpo_trainer_kwargs["dataset_tags"] = [
                d["path"] for d in self.cfg.datasets if not Path(d["path"]).is_dir()
            ]
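The net effect in the builder: per-algorithm wiring now sits behind a common strategy interface, so constructing training args becomes a lookup plus two delegated calls. A condensed sketch of that dispatch (illustrative, not the commit's literal code; SimpleNamespace stands in for axolotl's parsed config):

from types import SimpleNamespace

from axolotl.core.trainers.dpo import DPOStrategy
from axolotl.core.trainers.grpo import GRPOStrategy

cfg = SimpleNamespace(rl="grpo", grpo_use_vllm=True, grpo_vllm_device=None,
                      grpo_vllm_gpu_memory_utilization=0.9, grpo_reward_funcs=None)

# "grpo" gets its own strategy; "dpo" and "ipo" both fall through to DPO
strategy = GRPOStrategy if cfg.rl == "grpo" else DPOStrategy
training_args_cls = strategy.get_training_args_class()
training_args_kwargs = strategy.set_training_args_kwargs(cfg)
# -> AxolotlGRPOConfig, {'use_vllm': True, 'vllm_device': 'auto', 'vllm_gpu_memory_utilization': 0.9}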

View File

@@ -0,0 +1,33 @@
"""
DPO Specific Strategy for training
"""
from axolotl.core.trainers.dpo.trainer import AxolotlDPOTrainer
class DPOStrategy:
"""
Strategy for DPO training
"""
@classmethod
def get_trainer_class(cls):
return AxolotlDPOTrainer
@classmethod
def get_training_args_class(cls):
from axolotl.core.trainers.dpo.args import AxolotlDPOConfig
return AxolotlDPOConfig
@classmethod
def set_training_args_kwargs(cls, cfg):
training_args_kwargs = {}
if cfg.rl == "ipo":
training_args_kwargs["loss_type"] = "ipo"
training_args_kwargs["max_length"] = cfg.sequence_len
training_args_kwargs["max_completion_length"] = None
training_args_kwargs["max_prompt_length"] = cfg.sequence_len
training_args_kwargs["generate_during_eval"] = cfg.use_wandb
if cfg.dpo_use_weighting is not None:
training_args_kwargs["use_weighting"] = cfg.dpo_use_weighting
return training_args_kwargs
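To make the moved kwargs concrete: for an ipo run the strategy reproduces exactly what the builder previously set inline. A quick self-contained check (SimpleNamespace standing in for the real config object):

from types import SimpleNamespace

cfg = SimpleNamespace(rl="ipo", sequence_len=2048, use_wandb=False, dpo_use_weighting=None)
assert DPOStrategy.set_training_args_kwargs(cfg) == {
    "loss_type": "ipo",
    "max_length": 2048,
    "max_completion_length": None,
    "max_prompt_length": 2048,
    "generate_during_eval": False,
}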

View File

@@ -0,0 +1,15 @@
"""
Axolotl specific DPO args
"""
from dataclasses import dataclass
from trl import DPOConfig
from axolotl.core.training_args import AxolotlTrainingMixins
@dataclass
class AxolotlDPOConfig(AxolotlTrainingMixins, DPOConfig):
"""
DPO config for DPO training
"""

View File

@@ -0,0 +1,52 @@
"""
GRPO Specific Strategy for training
"""
import importlib
from axolotl.core.trainers.grpo.trainer import AxolotlGRPOTrainer
class GRPOStrategy:
"""
Strategy for GRPO training
"""
@classmethod
def get_trainer_class(cls):
return AxolotlGRPOTrainer
@classmethod
def get_training_args_class(cls):
from axolotl.core.trainers.grpo.args import AxolotlGRPOConfig
return AxolotlGRPOConfig
@classmethod
def set_training_args_kwargs(cls, cfg):
grpo_args_kwargs = {}
if cfg.grpo_use_vllm:
grpo_args_kwargs["use_vllm"] = cfg.grpo_use_vllm
if cfg.grpo_vllm_device:
grpo_args_kwargs["vllm_device"] = cfg.grpo_vllm_device
else:
grpo_args_kwargs["vllm_device"] = "auto"
if cfg.grpo_vllm_gpu_memory_utilization:
grpo_args_kwargs[
"vllm_gpu_memory_utilization"
] = cfg.grpo_vllm_gpu_memory_utilization
return grpo_args_kwargs
@classmethod
def set_trainer_kwargs(cls, cfg):
trainer_kwargs = {}
if cfg.grpo_reward_funcs:
reward_funcs = []
for reward_func_module in cfg.grpo_reward_funcs:
# use importlib to dynamically load the reward function from the module
reward_func_module_name = reward_func_module.split(".")[-1]
reward_func_module = importlib.import_module(reward_func_module)
reward_func = getattr(reward_func_module, reward_func_module_name)
reward_funcs.append(reward_func)
trainer_kwargs["reward_funcs"] = reward_funcs
return trainer_kwargs
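Note the import contract this loader creates: each entry in grpo_reward_funcs is imported as a module, and the callable is then looked up under the same name as the module's final path segment. So a hypothetical config entry "rewards.length_reward" expects a file rewards/length_reward.py that defines a function length_reward. A sketch of such a module, following trl's reward-function convention of one float per completion:

# rewards/length_reward.py -- hypothetical module; referenced in config as
# grpo_reward_funcs: ["rewards.length_reward"]
def length_reward(completions, **kwargs):
    # trl's GRPOTrainer passes the generated completions (plus any extra dataset
    # columns as kwargs) and expects a list of per-completion reward floats
    return [float(len(completion)) for completion in completions]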

View File

@@ -0,0 +1,12 @@
"""
Axolotl Specific Training Args
"""
from trl import GRPOConfig
from axolotl.core.training_args import AxolotlTrainingMixins
class AxolotlGRPOConfig(AxolotlTrainingMixins, GRPOConfig):
"""
Axolotl GRPO Config for GRPO training
"""

View File

@@ -0,0 +1,12 @@
"""
Axolotl GRPO trainer
"""
from trl import GRPOTrainer
from axolotl.core.trainers.base import SchedulerMixin
class AxolotlGRPOTrainer(SchedulerMixin, GRPOTrainer):
"""
Extend the base GRPOTrainer for axolotl helpers
"""

View File

@@ -5,7 +5,7 @@ from dataclasses import dataclass, field
from typing import Optional
from transformers import TrainingArguments
-from trl import CPOConfig, DPOConfig, KTOConfig, ORPOConfig, PRMConfig, RewardConfig
+from trl import CPOConfig, KTOConfig, ORPOConfig, PRMConfig, RewardConfig
@dataclass
@@ -217,13 +217,6 @@ class AxolotlTrainingArguments(AxolotlTrainingMixins, TrainingArguments):
"""
@dataclass
class AxolotlDPOConfig(AxolotlTrainingMixins, DPOConfig):
"""
DPO config for DPO training
"""
@dataclass
class AxolotlORPOConfig(AxolotlTrainingMixins, ORPOConfig):
"""

View File

@@ -24,6 +24,8 @@ from transformers.utils.import_utils import is_torch_npu_available
from axolotl.utils.config.models.internals import EnvCapabilities, GPUCapabilities

+from .grpo import GRPOConfig
+
LOG = logging.getLogger("axolotl.utils.config.models.input")

SUPPORTED_METRICS = {"sacrebleu", "comet", "ter", "chrf", "perplexity"}
@@ -33,6 +35,7 @@ class RLType(str, Enum):
"""RL trainer type configuration subset"""
dpo = "dpo" # pylint: disable=invalid-name
grpo = "grpo" # pylint: disable=invalid-name
ipo = "ipo" # pylint: disable=invalid-name
orpo = "orpo" # pylint: disable=invalid-name
kto = "kto" # pylint: disable=invalid-name
@@ -645,6 +648,7 @@ class AxolotlInputConfig(
    MLFlowConfig,
    CometConfig,
    LISAConfig,
+    GRPOConfig,
    GradioConfig,
    RayConfig,
    RemappedParameters,
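Since AxolotlInputConfig is assembled by inheriting one pydantic mixin per feature, adding GRPOConfig to the base list is all that is needed for the grpo_* keys to be parsed and validated. A toy reduction of the same composition pattern (GRPOMixin, TrainMixin, and InputConfig are illustrative names, not from the commit):

from typing import List, Optional

from pydantic import BaseModel

class GRPOMixin(BaseModel):
    grpo_use_vllm: Optional[bool] = False
    grpo_reward_funcs: Optional[List[str]] = None

class TrainMixin(BaseModel):
    sequence_len: Optional[int] = 2048

# mirrors AxolotlInputConfig inheriting GRPOConfig alongside the other mixins
class InputConfig(GRPOMixin, TrainMixin):
    pass

cfg = InputConfig(grpo_use_vllm=True)
print(cfg.grpo_use_vllm, cfg.sequence_len)  # True 2048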

View File

@@ -0,0 +1,17 @@
"""
GRPO specific configuration args
"""
from typing import List, Optional
from pydantic import BaseModel
class GRPOConfig(BaseModel):
"""
Input args for GRPO.
"""
grpo_use_vllm: Optional[bool] = False
grpo_vllm_device: Optional[str] = "auto"
grpo_vllm_gpu_memory_utilization: Optional[float] = 0.9
grpo_reward_funcs: Optional[List[str]] = None
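With the model above in scope, the grpo_* keys from a user config validate and coerce like any other pydantic fields; for example:

cfg = GRPOConfig()  # nothing set in the user config
assert cfg.grpo_vllm_device == "auto"
assert cfg.grpo_vllm_gpu_memory_utilization == 0.9

# pydantic coerces compatible strings, so a quoted "0.85" still parses to a float
cfg = GRPOConfig(grpo_vllm_gpu_memory_utilization="0.85")
assert cfg.grpo_vllm_gpu_memory_utilization == 0.85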