diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 6c7e248bf..7fb97b9d9 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -24,7 +24,7 @@ jobs: cuda_version: 12.4.1 python_version: "3.11" pytorch: 2.5.1 - axolotl_extras: + axolotl_extras: vllm is_latest: true - cuda: 124 cuda_version: 12.4.1 diff --git a/.github/workflows/multi-gpu-e2e.yml b/.github/workflows/multi-gpu-e2e.yml index 6c9c7bb49..ea00d749b 100644 --- a/.github/workflows/multi-gpu-e2e.yml +++ b/.github/workflows/multi-gpu-e2e.yml @@ -24,20 +24,21 @@ jobs: cuda_version: 12.4.1 python_version: "3.11" pytorch: 2.4.1 - axolotl_extras: + axolotl_extras: # no vllm support for 2.4.1 num_gpus: 2 nightly_build: "true" - cuda: 124 cuda_version: 12.4.1 python_version: "3.11" pytorch: 2.5.1 - axolotl_extras: + axolotl_extras: vllm num_gpus: 2 nightly_build: "true" - cuda: 124 cuda_version: 12.4.1 python_version: "3.11" pytorch: 2.6.0 + # awaiting vllm#12721 axolotl_extras: num_gpus: 2 nightly_build: "true" diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 21218d72f..889339005 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -204,7 +204,7 @@ jobs: python_version: "3.11" pytorch: 2.5.1 num_gpus: 1 - axolotl_extras: + axolotl_extras: vllm steps: - name: Checkout uses: actions/checkout@v4 diff --git a/requirements.txt b/requirements.txt index 37a3f0d64..44f49289a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -18,7 +18,7 @@ tokenizers>=0.21.0 accelerate==1.3.0 datasets==3.2.0 deepspeed==0.16.1 -trl==0.13.0 +trl==0.15.0 optimum==1.16.2 hf_transfer @@ -26,7 +26,7 @@ sentencepiece gradio==3.50.2 modal==0.70.5 -pydantic==2.6.3 +pydantic==2.10.6 addict fire PyYAML>=6.0 diff --git a/setup.py b/setup.py index 6730591c0..b84c34525 100644 --- a/setup.py +++ b/setup.py @@ -79,7 +79,7 @@ def parse_requirements(): if patch == 0: _install_requires.append("xformers==0.0.28.post2") else: - _install_requires.append("xformers==0.0.29") + _install_requires.append("xformers>=0.0.28.post3") _install_requires.pop(_install_requires.index(autoawq_version)) elif (major, minor) >= (2, 4): if patch == 0: @@ -125,7 +125,7 @@ setup( }, extras_require={ "flash-attn": [ - "flash-attn==2.7.0.post2", + "flash-attn==2.7.4.post1", ], "deepspeed": [ "deepspeed==0.16.1", @@ -156,5 +156,8 @@ setup( "ray": [ "ray[train]", ], + "vllm": [ + "vllm==0.7.2", + ], }, ) diff --git a/src/axolotl/cli/cloud/__init__.py b/src/axolotl/cli/cloud/__init__.py index fde46e397..b879601be 100644 --- a/src/axolotl/cli/cloud/__init__.py +++ b/src/axolotl/cli/cloud/__init__.py @@ -35,13 +35,18 @@ def do_cli_train( cloud_config: Union[Path, str], config: Union[Path, str], accelerate: bool = True, + cwd=None, + **kwargs, ) -> None: print_axolotl_text_art() cloud_cfg = load_cloud_cfg(cloud_config) cloud = ModalCloud(cloud_cfg) with open(config, "r", encoding="utf-8") as file: config_yaml = file.read() - cloud.train(config_yaml, accelerate=accelerate) + local_dirs = {} + if cwd and not Path(cwd).joinpath("src", "axolotl").exists(): + local_dirs = {"/workspace/mounts": cwd} + cloud.train(config_yaml, accelerate=accelerate, local_dirs=local_dirs, **kwargs) def do_cli_lm_eval( diff --git a/src/axolotl/cli/cloud/modal_.py b/src/axolotl/cli/cloud/modal_.py index bcc47ead9..6b724f732 100644 --- a/src/axolotl/cli/cloud/modal_.py +++ b/src/axolotl/cli/cloud/modal_.py @@ -7,6 +7,7 @@ import os import subprocess # nosec B404 from pathlib import Path from random import randint +from typing import 
Optional import modal @@ -22,8 +23,18 @@ def run_cmd(cmd: str, run_folder: str, volumes=None): # modal workaround so it doesn't use the automounted axolotl new_env = copy.deepcopy(os.environ) + if "PYTHONPATH" in new_env: - del new_env["PYTHONPATH"] + paths = ["/workspace/mounts"] + for sub_python_path_str in new_env["PYTHONPATH"].split(":"): + sub_python_path = Path(sub_python_path_str) + if not sub_python_path.joinpath("src", "axolotl").exists(): + # we don't want to use the automounted axolotl or unexpected behavior happens + paths.append(str(sub_python_path)) + if paths: + new_env["PYTHONPATH"] = ":".join(paths) + else: + del new_env["PYTHONPATH"] # Propagate errors from subprocess. if exit_code := subprocess.call( # nosec B603 @@ -203,9 +214,12 @@ class ModalCloud(Cloud): memory = int(self.config.memory) return 1024 * memory - def get_train_env(self): + def get_train_env(self, local_dirs=None): + image = self.get_image() + for mount, local_dir in (local_dirs or {}).items(): + image = image.add_local_dir(local_dir, mount) return self.app.function( - image=self.get_image(), + image=image, volumes={k: v[0] for k, v in self.volumes.items()}, cpu=16.0, gpu=self.get_train_gpu(), @@ -214,14 +228,21 @@ class ModalCloud(Cloud): secrets=self.get_secrets(), ) - def train(self, config_yaml: str, accelerate: bool = True): - modal_fn = self.get_train_env()(_train) + def train( + self, + config_yaml: str, + accelerate: bool = True, + local_dirs: Optional[dict[str, str]] = None, + **kwargs, + ): + modal_fn = self.get_train_env(local_dirs)(_train) with modal.enable_output(): with self.app.run(detach=True): modal_fn.remote( config_yaml, accelerate=accelerate, volumes={k: v[0] for k, v in self.volumes.items()}, + **kwargs, ) def lm_eval(self, config_yaml: str): @@ -252,7 +273,7 @@ def _preprocess(config_yaml: str, volumes=None): ) -def _train(config_yaml: str, accelerate: bool = True, volumes=None): +def _train(config_yaml: str, accelerate: bool = True, volumes=None, **kwargs): with open( "/workspace/artifacts/axolotl/config.yaml", "w", encoding="utf-8" ) as f_out: @@ -262,8 +283,11 @@ def _train(config_yaml: str, accelerate: bool = True, volumes=None): accelerate_args = "--accelerate" else: accelerate_args = "--no-accelerate" + num_processes_args = "" + if num_processes := kwargs.pop("num_processes", None): + num_processes_args = f"--num-processes {num_processes}" run_cmd( - f"axolotl train {accelerate_args} /workspace/artifacts/axolotl/config.yaml", + f"axolotl train {accelerate_args} {num_processes_args} /workspace/artifacts/axolotl/config.yaml", run_folder, volumes, ) diff --git a/src/axolotl/cli/main.py b/src/axolotl/cli/main.py index d7aa1f6a7..61dc4403a 100644 --- a/src/axolotl/cli/main.py +++ b/src/axolotl/cli/main.py @@ -2,6 +2,7 @@ # pylint: disable=redefined-outer-name import logging +import os import random import subprocess # nosec B404 import tempfile @@ -12,6 +13,7 @@ from typing import Optional import click import yaml +from dotenv import load_dotenv import axolotl from axolotl.cli.args import EvaluateCliArgs, PreprocessCliArgs, TrainerCliArgs @@ -199,7 +201,14 @@ def train( try: if accelerate: if cloud: - do_cli_train(cloud_config=cloud, config=config, accelerate=True) + cwd = os.getcwd() + do_cli_train( + cloud_config=cloud, + config=config, + accelerate=True, + cwd=cwd, + **kwargs, + ) else: accelerate_args = [] if "main_process_port" in kwargs: @@ -208,7 +217,7 @@ def train( accelerate_args.append(str(main_process_port)) if "num_processes" in kwargs: num_processes = 
kwargs.pop("num_processes", None) - accelerate_args.append("--num-processes") + accelerate_args.append("--num_processes") accelerate_args.append(str(num_processes)) base_cmd = ["accelerate", "launch"] @@ -220,7 +229,9 @@ def train( subprocess.run(cmd, check=True) # nosec B603 else: if cloud: - do_cli_train(cloud_config=cloud, config=config, accelerate=False) + do_cli_train( + cloud_config=cloud, config=config, accelerate=False, **kwargs + ) else: from axolotl.cli.train import do_cli @@ -381,4 +392,5 @@ def main(): if __name__ == "__main__": + load_dotenv() main() diff --git a/src/axolotl/common/datasets.py b/src/axolotl/common/datasets.py index cbc0d127c..db07eb43b 100644 --- a/src/axolotl/common/datasets.py +++ b/src/axolotl/common/datasets.py @@ -122,9 +122,11 @@ def load_preference_datasets( `total_num_steps`. """ train_dataset, eval_dataset = load_prepare_preference_datasets(cfg) - total_num_steps = int( + total_num_steps: Optional[int] = int( math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size) ) + if cfg.rl == "grpo": + total_num_steps = None if cli_args.debug or cfg.debug: LOG.info("check_dataset_labels...") diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index 89480d775..ae757cf43 100755 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -39,7 +39,6 @@ from trl.trainer.utils import RewardDataCollatorWithPadding from axolotl.core.trainers.base import ( AxolotlCPOTrainer, - AxolotlDPOTrainer, AxolotlKTOTrainer, AxolotlMambaTrainer, AxolotlORPOTrainer, @@ -48,9 +47,11 @@ from axolotl.core.trainers.base import ( AxolotlTrainer, ReLoRATrainer, ) +from axolotl.core.trainers.dpo import DPOStrategy +from axolotl.core.trainers.dpo.args import AxolotlDPOConfig +from axolotl.core.trainers.grpo import GRPOStrategy from axolotl.core.training_args import ( AxolotlCPOConfig, - AxolotlDPOConfig, AxolotlKTOConfig, AxolotlORPOConfig, AxolotlPRMConfig, @@ -641,9 +642,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase): tokenizer=self.tokenizer, ) - if self.cfg.rl == "orpo": - training_arguments_kwargs["orpo_alpha"] = self.cfg.orpo_alpha - if self.cfg.neftune_noise_alpha is not None: training_arguments_kwargs[ "neftune_noise_alpha" @@ -652,7 +650,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase): trainer_kwargs = {} if self.cfg.reward_model: - trainer_kwargs["max_length"] = self.cfg.sequence_len + training_arguments_kwargs["max_length"] = self.cfg.sequence_len # pylint: disable=duplicate-code if self.cfg.optimizer in [ @@ -965,10 +963,11 @@ class HFRLTrainerBuilder(TrainerBuilderBase): # default to saving each epoch if not defined training_args_kwargs["save_strategy"] = "epoch" - training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes + if self.cfg.dataset_processes: + training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes - if self.cfg.rl_beta: - training_args_kwargs["beta"] = self.cfg.rl_beta + if (self.cfg.trl and self.cfg.trl.beta) or self.cfg.rl_beta: + training_args_kwargs["beta"] = self.cfg.trl.beta or self.cfg.rl_beta if self.cfg.orpo_alpha: # trl does some odd mapping of alpha to beta to reuse the beta parameter ??? 
training_args_kwargs["beta"] = self.cfg.orpo_alpha @@ -977,6 +976,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase): training_args_kwargs["rpo_alpha"] = self.cfg.rpo_alpha training_args_cls = None + blocklist_args_kwargs = [] if self.cfg.rl == "simpo": training_args_cls = AxolotlCPOConfig training_args_kwargs["loss_type"] = "simpo" @@ -1001,11 +1001,15 @@ class HFRLTrainerBuilder(TrainerBuilderBase): self.cfg.kto_undesirable_weight or 1.0 ) - training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes training_args_kwargs["max_length"] = self.cfg.sequence_len if self.cfg.max_prompt_len: training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len + elif self.cfg.rl == "grpo": + training_args_cls = GRPOStrategy.get_training_args_class() + training_args_kwargs.update(GRPOStrategy.set_training_args_kwargs(self.cfg)) + blocklist_args_kwargs = GRPOStrategy.get_blocklist_args_kwargs() + else: training_args_cls = AxolotlDPOConfig if self.cfg.rl == "ipo": @@ -1016,11 +1020,21 @@ class HFRLTrainerBuilder(TrainerBuilderBase): training_args_kwargs["generate_during_eval"] = self.cfg.use_wandb if self.cfg.dpo_use_weighting is not None: training_args_kwargs["use_weighting"] = self.cfg.dpo_use_weighting + if self.cfg.dpo_use_logits_to_keep is not None: + training_args_kwargs[ + "use_logits_to_keep" + ] = self.cfg.dpo_use_logits_to_keep + for blocklist_key in blocklist_args_kwargs: + if blocklist_key in training_args_kwargs: + del training_args_kwargs[blocklist_key] + + max_steps = self.cfg.max_steps or total_num_steps or -1 + training_args_kwargs["num_train_epochs"] = self.cfg.num_epochs training_args = training_args_cls( # pylint: disable=unexpected-keyword-arg - output_dir=self.cfg.output_dir, + self.cfg.output_dir, per_device_train_batch_size=self.cfg.micro_batch_size, - max_steps=self.cfg.max_steps or total_num_steps, + max_steps=max_steps, gradient_accumulation_steps=self.cfg.gradient_accumulation_steps, learning_rate=self.cfg.learning_rate, warmup_steps=self.cfg.warmup_steps, @@ -1047,8 +1061,13 @@ class HFRLTrainerBuilder(TrainerBuilderBase): dpo_trainer_kwargs[ "precompute_ref_log_probs" ] = self.cfg.precompute_ref_log_probs - if self.cfg.rl in ["dpo", "ipo"]: - trainer_cls = AxolotlDPOTrainer + if self.cfg.rl == "grpo": + trainer_cls = GRPOStrategy.get_trainer_class() + trainer_cls_args = [self.model] + trainer_cls_args.extend(GRPOStrategy.set_trainer_args(self.cfg)) + dpo_trainer_kwargs.update(GRPOStrategy.set_trainer_kwargs(self.cfg)) + elif self.cfg.rl in ["dpo", "ipo"]: + trainer_cls = DPOStrategy.get_trainer_class() trainer_cls_args = [self.model, self.model_ref] elif self.cfg.rl == "orpo": trainer_cls = AxolotlORPOTrainer @@ -1063,12 +1082,14 @@ class HFRLTrainerBuilder(TrainerBuilderBase): raise ValueError(f"Unsupported RL: {self.cfg.rl}") sig = inspect.signature(trainer_cls) - if "processing_class" in sig.parameters.keys(): - dpo_trainer_kwargs["processing_class"] = self.tokenizer - else: + if "tokenizer" in sig.parameters.keys(): dpo_trainer_kwargs["tokenizer"] = self.tokenizer + else: + dpo_trainer_kwargs["processing_class"] = self.tokenizer - if self.cfg.datasets is not None and (trainer_cls is AxolotlDPOTrainer): + if self.cfg.datasets is not None and ( + trainer_cls is DPOStrategy.get_trainer_class() + ): dpo_trainer_kwargs["dataset_tags"] = [ d["path"] for d in self.cfg.datasets if not Path(d["path"]).is_dir() ] diff --git a/src/axolotl/core/trainers/base.py b/src/axolotl/core/trainers/base.py index 44a6d54d7..ee2545b21 100644 --- 
a/src/axolotl/core/trainers/base.py +++ b/src/axolotl/core/trainers/base.py @@ -5,30 +5,21 @@ module for customized trainers from __future__ import annotations # pylint: disable=too-many-lines -import gc import logging import os from collections import defaultdict from functools import wraps -from typing import Any, Dict, Literal, Optional, Union +from typing import Dict, Literal, Optional import torch from datasets import Dataset from peft.optimizers import create_loraplus_optimizer -from torch import nn from torch.optim.lr_scheduler import OneCycleLR from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler from transformers import Trainer from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, seed_worker from transformers.utils import is_sagemaker_mp_enabled -from trl import ( - CPOTrainer, - DPOTrainer, - KTOTrainer, - ORPOTrainer, - PRMTrainer, - RewardTrainer, -) +from trl import CPOTrainer, KTOTrainer, ORPOTrainer, PRMTrainer, RewardTrainer from trl.trainer.utils import pad_to_length from axolotl.monkeypatch.relora import ReLoRAScheduler @@ -847,107 +838,6 @@ class ReLoRATrainer(AxolotlTrainer): return self.lr_scheduler -class AxolotlDPOTrainer(SchedulerMixin, DPOTrainer): - """ - Extend the base DPOTrainer for axolotl helpers - """ - - tag_names = ["axolotl", "dpo"] - - def __init__(self, *args, dataset_tags=None, **kwargs): - super().__init__(*args, **kwargs) - self.dataset_tags = dataset_tags - self.optimizer = None - self.model_accepts_loss_kwargs = False - - def create_optimizer(self): - if self.args.loraplus_lr_ratio is None: - return super().create_optimizer() - - opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model - if self.optimizer is None: # pylint: disable=access-member-before-definition - optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs( - self.args, - opt_model, - ) - - loraplus_lr_ratio = getattr(self.args, "loraplus_lr_ratio", None) - if loraplus_lr_ratio: - print("Using lora+") - loraplus_lr_embedding = getattr(self.args, "loraplus_lr_embedding", None) - self.optimizer = create_loraplus_optimizer( # pylint: disable=attribute-defined-outside-init - opt_model, - optimizer_cls, - loraplus_lr_ratio=loraplus_lr_ratio, - loraplus_lr_embedding=loraplus_lr_embedding, - **optimizer_kwargs, - ) - - if is_sagemaker_mp_enabled(): - self.optimizer = smp.DistributedOptimizer( # pylint: disable=attribute-defined-outside-init - self.optimizer - ) - - return self.optimizer - - @wraps(DPOTrainer.push_to_hub) - def push_to_hub(self, *args, **kwargs) -> str: - """ - Overwrite the `push_to_hub` method in order to force-add the tags when pushing the - model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details. - """ - kwargs = _sanitize_kwargs_for_ds_tagging( - dataset_tags=self.dataset_tags, kwargs=kwargs - ) - kwargs = _sanitize_kwargs_for_tagging(tag_names=self.tag_names, kwargs=kwargs) - - return super().push_to_hub(*args, **kwargs) - - @staticmethod - def tokenize_row( - features, - processing_class, - max_prompt_length, - max_completion_length, - add_special_tokens, - ) -> Dict: - res = DPOTrainer.tokenize_row( - features, - processing_class, - max_prompt_length, - max_completion_length, - add_special_tokens, - ) - # fix when the tokenizer doesn't have a bos_token_id, e.g. 
Qwen - if processing_class.bos_token is None and res["prompt_input_ids"][0] is None: - for key in res.keys(): - res[key] = res[key][1:] - - if processing_class.bos_token and processing_class.bos_token_id is not None: - # dpo trainer may incorrectly prepend the bos_token_id to the dpo outputs - if res["chosen_input_ids"][0] == processing_class.bos_token_id: - res["chosen_input_ids"] = res["chosen_input_ids"][1:] - res["chosen_labels"] = res["chosen_labels"][1:] - res["chosen_attention_mask"] = res["chosen_attention_mask"][1:] - if res["rejected_input_ids"][0] == processing_class.bos_token_id: - res["rejected_input_ids"] = res["rejected_input_ids"][1:] - res["rejected_labels"] = res["rejected_labels"][1:] - res["rejected_attention_mask"] = res["rejected_attention_mask"][1:] - - return res - - def training_step( - self, - model: nn.Module, - inputs: Dict[str, Union[torch.Tensor, Any]], - num_items_in_batch=None, - ) -> torch.Tensor: - loss: torch.Tensor = super().training_step(model, inputs, num_items_in_batch) - gc.collect() - torch.cuda.empty_cache() - return loss - - class AxolotlORPOTrainer(SchedulerMixin, ORPOTrainer): """ Extend the base ORPOTrainer for axolotl helpers diff --git a/src/axolotl/core/trainers/dpo/__init__.py b/src/axolotl/core/trainers/dpo/__init__.py new file mode 100644 index 000000000..8187a7fb5 --- /dev/null +++ b/src/axolotl/core/trainers/dpo/__init__.py @@ -0,0 +1,33 @@ +""" +DPO Specific Strategy for training +""" +from axolotl.core.trainers.dpo.trainer import AxolotlDPOTrainer + + +class DPOStrategy: + """ + Strategy for DPO training + """ + + @classmethod + def get_trainer_class(cls): + return AxolotlDPOTrainer + + @classmethod + def get_training_args_class(cls): + from axolotl.core.trainers.dpo.args import AxolotlDPOConfig + + return AxolotlDPOConfig + + @classmethod + def set_training_args_kwargs(cls, cfg): + training_args_kwargs = {} + if cfg.rl == "ipo": + training_args_kwargs["loss_type"] = "ipo" + training_args_kwargs["max_length"] = cfg.sequence_len + training_args_kwargs["max_completion_length"] = None + training_args_kwargs["max_prompt_length"] = cfg.sequence_len + training_args_kwargs["generate_during_eval"] = cfg.use_wandb + if cfg.dpo_use_weighting is not None: + training_args_kwargs["use_weighting"] = cfg.dpo_use_weighting + return training_args_kwargs diff --git a/src/axolotl/core/trainers/dpo/args.py b/src/axolotl/core/trainers/dpo/args.py new file mode 100644 index 000000000..4cae67d3e --- /dev/null +++ b/src/axolotl/core/trainers/dpo/args.py @@ -0,0 +1,15 @@ +""" +Axolotl specific DPO args +""" +from dataclasses import dataclass + +from trl import DPOConfig + +from axolotl.core.training_args import AxolotlTrainingMixins + + +@dataclass +class AxolotlDPOConfig(AxolotlTrainingMixins, DPOConfig): + """ + DPO config for DPO training + """ diff --git a/src/axolotl/core/trainers/dpo/trainer.py b/src/axolotl/core/trainers/dpo/trainer.py new file mode 100644 index 000000000..a1de4cc82 --- /dev/null +++ b/src/axolotl/core/trainers/dpo/trainer.py @@ -0,0 +1,125 @@ +""" +DPO trainer for axolotl +""" +import gc +from functools import wraps +from typing import Any, Dict, Union + +import torch +from peft.optimizers import create_loraplus_optimizer +from torch import nn +from transformers import Trainer +from transformers.utils import is_sagemaker_mp_enabled +from trl import DPOTrainer + +from axolotl.core.trainers.base import ( + SchedulerMixin, + _sanitize_kwargs_for_ds_tagging, + _sanitize_kwargs_for_tagging, +) + +if is_sagemaker_mp_enabled(): + import 
smdistributed.modelparallel.torch as smp + + +class AxolotlDPOTrainer(SchedulerMixin, DPOTrainer): + """ + Extend the base DPOTrainer for axolotl helpers + """ + + tag_names = ["axolotl", "dpo"] + + def __init__(self, *args, dataset_tags=None, **kwargs): + super().__init__(*args, **kwargs) + self.dataset_tags = dataset_tags + self.optimizer = None + self.model_accepts_loss_kwargs = False + + def create_optimizer(self): + # pylint: disable=duplicate-code + if self.args.loraplus_lr_ratio is None: + return super().create_optimizer() + + opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model + if self.optimizer is None: # pylint: disable=access-member-before-definition + optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs( + self.args, + opt_model, + ) + + loraplus_lr_ratio = getattr(self.args, "loraplus_lr_ratio", None) + if loraplus_lr_ratio: + print("Using lora+") + loraplus_lr_embedding = getattr(self.args, "loraplus_lr_embedding", None) + # pylint: disable=duplicate-code + self.optimizer = create_loraplus_optimizer( # pylint: disable=attribute-defined-outside-init + opt_model, + optimizer_cls, + loraplus_lr_ratio=loraplus_lr_ratio, + loraplus_lr_embedding=loraplus_lr_embedding, + **optimizer_kwargs, + ) + + if is_sagemaker_mp_enabled(): + self.optimizer = smp.DistributedOptimizer( # pylint: disable=attribute-defined-outside-init + self.optimizer + ) + + return self.optimizer + + @wraps(DPOTrainer.push_to_hub) + def push_to_hub(self, *args, **kwargs) -> str: + """ + Overwrite the `push_to_hub` method in order to force-add the tags when pushing the + model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details. + """ + kwargs = _sanitize_kwargs_for_ds_tagging( + dataset_tags=self.dataset_tags, kwargs=kwargs + ) + kwargs = _sanitize_kwargs_for_tagging(tag_names=self.tag_names, kwargs=kwargs) + + return super().push_to_hub(*args, **kwargs) + + @staticmethod + def tokenize_row( + features, + processing_class, + max_prompt_length, + max_completion_length, + add_special_tokens, + ) -> Dict: + res = DPOTrainer.tokenize_row( + features, + processing_class, + max_prompt_length, + max_completion_length, + add_special_tokens, + ) + # fix when the tokenizer doesn't have a bos_token_id, e.g. 
Qwen + if processing_class.bos_token is None and res["prompt_input_ids"][0] is None: + for key in res.keys(): + res[key] = res[key][1:] + + if processing_class.bos_token and processing_class.bos_token_id is not None: + # dpo trainer may incorrectly prepend the bos_token_id to the dpo outputs + if res["chosen_input_ids"][0] == processing_class.bos_token_id: + res["chosen_input_ids"] = res["chosen_input_ids"][1:] + res["chosen_labels"] = res["chosen_labels"][1:] + res["chosen_attention_mask"] = res["chosen_attention_mask"][1:] + if res["rejected_input_ids"][0] == processing_class.bos_token_id: + res["rejected_input_ids"] = res["rejected_input_ids"][1:] + res["rejected_labels"] = res["rejected_labels"][1:] + res["rejected_attention_mask"] = res["rejected_attention_mask"][1:] + + return res + + def training_step( + self, + model: nn.Module, + inputs: Dict[str, Union[torch.Tensor, Any]], + num_items_in_batch=None, + ) -> torch.Tensor: + loss: torch.Tensor = super().training_step(model, inputs, num_items_in_batch) + gc.collect() + torch.cuda.empty_cache() + return loss diff --git a/src/axolotl/core/trainers/grpo/__init__.py b/src/axolotl/core/trainers/grpo/__init__.py new file mode 100644 index 000000000..5202cb09d --- /dev/null +++ b/src/axolotl/core/trainers/grpo/__init__.py @@ -0,0 +1,119 @@ +""" +GRPO Specific Strategy for training +""" + +import importlib +import inspect +import logging + +from trl.trainer.grpo_trainer import RewardFunc + +from axolotl.core.trainers.grpo.trainer import AxolotlGRPOTrainer + +LOG = logging.getLogger("axolotl") + + +class GRPOStrategy: + """ + Strategy for GRPO training + """ + + @classmethod + def get_trainer_class(cls): + return AxolotlGRPOTrainer + + @classmethod + def get_training_args_class(cls): + from axolotl.core.trainers.grpo.args import AxolotlGRPOConfig + + return AxolotlGRPOConfig + + @classmethod + def set_training_args_kwargs(cls, cfg): + grpo_args_kwargs = {} + if cfg.trl and cfg.trl.use_vllm: + grpo_args_kwargs["use_vllm"] = cfg.trl.use_vllm + if cfg.trl and cfg.trl.vllm_device: + grpo_args_kwargs["vllm_device"] = cfg.trl.vllm_device + else: + grpo_args_kwargs["vllm_device"] = "auto" + if cfg.trl and cfg.trl.vllm_gpu_memory_utilization: + grpo_args_kwargs[ + "vllm_gpu_memory_utilization" + ] = cfg.trl.vllm_gpu_memory_utilization + if cfg.trl and cfg.trl.vllm_max_model_len: + grpo_args_kwargs["vllm_max_model_len"] = cfg.trl.vllm_max_model_len + if cfg.trl and cfg.trl.num_generations: + grpo_args_kwargs["num_generations"] = cfg.trl.num_generations + if cfg.trl and cfg.trl.sync_ref_model: + grpo_args_kwargs["sync_ref_model"] = cfg.trl.sync_ref_model + if cfg.trl and cfg.trl.ref_model_mixup_alpha: + grpo_args_kwargs[ + "ref_model_mixup_alpha" + ] = cfg.trl.ref_model_mixup_alpha + if cfg.trl and cfg.trl.ref_model_sync_steps: + grpo_args_kwargs["ref_model_sync_steps"] = cfg.trl.ref_model_sync_steps + grpo_args_kwargs["max_completion_length"] = cfg.trl.max_completion_length + grpo_args_kwargs["log_completions"] = cfg.trl.log_completions + return grpo_args_kwargs + + @classmethod + def set_trainer_args(cls, cfg): + trainer_args = [] + if cfg.trl and cfg.trl.reward_funcs: + reward_funcs = [] + for reward_func_fqn in cfg.trl.reward_funcs: + reward_funcs.append(cls.get_reward_func(reward_func_fqn)) + trainer_args.append(reward_funcs) + return trainer_args + + @classmethod + def set_trainer_kwargs(cls, cfg): + trainer_kwargs = {} + if cfg.trl and cfg.trl.reward_processing_classes: + trainer_kwargs[ + "reward_processing_classes" + ] = 
cfg.trl.reward_processing_classes + return trainer_kwargs + + @classmethod + def get_collator(cls, *args, **kwargs): # pylint: disable=unused-argument + # No data collation is needed in GRPO, handled by trl's trainer __init__ + return None + + @classmethod + def get_blocklist_args_kwargs(cls): + return ["dataset_num_proc"] + + @classmethod + def get_reward_func(cls, reward_func_fqn: str) -> RewardFunc: + """ + Returns the reward function from the given fully qualified name, or the path to the reward function model. + + Args: + reward_func_fqn (str): Fully qualified name of the reward function (e.g. r1_grpo.gsm8k_transform), + or a HF hub path to the reward model. + Raises: + ValueError: If the reward function does not accept at least two arguments. + + Returns: + RewardFunc: A callable that accepts prompts and completions and returns rewards, + or a path to a reward model. + + """ + try: + # use importlib to dynamically load the reward function from the module + reward_func_module_name = reward_func_fqn.split(".")[-1] + reward_func_module = importlib.import_module(reward_func_fqn.split(".")[-2]) + reward_func = getattr(reward_func_module, reward_func_module_name) + if not len(inspect.signature(reward_func).parameters) >= 2: + raise ValueError( + "Reward function must accept at least two arguments: prompts: list and completions: list" + ) + return reward_func + except ModuleNotFoundError: + # the user has passed a string (ideally indicating the path of a reward model) + LOG.info( + f"Reward function {reward_func_fqn} is a pre-trained model path - if this is unexpected, please check the reward function path." + ) + return reward_func diff --git a/src/axolotl/core/trainers/grpo/args.py b/src/axolotl/core/trainers/grpo/args.py new file mode 100644 index 000000000..e14e6b0dc --- /dev/null +++ b/src/axolotl/core/trainers/grpo/args.py @@ -0,0 +1,15 @@ +""" +Axolotl Specific Training Args +""" +from dataclasses import dataclass + +from trl import GRPOConfig + +from axolotl.core.training_args import AxolotlTrainingMixins + + +@dataclass +class AxolotlGRPOConfig(AxolotlTrainingMixins, GRPOConfig): + """ + Axolotl GRPO Config for GRPO training + """ diff --git a/src/axolotl/core/trainers/grpo/trainer.py b/src/axolotl/core/trainers/grpo/trainer.py new file mode 100644 index 000000000..8f8b9fcf9 --- /dev/null +++ b/src/axolotl/core/trainers/grpo/trainer.py @@ -0,0 +1,107 @@ +""" +Axolotl GRPO trainer +""" +from accelerate.utils import is_peft_model +from accelerate.utils.other import is_compiled_module +from transformers import PreTrainedModel +from trl import GRPOConfig, GRPOTrainer +from trl.models import unwrap_model_for_generation + +from axolotl.core.trainers.base import SchedulerMixin + + +# mypy: ignore-errors +class AxolotlGRPOTrainer(SchedulerMixin, GRPOTrainer): + """ + Extend the base GRPOTrainer for axolotl helpers + """ + + _tag_names = ["trl", "grpo", "axolotl"] + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # pylint: disable=access-member-before-definition + # Enable gradient checkpointing if requested + if kwargs["args"].gradient_checkpointing: + # Ensure use_cache is disabled + if hasattr(self.model, "config"): + self.model.config.use_cache = False + + # Enable gradient checkpointing on the base model for PEFT + if is_peft_model(self.model) and hasattr( + self.model.base_model, "gradient_checkpointing_enable" + ): + self.model.base_model.gradient_checkpointing_enable() + # Enable gradient checkpointing for non-PEFT models + elif hasattr(self.model, 
"gradient_checkpointing_enable"): + self.model.gradient_checkpointing_enable() + self.model = self._enable_gradient_checkpointing(self.model, kwargs["args"]) + # pylint: enable=access-member-before-definition + + def _enable_gradient_checkpointing( + self, model: PreTrainedModel, args: GRPOConfig + ) -> PreTrainedModel: + """Enables gradient checkpointing for the model.""" + # pylint: disable=unused-argument,redefined-builtin + gradient_checkpointing_kwargs = args.gradient_checkpointing_kwargs or {} + use_reentrant = ( + "use_reentrant" not in gradient_checkpointing_kwargs + or gradient_checkpointing_kwargs["use_reentrant"] + ) + + if use_reentrant: + if hasattr(model, "enable_input_require_grads"): + model.enable_input_require_grads() + else: + + def make_inputs_require_grad(module, input, output): + output.requires_grad_(True) + + model.get_input_embeddings().register_forward_hook( + make_inputs_require_grad + ) + + return model + # pylint: enable=unused-argument,redefined-builtin + + def _move_model_to_vllm(self): + with unwrap_model_for_generation( + self.model, + self.accelerator, + gather_deepspeed3_params=self.args.ds3_gather_for_generation, + ) as unwrapped_model: + if is_compiled_module(unwrapped_model): + unwrapped_model = ( + unwrapped_model._orig_mod # pylint: disable=protected-access + ) + if is_peft_model(unwrapped_model): + unwrapped_model.merge_adapter() + state_dict = unwrapped_model.state_dict() + unwrapped_model.unmerge_adapter() + # Remove base_model and base_layer prefixes + state_dict = { + k.removeprefix("base_model.model.") + .removeprefix("base_model.model.") + .replace(".base_layer", ""): v + for k, v in state_dict.items() + } + # Remove values with adapter prefix (example: "_lora") + state_dict = { + k: v + for k, v in state_dict.items() + if unwrapped_model.prefix not in k + } + # When module to save, remove its prefix and discard the original module + state_dict = { + k.replace("modules_to_save.default.", ""): v + for k, v in state_dict.items() + if "original_module" not in k + } + else: + state_dict = unwrapped_model.state_dict() + if self.accelerator.is_main_process: + llm_model = ( + self.llm.llm_engine.model_executor.driver_worker.model_runner.model + ) + llm_model.load_weights(state_dict.items()) diff --git a/src/axolotl/core/training_args.py b/src/axolotl/core/training_args.py index 9eae52162..7cace7643 100644 --- a/src/axolotl/core/training_args.py +++ b/src/axolotl/core/training_args.py @@ -5,7 +5,7 @@ from dataclasses import dataclass, field from typing import Optional from transformers import TrainingArguments -from trl import CPOConfig, DPOConfig, KTOConfig, ORPOConfig, PRMConfig, RewardConfig +from trl import CPOConfig, KTOConfig, ORPOConfig, PRMConfig, RewardConfig @dataclass @@ -217,13 +217,6 @@ class AxolotlTrainingArguments(AxolotlTrainingMixins, TrainingArguments): """ -@dataclass -class AxolotlDPOConfig(AxolotlTrainingMixins, DPOConfig): - """ - DPO config for DPO training - """ - - @dataclass class AxolotlORPOConfig(AxolotlTrainingMixins, ORPOConfig): """ diff --git a/src/axolotl/prompt_strategies/base.py b/src/axolotl/prompt_strategies/base.py index cddb3d0e1..c146133fb 100644 --- a/src/axolotl/prompt_strategies/base.py +++ b/src/axolotl/prompt_strategies/base.py @@ -13,8 +13,19 @@ def load(strategy, cfg, module_base=None, **kwargs): if len(strategy.split(".")) == 1: strategy = strategy + ".default" load_fn = strategy.split(".")[-1] - strategy = ".".join(strategy.split(".")[:-1]) - mod = importlib.import_module(f".{strategy}", module_base) + 
if len(strategy.split(".")) > 1: + try: + importlib.import_module( + strategy.split(".")[-2], + ".".join(strategy.split(".")[:-2]), + ) + module_base = ".".join(strategy.split(".")[:-2]) + strategy = strategy.split(".")[-2] + except ModuleNotFoundError: + strategy = "." + ".".join(strategy.split(".")[:-1]) + else: + strategy = "." + ".".join(strategy.split(".")[:-1]) + mod = importlib.import_module(strategy, module_base) func = getattr(mod, load_fn) return func(cfg, **kwargs) except Exception: # pylint: disable=broad-exception-caught diff --git a/src/axolotl/prompt_strategies/dpo/passthrough.py b/src/axolotl/prompt_strategies/dpo/passthrough.py new file mode 100644 index 000000000..1fcb838db --- /dev/null +++ b/src/axolotl/prompt_strategies/dpo/passthrough.py @@ -0,0 +1,14 @@ +""" +DPO prompt strategies passthrough/zero-processing strategy +""" + + +def default( + cfg, dataset_idx=0, **kwargs +): # pylint: disable=possibly-unused-variable,unused-argument + def transform_fn( + sample, tokenizer=None + ): # pylint: disable=possibly-unused-variable,unused-argument + return sample + + return transform_fn diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py index aa79c0f61..868328b0b 100644 --- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py +++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py @@ -24,6 +24,8 @@ from transformers.utils.import_utils import is_torch_npu_available from axolotl.utils.config.models.internals import EnvCapabilities, GPUCapabilities +from .trl import TRLConfig + LOG = logging.getLogger("axolotl.utils.config.models.input") SUPPORTED_METRICS = {"sacrebleu", "comet", "ter", "chrf", "perplexity"} @@ -33,6 +35,7 @@ class RLType(str, Enum): """RL trainer type configuration subset""" dpo = "dpo" # pylint: disable=invalid-name + grpo = "grpo" # pylint: disable=invalid-name ipo = "ipo" # pylint: disable=invalid-name orpo = "orpo" # pylint: disable=invalid-name kto = "kto" # pylint: disable=invalid-name @@ -664,14 +667,20 @@ class AxolotlInputConfig( auto_resume_from_checkpoints: Optional[bool] = None resize_token_embeddings_to_32x: Optional[bool] = None mean_resizing_embeddings: Optional[bool] = False + # optionally shrink the embeddings when the tokenizer vocab size is smaller + shrink_embeddings: Optional[bool] = None rl: Optional[RLType] = None + trl: Optional[TRLConfig] = Field( + default_factory=lambda: TRLConfig(), # pylint: disable=unnecessary-lambda + ) reward_model: Optional[bool] = None process_reward_model: Optional[bool] = None num_labels: Optional[int] = None dpo_use_weighting: Optional[ bool ] = None # whether to use weighting in DPO trainer. If none, default is false in the trainer. + dpo_use_logits_to_keep: Optional[bool] = None datasets: Optional[conlist(Union[SFTDataset, DPODataset, KTODataset, StepwiseSupervisedDataset], min_length=1)] = None # type: ignore test_datasets: Optional[conlist(Union[SFTDataset, DPODataset, KTODataset, StepwiseSupervisedDataset], min_length=1)] = None # type: ignore diff --git a/src/axolotl/utils/config/models/input/v0_4_1/trl.py b/src/axolotl/utils/config/models/input/v0_4_1/trl.py new file mode 100644 index 000000000..6361bb249 --- /dev/null +++ b/src/axolotl/utils/config/models/input/v0_4_1/trl.py @@ -0,0 +1,35 @@ +""" +GRPO specific configuration args +""" +from typing import List, Optional + +from pydantic import BaseModel, Field + + +class TRLConfig(BaseModel): + """ + Input args for TRL. 
+ """ + + beta: Optional[float] = None + max_completion_length: Optional[int] = Field( + default=None, + json_schema_extra={ + "description": "Maximum length of the completion for RL training" + }, + ) + + # GRPO specific args + use_vllm: Optional[bool] = False + vllm_device: Optional[str] = "auto" + vllm_gpu_memory_utilization: Optional[float] = 0.9 + vllm_max_model_len: Optional[int] = None + vllm_dtype: Optional[str] = "auto" + + reward_funcs: Optional[List[str]] = None + num_generations: Optional[int] = None + log_completions: Optional[bool] = False + + sync_ref_model: Optional[bool] = False + ref_model_mixup_alpha: Optional[float] = 0.9 + ref_model_sync_steps: Optional[int] = 64 diff --git a/src/axolotl/utils/data/rl.py b/src/axolotl/utils/data/rl.py index 67075cc9f..4c7b71292 100644 --- a/src/axolotl/utils/data/rl.py +++ b/src/axolotl/utils/data/rl.py @@ -58,7 +58,7 @@ def _save_preprocessed_ds(cfg, sub_cfg, dataset): dataset.save_to_disk(str(prepared_ds_path)) -def map_dataset(cfg, data_set, ds_transform_fn, tokenizer): +def map_dataset(cfg, data_set, ds_transform_fn, tokenizer, **map_kwargs): sig = inspect.signature(ds_transform_fn) if "tokenizer" in sig.parameters: if not tokenizer: @@ -71,6 +71,7 @@ def map_dataset(cfg, data_set, ds_transform_fn, tokenizer): data_set = data_set.map( ds_transform_fn, desc="Mapping RL Dataset", + **map_kwargs, ) return data_set @@ -113,6 +114,9 @@ def drop_long_rl_seq( return (len_prompt + len_completion) <= sequence_len + if rl == "grpo": + return True + raise ValueError("Unknown RL type") @@ -140,36 +144,45 @@ def load_prepare_preference_datasets(cfg): else: ds_transform_fn = load_dpo(_type, _cfg, dataset_idx=i) + map_kwargs = {} + if isinstance(ds_transform_fn, tuple): + ds_transform_fn, map_kwargs = ds_transform_fn split_datasets[i] = map_dataset( - cfg, data_set, ds_transform_fn, tokenizer + cfg, data_set, ds_transform_fn, tokenizer, **map_kwargs ) elif _cfg.rl == "kto": ds_transform_fn = load_kto(_type, _cfg, dataset_idx=i) + map_kwargs = {} + if isinstance(ds_transform_fn, tuple): + ds_transform_fn, map_kwargs = ds_transform_fn split_datasets[i] = map_dataset( - cfg, data_set, ds_transform_fn, tokenizer + cfg, data_set, ds_transform_fn, tokenizer, **map_kwargs ) else: # If no `type` is provided, assume the dataset is already in the expected format with # "prompt", "chosen" and "rejected" already preprocessed split_datasets[i] = data_set - drop_long = partial( - drop_long_rl_seq, - rl=_cfg.rl, - tokenizer=tokenizer, - sequence_len=cfg.sequence_len, - ) + if not cfg.skip_prepare_dataset: + drop_long = partial( + drop_long_rl_seq, + rl=_cfg.rl, + tokenizer=tokenizer, + sequence_len=cfg.sequence_len, + ) - prior_len = len(split_datasets[i]) - split_datasets[i] = split_datasets[i].filter( - drop_long, - num_proc=cfg.dataset_processes, - load_from_cache_file=not cfg.is_preprocess, - desc="Dropping Long Sequences", - ) - dropped = prior_len - len(split_datasets[i]) - if dropped: - LOG.warning(f"Dropped {dropped} long samples from dataset index {i}") + prior_len = len(split_datasets[i]) + split_datasets[i] = split_datasets[i].filter( + drop_long, + num_proc=cfg.dataset_processes, + load_from_cache_file=not cfg.is_preprocess, + desc="Dropping Long Sequences", + ) + dropped = prior_len - len(split_datasets[i]) + if dropped: + LOG.warning( + f"Dropped {dropped} long samples from dataset index {i}" + ) combined_datasets = concatenate_datasets(split_datasets) combined_datasets = combined_datasets.shuffle(seed=cfg.seed) diff --git 
a/src/axolotl/utils/lora.py b/src/axolotl/utils/lora.py new file mode 100644 index 000000000..759c17ac2 --- /dev/null +++ b/src/axolotl/utils/lora.py @@ -0,0 +1,75 @@ +# Copyright 2025 Axolotl AI. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +module to get the state dict of a merged lora model +""" +import torch +from peft.tuners.tuners_utils import onload_layer +from peft.utils import ModulesToSaveWrapper, _get_submodules + + +def get_lora_merged_state_dict( + model: torch.nn.Module, +) -> dict: + r""" + Create and return a state_dict that has the LoRA deltas + merged into the base model’s weights, without modifying `model` in place. + + Arguments: + model (torch.nn.Module): A model that has LoRA/PEFT adapters attached. + + Returns: + dict: A state_dict of the merged parameters. + """ + + base_model_prefix = "base_model.model." + state_dict = {} + key_list = [key for key, _ in model.named_modules() if model.prefix not in key] + for key in key_list: + try: + _, target, _ = _get_submodules(model, key) + except AttributeError: + continue + with onload_layer(target): + weight_key = key.replace(base_model_prefix, "") + ".weight" + bias_key = key.replace(base_model_prefix, "") + ".bias" + if hasattr(target, "base_layer"): + target.merge(safe_merge=True, adapter_names=None) + # get the state_dict of target.base_layer + layer_state_dict = target.base_layer.state_dict() + state_dict[weight_key] = layer_state_dict["weight"] + elif isinstance(target, ModulesToSaveWrapper): + # save any additional trainable modules part of `modules_to_save` + new_module = target.modules_to_save[target.active_adapter] + if hasattr(new_module, "base_layer"): + # check if the module is itself a tuner layer + new_module.merge(safe_merge=True, adapter_names=None) + layer_state_dict = new_module.state_dict() + state_dict[weight_key] = layer_state_dict["weight"] + elif hasattr(target, "weight"): + if any( + skip in key + for skip in [ + ".original_module", + ".modules_to_save", + ".base_layer", + ] + ): + continue + layer_state_dict = target.state_dict() + state_dict[weight_key] = layer_state_dict["weight"] + if hasattr(target, "bias") and "bias" in layer_state_dict.keys(): + state_dict[bias_key] = layer_state_dict["bias"] + return state_dict diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index be5b2782a..a96ecb0cf 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -1053,9 +1053,12 @@ class ModelLoader: if self.cfg.resize_token_embeddings_to_32x else len(self.tokenizer) ) - if ( - hasattr(self.model, "get_input_embeddings") - and self.model.get_input_embeddings().num_embeddings != embeddings_len + if hasattr(self.model, "get_input_embeddings") and ( + self.model.get_input_embeddings().num_embeddings < embeddings_len + or ( + self.model.get_input_embeddings().num_embeddings > embeddings_len + and self.cfg.shrink_embeddings + ) ): resize_kwargs = {} if self.cfg.mean_resizing_embeddings is not None: @@ -1309,6 +1312,7 @@ def load_lora(model, cfg, 
inference=False, config_only=False): lora_config_kwargs["init_lora_weights"] = "loftq" if cfg.peft_use_dora: lora_config_kwargs["use_dora"] = cfg.peft_use_dora + LOG.info("Initializing LoRA weights using dora. This might take longer.") if cfg.peft_use_rslora: lora_config_kwargs["use_rslora"] = cfg.peft_use_rslora if cfg.peft_layer_replication: diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py index 61f03e7ad..c8e365fc5 100644 --- a/src/axolotl/utils/trainer.py +++ b/src/axolotl/utils/trainer.py @@ -576,7 +576,7 @@ def prepare_opinionated_env(cfg): def setup_trainer( cfg, train_dataset, eval_dataset, model, tokenizer, processor, total_num_steps ): - if cfg.rl in ("dpo", "ipo", "orpo", "kto", "simpo"): + if cfg.rl: trainer_builder = HFRLTrainerBuilder(cfg, model[0], tokenizer, processor) trainer_builder.model_ref = model[1] trainer_builder.peft_config = model[2] diff --git a/tests/e2e/multigpu/test_grpo.py b/tests/e2e/multigpu/test_grpo.py new file mode 100644 index 000000000..d2b84994b --- /dev/null +++ b/tests/e2e/multigpu/test_grpo.py @@ -0,0 +1,173 @@ +""" +GRPO test suite +""" +import random +from pathlib import Path + +import pytest +import yaml +from accelerate.test_utils import execute_subprocess_async +from e2e.utils import require_vllm +from transformers.testing_utils import get_torch_dist_unique_port + +from axolotl.utils.dict import DictDefault + + +class TestGRPO: + """ + Test case for GRPO training using multilpe GPUs + """ + + def _utils_write_yaml_and_rewards(self, cfg, temp_dir, suffix=""): + # write cfg to yaml file + Path(temp_dir).mkdir(parents=True, exist_ok=True) + with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout: + fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper)) + with open(f"rewards_{suffix}.py", "w", encoding="utf-8") as fout: + fout.write( + """import random +def rand_reward_func(completions, **kwargs) -> list[float]: + return [random.uniform(0, 1) for _ in completions] + +def oai_gsm8k_transform(cfg, *args, **kwargs): + def transform_fn(example, tokenizer=None): + label = example["answer"].split("####")[-1].strip().replace(",", "") + return { + "prompt": [{"role": "user", "content": example["question"]},], + "answer": label, + } + return transform_fn, {"remove_columns": ["question"]} +""" + ) + + @pytest.mark.parametrize( + "num_gpus", + [1, 2], + ) + @require_vllm + def test_llama_dora(self, temp_dir, num_gpus): + rnd_reward_suffix = str(random.randint(1000, 9999)) + cfg = DictDefault( + { + "base_model": "HuggingFaceTB/SmolLM2-135M", + "chat_template": "llama3", + "rl": "grpo", + "trl": { + "beta": 0.001, + "max_completion_length": 256, + "use_vllm": True, + "vllm_device": "auto" if num_gpus == 1 else "cuda:1", + "vllm_gpu_memory_utilization": 0.15, + "num_generations": 4, + "reward_funcs": [f"rewards_{rnd_reward_suffix}.rand_reward_func"], + }, + "datasets": [ + { + "path": "openai/gsm8k", + "name": "main", + "type": f"rewards_{rnd_reward_suffix}.oai_gsm8k_transform", + }, + ], + "adapter": "lora", + "lora_r": 8, + "lora_alpha": 16, + "lora_dropout": 0.05, + "lora_target_linear": True, + "peft_use_dora": True, + "flash_attention": True, + "sequence_len": 1024, + "special_tokens": { + "pad_token": "<|endoftext|>", + }, + "max_steps": 5, + "num_epochs": 1, + "micro_batch_size": 4, + "gradient_accumulation_steps": 2, + "warmup_steps": 10, + "val_set_size": 0.0, + "output_dir": temp_dir, + "learning_rate": 0.0001, + "optimizer": "adamw_torch_fused", + "lr_scheduler": "cosine", + "save_safetensors": True, + 
"bf16": "auto", + "use_tensorboard": True, + } + ) + + self._utils_write_yaml_and_rewards(cfg, temp_dir, suffix=rnd_reward_suffix) + + execute_subprocess_async( + [ + "axolotl", + "train", + str(Path(temp_dir) / "config.yaml"), + "--num-processes", + str(num_gpus), + "--main-process-port", + f"{get_torch_dist_unique_port()}", + ] + ) + + @pytest.mark.parametrize( + "num_gpus", + [1, 2], + ) + @require_vllm + def test_llama_fft(self, temp_dir, num_gpus): + rnd_reward_suffix = str(random.randint(1000, 9999)) + cfg = DictDefault( + { + "base_model": "HuggingFaceTB/SmolLM2-135M", + "chat_template": "llama3", + "rl": "grpo", + "trl": { + "beta": 0.001, + "max_completion_length": 256, + "use_vllm": True, + "vllm_device": "auto" if num_gpus == 1 else "cuda:1", + "vllm_gpu_memory_utilization": 0.15, + "num_generations": 4, + "reward_funcs": [f"rewards_{rnd_reward_suffix}.rand_reward_func"], + }, + "datasets": [ + { + "path": "openai/gsm8k", + "name": "main", + "type": f"rewards_{rnd_reward_suffix}.oai_gsm8k_transform", + }, + ], + "flash_attention": True, + "sequence_len": 1024, + "special_tokens": { + "pad_token": "<|endoftext|>", + }, + "max_steps": 5, + "num_epochs": 1, + "micro_batch_size": 4, + "gradient_accumulation_steps": 2, + "warmup_steps": 10, + "val_set_size": 0.0, + "output_dir": temp_dir, + "learning_rate": 0.0001, + "optimizer": "adamw_torch_fused", + "lr_scheduler": "cosine", + "save_safetensors": True, + "bf16": "auto", + "use_tensorboard": True, + } + ) + + self._utils_write_yaml_and_rewards(cfg, temp_dir, suffix=rnd_reward_suffix) + + execute_subprocess_async( + [ + "axolotl", + "train", + str(Path(temp_dir) / "config.yaml"), + "--num-processes", + str(num_gpus), + "--main-process-port", + f"{get_torch_dist_unique_port()}", + ] + ) diff --git a/tests/e2e/utils.py b/tests/e2e/utils.py index a9f7fb28d..ff96f1f58 100644 --- a/tests/e2e/utils.py +++ b/tests/e2e/utils.py @@ -78,6 +78,24 @@ def require_torch_lt_2_6_0(test_case): return unittest.skipUnless(is_max_2_6_0(), "test requires torch<2.6.0")(test_case) +def require_vllm(test_case): + """ + Decorator marking a test that requires a vllm to be installed + """ + + def is_vllm_installed(): + try: + import vllm # pylint: disable=unused-import # noqa: F401 + + return True + except ImportError: + return False + + return unittest.skipUnless( + is_vllm_installed(), "test requires a vllm to be installed" + )(test_case) + + def is_hopper(): compute_capability = torch.cuda.get_device_capability() return compute_capability == (9, 0)