GRPO (#2307)
This commit is contained in:
2
.github/workflows/main.yml
vendored
2
.github/workflows/main.yml
vendored
@@ -24,7 +24,7 @@ jobs:
|
|||||||
cuda_version: 12.4.1
|
cuda_version: 12.4.1
|
||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.5.1
|
pytorch: 2.5.1
|
||||||
axolotl_extras:
|
axolotl_extras: vllm
|
||||||
is_latest: true
|
is_latest: true
|
||||||
- cuda: 124
|
- cuda: 124
|
||||||
cuda_version: 12.4.1
|
cuda_version: 12.4.1
|
||||||
|
|||||||
5
.github/workflows/multi-gpu-e2e.yml
vendored
5
.github/workflows/multi-gpu-e2e.yml
vendored
@@ -24,20 +24,21 @@ jobs:
|
|||||||
cuda_version: 12.4.1
|
cuda_version: 12.4.1
|
||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.4.1
|
pytorch: 2.4.1
|
||||||
axolotl_extras:
|
axolotl_extras: # no vllm support for 2.4.1
|
||||||
num_gpus: 2
|
num_gpus: 2
|
||||||
nightly_build: "true"
|
nightly_build: "true"
|
||||||
- cuda: 124
|
- cuda: 124
|
||||||
cuda_version: 12.4.1
|
cuda_version: 12.4.1
|
||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.5.1
|
pytorch: 2.5.1
|
||||||
axolotl_extras:
|
axolotl_extras: vllm
|
||||||
num_gpus: 2
|
num_gpus: 2
|
||||||
nightly_build: "true"
|
nightly_build: "true"
|
||||||
- cuda: 124
|
- cuda: 124
|
||||||
cuda_version: 12.4.1
|
cuda_version: 12.4.1
|
||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.6.0
|
pytorch: 2.6.0
|
||||||
|
# awaiting vllm#12721
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
num_gpus: 2
|
num_gpus: 2
|
||||||
nightly_build: "true"
|
nightly_build: "true"
|
||||||
|
|||||||
2
.github/workflows/tests.yml
vendored
2
.github/workflows/tests.yml
vendored
@@ -204,7 +204,7 @@ jobs:
|
|||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.5.1
|
pytorch: 2.5.1
|
||||||
num_gpus: 1
|
num_gpus: 1
|
||||||
axolotl_extras:
|
axolotl_extras: vllm
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v4
|
||||||
|
|||||||
@@ -18,7 +18,7 @@ tokenizers>=0.21.0
|
|||||||
accelerate==1.3.0
|
accelerate==1.3.0
|
||||||
datasets==3.2.0
|
datasets==3.2.0
|
||||||
deepspeed==0.16.1
|
deepspeed==0.16.1
|
||||||
trl==0.13.0
|
trl==0.15.0
|
||||||
|
|
||||||
optimum==1.16.2
|
optimum==1.16.2
|
||||||
hf_transfer
|
hf_transfer
|
||||||
@@ -26,7 +26,7 @@ sentencepiece
|
|||||||
gradio==3.50.2
|
gradio==3.50.2
|
||||||
|
|
||||||
modal==0.70.5
|
modal==0.70.5
|
||||||
pydantic==2.6.3
|
pydantic==2.10.6
|
||||||
addict
|
addict
|
||||||
fire
|
fire
|
||||||
PyYAML>=6.0
|
PyYAML>=6.0
|
||||||
|
|||||||
7
setup.py
7
setup.py
@@ -79,7 +79,7 @@ def parse_requirements():
|
|||||||
if patch == 0:
|
if patch == 0:
|
||||||
_install_requires.append("xformers==0.0.28.post2")
|
_install_requires.append("xformers==0.0.28.post2")
|
||||||
else:
|
else:
|
||||||
_install_requires.append("xformers==0.0.29")
|
_install_requires.append("xformers>=0.0.28.post3")
|
||||||
_install_requires.pop(_install_requires.index(autoawq_version))
|
_install_requires.pop(_install_requires.index(autoawq_version))
|
||||||
elif (major, minor) >= (2, 4):
|
elif (major, minor) >= (2, 4):
|
||||||
if patch == 0:
|
if patch == 0:
|
||||||
@@ -125,7 +125,7 @@ setup(
|
|||||||
},
|
},
|
||||||
extras_require={
|
extras_require={
|
||||||
"flash-attn": [
|
"flash-attn": [
|
||||||
"flash-attn==2.7.0.post2",
|
"flash-attn==2.7.4.post1",
|
||||||
],
|
],
|
||||||
"deepspeed": [
|
"deepspeed": [
|
||||||
"deepspeed==0.16.1",
|
"deepspeed==0.16.1",
|
||||||
@@ -156,5 +156,8 @@ setup(
|
|||||||
"ray": [
|
"ray": [
|
||||||
"ray[train]",
|
"ray[train]",
|
||||||
],
|
],
|
||||||
|
"vllm": [
|
||||||
|
"vllm==0.7.2",
|
||||||
|
],
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -35,13 +35,18 @@ def do_cli_train(
|
|||||||
cloud_config: Union[Path, str],
|
cloud_config: Union[Path, str],
|
||||||
config: Union[Path, str],
|
config: Union[Path, str],
|
||||||
accelerate: bool = True,
|
accelerate: bool = True,
|
||||||
|
cwd=None,
|
||||||
|
**kwargs,
|
||||||
) -> None:
|
) -> None:
|
||||||
print_axolotl_text_art()
|
print_axolotl_text_art()
|
||||||
cloud_cfg = load_cloud_cfg(cloud_config)
|
cloud_cfg = load_cloud_cfg(cloud_config)
|
||||||
cloud = ModalCloud(cloud_cfg)
|
cloud = ModalCloud(cloud_cfg)
|
||||||
with open(config, "r", encoding="utf-8") as file:
|
with open(config, "r", encoding="utf-8") as file:
|
||||||
config_yaml = file.read()
|
config_yaml = file.read()
|
||||||
cloud.train(config_yaml, accelerate=accelerate)
|
local_dirs = {}
|
||||||
|
if cwd and not Path(cwd).joinpath("src", "axolotl").exists():
|
||||||
|
local_dirs = {"/workspace/mounts": cwd}
|
||||||
|
cloud.train(config_yaml, accelerate=accelerate, local_dirs=local_dirs, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
def do_cli_lm_eval(
|
def do_cli_lm_eval(
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import os
|
|||||||
import subprocess # nosec B404
|
import subprocess # nosec B404
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from random import randint
|
from random import randint
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
import modal
|
import modal
|
||||||
|
|
||||||
@@ -22,8 +23,18 @@ def run_cmd(cmd: str, run_folder: str, volumes=None):
|
|||||||
|
|
||||||
# modal workaround so it doesn't use the automounted axolotl
|
# modal workaround so it doesn't use the automounted axolotl
|
||||||
new_env = copy.deepcopy(os.environ)
|
new_env = copy.deepcopy(os.environ)
|
||||||
|
|
||||||
if "PYTHONPATH" in new_env:
|
if "PYTHONPATH" in new_env:
|
||||||
del new_env["PYTHONPATH"]
|
paths = ["/workspace/mounts"]
|
||||||
|
for sub_python_path_str in new_env["PYTHONPATH"].split(":"):
|
||||||
|
sub_python_path = Path(sub_python_path_str)
|
||||||
|
if not sub_python_path.joinpath("src", "axolotl").exists():
|
||||||
|
# we don't want to use the automounted axolotl or unexpected behavior happens
|
||||||
|
paths.append(str(sub_python_path))
|
||||||
|
if paths:
|
||||||
|
new_env["PYTHONPATH"] = ":".join(paths)
|
||||||
|
else:
|
||||||
|
del new_env["PYTHONPATH"]
|
||||||
|
|
||||||
# Propagate errors from subprocess.
|
# Propagate errors from subprocess.
|
||||||
if exit_code := subprocess.call( # nosec B603
|
if exit_code := subprocess.call( # nosec B603
|
||||||
@@ -203,9 +214,12 @@ class ModalCloud(Cloud):
|
|||||||
memory = int(self.config.memory)
|
memory = int(self.config.memory)
|
||||||
return 1024 * memory
|
return 1024 * memory
|
||||||
|
|
||||||
def get_train_env(self):
|
def get_train_env(self, local_dirs=None):
|
||||||
|
image = self.get_image()
|
||||||
|
for mount, local_dir in (local_dirs or {}).items():
|
||||||
|
image = image.add_local_dir(local_dir, mount)
|
||||||
return self.app.function(
|
return self.app.function(
|
||||||
image=self.get_image(),
|
image=image,
|
||||||
volumes={k: v[0] for k, v in self.volumes.items()},
|
volumes={k: v[0] for k, v in self.volumes.items()},
|
||||||
cpu=16.0,
|
cpu=16.0,
|
||||||
gpu=self.get_train_gpu(),
|
gpu=self.get_train_gpu(),
|
||||||
@@ -214,14 +228,21 @@ class ModalCloud(Cloud):
|
|||||||
secrets=self.get_secrets(),
|
secrets=self.get_secrets(),
|
||||||
)
|
)
|
||||||
|
|
||||||
def train(self, config_yaml: str, accelerate: bool = True):
|
def train(
|
||||||
modal_fn = self.get_train_env()(_train)
|
self,
|
||||||
|
config_yaml: str,
|
||||||
|
accelerate: bool = True,
|
||||||
|
local_dirs: Optional[dict[str, str]] = None,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
modal_fn = self.get_train_env(local_dirs)(_train)
|
||||||
with modal.enable_output():
|
with modal.enable_output():
|
||||||
with self.app.run(detach=True):
|
with self.app.run(detach=True):
|
||||||
modal_fn.remote(
|
modal_fn.remote(
|
||||||
config_yaml,
|
config_yaml,
|
||||||
accelerate=accelerate,
|
accelerate=accelerate,
|
||||||
volumes={k: v[0] for k, v in self.volumes.items()},
|
volumes={k: v[0] for k, v in self.volumes.items()},
|
||||||
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
def lm_eval(self, config_yaml: str):
|
def lm_eval(self, config_yaml: str):
|
||||||
@@ -252,7 +273,7 @@ def _preprocess(config_yaml: str, volumes=None):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _train(config_yaml: str, accelerate: bool = True, volumes=None):
|
def _train(config_yaml: str, accelerate: bool = True, volumes=None, **kwargs):
|
||||||
with open(
|
with open(
|
||||||
"/workspace/artifacts/axolotl/config.yaml", "w", encoding="utf-8"
|
"/workspace/artifacts/axolotl/config.yaml", "w", encoding="utf-8"
|
||||||
) as f_out:
|
) as f_out:
|
||||||
@@ -262,8 +283,11 @@ def _train(config_yaml: str, accelerate: bool = True, volumes=None):
|
|||||||
accelerate_args = "--accelerate"
|
accelerate_args = "--accelerate"
|
||||||
else:
|
else:
|
||||||
accelerate_args = "--no-accelerate"
|
accelerate_args = "--no-accelerate"
|
||||||
|
num_processes_args = ""
|
||||||
|
if num_processes := kwargs.pop("num_processes", None):
|
||||||
|
num_processes_args = f"--num-processes {num_processes}"
|
||||||
run_cmd(
|
run_cmd(
|
||||||
f"axolotl train {accelerate_args} /workspace/artifacts/axolotl/config.yaml",
|
f"axolotl train {accelerate_args} {num_processes_args} /workspace/artifacts/axolotl/config.yaml",
|
||||||
run_folder,
|
run_folder,
|
||||||
volumes,
|
volumes,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -2,6 +2,7 @@
|
|||||||
# pylint: disable=redefined-outer-name
|
# pylint: disable=redefined-outer-name
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
import os
|
||||||
import random
|
import random
|
||||||
import subprocess # nosec B404
|
import subprocess # nosec B404
|
||||||
import tempfile
|
import tempfile
|
||||||
@@ -12,6 +13,7 @@ from typing import Optional
|
|||||||
|
|
||||||
import click
|
import click
|
||||||
import yaml
|
import yaml
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
import axolotl
|
import axolotl
|
||||||
from axolotl.cli.args import EvaluateCliArgs, PreprocessCliArgs, TrainerCliArgs
|
from axolotl.cli.args import EvaluateCliArgs, PreprocessCliArgs, TrainerCliArgs
|
||||||
@@ -199,7 +201,14 @@ def train(
|
|||||||
try:
|
try:
|
||||||
if accelerate:
|
if accelerate:
|
||||||
if cloud:
|
if cloud:
|
||||||
do_cli_train(cloud_config=cloud, config=config, accelerate=True)
|
cwd = os.getcwd()
|
||||||
|
do_cli_train(
|
||||||
|
cloud_config=cloud,
|
||||||
|
config=config,
|
||||||
|
accelerate=True,
|
||||||
|
cwd=cwd,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
accelerate_args = []
|
accelerate_args = []
|
||||||
if "main_process_port" in kwargs:
|
if "main_process_port" in kwargs:
|
||||||
@@ -208,7 +217,7 @@ def train(
|
|||||||
accelerate_args.append(str(main_process_port))
|
accelerate_args.append(str(main_process_port))
|
||||||
if "num_processes" in kwargs:
|
if "num_processes" in kwargs:
|
||||||
num_processes = kwargs.pop("num_processes", None)
|
num_processes = kwargs.pop("num_processes", None)
|
||||||
accelerate_args.append("--num-processes")
|
accelerate_args.append("--num_processes")
|
||||||
accelerate_args.append(str(num_processes))
|
accelerate_args.append(str(num_processes))
|
||||||
|
|
||||||
base_cmd = ["accelerate", "launch"]
|
base_cmd = ["accelerate", "launch"]
|
||||||
@@ -220,7 +229,9 @@ def train(
|
|||||||
subprocess.run(cmd, check=True) # nosec B603
|
subprocess.run(cmd, check=True) # nosec B603
|
||||||
else:
|
else:
|
||||||
if cloud:
|
if cloud:
|
||||||
do_cli_train(cloud_config=cloud, config=config, accelerate=False)
|
do_cli_train(
|
||||||
|
cloud_config=cloud, config=config, accelerate=False, **kwargs
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
from axolotl.cli.train import do_cli
|
from axolotl.cli.train import do_cli
|
||||||
|
|
||||||
@@ -381,4 +392,5 @@ def main():
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
load_dotenv()
|
||||||
main()
|
main()
|
||||||
|
|||||||
@@ -122,9 +122,11 @@ def load_preference_datasets(
|
|||||||
`total_num_steps`.
|
`total_num_steps`.
|
||||||
"""
|
"""
|
||||||
train_dataset, eval_dataset = load_prepare_preference_datasets(cfg)
|
train_dataset, eval_dataset = load_prepare_preference_datasets(cfg)
|
||||||
total_num_steps = int(
|
total_num_steps: Optional[int] = int(
|
||||||
math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
|
math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
|
||||||
)
|
)
|
||||||
|
if cfg.rl == "grpo":
|
||||||
|
total_num_steps = None
|
||||||
|
|
||||||
if cli_args.debug or cfg.debug:
|
if cli_args.debug or cfg.debug:
|
||||||
LOG.info("check_dataset_labels...")
|
LOG.info("check_dataset_labels...")
|
||||||
|
|||||||
@@ -39,7 +39,6 @@ from trl.trainer.utils import RewardDataCollatorWithPadding
|
|||||||
|
|
||||||
from axolotl.core.trainers.base import (
|
from axolotl.core.trainers.base import (
|
||||||
AxolotlCPOTrainer,
|
AxolotlCPOTrainer,
|
||||||
AxolotlDPOTrainer,
|
|
||||||
AxolotlKTOTrainer,
|
AxolotlKTOTrainer,
|
||||||
AxolotlMambaTrainer,
|
AxolotlMambaTrainer,
|
||||||
AxolotlORPOTrainer,
|
AxolotlORPOTrainer,
|
||||||
@@ -48,9 +47,11 @@ from axolotl.core.trainers.base import (
|
|||||||
AxolotlTrainer,
|
AxolotlTrainer,
|
||||||
ReLoRATrainer,
|
ReLoRATrainer,
|
||||||
)
|
)
|
||||||
|
from axolotl.core.trainers.dpo import DPOStrategy
|
||||||
|
from axolotl.core.trainers.dpo.args import AxolotlDPOConfig
|
||||||
|
from axolotl.core.trainers.grpo import GRPOStrategy
|
||||||
from axolotl.core.training_args import (
|
from axolotl.core.training_args import (
|
||||||
AxolotlCPOConfig,
|
AxolotlCPOConfig,
|
||||||
AxolotlDPOConfig,
|
|
||||||
AxolotlKTOConfig,
|
AxolotlKTOConfig,
|
||||||
AxolotlORPOConfig,
|
AxolotlORPOConfig,
|
||||||
AxolotlPRMConfig,
|
AxolotlPRMConfig,
|
||||||
@@ -641,9 +642,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
tokenizer=self.tokenizer,
|
tokenizer=self.tokenizer,
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.cfg.rl == "orpo":
|
|
||||||
training_arguments_kwargs["orpo_alpha"] = self.cfg.orpo_alpha
|
|
||||||
|
|
||||||
if self.cfg.neftune_noise_alpha is not None:
|
if self.cfg.neftune_noise_alpha is not None:
|
||||||
training_arguments_kwargs[
|
training_arguments_kwargs[
|
||||||
"neftune_noise_alpha"
|
"neftune_noise_alpha"
|
||||||
@@ -652,7 +650,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
trainer_kwargs = {}
|
trainer_kwargs = {}
|
||||||
|
|
||||||
if self.cfg.reward_model:
|
if self.cfg.reward_model:
|
||||||
trainer_kwargs["max_length"] = self.cfg.sequence_len
|
training_arguments_kwargs["max_length"] = self.cfg.sequence_len
|
||||||
|
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
if self.cfg.optimizer in [
|
if self.cfg.optimizer in [
|
||||||
@@ -965,10 +963,11 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
|||||||
# default to saving each epoch if not defined
|
# default to saving each epoch if not defined
|
||||||
training_args_kwargs["save_strategy"] = "epoch"
|
training_args_kwargs["save_strategy"] = "epoch"
|
||||||
|
|
||||||
training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
|
if self.cfg.dataset_processes:
|
||||||
|
training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
|
||||||
|
|
||||||
if self.cfg.rl_beta:
|
if (self.cfg.trl and self.cfg.trl.beta) or self.cfg.rl_beta:
|
||||||
training_args_kwargs["beta"] = self.cfg.rl_beta
|
training_args_kwargs["beta"] = self.cfg.trl.beta or self.cfg.rl_beta
|
||||||
if self.cfg.orpo_alpha:
|
if self.cfg.orpo_alpha:
|
||||||
# trl does some odd mapping of alpha to beta to reuse the beta parameter ???
|
# trl does some odd mapping of alpha to beta to reuse the beta parameter ???
|
||||||
training_args_kwargs["beta"] = self.cfg.orpo_alpha
|
training_args_kwargs["beta"] = self.cfg.orpo_alpha
|
||||||
@@ -977,6 +976,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
|||||||
training_args_kwargs["rpo_alpha"] = self.cfg.rpo_alpha
|
training_args_kwargs["rpo_alpha"] = self.cfg.rpo_alpha
|
||||||
|
|
||||||
training_args_cls = None
|
training_args_cls = None
|
||||||
|
blocklist_args_kwargs = []
|
||||||
if self.cfg.rl == "simpo":
|
if self.cfg.rl == "simpo":
|
||||||
training_args_cls = AxolotlCPOConfig
|
training_args_cls = AxolotlCPOConfig
|
||||||
training_args_kwargs["loss_type"] = "simpo"
|
training_args_kwargs["loss_type"] = "simpo"
|
||||||
@@ -1001,11 +1001,15 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
|||||||
self.cfg.kto_undesirable_weight or 1.0
|
self.cfg.kto_undesirable_weight or 1.0
|
||||||
)
|
)
|
||||||
|
|
||||||
training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
|
|
||||||
training_args_kwargs["max_length"] = self.cfg.sequence_len
|
training_args_kwargs["max_length"] = self.cfg.sequence_len
|
||||||
if self.cfg.max_prompt_len:
|
if self.cfg.max_prompt_len:
|
||||||
training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len
|
training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len
|
||||||
|
|
||||||
|
elif self.cfg.rl == "grpo":
|
||||||
|
training_args_cls = GRPOStrategy.get_training_args_class()
|
||||||
|
training_args_kwargs.update(GRPOStrategy.set_training_args_kwargs(self.cfg))
|
||||||
|
blocklist_args_kwargs = GRPOStrategy.get_blocklist_args_kwargs()
|
||||||
|
|
||||||
else:
|
else:
|
||||||
training_args_cls = AxolotlDPOConfig
|
training_args_cls = AxolotlDPOConfig
|
||||||
if self.cfg.rl == "ipo":
|
if self.cfg.rl == "ipo":
|
||||||
@@ -1016,11 +1020,21 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
|||||||
training_args_kwargs["generate_during_eval"] = self.cfg.use_wandb
|
training_args_kwargs["generate_during_eval"] = self.cfg.use_wandb
|
||||||
if self.cfg.dpo_use_weighting is not None:
|
if self.cfg.dpo_use_weighting is not None:
|
||||||
training_args_kwargs["use_weighting"] = self.cfg.dpo_use_weighting
|
training_args_kwargs["use_weighting"] = self.cfg.dpo_use_weighting
|
||||||
|
if self.cfg.dpo_use_logits_to_keep is not None:
|
||||||
|
training_args_kwargs[
|
||||||
|
"use_logits_to_keep"
|
||||||
|
] = self.cfg.dpo_use_logits_to_keep
|
||||||
|
|
||||||
|
for blocklist_key in blocklist_args_kwargs:
|
||||||
|
if blocklist_key in training_args_kwargs:
|
||||||
|
del training_args_kwargs[blocklist_key]
|
||||||
|
|
||||||
|
max_steps = self.cfg.max_steps or total_num_steps or -1
|
||||||
|
training_args_kwargs["num_train_epochs"] = self.cfg.num_epochs
|
||||||
training_args = training_args_cls( # pylint: disable=unexpected-keyword-arg
|
training_args = training_args_cls( # pylint: disable=unexpected-keyword-arg
|
||||||
output_dir=self.cfg.output_dir,
|
self.cfg.output_dir,
|
||||||
per_device_train_batch_size=self.cfg.micro_batch_size,
|
per_device_train_batch_size=self.cfg.micro_batch_size,
|
||||||
max_steps=self.cfg.max_steps or total_num_steps,
|
max_steps=max_steps,
|
||||||
gradient_accumulation_steps=self.cfg.gradient_accumulation_steps,
|
gradient_accumulation_steps=self.cfg.gradient_accumulation_steps,
|
||||||
learning_rate=self.cfg.learning_rate,
|
learning_rate=self.cfg.learning_rate,
|
||||||
warmup_steps=self.cfg.warmup_steps,
|
warmup_steps=self.cfg.warmup_steps,
|
||||||
@@ -1047,8 +1061,13 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
|||||||
dpo_trainer_kwargs[
|
dpo_trainer_kwargs[
|
||||||
"precompute_ref_log_probs"
|
"precompute_ref_log_probs"
|
||||||
] = self.cfg.precompute_ref_log_probs
|
] = self.cfg.precompute_ref_log_probs
|
||||||
if self.cfg.rl in ["dpo", "ipo"]:
|
if self.cfg.rl == "grpo":
|
||||||
trainer_cls = AxolotlDPOTrainer
|
trainer_cls = GRPOStrategy.get_trainer_class()
|
||||||
|
trainer_cls_args = [self.model]
|
||||||
|
trainer_cls_args.extend(GRPOStrategy.set_trainer_args(self.cfg))
|
||||||
|
dpo_trainer_kwargs.update(GRPOStrategy.set_trainer_kwargs(self.cfg))
|
||||||
|
elif self.cfg.rl in ["dpo", "ipo"]:
|
||||||
|
trainer_cls = DPOStrategy.get_trainer_class()
|
||||||
trainer_cls_args = [self.model, self.model_ref]
|
trainer_cls_args = [self.model, self.model_ref]
|
||||||
elif self.cfg.rl == "orpo":
|
elif self.cfg.rl == "orpo":
|
||||||
trainer_cls = AxolotlORPOTrainer
|
trainer_cls = AxolotlORPOTrainer
|
||||||
@@ -1063,12 +1082,14 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
|||||||
raise ValueError(f"Unsupported RL: {self.cfg.rl}")
|
raise ValueError(f"Unsupported RL: {self.cfg.rl}")
|
||||||
|
|
||||||
sig = inspect.signature(trainer_cls)
|
sig = inspect.signature(trainer_cls)
|
||||||
if "processing_class" in sig.parameters.keys():
|
if "tokenizer" in sig.parameters.keys():
|
||||||
dpo_trainer_kwargs["processing_class"] = self.tokenizer
|
|
||||||
else:
|
|
||||||
dpo_trainer_kwargs["tokenizer"] = self.tokenizer
|
dpo_trainer_kwargs["tokenizer"] = self.tokenizer
|
||||||
|
else:
|
||||||
|
dpo_trainer_kwargs["processing_class"] = self.tokenizer
|
||||||
|
|
||||||
if self.cfg.datasets is not None and (trainer_cls is AxolotlDPOTrainer):
|
if self.cfg.datasets is not None and (
|
||||||
|
trainer_cls is DPOStrategy.get_trainer_class()
|
||||||
|
):
|
||||||
dpo_trainer_kwargs["dataset_tags"] = [
|
dpo_trainer_kwargs["dataset_tags"] = [
|
||||||
d["path"] for d in self.cfg.datasets if not Path(d["path"]).is_dir()
|
d["path"] for d in self.cfg.datasets if not Path(d["path"]).is_dir()
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -5,30 +5,21 @@ module for customized trainers
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
# pylint: disable=too-many-lines
|
# pylint: disable=too-many-lines
|
||||||
import gc
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
from typing import Any, Dict, Literal, Optional, Union
|
from typing import Dict, Literal, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from datasets import Dataset
|
from datasets import Dataset
|
||||||
from peft.optimizers import create_loraplus_optimizer
|
from peft.optimizers import create_loraplus_optimizer
|
||||||
from torch import nn
|
|
||||||
from torch.optim.lr_scheduler import OneCycleLR
|
from torch.optim.lr_scheduler import OneCycleLR
|
||||||
from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
|
from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
|
||||||
from transformers import Trainer
|
from transformers import Trainer
|
||||||
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, seed_worker
|
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, seed_worker
|
||||||
from transformers.utils import is_sagemaker_mp_enabled
|
from transformers.utils import is_sagemaker_mp_enabled
|
||||||
from trl import (
|
from trl import CPOTrainer, KTOTrainer, ORPOTrainer, PRMTrainer, RewardTrainer
|
||||||
CPOTrainer,
|
|
||||||
DPOTrainer,
|
|
||||||
KTOTrainer,
|
|
||||||
ORPOTrainer,
|
|
||||||
PRMTrainer,
|
|
||||||
RewardTrainer,
|
|
||||||
)
|
|
||||||
from trl.trainer.utils import pad_to_length
|
from trl.trainer.utils import pad_to_length
|
||||||
|
|
||||||
from axolotl.monkeypatch.relora import ReLoRAScheduler
|
from axolotl.monkeypatch.relora import ReLoRAScheduler
|
||||||
@@ -847,107 +838,6 @@ class ReLoRATrainer(AxolotlTrainer):
|
|||||||
return self.lr_scheduler
|
return self.lr_scheduler
|
||||||
|
|
||||||
|
|
||||||
class AxolotlDPOTrainer(SchedulerMixin, DPOTrainer):
|
|
||||||
"""
|
|
||||||
Extend the base DPOTrainer for axolotl helpers
|
|
||||||
"""
|
|
||||||
|
|
||||||
tag_names = ["axolotl", "dpo"]
|
|
||||||
|
|
||||||
def __init__(self, *args, dataset_tags=None, **kwargs):
|
|
||||||
super().__init__(*args, **kwargs)
|
|
||||||
self.dataset_tags = dataset_tags
|
|
||||||
self.optimizer = None
|
|
||||||
self.model_accepts_loss_kwargs = False
|
|
||||||
|
|
||||||
def create_optimizer(self):
|
|
||||||
if self.args.loraplus_lr_ratio is None:
|
|
||||||
return super().create_optimizer()
|
|
||||||
|
|
||||||
opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
|
|
||||||
if self.optimizer is None: # pylint: disable=access-member-before-definition
|
|
||||||
optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(
|
|
||||||
self.args,
|
|
||||||
opt_model,
|
|
||||||
)
|
|
||||||
|
|
||||||
loraplus_lr_ratio = getattr(self.args, "loraplus_lr_ratio", None)
|
|
||||||
if loraplus_lr_ratio:
|
|
||||||
print("Using lora+")
|
|
||||||
loraplus_lr_embedding = getattr(self.args, "loraplus_lr_embedding", None)
|
|
||||||
self.optimizer = create_loraplus_optimizer( # pylint: disable=attribute-defined-outside-init
|
|
||||||
opt_model,
|
|
||||||
optimizer_cls,
|
|
||||||
loraplus_lr_ratio=loraplus_lr_ratio,
|
|
||||||
loraplus_lr_embedding=loraplus_lr_embedding,
|
|
||||||
**optimizer_kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
if is_sagemaker_mp_enabled():
|
|
||||||
self.optimizer = smp.DistributedOptimizer( # pylint: disable=attribute-defined-outside-init
|
|
||||||
self.optimizer
|
|
||||||
)
|
|
||||||
|
|
||||||
return self.optimizer
|
|
||||||
|
|
||||||
@wraps(DPOTrainer.push_to_hub)
|
|
||||||
def push_to_hub(self, *args, **kwargs) -> str:
|
|
||||||
"""
|
|
||||||
Overwrite the `push_to_hub` method in order to force-add the tags when pushing the
|
|
||||||
model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details.
|
|
||||||
"""
|
|
||||||
kwargs = _sanitize_kwargs_for_ds_tagging(
|
|
||||||
dataset_tags=self.dataset_tags, kwargs=kwargs
|
|
||||||
)
|
|
||||||
kwargs = _sanitize_kwargs_for_tagging(tag_names=self.tag_names, kwargs=kwargs)
|
|
||||||
|
|
||||||
return super().push_to_hub(*args, **kwargs)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def tokenize_row(
|
|
||||||
features,
|
|
||||||
processing_class,
|
|
||||||
max_prompt_length,
|
|
||||||
max_completion_length,
|
|
||||||
add_special_tokens,
|
|
||||||
) -> Dict:
|
|
||||||
res = DPOTrainer.tokenize_row(
|
|
||||||
features,
|
|
||||||
processing_class,
|
|
||||||
max_prompt_length,
|
|
||||||
max_completion_length,
|
|
||||||
add_special_tokens,
|
|
||||||
)
|
|
||||||
# fix when the tokenizer doesn't have a bos_token_id, e.g. Qwen
|
|
||||||
if processing_class.bos_token is None and res["prompt_input_ids"][0] is None:
|
|
||||||
for key in res.keys():
|
|
||||||
res[key] = res[key][1:]
|
|
||||||
|
|
||||||
if processing_class.bos_token and processing_class.bos_token_id is not None:
|
|
||||||
# dpo trainer may incorrectly prepend the bos_token_id to the dpo outputs
|
|
||||||
if res["chosen_input_ids"][0] == processing_class.bos_token_id:
|
|
||||||
res["chosen_input_ids"] = res["chosen_input_ids"][1:]
|
|
||||||
res["chosen_labels"] = res["chosen_labels"][1:]
|
|
||||||
res["chosen_attention_mask"] = res["chosen_attention_mask"][1:]
|
|
||||||
if res["rejected_input_ids"][0] == processing_class.bos_token_id:
|
|
||||||
res["rejected_input_ids"] = res["rejected_input_ids"][1:]
|
|
||||||
res["rejected_labels"] = res["rejected_labels"][1:]
|
|
||||||
res["rejected_attention_mask"] = res["rejected_attention_mask"][1:]
|
|
||||||
|
|
||||||
return res
|
|
||||||
|
|
||||||
def training_step(
|
|
||||||
self,
|
|
||||||
model: nn.Module,
|
|
||||||
inputs: Dict[str, Union[torch.Tensor, Any]],
|
|
||||||
num_items_in_batch=None,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
loss: torch.Tensor = super().training_step(model, inputs, num_items_in_batch)
|
|
||||||
gc.collect()
|
|
||||||
torch.cuda.empty_cache()
|
|
||||||
return loss
|
|
||||||
|
|
||||||
|
|
||||||
class AxolotlORPOTrainer(SchedulerMixin, ORPOTrainer):
|
class AxolotlORPOTrainer(SchedulerMixin, ORPOTrainer):
|
||||||
"""
|
"""
|
||||||
Extend the base ORPOTrainer for axolotl helpers
|
Extend the base ORPOTrainer for axolotl helpers
|
||||||
|
|||||||
33
src/axolotl/core/trainers/dpo/__init__.py
Normal file
33
src/axolotl/core/trainers/dpo/__init__.py
Normal file
@@ -0,0 +1,33 @@
|
|||||||
|
"""
|
||||||
|
DPO Specific Strategy for training
|
||||||
|
"""
|
||||||
|
from axolotl.core.trainers.dpo.trainer import AxolotlDPOTrainer
|
||||||
|
|
||||||
|
|
||||||
|
class DPOStrategy:
|
||||||
|
"""
|
||||||
|
Strategy for DPO training
|
||||||
|
"""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_trainer_class(cls):
|
||||||
|
return AxolotlDPOTrainer
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_training_args_class(cls):
|
||||||
|
from axolotl.core.trainers.dpo.args import AxolotlDPOConfig
|
||||||
|
|
||||||
|
return AxolotlDPOConfig
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def set_training_args_kwargs(cls, cfg):
|
||||||
|
training_args_kwargs = {}
|
||||||
|
if cfg.rl == "ipo":
|
||||||
|
training_args_kwargs["loss_type"] = "ipo"
|
||||||
|
training_args_kwargs["max_length"] = cfg.sequence_len
|
||||||
|
training_args_kwargs["max_completion_length"] = None
|
||||||
|
training_args_kwargs["max_prompt_length"] = cfg.sequence_len
|
||||||
|
training_args_kwargs["generate_during_eval"] = cfg.use_wandb
|
||||||
|
if cfg.dpo_use_weighting is not None:
|
||||||
|
training_args_kwargs["use_weighting"] = cfg.dpo_use_weighting
|
||||||
|
return training_args_kwargs
|
||||||
15
src/axolotl/core/trainers/dpo/args.py
Normal file
15
src/axolotl/core/trainers/dpo/args.py
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
"""
|
||||||
|
Axolotl specific DPO args
|
||||||
|
"""
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
from trl import DPOConfig
|
||||||
|
|
||||||
|
from axolotl.core.training_args import AxolotlTrainingMixins
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class AxolotlDPOConfig(AxolotlTrainingMixins, DPOConfig):
|
||||||
|
"""
|
||||||
|
DPO config for DPO training
|
||||||
|
"""
|
||||||
125
src/axolotl/core/trainers/dpo/trainer.py
Normal file
125
src/axolotl/core/trainers/dpo/trainer.py
Normal file
@@ -0,0 +1,125 @@
|
|||||||
|
"""
|
||||||
|
DPO trainer for axolotl
|
||||||
|
"""
|
||||||
|
import gc
|
||||||
|
from functools import wraps
|
||||||
|
from typing import Any, Dict, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from peft.optimizers import create_loraplus_optimizer
|
||||||
|
from torch import nn
|
||||||
|
from transformers import Trainer
|
||||||
|
from transformers.utils import is_sagemaker_mp_enabled
|
||||||
|
from trl import DPOTrainer
|
||||||
|
|
||||||
|
from axolotl.core.trainers.base import (
|
||||||
|
SchedulerMixin,
|
||||||
|
_sanitize_kwargs_for_ds_tagging,
|
||||||
|
_sanitize_kwargs_for_tagging,
|
||||||
|
)
|
||||||
|
|
||||||
|
if is_sagemaker_mp_enabled():
|
||||||
|
import smdistributed.modelparallel.torch as smp
|
||||||
|
|
||||||
|
|
||||||
|
class AxolotlDPOTrainer(SchedulerMixin, DPOTrainer):
|
||||||
|
"""
|
||||||
|
Extend the base DPOTrainer for axolotl helpers
|
||||||
|
"""
|
||||||
|
|
||||||
|
tag_names = ["axolotl", "dpo"]
|
||||||
|
|
||||||
|
def __init__(self, *args, dataset_tags=None, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self.dataset_tags = dataset_tags
|
||||||
|
self.optimizer = None
|
||||||
|
self.model_accepts_loss_kwargs = False
|
||||||
|
|
||||||
|
def create_optimizer(self):
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
|
if self.args.loraplus_lr_ratio is None:
|
||||||
|
return super().create_optimizer()
|
||||||
|
|
||||||
|
opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
|
||||||
|
if self.optimizer is None: # pylint: disable=access-member-before-definition
|
||||||
|
optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(
|
||||||
|
self.args,
|
||||||
|
opt_model,
|
||||||
|
)
|
||||||
|
|
||||||
|
loraplus_lr_ratio = getattr(self.args, "loraplus_lr_ratio", None)
|
||||||
|
if loraplus_lr_ratio:
|
||||||
|
print("Using lora+")
|
||||||
|
loraplus_lr_embedding = getattr(self.args, "loraplus_lr_embedding", None)
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
|
self.optimizer = create_loraplus_optimizer( # pylint: disable=attribute-defined-outside-init
|
||||||
|
opt_model,
|
||||||
|
optimizer_cls,
|
||||||
|
loraplus_lr_ratio=loraplus_lr_ratio,
|
||||||
|
loraplus_lr_embedding=loraplus_lr_embedding,
|
||||||
|
**optimizer_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
if is_sagemaker_mp_enabled():
|
||||||
|
self.optimizer = smp.DistributedOptimizer( # pylint: disable=attribute-defined-outside-init
|
||||||
|
self.optimizer
|
||||||
|
)
|
||||||
|
|
||||||
|
return self.optimizer
|
||||||
|
|
||||||
|
@wraps(DPOTrainer.push_to_hub)
|
||||||
|
def push_to_hub(self, *args, **kwargs) -> str:
|
||||||
|
"""
|
||||||
|
Overwrite the `push_to_hub` method in order to force-add the tags when pushing the
|
||||||
|
model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details.
|
||||||
|
"""
|
||||||
|
kwargs = _sanitize_kwargs_for_ds_tagging(
|
||||||
|
dataset_tags=self.dataset_tags, kwargs=kwargs
|
||||||
|
)
|
||||||
|
kwargs = _sanitize_kwargs_for_tagging(tag_names=self.tag_names, kwargs=kwargs)
|
||||||
|
|
||||||
|
return super().push_to_hub(*args, **kwargs)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def tokenize_row(
|
||||||
|
features,
|
||||||
|
processing_class,
|
||||||
|
max_prompt_length,
|
||||||
|
max_completion_length,
|
||||||
|
add_special_tokens,
|
||||||
|
) -> Dict:
|
||||||
|
res = DPOTrainer.tokenize_row(
|
||||||
|
features,
|
||||||
|
processing_class,
|
||||||
|
max_prompt_length,
|
||||||
|
max_completion_length,
|
||||||
|
add_special_tokens,
|
||||||
|
)
|
||||||
|
# fix when the tokenizer doesn't have a bos_token_id, e.g. Qwen
|
||||||
|
if processing_class.bos_token is None and res["prompt_input_ids"][0] is None:
|
||||||
|
for key in res.keys():
|
||||||
|
res[key] = res[key][1:]
|
||||||
|
|
||||||
|
if processing_class.bos_token and processing_class.bos_token_id is not None:
|
||||||
|
# dpo trainer may incorrectly prepend the bos_token_id to the dpo outputs
|
||||||
|
if res["chosen_input_ids"][0] == processing_class.bos_token_id:
|
||||||
|
res["chosen_input_ids"] = res["chosen_input_ids"][1:]
|
||||||
|
res["chosen_labels"] = res["chosen_labels"][1:]
|
||||||
|
res["chosen_attention_mask"] = res["chosen_attention_mask"][1:]
|
||||||
|
if res["rejected_input_ids"][0] == processing_class.bos_token_id:
|
||||||
|
res["rejected_input_ids"] = res["rejected_input_ids"][1:]
|
||||||
|
res["rejected_labels"] = res["rejected_labels"][1:]
|
||||||
|
res["rejected_attention_mask"] = res["rejected_attention_mask"][1:]
|
||||||
|
|
||||||
|
return res
|
||||||
|
|
||||||
|
def training_step(
|
||||||
|
self,
|
||||||
|
model: nn.Module,
|
||||||
|
inputs: Dict[str, Union[torch.Tensor, Any]],
|
||||||
|
num_items_in_batch=None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
loss: torch.Tensor = super().training_step(model, inputs, num_items_in_batch)
|
||||||
|
gc.collect()
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
return loss
|
||||||
119
src/axolotl/core/trainers/grpo/__init__.py
Normal file
119
src/axolotl/core/trainers/grpo/__init__.py
Normal file
@@ -0,0 +1,119 @@
|
|||||||
|
"""
|
||||||
|
GRPO Specific Strategy for training
|
||||||
|
"""
|
||||||
|
|
||||||
|
import importlib
|
||||||
|
import inspect
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from trl.trainer.grpo_trainer import RewardFunc
|
||||||
|
|
||||||
|
from axolotl.core.trainers.grpo.trainer import AxolotlGRPOTrainer
|
||||||
|
|
||||||
|
LOG = logging.getLogger("axolotl")
|
||||||
|
|
||||||
|
|
||||||
|
class GRPOStrategy:
|
||||||
|
"""
|
||||||
|
Strategy for GRPO training
|
||||||
|
"""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_trainer_class(cls):
|
||||||
|
return AxolotlGRPOTrainer
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_training_args_class(cls):
|
||||||
|
from axolotl.core.trainers.grpo.args import AxolotlGRPOConfig
|
||||||
|
|
||||||
|
return AxolotlGRPOConfig
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def set_training_args_kwargs(cls, cfg):
|
||||||
|
grpo_args_kwargs = {}
|
||||||
|
if cfg.trl and cfg.trl.use_vllm:
|
||||||
|
grpo_args_kwargs["use_vllm"] = cfg.trl.use_vllm
|
||||||
|
if cfg.trl and cfg.trl.vllm_device:
|
||||||
|
grpo_args_kwargs["vllm_device"] = cfg.trl.vllm_device
|
||||||
|
else:
|
||||||
|
grpo_args_kwargs["vllm_device"] = "auto"
|
||||||
|
if cfg.trl and cfg.trl.vllm_gpu_memory_utilization:
|
||||||
|
grpo_args_kwargs[
|
||||||
|
"vllm_gpu_memory_utilization"
|
||||||
|
] = cfg.trl.vllm_gpu_memory_utilization
|
||||||
|
if cfg.trl and cfg.trl.vllm_max_model_len:
|
||||||
|
grpo_args_kwargs["vllm_max_model_len"] = cfg.trl.vllm_max_model_len
|
||||||
|
if cfg.trl and cfg.trl.num_generations:
|
||||||
|
grpo_args_kwargs["num_generations"] = cfg.trl.num_generations
|
||||||
|
if cfg.trl and cfg.trl.sync_ref_model:
|
||||||
|
grpo_args_kwargs["sync_ref_model"] = cfg.trl.sync_ref_model
|
||||||
|
if cfg.trl and cfg.trl.ref_model_mixup_alpha:
|
||||||
|
grpo_args_kwargs[
|
||||||
|
"ref_model_mixup_alpha"
|
||||||
|
] = cfg.trl.ref_model_mixup_alpha
|
||||||
|
if cfg.trl and cfg.trl.ref_model_sync_steps:
|
||||||
|
grpo_args_kwargs["ref_model_sync_steps"] = cfg.trl.ref_model_sync_steps
|
||||||
|
grpo_args_kwargs["max_completion_length"] = cfg.trl.max_completion_length
|
||||||
|
grpo_args_kwargs["log_completions"] = cfg.trl.log_completions
|
||||||
|
return grpo_args_kwargs
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def set_trainer_args(cls, cfg):
|
||||||
|
trainer_args = []
|
||||||
|
if cfg.trl and cfg.trl.reward_funcs:
|
||||||
|
reward_funcs = []
|
||||||
|
for reward_func_fqn in cfg.trl.reward_funcs:
|
||||||
|
reward_funcs.append(cls.get_reward_func(reward_func_fqn))
|
||||||
|
trainer_args.append(reward_funcs)
|
||||||
|
return trainer_args
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def set_trainer_kwargs(cls, cfg):
|
||||||
|
trainer_kwargs = {}
|
||||||
|
if cfg.trl and cfg.trl.reward_processing_classes:
|
||||||
|
trainer_kwargs[
|
||||||
|
"reward_processing_classes"
|
||||||
|
] = cfg.trl.reward_processing_classes
|
||||||
|
return trainer_kwargs
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_collator(cls, *args, **kwargs): # pylint: disable=unused-argument
|
||||||
|
# No data collation is needed in GRPO, handled by trl's trainer __init__
|
||||||
|
return None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_blocklist_args_kwargs(cls):
|
||||||
|
return ["dataset_num_proc"]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_reward_func(cls, reward_func_fqn: str) -> RewardFunc:
|
||||||
|
"""
|
||||||
|
Returns the reward function from the given fully qualified name, or the path to the reward function model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
reward_func_fqn (str): Fully qualified name of the reward function (e.g. r1_grpo.gsm8k_transform),
|
||||||
|
or a HF hub path to the reward model.
|
||||||
|
Raises:
|
||||||
|
ValueError: If the reward function does not accept at least two arguments.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
RewardFunc: A callable that accepts prompts and completions and returns rewards,
|
||||||
|
or a path to a reward model.
|
||||||
|
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# use importlib to dynamically load the reward function from the module
|
||||||
|
reward_func_module_name = reward_func_fqn.split(".")[-1]
|
||||||
|
reward_func_module = importlib.import_module(reward_func_fqn.split(".")[-2])
|
||||||
|
reward_func = getattr(reward_func_module, reward_func_module_name)
|
||||||
|
if not len(inspect.signature(reward_func).parameters) >= 2:
|
||||||
|
raise ValueError(
|
||||||
|
"Reward function must accept at least two arguments: prompts: list and completions: list"
|
||||||
|
)
|
||||||
|
return reward_func
|
||||||
|
except ModuleNotFoundError:
|
||||||
|
# the user has passed a string (ideally indicating the path of a reward model)
|
||||||
|
LOG.info(
|
||||||
|
f"Reward function {reward_func_fqn} is a pre-trained model path - if this is unexpected, please check the reward function path."
|
||||||
|
)
|
||||||
|
return reward_func
|
||||||
15
src/axolotl/core/trainers/grpo/args.py
Normal file
15
src/axolotl/core/trainers/grpo/args.py
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
"""
|
||||||
|
Axolotl Specific Training Args
|
||||||
|
"""
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
from trl import GRPOConfig
|
||||||
|
|
||||||
|
from axolotl.core.training_args import AxolotlTrainingMixins
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class AxolotlGRPOConfig(AxolotlTrainingMixins, GRPOConfig):
|
||||||
|
"""
|
||||||
|
Axolotl GRPO Config for GRPO training
|
||||||
|
"""
|
||||||
107
src/axolotl/core/trainers/grpo/trainer.py
Normal file
107
src/axolotl/core/trainers/grpo/trainer.py
Normal file
@@ -0,0 +1,107 @@
|
|||||||
|
"""
|
||||||
|
Axolotl GRPO trainer
|
||||||
|
"""
|
||||||
|
from accelerate.utils import is_peft_model
|
||||||
|
from accelerate.utils.other import is_compiled_module
|
||||||
|
from transformers import PreTrainedModel
|
||||||
|
from trl import GRPOConfig, GRPOTrainer
|
||||||
|
from trl.models import unwrap_model_for_generation
|
||||||
|
|
||||||
|
from axolotl.core.trainers.base import SchedulerMixin
|
||||||
|
|
||||||
|
|
||||||
|
# mypy: ignore-errors
|
||||||
|
class AxolotlGRPOTrainer(SchedulerMixin, GRPOTrainer):
|
||||||
|
"""
|
||||||
|
Extend the base GRPOTrainer for axolotl helpers
|
||||||
|
"""
|
||||||
|
|
||||||
|
_tag_names = ["trl", "grpo", "axolotl"]
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
|
# pylint: disable=access-member-before-definition
|
||||||
|
# Enable gradient checkpointing if requested
|
||||||
|
if kwargs["args"].gradient_checkpointing:
|
||||||
|
# Ensure use_cache is disabled
|
||||||
|
if hasattr(self.model, "config"):
|
||||||
|
self.model.config.use_cache = False
|
||||||
|
|
||||||
|
# Enable gradient checkpointing on the base model for PEFT
|
||||||
|
if is_peft_model(self.model) and hasattr(
|
||||||
|
self.model.base_model, "gradient_checkpointing_enable"
|
||||||
|
):
|
||||||
|
self.model.base_model.gradient_checkpointing_enable()
|
||||||
|
# Enable gradient checkpointing for non-PEFT models
|
||||||
|
elif hasattr(self.model, "gradient_checkpointing_enable"):
|
||||||
|
self.model.gradient_checkpointing_enable()
|
||||||
|
self.model = self._enable_gradient_checkpointing(self.model, kwargs["args"])
|
||||||
|
# pylint: enable=access-member-before-definition
|
||||||
|
|
||||||
|
def _enable_gradient_checkpointing(
|
||||||
|
self, model: PreTrainedModel, args: GRPOConfig
|
||||||
|
) -> PreTrainedModel:
|
||||||
|
"""Enables gradient checkpointing for the model."""
|
||||||
|
# pylint: disable=unused-argument,redefined-builtin
|
||||||
|
gradient_checkpointing_kwargs = args.gradient_checkpointing_kwargs or {}
|
||||||
|
use_reentrant = (
|
||||||
|
"use_reentrant" not in gradient_checkpointing_kwargs
|
||||||
|
or gradient_checkpointing_kwargs["use_reentrant"]
|
||||||
|
)
|
||||||
|
|
||||||
|
if use_reentrant:
|
||||||
|
if hasattr(model, "enable_input_require_grads"):
|
||||||
|
model.enable_input_require_grads()
|
||||||
|
else:
|
||||||
|
|
||||||
|
def make_inputs_require_grad(module, input, output):
|
||||||
|
output.requires_grad_(True)
|
||||||
|
|
||||||
|
model.get_input_embeddings().register_forward_hook(
|
||||||
|
make_inputs_require_grad
|
||||||
|
)
|
||||||
|
|
||||||
|
return model
|
||||||
|
# pylint: enable=unused-argument,redefined-builtin
|
||||||
|
|
||||||
|
def _move_model_to_vllm(self):
|
||||||
|
with unwrap_model_for_generation(
|
||||||
|
self.model,
|
||||||
|
self.accelerator,
|
||||||
|
gather_deepspeed3_params=self.args.ds3_gather_for_generation,
|
||||||
|
) as unwrapped_model:
|
||||||
|
if is_compiled_module(unwrapped_model):
|
||||||
|
unwrapped_model = (
|
||||||
|
unwrapped_model._orig_mod # pylint: disable=protected-access
|
||||||
|
)
|
||||||
|
if is_peft_model(unwrapped_model):
|
||||||
|
unwrapped_model.merge_adapter()
|
||||||
|
state_dict = unwrapped_model.state_dict()
|
||||||
|
unwrapped_model.unmerge_adapter()
|
||||||
|
# Remove base_model and base_layer prefixes
|
||||||
|
state_dict = {
|
||||||
|
k.removeprefix("base_model.model.")
|
||||||
|
.removeprefix("base_model.model.")
|
||||||
|
.replace(".base_layer", ""): v
|
||||||
|
for k, v in state_dict.items()
|
||||||
|
}
|
||||||
|
# Remove values with adapter prefix (example: "_lora")
|
||||||
|
state_dict = {
|
||||||
|
k: v
|
||||||
|
for k, v in state_dict.items()
|
||||||
|
if unwrapped_model.prefix not in k
|
||||||
|
}
|
||||||
|
# When module to save, remove its prefix and discard the original module
|
||||||
|
state_dict = {
|
||||||
|
k.replace("modules_to_save.default.", ""): v
|
||||||
|
for k, v in state_dict.items()
|
||||||
|
if "original_module" not in k
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
state_dict = unwrapped_model.state_dict()
|
||||||
|
if self.accelerator.is_main_process:
|
||||||
|
llm_model = (
|
||||||
|
self.llm.llm_engine.model_executor.driver_worker.model_runner.model
|
||||||
|
)
|
||||||
|
llm_model.load_weights(state_dict.items())
|
||||||
@@ -5,7 +5,7 @@ from dataclasses import dataclass, field
|
|||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from transformers import TrainingArguments
|
from transformers import TrainingArguments
|
||||||
from trl import CPOConfig, DPOConfig, KTOConfig, ORPOConfig, PRMConfig, RewardConfig
|
from trl import CPOConfig, KTOConfig, ORPOConfig, PRMConfig, RewardConfig
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -217,13 +217,6 @@ class AxolotlTrainingArguments(AxolotlTrainingMixins, TrainingArguments):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class AxolotlDPOConfig(AxolotlTrainingMixins, DPOConfig):
|
|
||||||
"""
|
|
||||||
DPO config for DPO training
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class AxolotlORPOConfig(AxolotlTrainingMixins, ORPOConfig):
|
class AxolotlORPOConfig(AxolotlTrainingMixins, ORPOConfig):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -13,8 +13,19 @@ def load(strategy, cfg, module_base=None, **kwargs):
|
|||||||
if len(strategy.split(".")) == 1:
|
if len(strategy.split(".")) == 1:
|
||||||
strategy = strategy + ".default"
|
strategy = strategy + ".default"
|
||||||
load_fn = strategy.split(".")[-1]
|
load_fn = strategy.split(".")[-1]
|
||||||
strategy = ".".join(strategy.split(".")[:-1])
|
if len(strategy.split(".")) > 1:
|
||||||
mod = importlib.import_module(f".{strategy}", module_base)
|
try:
|
||||||
|
importlib.import_module(
|
||||||
|
strategy.split(".")[-2],
|
||||||
|
".".join(strategy.split(".")[:-2]),
|
||||||
|
)
|
||||||
|
module_base = ".".join(strategy.split(".")[:-2])
|
||||||
|
strategy = strategy.split(".")[-2]
|
||||||
|
except ModuleNotFoundError:
|
||||||
|
strategy = "." + ".".join(strategy.split(".")[:-1])
|
||||||
|
else:
|
||||||
|
strategy = "." + ".".join(strategy.split(".")[:-1])
|
||||||
|
mod = importlib.import_module(strategy, module_base)
|
||||||
func = getattr(mod, load_fn)
|
func = getattr(mod, load_fn)
|
||||||
return func(cfg, **kwargs)
|
return func(cfg, **kwargs)
|
||||||
except Exception: # pylint: disable=broad-exception-caught
|
except Exception: # pylint: disable=broad-exception-caught
|
||||||
|
|||||||
14
src/axolotl/prompt_strategies/dpo/passthrough.py
Normal file
14
src/axolotl/prompt_strategies/dpo/passthrough.py
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
"""
|
||||||
|
DPO prompt strategies passthrough/zero-processing strategy
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def default(
|
||||||
|
cfg, dataset_idx=0, **kwargs
|
||||||
|
): # pylint: disable=possibly-unused-variable,unused-argument
|
||||||
|
def transform_fn(
|
||||||
|
sample, tokenizer=None
|
||||||
|
): # pylint: disable=possibly-unused-variable,unused-argument
|
||||||
|
return sample
|
||||||
|
|
||||||
|
return transform_fn
|
||||||
@@ -24,6 +24,8 @@ from transformers.utils.import_utils import is_torch_npu_available
|
|||||||
|
|
||||||
from axolotl.utils.config.models.internals import EnvCapabilities, GPUCapabilities
|
from axolotl.utils.config.models.internals import EnvCapabilities, GPUCapabilities
|
||||||
|
|
||||||
|
from .trl import TRLConfig
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl.utils.config.models.input")
|
LOG = logging.getLogger("axolotl.utils.config.models.input")
|
||||||
|
|
||||||
SUPPORTED_METRICS = {"sacrebleu", "comet", "ter", "chrf", "perplexity"}
|
SUPPORTED_METRICS = {"sacrebleu", "comet", "ter", "chrf", "perplexity"}
|
||||||
@@ -33,6 +35,7 @@ class RLType(str, Enum):
|
|||||||
"""RL trainer type configuration subset"""
|
"""RL trainer type configuration subset"""
|
||||||
|
|
||||||
dpo = "dpo" # pylint: disable=invalid-name
|
dpo = "dpo" # pylint: disable=invalid-name
|
||||||
|
grpo = "grpo" # pylint: disable=invalid-name
|
||||||
ipo = "ipo" # pylint: disable=invalid-name
|
ipo = "ipo" # pylint: disable=invalid-name
|
||||||
orpo = "orpo" # pylint: disable=invalid-name
|
orpo = "orpo" # pylint: disable=invalid-name
|
||||||
kto = "kto" # pylint: disable=invalid-name
|
kto = "kto" # pylint: disable=invalid-name
|
||||||
@@ -664,14 +667,20 @@ class AxolotlInputConfig(
|
|||||||
auto_resume_from_checkpoints: Optional[bool] = None
|
auto_resume_from_checkpoints: Optional[bool] = None
|
||||||
resize_token_embeddings_to_32x: Optional[bool] = None
|
resize_token_embeddings_to_32x: Optional[bool] = None
|
||||||
mean_resizing_embeddings: Optional[bool] = False
|
mean_resizing_embeddings: Optional[bool] = False
|
||||||
|
# optionally shrink the embeddings when the tokenizer vocab size is smaller
|
||||||
|
shrink_embeddings: Optional[bool] = None
|
||||||
|
|
||||||
rl: Optional[RLType] = None
|
rl: Optional[RLType] = None
|
||||||
|
trl: Optional[TRLConfig] = Field(
|
||||||
|
default_factory=lambda: TRLConfig(), # pylint: disable=unnecessary-lambda
|
||||||
|
)
|
||||||
reward_model: Optional[bool] = None
|
reward_model: Optional[bool] = None
|
||||||
process_reward_model: Optional[bool] = None
|
process_reward_model: Optional[bool] = None
|
||||||
num_labels: Optional[int] = None
|
num_labels: Optional[int] = None
|
||||||
dpo_use_weighting: Optional[
|
dpo_use_weighting: Optional[
|
||||||
bool
|
bool
|
||||||
] = None # whether to use weighting in DPO trainer. If none, default is false in the trainer.
|
] = None # whether to use weighting in DPO trainer. If none, default is false in the trainer.
|
||||||
|
dpo_use_logits_to_keep: Optional[bool] = None
|
||||||
|
|
||||||
datasets: Optional[conlist(Union[SFTDataset, DPODataset, KTODataset, StepwiseSupervisedDataset], min_length=1)] = None # type: ignore
|
datasets: Optional[conlist(Union[SFTDataset, DPODataset, KTODataset, StepwiseSupervisedDataset], min_length=1)] = None # type: ignore
|
||||||
test_datasets: Optional[conlist(Union[SFTDataset, DPODataset, KTODataset, StepwiseSupervisedDataset], min_length=1)] = None # type: ignore
|
test_datasets: Optional[conlist(Union[SFTDataset, DPODataset, KTODataset, StepwiseSupervisedDataset], min_length=1)] = None # type: ignore
|
||||||
|
|||||||
35
src/axolotl/utils/config/models/input/v0_4_1/trl.py
Normal file
35
src/axolotl/utils/config/models/input/v0_4_1/trl.py
Normal file
@@ -0,0 +1,35 @@
|
|||||||
|
"""
|
||||||
|
GRPO specific configuration args
|
||||||
|
"""
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
|
class TRLConfig(BaseModel):
|
||||||
|
"""
|
||||||
|
Input args for TRL.
|
||||||
|
"""
|
||||||
|
|
||||||
|
beta: Optional[float] = None
|
||||||
|
max_completion_length: Optional[int] = Field(
|
||||||
|
default=None,
|
||||||
|
json_schema_extra={
|
||||||
|
"description": "Maximum length of the completion for RL training"
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
# GRPO specific args
|
||||||
|
use_vllm: Optional[bool] = False
|
||||||
|
vllm_device: Optional[str] = "auto"
|
||||||
|
vllm_gpu_memory_utilization: Optional[float] = 0.9
|
||||||
|
vllm_max_model_len: Optional[int] = None
|
||||||
|
vllm_dtype: Optional[str] = "auto"
|
||||||
|
|
||||||
|
reward_funcs: Optional[List[str]] = None
|
||||||
|
num_generations: Optional[int] = None
|
||||||
|
log_completions: Optional[bool] = False
|
||||||
|
|
||||||
|
sync_ref_model: Optional[bool] = False
|
||||||
|
ref_model_mixup_alpha: Optional[float] = 0.9
|
||||||
|
ref_model_sync_steps: Optional[int] = 64
|
||||||
@@ -58,7 +58,7 @@ def _save_preprocessed_ds(cfg, sub_cfg, dataset):
|
|||||||
dataset.save_to_disk(str(prepared_ds_path))
|
dataset.save_to_disk(str(prepared_ds_path))
|
||||||
|
|
||||||
|
|
||||||
def map_dataset(cfg, data_set, ds_transform_fn, tokenizer):
|
def map_dataset(cfg, data_set, ds_transform_fn, tokenizer, **map_kwargs):
|
||||||
sig = inspect.signature(ds_transform_fn)
|
sig = inspect.signature(ds_transform_fn)
|
||||||
if "tokenizer" in sig.parameters:
|
if "tokenizer" in sig.parameters:
|
||||||
if not tokenizer:
|
if not tokenizer:
|
||||||
@@ -71,6 +71,7 @@ def map_dataset(cfg, data_set, ds_transform_fn, tokenizer):
|
|||||||
data_set = data_set.map(
|
data_set = data_set.map(
|
||||||
ds_transform_fn,
|
ds_transform_fn,
|
||||||
desc="Mapping RL Dataset",
|
desc="Mapping RL Dataset",
|
||||||
|
**map_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
return data_set
|
return data_set
|
||||||
@@ -113,6 +114,9 @@ def drop_long_rl_seq(
|
|||||||
|
|
||||||
return (len_prompt + len_completion) <= sequence_len
|
return (len_prompt + len_completion) <= sequence_len
|
||||||
|
|
||||||
|
if rl == "grpo":
|
||||||
|
return True
|
||||||
|
|
||||||
raise ValueError("Unknown RL type")
|
raise ValueError("Unknown RL type")
|
||||||
|
|
||||||
|
|
||||||
@@ -140,36 +144,45 @@ def load_prepare_preference_datasets(cfg):
|
|||||||
else:
|
else:
|
||||||
ds_transform_fn = load_dpo(_type, _cfg, dataset_idx=i)
|
ds_transform_fn = load_dpo(_type, _cfg, dataset_idx=i)
|
||||||
|
|
||||||
|
map_kwargs = {}
|
||||||
|
if isinstance(ds_transform_fn, tuple):
|
||||||
|
ds_transform_fn, map_kwargs = ds_transform_fn
|
||||||
split_datasets[i] = map_dataset(
|
split_datasets[i] = map_dataset(
|
||||||
cfg, data_set, ds_transform_fn, tokenizer
|
cfg, data_set, ds_transform_fn, tokenizer, **map_kwargs
|
||||||
)
|
)
|
||||||
elif _cfg.rl == "kto":
|
elif _cfg.rl == "kto":
|
||||||
ds_transform_fn = load_kto(_type, _cfg, dataset_idx=i)
|
ds_transform_fn = load_kto(_type, _cfg, dataset_idx=i)
|
||||||
|
map_kwargs = {}
|
||||||
|
if isinstance(ds_transform_fn, tuple):
|
||||||
|
ds_transform_fn, map_kwargs = ds_transform_fn
|
||||||
split_datasets[i] = map_dataset(
|
split_datasets[i] = map_dataset(
|
||||||
cfg, data_set, ds_transform_fn, tokenizer
|
cfg, data_set, ds_transform_fn, tokenizer, **map_kwargs
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# If no `type` is provided, assume the dataset is already in the expected format with
|
# If no `type` is provided, assume the dataset is already in the expected format with
|
||||||
# "prompt", "chosen" and "rejected" already preprocessed
|
# "prompt", "chosen" and "rejected" already preprocessed
|
||||||
split_datasets[i] = data_set
|
split_datasets[i] = data_set
|
||||||
|
|
||||||
drop_long = partial(
|
if not cfg.skip_prepare_dataset:
|
||||||
drop_long_rl_seq,
|
drop_long = partial(
|
||||||
rl=_cfg.rl,
|
drop_long_rl_seq,
|
||||||
tokenizer=tokenizer,
|
rl=_cfg.rl,
|
||||||
sequence_len=cfg.sequence_len,
|
tokenizer=tokenizer,
|
||||||
)
|
sequence_len=cfg.sequence_len,
|
||||||
|
)
|
||||||
|
|
||||||
prior_len = len(split_datasets[i])
|
prior_len = len(split_datasets[i])
|
||||||
split_datasets[i] = split_datasets[i].filter(
|
split_datasets[i] = split_datasets[i].filter(
|
||||||
drop_long,
|
drop_long,
|
||||||
num_proc=cfg.dataset_processes,
|
num_proc=cfg.dataset_processes,
|
||||||
load_from_cache_file=not cfg.is_preprocess,
|
load_from_cache_file=not cfg.is_preprocess,
|
||||||
desc="Dropping Long Sequences",
|
desc="Dropping Long Sequences",
|
||||||
)
|
)
|
||||||
dropped = prior_len - len(split_datasets[i])
|
dropped = prior_len - len(split_datasets[i])
|
||||||
if dropped:
|
if dropped:
|
||||||
LOG.warning(f"Dropped {dropped} long samples from dataset index {i}")
|
LOG.warning(
|
||||||
|
f"Dropped {dropped} long samples from dataset index {i}"
|
||||||
|
)
|
||||||
|
|
||||||
combined_datasets = concatenate_datasets(split_datasets)
|
combined_datasets = concatenate_datasets(split_datasets)
|
||||||
combined_datasets = combined_datasets.shuffle(seed=cfg.seed)
|
combined_datasets = combined_datasets.shuffle(seed=cfg.seed)
|
||||||
|
|||||||
75
src/axolotl/utils/lora.py
Normal file
75
src/axolotl/utils/lora.py
Normal file
@@ -0,0 +1,75 @@
|
|||||||
|
# Copyright 2025 Axolotl AI. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
"""
|
||||||
|
module to get the state dict of a merged lora model
|
||||||
|
"""
|
||||||
|
import torch
|
||||||
|
from peft.tuners.tuners_utils import onload_layer
|
||||||
|
from peft.utils import ModulesToSaveWrapper, _get_submodules
|
||||||
|
|
||||||
|
|
||||||
|
def get_lora_merged_state_dict(
|
||||||
|
model: torch.nn.Module,
|
||||||
|
) -> dict:
|
||||||
|
r"""
|
||||||
|
Create and return a state_dict that has the LoRA deltas
|
||||||
|
merged into the base model’s weights, without modifying `model` in place.
|
||||||
|
|
||||||
|
Arguments:
|
||||||
|
model (torch.nn.Module): A model that has LoRA/PEFT adapters attached.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: A state_dict of the merged parameters.
|
||||||
|
"""
|
||||||
|
|
||||||
|
base_model_prefix = "base_model.model."
|
||||||
|
state_dict = {}
|
||||||
|
key_list = [key for key, _ in model.named_modules() if model.prefix not in key]
|
||||||
|
for key in key_list:
|
||||||
|
try:
|
||||||
|
_, target, _ = _get_submodules(model, key)
|
||||||
|
except AttributeError:
|
||||||
|
continue
|
||||||
|
with onload_layer(target):
|
||||||
|
weight_key = key.replace(base_model_prefix, "") + ".weight"
|
||||||
|
bias_key = key.replace(base_model_prefix, "") + ".bias"
|
||||||
|
if hasattr(target, "base_layer"):
|
||||||
|
target.merge(safe_merge=True, adapter_names=None)
|
||||||
|
# get the state_dict of target.base_layer
|
||||||
|
layer_state_dict = target.base_layer.state_dict()
|
||||||
|
state_dict[weight_key] = layer_state_dict["weight"]
|
||||||
|
elif isinstance(target, ModulesToSaveWrapper):
|
||||||
|
# save any additional trainable modules part of `modules_to_save`
|
||||||
|
new_module = target.modules_to_save[target.active_adapter]
|
||||||
|
if hasattr(new_module, "base_layer"):
|
||||||
|
# check if the module is itself a tuner layer
|
||||||
|
new_module.merge(safe_merge=True, adapter_names=None)
|
||||||
|
layer_state_dict = new_module.state_dict()
|
||||||
|
state_dict[weight_key] = layer_state_dict["weight"]
|
||||||
|
elif hasattr(target, "weight"):
|
||||||
|
if any(
|
||||||
|
skip in key
|
||||||
|
for skip in [
|
||||||
|
".original_module",
|
||||||
|
".modules_to_save",
|
||||||
|
".base_layer",
|
||||||
|
]
|
||||||
|
):
|
||||||
|
continue
|
||||||
|
layer_state_dict = target.state_dict()
|
||||||
|
state_dict[weight_key] = layer_state_dict["weight"]
|
||||||
|
if hasattr(target, "bias") and "bias" in layer_state_dict.keys():
|
||||||
|
state_dict[bias_key] = layer_state_dict["bias"]
|
||||||
|
return state_dict
|
||||||
@@ -1053,9 +1053,12 @@ class ModelLoader:
|
|||||||
if self.cfg.resize_token_embeddings_to_32x
|
if self.cfg.resize_token_embeddings_to_32x
|
||||||
else len(self.tokenizer)
|
else len(self.tokenizer)
|
||||||
)
|
)
|
||||||
if (
|
if hasattr(self.model, "get_input_embeddings") and (
|
||||||
hasattr(self.model, "get_input_embeddings")
|
self.model.get_input_embeddings().num_embeddings < embeddings_len
|
||||||
and self.model.get_input_embeddings().num_embeddings != embeddings_len
|
or (
|
||||||
|
self.model.get_input_embeddings().num_embeddings > embeddings_len
|
||||||
|
and self.cfg.shrink_embeddings
|
||||||
|
)
|
||||||
):
|
):
|
||||||
resize_kwargs = {}
|
resize_kwargs = {}
|
||||||
if self.cfg.mean_resizing_embeddings is not None:
|
if self.cfg.mean_resizing_embeddings is not None:
|
||||||
@@ -1309,6 +1312,7 @@ def load_lora(model, cfg, inference=False, config_only=False):
|
|||||||
lora_config_kwargs["init_lora_weights"] = "loftq"
|
lora_config_kwargs["init_lora_weights"] = "loftq"
|
||||||
if cfg.peft_use_dora:
|
if cfg.peft_use_dora:
|
||||||
lora_config_kwargs["use_dora"] = cfg.peft_use_dora
|
lora_config_kwargs["use_dora"] = cfg.peft_use_dora
|
||||||
|
LOG.info("Initializing LoRA weights using dora. This might take longer.")
|
||||||
if cfg.peft_use_rslora:
|
if cfg.peft_use_rslora:
|
||||||
lora_config_kwargs["use_rslora"] = cfg.peft_use_rslora
|
lora_config_kwargs["use_rslora"] = cfg.peft_use_rslora
|
||||||
if cfg.peft_layer_replication:
|
if cfg.peft_layer_replication:
|
||||||
|
|||||||
@@ -576,7 +576,7 @@ def prepare_opinionated_env(cfg):
|
|||||||
def setup_trainer(
|
def setup_trainer(
|
||||||
cfg, train_dataset, eval_dataset, model, tokenizer, processor, total_num_steps
|
cfg, train_dataset, eval_dataset, model, tokenizer, processor, total_num_steps
|
||||||
):
|
):
|
||||||
if cfg.rl in ("dpo", "ipo", "orpo", "kto", "simpo"):
|
if cfg.rl:
|
||||||
trainer_builder = HFRLTrainerBuilder(cfg, model[0], tokenizer, processor)
|
trainer_builder = HFRLTrainerBuilder(cfg, model[0], tokenizer, processor)
|
||||||
trainer_builder.model_ref = model[1]
|
trainer_builder.model_ref = model[1]
|
||||||
trainer_builder.peft_config = model[2]
|
trainer_builder.peft_config = model[2]
|
||||||
|
|||||||
173
tests/e2e/multigpu/test_grpo.py
Normal file
173
tests/e2e/multigpu/test_grpo.py
Normal file
@@ -0,0 +1,173 @@
|
|||||||
|
"""
|
||||||
|
GRPO test suite
|
||||||
|
"""
|
||||||
|
import random
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import yaml
|
||||||
|
from accelerate.test_utils import execute_subprocess_async
|
||||||
|
from e2e.utils import require_vllm
|
||||||
|
from transformers.testing_utils import get_torch_dist_unique_port
|
||||||
|
|
||||||
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
|
|
||||||
|
class TestGRPO:
|
||||||
|
"""
|
||||||
|
Test case for GRPO training using multilpe GPUs
|
||||||
|
"""
|
||||||
|
|
||||||
|
def _utils_write_yaml_and_rewards(self, cfg, temp_dir, suffix=""):
|
||||||
|
# write cfg to yaml file
|
||||||
|
Path(temp_dir).mkdir(parents=True, exist_ok=True)
|
||||||
|
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
|
||||||
|
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
|
||||||
|
with open(f"rewards_{suffix}.py", "w", encoding="utf-8") as fout:
|
||||||
|
fout.write(
|
||||||
|
"""import random
|
||||||
|
def rand_reward_func(completions, **kwargs) -> list[float]:
|
||||||
|
return [random.uniform(0, 1) for _ in completions]
|
||||||
|
|
||||||
|
def oai_gsm8k_transform(cfg, *args, **kwargs):
|
||||||
|
def transform_fn(example, tokenizer=None):
|
||||||
|
label = example["answer"].split("####")[-1].strip().replace(",", "")
|
||||||
|
return {
|
||||||
|
"prompt": [{"role": "user", "content": example["question"]},],
|
||||||
|
"answer": label,
|
||||||
|
}
|
||||||
|
return transform_fn, {"remove_columns": ["question"]}
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"num_gpus",
|
||||||
|
[1, 2],
|
||||||
|
)
|
||||||
|
@require_vllm
|
||||||
|
def test_llama_dora(self, temp_dir, num_gpus):
|
||||||
|
rnd_reward_suffix = str(random.randint(1000, 9999))
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||||
|
"chat_template": "llama3",
|
||||||
|
"rl": "grpo",
|
||||||
|
"trl": {
|
||||||
|
"beta": 0.001,
|
||||||
|
"max_completion_length": 256,
|
||||||
|
"use_vllm": True,
|
||||||
|
"vllm_device": "auto" if num_gpus == 1 else "cuda:1",
|
||||||
|
"vllm_gpu_memory_utilization": 0.15,
|
||||||
|
"num_generations": 4,
|
||||||
|
"reward_funcs": [f"rewards_{rnd_reward_suffix}.rand_reward_func"],
|
||||||
|
},
|
||||||
|
"datasets": [
|
||||||
|
{
|
||||||
|
"path": "openai/gsm8k",
|
||||||
|
"name": "main",
|
||||||
|
"type": f"rewards_{rnd_reward_suffix}.oai_gsm8k_transform",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
"adapter": "lora",
|
||||||
|
"lora_r": 8,
|
||||||
|
"lora_alpha": 16,
|
||||||
|
"lora_dropout": 0.05,
|
||||||
|
"lora_target_linear": True,
|
||||||
|
"peft_use_dora": True,
|
||||||
|
"flash_attention": True,
|
||||||
|
"sequence_len": 1024,
|
||||||
|
"special_tokens": {
|
||||||
|
"pad_token": "<|endoftext|>",
|
||||||
|
},
|
||||||
|
"max_steps": 5,
|
||||||
|
"num_epochs": 1,
|
||||||
|
"micro_batch_size": 4,
|
||||||
|
"gradient_accumulation_steps": 2,
|
||||||
|
"warmup_steps": 10,
|
||||||
|
"val_set_size": 0.0,
|
||||||
|
"output_dir": temp_dir,
|
||||||
|
"learning_rate": 0.0001,
|
||||||
|
"optimizer": "adamw_torch_fused",
|
||||||
|
"lr_scheduler": "cosine",
|
||||||
|
"save_safetensors": True,
|
||||||
|
"bf16": "auto",
|
||||||
|
"use_tensorboard": True,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
self._utils_write_yaml_and_rewards(cfg, temp_dir, suffix=rnd_reward_suffix)
|
||||||
|
|
||||||
|
execute_subprocess_async(
|
||||||
|
[
|
||||||
|
"axolotl",
|
||||||
|
"train",
|
||||||
|
str(Path(temp_dir) / "config.yaml"),
|
||||||
|
"--num-processes",
|
||||||
|
str(num_gpus),
|
||||||
|
"--main-process-port",
|
||||||
|
f"{get_torch_dist_unique_port()}",
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"num_gpus",
|
||||||
|
[1, 2],
|
||||||
|
)
|
||||||
|
@require_vllm
|
||||||
|
def test_llama_fft(self, temp_dir, num_gpus):
|
||||||
|
rnd_reward_suffix = str(random.randint(1000, 9999))
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||||
|
"chat_template": "llama3",
|
||||||
|
"rl": "grpo",
|
||||||
|
"trl": {
|
||||||
|
"beta": 0.001,
|
||||||
|
"max_completion_length": 256,
|
||||||
|
"use_vllm": True,
|
||||||
|
"vllm_device": "auto" if num_gpus == 1 else "cuda:1",
|
||||||
|
"vllm_gpu_memory_utilization": 0.15,
|
||||||
|
"num_generations": 4,
|
||||||
|
"reward_funcs": [f"rewards_{rnd_reward_suffix}.rand_reward_func"],
|
||||||
|
},
|
||||||
|
"datasets": [
|
||||||
|
{
|
||||||
|
"path": "openai/gsm8k",
|
||||||
|
"name": "main",
|
||||||
|
"type": f"rewards_{rnd_reward_suffix}.oai_gsm8k_transform",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
"flash_attention": True,
|
||||||
|
"sequence_len": 1024,
|
||||||
|
"special_tokens": {
|
||||||
|
"pad_token": "<|endoftext|>",
|
||||||
|
},
|
||||||
|
"max_steps": 5,
|
||||||
|
"num_epochs": 1,
|
||||||
|
"micro_batch_size": 4,
|
||||||
|
"gradient_accumulation_steps": 2,
|
||||||
|
"warmup_steps": 10,
|
||||||
|
"val_set_size": 0.0,
|
||||||
|
"output_dir": temp_dir,
|
||||||
|
"learning_rate": 0.0001,
|
||||||
|
"optimizer": "adamw_torch_fused",
|
||||||
|
"lr_scheduler": "cosine",
|
||||||
|
"save_safetensors": True,
|
||||||
|
"bf16": "auto",
|
||||||
|
"use_tensorboard": True,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
self._utils_write_yaml_and_rewards(cfg, temp_dir, suffix=rnd_reward_suffix)
|
||||||
|
|
||||||
|
execute_subprocess_async(
|
||||||
|
[
|
||||||
|
"axolotl",
|
||||||
|
"train",
|
||||||
|
str(Path(temp_dir) / "config.yaml"),
|
||||||
|
"--num-processes",
|
||||||
|
str(num_gpus),
|
||||||
|
"--main-process-port",
|
||||||
|
f"{get_torch_dist_unique_port()}",
|
||||||
|
]
|
||||||
|
)
|
||||||
@@ -78,6 +78,24 @@ def require_torch_lt_2_6_0(test_case):
|
|||||||
return unittest.skipUnless(is_max_2_6_0(), "test requires torch<2.6.0")(test_case)
|
return unittest.skipUnless(is_max_2_6_0(), "test requires torch<2.6.0")(test_case)
|
||||||
|
|
||||||
|
|
||||||
|
def require_vllm(test_case):
|
||||||
|
"""
|
||||||
|
Decorator marking a test that requires a vllm to be installed
|
||||||
|
"""
|
||||||
|
|
||||||
|
def is_vllm_installed():
|
||||||
|
try:
|
||||||
|
import vllm # pylint: disable=unused-import # noqa: F401
|
||||||
|
|
||||||
|
return True
|
||||||
|
except ImportError:
|
||||||
|
return False
|
||||||
|
|
||||||
|
return unittest.skipUnless(
|
||||||
|
is_vllm_installed(), "test requires a vllm to be installed"
|
||||||
|
)(test_case)
|
||||||
|
|
||||||
|
|
||||||
def is_hopper():
|
def is_hopper():
|
||||||
compute_capability = torch.cuda.get_device_capability()
|
compute_capability = torch.cuda.get_device_capability()
|
||||||
return compute_capability == (9, 0)
|
return compute_capability == (9, 0)
|
||||||
|
|||||||
Reference in New Issue
Block a user