refactor a bit for better grpo support
@@ -13,7 +13,7 @@ liger-kernel==0.5.2
 packaging==23.2
 
 peft==0.14.0
-transformers==4.48.1
+transformers==4.48.2
 tokenizers>=0.21.0
 accelerate==1.3.0
 datasets==3.2.0

@@ -35,6 +35,7 @@ from transformers import (
     EarlyStoppingCallback,
     TrainerCallback,
 )
+from trl import DPOConfig
 from trl.trainer.utils import RewardDataCollatorWithPadding
 
 from axolotl.core.trainers.base import (

@@ -47,10 +48,10 @@ from axolotl.core.trainers.base import (
     AxolotlTrainer,
     ReLoRATrainer,
 )
-from axolotl.core.trainers.dpo_trainer import AxolotlDPOTrainer
+from axolotl.core.trainers.dpo import DPOStrategy
+from axolotl.core.trainers.grpo import GRPOStrategy
 from axolotl.core.training_args import (
     AxolotlCPOConfig,
-    AxolotlDPOConfig,
     AxolotlKTOConfig,
     AxolotlORPOConfig,
     AxolotlPRMConfig,

@@ -1006,16 +1007,13 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
             if self.cfg.max_prompt_len:
                 training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len
 
+        elif self.cfg.rl == "grpo":
+            training_args_cls = GRPOStrategy.get_training_args_class()
+            training_args_kwargs.update(GRPOStrategy.set_training_args_kwargs(self.cfg))
+
         else:
-            training_args_cls = AxolotlDPOConfig
-            if self.cfg.rl == "ipo":
-                training_args_kwargs["loss_type"] = "ipo"
-            training_args_kwargs["max_length"] = self.cfg.sequence_len
-            training_args_kwargs["max_completion_length"] = None
-            training_args_kwargs["max_prompt_length"] = self.cfg.sequence_len
-            training_args_kwargs["generate_during_eval"] = self.cfg.use_wandb
-            if self.cfg.dpo_use_weighting is not None:
-                training_args_kwargs["use_weighting"] = self.cfg.dpo_use_weighting
+            training_args_cls = DPOConfig.get_training_args_class()
+            training_args_kwargs.update(DPOConfig.set_training_args_kwargs(self.cfg))
 
         training_args = training_args_cls(  # pylint: disable=unexpected-keyword-arg
             output_dir=self.cfg.output_dir,

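The inline per-algorithm kwargs block is replaced by strategy objects: the builder now only selects a strategy, asks it for the config class, and merges in whatever kwargs the strategy derives from the axolotl config. A minimal sketch of that dispatch shape, using stand-in names (`StubStrategy` and the `SimpleNamespace` cfg fields are illustrative, not the commit's real classes):

from types import SimpleNamespace


class StubStrategy:
    """Stand-in for DPOStrategy/GRPOStrategy: one class per RL algorithm."""

    @classmethod
    def get_training_args_class(cls):
        return dict  # a real strategy returns a TrainingArguments subclass

    @classmethod
    def set_training_args_kwargs(cls, cfg):
        kwargs = {"max_length": cfg.sequence_len}
        if cfg.rl == "ipo":
            kwargs["loss_type"] = "ipo"
        return kwargs


cfg = SimpleNamespace(rl="ipo", sequence_len=2048)
training_args_cls = StubStrategy.get_training_args_class()
training_args_kwargs = {}
training_args_kwargs.update(StubStrategy.set_training_args_kwargs(cfg))
print(training_args_cls(training_args_kwargs))  # {'max_length': 2048, 'loss_type': 'ipo'}
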
@@ -1048,7 +1046,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
                 "precompute_ref_log_probs"
             ] = self.cfg.precompute_ref_log_probs
         if self.cfg.rl in ["dpo", "ipo"]:
-            trainer_cls = AxolotlDPOTrainer
+            trainer_cls = DPOStrategy.get_trainer_class()
             trainer_cls_args = [self.model, self.model_ref]
         elif self.cfg.rl == "orpo":
             trainer_cls = AxolotlORPOTrainer

@@ -1068,7 +1066,9 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
         else:
             dpo_trainer_kwargs["tokenizer"] = self.tokenizer
 
-        if self.cfg.datasets is not None and (trainer_cls is AxolotlDPOTrainer):
+        if self.cfg.datasets is not None and (
+            trainer_cls is DPOStrategy.get_trainer_class()
+        ):
             dpo_trainer_kwargs["dataset_tags"] = [
                 d["path"] for d in self.cfg.datasets if not Path(d["path"]).is_dir()
             ]

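The guard now compares against `DPOStrategy.get_trainer_class()` with `is`; that works because the classmethod returns the same class object on every call, so the identity test is equivalent to the old direct comparison with `AxolotlDPOTrainer`. A tiny illustration with stand-in classes:

class FakeTrainer:
    pass


class FakeStrategy:
    @classmethod
    def get_trainer_class(cls):
        return FakeTrainer  # always the same class object


trainer_cls = FakeStrategy.get_trainer_class()
assert trainer_cls is FakeStrategy.get_trainer_class()  # identity holds
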
src/axolotl/core/trainers/dpo/__init__.py (new file, +33)
@@ -0,0 +1,33 @@
"""
|
||||
DPO Specific Strategy for training
|
||||
"""
|
||||
from axolotl.core.trainers.dpo.trainer import AxolotlDPOTrainer
|
||||
|
||||
|
||||
class DPOStrategy:
|
||||
"""
|
||||
Strategy for DPO training
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def get_trainer_class(cls):
|
||||
return AxolotlDPOTrainer
|
||||
|
||||
@classmethod
|
||||
def get_training_args_class(cls):
|
||||
from axolotl.core.trainers.dpo.args import AxolotlDPOConfig
|
||||
|
||||
return AxolotlDPOConfig
|
||||
|
||||
@classmethod
|
||||
def set_training_args_kwargs(cls, cfg):
|
||||
training_args_kwargs = {}
|
||||
if cfg.rl == "ipo":
|
||||
training_args_kwargs["loss_type"] = "ipo"
|
||||
training_args_kwargs["max_length"] = cfg.sequence_len
|
||||
training_args_kwargs["max_completion_length"] = None
|
||||
training_args_kwargs["max_prompt_length"] = cfg.sequence_len
|
||||
training_args_kwargs["generate_during_eval"] = cfg.use_wandb
|
||||
if cfg.dpo_use_weighting is not None:
|
||||
training_args_kwargs["use_weighting"] = cfg.dpo_use_weighting
|
||||
return training_args_kwargs
|
||||
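A hedged usage sketch of the new strategy: given a minimal stand-in for the axolotl config (the field values below are illustrative), `set_training_args_kwargs` yields the kwargs the builder merges before instantiating the config class. Assumes axolotl is importable:

from types import SimpleNamespace

from axolotl.core.trainers.dpo import DPOStrategy

cfg = SimpleNamespace(
    rl="ipo", sequence_len=2048, use_wandb=False, dpo_use_weighting=None
)
print(DPOStrategy.set_training_args_kwargs(cfg))
# expected: {'loss_type': 'ipo', 'max_length': 2048, 'max_completion_length': None,
#            'max_prompt_length': 2048, 'generate_during_eval': False}

The in-function import of `AxolotlDPOConfig` inside `get_training_args_class` defers loading the args module until it is actually needed, a common guard against circular imports between the strategy and `core.training_args`.
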
src/axolotl/core/trainers/dpo/args.py (new file, +15)
@@ -0,0 +1,15 @@
"""
|
||||
Axolotl specific DPO args
|
||||
"""
|
||||
from dataclasses import dataclass
|
||||
|
||||
from trl import DPOConfig
|
||||
|
||||
from axolotl.core.training_args import AxolotlTrainingMixins
|
||||
|
||||
|
||||
@dataclass
|
||||
class AxolotlDPOConfig(AxolotlTrainingMixins, DPOConfig):
|
||||
"""
|
||||
DPO config for DPO training
|
||||
"""
|
||||
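The config class is plain dataclass layering: axolotl-side fields from `AxolotlTrainingMixins` and TRL-side fields from `DPOConfig` merge into a single generated `__init__`. A self-contained demonstration of the pattern with fake classes (these are not axolotl's real fields):

from dataclasses import dataclass


@dataclass
class FakeMixins:
    curriculum_sampling: bool = False  # pretend axolotl-side field


@dataclass
class FakeDPOConfig:
    beta: float = 0.1  # pretend TRL-side field


@dataclass
class FakeAxolotlDPOConfig(FakeMixins, FakeDPOConfig):
    """Fields from both parents end up in one __init__; no body needed."""


print(FakeAxolotlDPOConfig(beta=0.05, curriculum_sampling=True))
# FakeAxolotlDPOConfig(beta=0.05, curriculum_sampling=True)
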
src/axolotl/core/trainers/grpo/__init__.py (new file, +52)
@@ -0,0 +1,52 @@
"""
|
||||
GRPO Specific Strategy for training
|
||||
"""
|
||||
import importlib
|
||||
|
||||
from axolotl.core.trainers.grpo.trainer import AxolotlGRPOTrainer
|
||||
|
||||
|
||||
class GRPOStrategy:
|
||||
"""
|
||||
Strategy for GRPO training
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def get_trainer_class(cls):
|
||||
return AxolotlGRPOTrainer
|
||||
|
||||
@classmethod
|
||||
def get_training_args_class(cls):
|
||||
from axolotl.core.trainers.grpo.args import AxolotlGRPOConfig
|
||||
|
||||
return AxolotlGRPOConfig
|
||||
|
||||
@classmethod
|
||||
def set_training_args_kwargs(cls, cfg):
|
||||
grpo_args_kwargs = {}
|
||||
if cfg.grpo_use_vllm:
|
||||
grpo_args_kwargs["use_vllm"] = cfg.grpo_use_vllm
|
||||
if cfg.grpo_vllm_device:
|
||||
grpo_args_kwargs["vllm_device"] = cfg.grpo_vllm_device
|
||||
else:
|
||||
grpo_args_kwargs["vllm_device"] = "auto"
|
||||
if cfg.grpo_vllm_gpu_memory_utilization:
|
||||
grpo_args_kwargs[
|
||||
"vllm_gpu_memory_utilization"
|
||||
] = cfg.grpo_vllm_gpu_memory_utilization
|
||||
return grpo_args_kwargs
|
||||
|
||||
@classmethod
|
||||
def set_trainer_kwargs(cls, cfg):
|
||||
trainer_kwargs = {}
|
||||
if cfg.grpo_reward_funcs:
|
||||
reward_funcs = []
|
||||
for reward_func_module in cfg.grpo_reward_funcs:
|
||||
# use importlib to dynamically load the reward function from the module
|
||||
reward_func_module_name = reward_func_module.split(".")[-1]
|
||||
reward_func_module = importlib.import_module(reward_func_module)
|
||||
reward_func = getattr(reward_func_module, reward_func_module_name)
|
||||
reward_funcs.append(reward_func)
|
||||
trainer_kwargs["reward_funcs"] = reward_funcs
|
||||
|
||||
return trainer_kwargs
|
||||
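`set_trainer_kwargs` resolves each `grpo_reward_funcs` entry by importing the dotted path as a module and then reading an attribute named after the last path segment, so the callable must share its name with the module that defines it. A hypothetical reward module satisfying that convention (the file name, function name, and reward logic here are made up; TRL's `GRPOTrainer` calls reward functions with the batch's completions, plus extra dataset columns as keyword arguments, and expects one float per completion):

# rewards/ends_with_period.py -- referenced in the config as
# "rewards.ends_with_period", so the loader effectively does
# getattr(import_module("rewards.ends_with_period"), "ends_with_period")
def ends_with_period(completions, **kwargs):
    """Toy reward: 1.0 when a completion ends with a period."""
    return [1.0 if str(c).strip().endswith(".") else 0.0 for c in completions]

Note the loop reuses the name `reward_func_module` for both the dotted string and the imported module; this works only because the string is copied into `reward_func_module_name` before the rebind.
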
src/axolotl/core/trainers/grpo/args.py (new file, +12)
@@ -0,0 +1,12 @@
"""
|
||||
Axolotl Specific Training Args
|
||||
"""
|
||||
from trl import GRPOConfig
|
||||
|
||||
from axolotl.core.training_args import AxolotlTrainingMixins
|
||||
|
||||
|
||||
class AxolotlGRPOConfig(AxolotlTrainingMixins, GRPOConfig):
|
||||
"""
|
||||
Axolotl GRPO Config for GRPO training
|
||||
"""
|
||||
src/axolotl/core/trainers/grpo/trainer.py (new file, +12)
@@ -0,0 +1,12 @@
"""
|
||||
Axolotl GRPO trainer
|
||||
"""
|
||||
from trl import GRPOTrainer
|
||||
|
||||
from axolotl.core.trainers.base import SchedulerMixin
|
||||
|
||||
|
||||
class AxolotlGRPOTrainer(SchedulerMixin, GRPOTrainer):
|
||||
"""
|
||||
Extend the base GRPOTrainer for axolotl helpers
|
||||
"""
|
||||
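Listing `SchedulerMixin` before `GRPOTrainer` puts axolotl's scheduler hooks ahead of TRL's in the method resolution order, while `super()` inside the mixin still reaches the TRL implementation. A stand-in demonstration of why the base order matters (these classes are illustrative, not the real mixin body):

class FakeTRLTrainer:
    def create_scheduler(self):
        return "trl scheduler"


class FakeSchedulerMixin:
    def create_scheduler(self):
        # axolotl-specific scheduler selection would happen here,
        # deferring to the parent implementation otherwise
        return "axolotl wraps: " + super().create_scheduler()


class FakeGRPOTrainer(FakeSchedulerMixin, FakeTRLTrainer):  # mixin first
    pass


print(FakeGRPOTrainer().create_scheduler())  # axolotl wraps: trl scheduler
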
@@ -5,7 +5,7 @@ from dataclasses import dataclass, field
 from typing import Optional
 
 from transformers import TrainingArguments
-from trl import CPOConfig, DPOConfig, KTOConfig, ORPOConfig, PRMConfig, RewardConfig
+from trl import CPOConfig, KTOConfig, ORPOConfig, PRMConfig, RewardConfig
 
 
 @dataclass

@@ -217,13 +217,6 @@ class AxolotlTrainingArguments(AxolotlTrainingMixins, TrainingArguments):
     """
 
 
-@dataclass
-class AxolotlDPOConfig(AxolotlTrainingMixins, DPOConfig):
-    """
-    DPO config for DPO training
-    """
-
-
 @dataclass
 class AxolotlORPOConfig(AxolotlTrainingMixins, ORPOConfig):
     """

@@ -24,6 +24,8 @@ from transformers.utils.import_utils import is_torch_npu_available
 
 from axolotl.utils.config.models.internals import EnvCapabilities, GPUCapabilities
 
+from .grpo import GRPOConfig
+
 LOG = logging.getLogger("axolotl.utils.config.models.input")
 
 SUPPORTED_METRICS = {"sacrebleu", "comet", "ter", "chrf", "perplexity"}

@@ -33,6 +35,7 @@ class RLType(str, Enum):
     """RL trainer type configuration subset"""
 
     dpo = "dpo"  # pylint: disable=invalid-name
+    grpo = "grpo"  # pylint: disable=invalid-name
     ipo = "ipo"  # pylint: disable=invalid-name
     orpo = "orpo"  # pylint: disable=invalid-name
     kto = "kto"  # pylint: disable=invalid-name

@@ -645,6 +648,7 @@ class AxolotlInputConfig(
     MLFlowConfig,
     CometConfig,
     LISAConfig,
+    GRPOConfig,
     GradioConfig,
     RayConfig,
     RemappedParameters,

src/axolotl/utils/config/models/input/v0_4_1/grpo.py (new file, +17)
@@ -0,0 +1,17 @@
"""
|
||||
GRPO specific configuration args
|
||||
"""
|
||||
from typing import List, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class GRPOConfig(BaseModel):
|
||||
"""
|
||||
Input args for GRPO.
|
||||
"""
|
||||
|
||||
grpo_use_vllm: Optional[bool] = False
|
||||
grpo_vllm_device: Optional[str] = "auto"
|
||||
grpo_vllm_gpu_memory_utilization: Optional[float] = 0.9
|
||||
grpo_reward_funcs: Optional[List[str]] = None
|
||||
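Because this pydantic `GRPOConfig` is mixed into `AxolotlInputConfig` (see the hunk above), the `grpo_*` keys validate like any other top-level config option. A quick, hedged check of the model on its own (requires axolotl installed; the values are examples, not recommendations):

from axolotl.utils.config.models.input.v0_4_1.grpo import GRPOConfig

cfg = GRPOConfig(
    grpo_use_vllm=True,
    grpo_vllm_gpu_memory_utilization=0.8,
    grpo_reward_funcs=["rewards.ends_with_period"],
)
print(cfg.grpo_vllm_device)  # "auto" -- the declared default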