Refactor a bit for better GRPO support
This commit is contained in:
@@ -13,7 +13,7 @@ liger-kernel==0.5.2
|
|||||||
packaging==23.2
|
packaging==23.2
|
||||||
|
|
||||||
peft==0.14.0
|
peft==0.14.0
|
||||||
transformers==4.48.1
|
transformers==4.48.2
|
||||||
tokenizers>=0.21.0
|
tokenizers>=0.21.0
|
||||||
accelerate==1.3.0
|
accelerate==1.3.0
|
||||||
datasets==3.2.0
|
datasets==3.2.0
|
||||||
|
|||||||
@@ -35,6 +35,7 @@ from transformers import (
|
|||||||
EarlyStoppingCallback,
|
EarlyStoppingCallback,
|
||||||
TrainerCallback,
|
TrainerCallback,
|
||||||
)
|
)
|
||||||
|
from trl import DPOConfig
|
||||||
from trl.trainer.utils import RewardDataCollatorWithPadding
|
from trl.trainer.utils import RewardDataCollatorWithPadding
|
||||||
|
|
||||||
from axolotl.core.trainers.base import (
|
from axolotl.core.trainers.base import (
|
||||||
@@ -47,10 +48,10 @@ from axolotl.core.trainers.base import (
|
|||||||
AxolotlTrainer,
|
AxolotlTrainer,
|
||||||
ReLoRATrainer,
|
ReLoRATrainer,
|
||||||
)
|
)
|
||||||
from axolotl.core.trainers.dpo_trainer import AxolotlDPOTrainer
|
from axolotl.core.trainers.dpo import DPOStrategy
|
||||||
|
from axolotl.core.trainers.grpo import GRPOStrategy
|
||||||
from axolotl.core.training_args import (
|
from axolotl.core.training_args import (
|
||||||
AxolotlCPOConfig,
|
AxolotlCPOConfig,
|
||||||
AxolotlDPOConfig,
|
|
||||||
AxolotlKTOConfig,
|
AxolotlKTOConfig,
|
||||||
AxolotlORPOConfig,
|
AxolotlORPOConfig,
|
||||||
AxolotlPRMConfig,
|
AxolotlPRMConfig,
|
||||||
@@ -1006,16 +1007,13 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
|||||||
if self.cfg.max_prompt_len:
|
if self.cfg.max_prompt_len:
|
||||||
training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len
|
training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len
|
||||||
|
|
||||||
|
elif self.cfg.rl == "grpo":
|
||||||
|
training_args_cls = GRPOStrategy.get_training_args_class()
|
||||||
|
training_args_kwargs.update(GRPOStrategy.set_training_args_kwargs(self.cfg))
|
||||||
|
|
||||||
else:
|
else:
|
||||||
training_args_cls = AxolotlDPOConfig
|
training_args_cls = DPOConfig.get_training_args_class()
|
||||||
if self.cfg.rl == "ipo":
|
training_args_kwargs.update(DPOConfig.set_training_args_kwargs(self.cfg))
|
||||||
training_args_kwargs["loss_type"] = "ipo"
|
|
||||||
training_args_kwargs["max_length"] = self.cfg.sequence_len
|
|
||||||
training_args_kwargs["max_completion_length"] = None
|
|
||||||
training_args_kwargs["max_prompt_length"] = self.cfg.sequence_len
|
|
||||||
training_args_kwargs["generate_during_eval"] = self.cfg.use_wandb
|
|
||||||
if self.cfg.dpo_use_weighting is not None:
|
|
||||||
training_args_kwargs["use_weighting"] = self.cfg.dpo_use_weighting
|
|
||||||
|
|
||||||
training_args = training_args_cls( # pylint: disable=unexpected-keyword-arg
|
training_args = training_args_cls( # pylint: disable=unexpected-keyword-arg
|
||||||
output_dir=self.cfg.output_dir,
|
output_dir=self.cfg.output_dir,
|
||||||
@@ -1048,7 +1046,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
|||||||
"precompute_ref_log_probs"
|
"precompute_ref_log_probs"
|
||||||
] = self.cfg.precompute_ref_log_probs
|
] = self.cfg.precompute_ref_log_probs
|
||||||
if self.cfg.rl in ["dpo", "ipo"]:
|
if self.cfg.rl in ["dpo", "ipo"]:
|
||||||
trainer_cls = AxolotlDPOTrainer
|
trainer_cls = DPOStrategy.get_trainer_class()
|
||||||
trainer_cls_args = [self.model, self.model_ref]
|
trainer_cls_args = [self.model, self.model_ref]
|
||||||
elif self.cfg.rl == "orpo":
|
elif self.cfg.rl == "orpo":
|
||||||
trainer_cls = AxolotlORPOTrainer
|
trainer_cls = AxolotlORPOTrainer
|
||||||
@@ -1068,7 +1066,9 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
|||||||
else:
|
else:
|
||||||
dpo_trainer_kwargs["tokenizer"] = self.tokenizer
|
dpo_trainer_kwargs["tokenizer"] = self.tokenizer
|
||||||
|
|
||||||
if self.cfg.datasets is not None and (trainer_cls is AxolotlDPOTrainer):
|
if self.cfg.datasets is not None and (
|
||||||
|
trainer_cls is DPOStrategy.get_trainer_class()
|
||||||
|
):
|
||||||
dpo_trainer_kwargs["dataset_tags"] = [
|
dpo_trainer_kwargs["dataset_tags"] = [
|
||||||
d["path"] for d in self.cfg.datasets if not Path(d["path"]).is_dir()
|
d["path"] for d in self.cfg.datasets if not Path(d["path"]).is_dir()
|
||||||
]
|
]
|
||||||
|
|||||||
33
src/axolotl/core/trainers/dpo/__init__.py
Normal file
33
src/axolotl/core/trainers/dpo/__init__.py
Normal file
@@ -0,0 +1,33 @@
|
|||||||
|
"""
|
||||||
|
DPO Specific Strategy for training
|
||||||
|
"""
|
||||||
|
from axolotl.core.trainers.dpo.trainer import AxolotlDPOTrainer
|
||||||
|
|
||||||
|
|
||||||
|
class DPOStrategy:
    """
    Strategy for DPO training.

    Bundles the three lookups needed for a DPO-style run: the trainer class,
    the training-arguments class, and the extra training-argument keyword
    values derived from the axolotl config.
    """

    @classmethod
    def get_trainer_class(cls):
        """Return the trainer class used for DPO training."""
        return AxolotlDPOTrainer

    @classmethod
    def get_training_args_class(cls):
        """Return the training-arguments class used for DPO training."""
        # Resolved at call time rather than at module import time.
        from axolotl.core.trainers.dpo.args import AxolotlDPOConfig

        return AxolotlDPOConfig

    @classmethod
    def set_training_args_kwargs(cls, cfg):
        """
        Build the DPO-specific training-argument keyword values from ``cfg``.

        ``loss_type`` is only emitted when ``cfg.rl`` is ``"ipo"`` and
        ``use_weighting`` only when ``cfg.dpo_use_weighting`` is not ``None``;
        the remaining keys are always present.
        """
        ipo_kwargs = {"loss_type": "ipo"} if cfg.rl == "ipo" else {}
        sequence_len = cfg.sequence_len
        use_weighting = cfg.dpo_use_weighting
        weighting_kwargs = (
            {} if use_weighting is None else {"use_weighting": use_weighting}
        )
        return {
            **ipo_kwargs,
            # prompt and total length are both capped at the sequence length
            "max_length": sequence_len,
            "max_completion_length": None,
            "max_prompt_length": sequence_len,
            "generate_during_eval": cfg.use_wandb,
            **weighting_kwargs,
        }
|
||||||
15
src/axolotl/core/trainers/dpo/args.py
Normal file
15
src/axolotl/core/trainers/dpo/args.py
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
"""
|
||||||
|
Axolotl specific DPO args
|
||||||
|
"""
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
from trl import DPOConfig
|
||||||
|
|
||||||
|
from axolotl.core.training_args import AxolotlTrainingMixins
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class AxolotlDPOConfig(AxolotlTrainingMixins, DPOConfig):
    """
    Training arguments for DPO runs.

    Combines the axolotl training mixin fields with TRL's ``DPOConfig``;
    no additional fields are declared here.
    """
|
||||||
52
src/axolotl/core/trainers/grpo/__init__.py
Normal file
52
src/axolotl/core/trainers/grpo/__init__.py
Normal file
@@ -0,0 +1,52 @@
|
|||||||
|
"""
|
||||||
|
GRPO Specific Strategy for training
|
||||||
|
"""
|
||||||
|
import importlib
|
||||||
|
|
||||||
|
from axolotl.core.trainers.grpo.trainer import AxolotlGRPOTrainer
|
||||||
|
|
||||||
|
|
||||||
|
class GRPOStrategy:
    """
    Strategy for GRPO training.

    Bundles the lookups needed for a GRPO run: the trainer class, the
    training-arguments class, and the keyword values for both that are
    derived from the axolotl config.
    """

    @classmethod
    def get_trainer_class(cls):
        """Return the trainer class used for GRPO training."""
        return AxolotlGRPOTrainer

    @classmethod
    def get_training_args_class(cls):
        """Return the training-arguments class used for GRPO training."""
        # Resolved at call time rather than at module import time.
        from axolotl.core.trainers.grpo.args import AxolotlGRPOConfig

        return AxolotlGRPOConfig

    @classmethod
    def set_training_args_kwargs(cls, cfg):
        """
        Build the GRPO-specific training-argument keyword values from ``cfg``.

        Returns an empty dict unless ``cfg.grpo_use_vllm`` is truthy.
        """
        grpo_args_kwargs = {}
        # NOTE(review): the original indentation was lost; the device and
        # memory options are assumed to apply only when vLLM is enabled -
        # confirm against the repository.
        if cfg.grpo_use_vllm:
            grpo_args_kwargs["use_vllm"] = cfg.grpo_use_vllm
            # fall back to "auto" when no explicit device is configured
            grpo_args_kwargs["vllm_device"] = cfg.grpo_vllm_device or "auto"
            if cfg.grpo_vllm_gpu_memory_utilization:
                grpo_args_kwargs[
                    "vllm_gpu_memory_utilization"
                ] = cfg.grpo_vllm_gpu_memory_utilization
        return grpo_args_kwargs

    @classmethod
    def set_trainer_kwargs(cls, cfg):
        """
        Build the GRPO-specific trainer keyword values from ``cfg``.

        Each entry of ``cfg.grpo_reward_funcs`` is a dotted path that is
        resolved to a Python object and collected, in order, under the
        ``reward_funcs`` key. Returns an empty dict when no reward functions
        are configured.

        Raises:
            ModuleNotFoundError: if the module part of a path cannot be imported.
            AttributeError: if the module does not define the named attribute.
        """
        trainer_kwargs = {}
        if cfg.grpo_reward_funcs:
            trainer_kwargs["reward_funcs"] = [
                cls._load_reward_func(reward_func_path)
                for reward_func_path in cfg.grpo_reward_funcs
            ]
        return trainer_kwargs

    @staticmethod
    def _load_reward_func(reward_func_path):
        """Resolve a dotted path such as ``pkg.module.func`` to the named attribute."""
        module_path, _, attr_name = reward_func_path.rpartition(".")
        try:
            # Backward compatible: the whole path may itself be a module that
            # defines an attribute named after its last component.
            module = importlib.import_module(reward_func_path)
        except ModuleNotFoundError as err:
            # Only fall back when the full path itself is what could not be
            # found; a missing parent package or a missing dependency inside
            # the target module must surface unchanged.
            if not module_path or err.name != reward_func_path:
                raise
            module = importlib.import_module(module_path)
        return getattr(module, attr_name)
|
||||||
12
src/axolotl/core/trainers/grpo/args.py
Normal file
12
src/axolotl/core/trainers/grpo/args.py
Normal file
@@ -0,0 +1,12 @@
|
|||||||
|
"""
|
||||||
|
Axolotl Specific Training Args
|
||||||
|
"""
|
||||||
|
from dataclasses import dataclass

from trl import GRPOConfig

from axolotl.core.training_args import AxolotlTrainingMixins
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class AxolotlGRPOConfig(AxolotlTrainingMixins, GRPOConfig):
    """
    Axolotl GRPO Config for GRPO training.

    Combines the axolotl training mixin fields with TRL's ``GRPOConfig``;
    no additional fields are declared here. The ``@dataclass`` decorator
    regenerates ``__init__`` over the fields of both bases, matching the
    other Axolotl ``*Config`` classes.
    """
|
||||||
12
src/axolotl/core/trainers/grpo/trainer.py
Normal file
12
src/axolotl/core/trainers/grpo/trainer.py
Normal file
@@ -0,0 +1,12 @@
|
|||||||
|
"""
|
||||||
|
Axolotl GRPO trainer
|
||||||
|
"""
|
||||||
|
from trl import GRPOTrainer
|
||||||
|
|
||||||
|
from axolotl.core.trainers.base import SchedulerMixin
|
||||||
|
|
||||||
|
|
||||||
|
class AxolotlGRPOTrainer(SchedulerMixin, GRPOTrainer):
    """
    GRPO trainer for axolotl.

    Layers axolotl's ``SchedulerMixin`` on top of TRL's ``GRPOTrainer``;
    no methods are overridden here.
    """
|
||||||
@@ -5,7 +5,7 @@ from dataclasses import dataclass, field
|
|||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from transformers import TrainingArguments
|
from transformers import TrainingArguments
|
||||||
from trl import CPOConfig, DPOConfig, KTOConfig, ORPOConfig, PRMConfig, RewardConfig
|
from trl import CPOConfig, KTOConfig, ORPOConfig, PRMConfig, RewardConfig
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -217,13 +217,6 @@ class AxolotlTrainingArguments(AxolotlTrainingMixins, TrainingArguments):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class AxolotlDPOConfig(AxolotlTrainingMixins, DPOConfig):
|
|
||||||
"""
|
|
||||||
DPO config for DPO training
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class AxolotlORPOConfig(AxolotlTrainingMixins, ORPOConfig):
|
class AxolotlORPOConfig(AxolotlTrainingMixins, ORPOConfig):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -24,6 +24,8 @@ from transformers.utils.import_utils import is_torch_npu_available
|
|||||||
|
|
||||||
from axolotl.utils.config.models.internals import EnvCapabilities, GPUCapabilities
|
from axolotl.utils.config.models.internals import EnvCapabilities, GPUCapabilities
|
||||||
|
|
||||||
|
from .grpo import GRPOConfig
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl.utils.config.models.input")
|
LOG = logging.getLogger("axolotl.utils.config.models.input")
|
||||||
|
|
||||||
SUPPORTED_METRICS = {"sacrebleu", "comet", "ter", "chrf", "perplexity"}
|
SUPPORTED_METRICS = {"sacrebleu", "comet", "ter", "chrf", "perplexity"}
|
||||||
@@ -33,6 +35,7 @@ class RLType(str, Enum):
|
|||||||
"""RL trainer type configuration subset"""
|
"""RL trainer type configuration subset"""
|
||||||
|
|
||||||
dpo = "dpo" # pylint: disable=invalid-name
|
dpo = "dpo" # pylint: disable=invalid-name
|
||||||
|
grpo = "grpo" # pylint: disable=invalid-name
|
||||||
ipo = "ipo" # pylint: disable=invalid-name
|
ipo = "ipo" # pylint: disable=invalid-name
|
||||||
orpo = "orpo" # pylint: disable=invalid-name
|
orpo = "orpo" # pylint: disable=invalid-name
|
||||||
kto = "kto" # pylint: disable=invalid-name
|
kto = "kto" # pylint: disable=invalid-name
|
||||||
@@ -645,6 +648,7 @@ class AxolotlInputConfig(
|
|||||||
MLFlowConfig,
|
MLFlowConfig,
|
||||||
CometConfig,
|
CometConfig,
|
||||||
LISAConfig,
|
LISAConfig,
|
||||||
|
GRPOConfig,
|
||||||
GradioConfig,
|
GradioConfig,
|
||||||
RayConfig,
|
RayConfig,
|
||||||
RemappedParameters,
|
RemappedParameters,
|
||||||
|
|||||||
17
src/axolotl/utils/config/models/input/v0_4_1/grpo.py
Normal file
17
src/axolotl/utils/config/models/input/v0_4_1/grpo.py
Normal file
@@ -0,0 +1,17 @@
|
|||||||
|
"""
|
||||||
|
GRPO specific configuration args
|
||||||
|
"""
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
|
class GRPOConfig(BaseModel):
    """
    Input args for GRPO.

    User-facing ``grpo_*`` configuration options that are mixed into the main
    axolotl input config.
    """

    # Enables vLLM; forwarded as the `use_vllm` training argument when truthy.
    grpo_use_vllm: Optional[bool] = False
    # Forwarded as `vllm_device`; "auto" is also used when this is unset.
    grpo_vllm_device: Optional[str] = "auto"
    # Forwarded as `vllm_gpu_memory_utilization` when truthy.
    grpo_vllm_gpu_memory_utilization: Optional[float] = 0.9
    # Dotted import paths of reward functions to pass to the trainer.
    grpo_reward_funcs: Optional[List[str]] = None
|
||||||
Reference in New Issue
Block a user