From b5f1e53a0fbb43528c753f017bf099fa99f42c3e Mon Sep 17 00:00:00 2001 From: Dan Saunders Date: Fri, 23 May 2025 15:51:11 -0400 Subject: [PATCH] models.py -> loaders/ module refactor (#2680) * models.py -> loaders/ module refactor * refactor ModelLoader class * plugin manager changes * circular import fix * pytest * pytest * minor improvements * fix * minor changes * fix test * remove dead code * coderabbit comments * lint * fix * coderabbit suggestion I liked * more coderabbit * review comments, yak shaving * lint * updating in light of SP ctx manager changes * review comment * review comment 2 --- src/axolotl/cli/utils.py | 6 +- src/axolotl/common/datasets.py | 2 +- src/axolotl/core/trainer_builder.py | 2 +- src/axolotl/core/trainers/grpo/trainer.py | 2 +- src/axolotl/integrations/base.py | 550 +++--- src/axolotl/loaders/__init__.py | 10 + src/axolotl/loaders/adapter.py | 206 +++ src/axolotl/loaders/constants.py | 21 + src/axolotl/loaders/model.py | 754 ++++++++ src/axolotl/loaders/patch_manager.py | 380 ++++ src/axolotl/loaders/processor.py | 56 + src/axolotl/loaders/tokenizer.py | 281 +++ src/axolotl/loaders/utils.py | 211 +++ .../gradient_checkpointing/__init__.py | 4 +- .../gradient_checkpointing/offload_cpu.py | 0 .../gradient_checkpointing/offload_disk.py | 0 src/axolotl/monkeypatch/peft/utils.py | 2 +- src/axolotl/train.py | 12 +- src/axolotl/utils/config/__init__.py | 3 +- .../utils/ctx_managers/sequence_parallel.py | 2 +- src/axolotl/utils/data/rl.py | 2 +- src/axolotl/utils/lora_embeddings.py | 14 - src/axolotl/utils/models.py | 1648 ----------------- src/axolotl/utils/schemas/config.py | 10 + tests/core/test_trainer_builder.py | 8 +- tests/e2e/patched/test_model_patches.py | 6 +- tests/e2e/test_load_model.py | 13 +- tests/patched/test_validation.py | 16 +- tests/test_exact_deduplication.py | 22 +- .../{utils/test_models.py => test_loaders.py} | 37 +- tests/test_lora.py | 6 +- tests/test_tokenizers.py | 2 +- tests/utils/__init__.py | 0 33 files changed, 2249 insertions(+), 2039 deletions(-) create mode 100644 src/axolotl/loaders/__init__.py create mode 100644 src/axolotl/loaders/adapter.py create mode 100644 src/axolotl/loaders/constants.py create mode 100644 src/axolotl/loaders/model.py create mode 100644 src/axolotl/loaders/patch_manager.py create mode 100644 src/axolotl/loaders/processor.py create mode 100644 src/axolotl/loaders/tokenizer.py create mode 100644 src/axolotl/loaders/utils.py rename src/axolotl/{utils => monkeypatch}/gradient_checkpointing/__init__.py (91%) rename src/axolotl/{utils => monkeypatch}/gradient_checkpointing/offload_cpu.py (100%) rename src/axolotl/{utils => monkeypatch}/gradient_checkpointing/offload_disk.py (100%) delete mode 100644 src/axolotl/utils/lora_embeddings.py delete mode 100644 src/axolotl/utils/models.py rename tests/{utils/test_models.py => test_loaders.py} (83%) delete mode 100644 tests/utils/__init__.py diff --git a/src/axolotl/cli/utils.py b/src/axolotl/cli/utils.py index ee00db39d..e681589f3 100644 --- a/src/axolotl/cli/utils.py +++ b/src/axolotl/cli/utils.py @@ -20,8 +20,9 @@ from transformers import ( ProcessorMixin, ) +from axolotl.loaders import load_processor, load_tokenizer +from axolotl.loaders.model import ModelLoader from axolotl.utils.dict import DictDefault -from axolotl.utils.models import load_model, load_processor, load_tokenizer LOG = logging.getLogger(__name__) @@ -318,7 +319,8 @@ def load_model_and_tokenizer( tokenizer = load_tokenizer(cfg) LOG.info("loading model...") - model, _ = load_model(cfg, tokenizer, 
inference=inference) + model_loader = ModelLoader(cfg, tokenizer, inference=inference) + model, _ = model_loader.load() processor = None if cfg.is_multimodal: diff --git a/src/axolotl/common/datasets.py b/src/axolotl/common/datasets.py index f944cbd6a..e3ffb7ae9 100644 --- a/src/axolotl/common/datasets.py +++ b/src/axolotl/common/datasets.py @@ -10,10 +10,10 @@ from datasets import Dataset import axolotl.monkeypatch.data.batch_dataset_fetcher # pylint: disable=unused-import # noqa: F401 from axolotl.cli.args import PreprocessCliArgs, TrainerCliArgs +from axolotl.loaders import load_processor, load_tokenizer from axolotl.utils.data import prepare_dataset from axolotl.utils.data.rl import load_prepare_preference_datasets from axolotl.utils.dict import DictDefault -from axolotl.utils.models import load_processor, load_tokenizer from axolotl.utils.schemas.enums import RLType from axolotl.utils.tokenization import check_dataset_labels diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index 863b065e6..9709f0fd4 100755 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -59,6 +59,7 @@ from axolotl.core.training_args import ( AxolotlTrainingArguments, ) from axolotl.integrations.base import PluginManager +from axolotl.loaders.utils import ensure_dtype from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES from axolotl.monkeypatch.relora import ReLoRACallback from axolotl.monkeypatch.trainer.lr import patch_trainer_get_lr @@ -86,7 +87,6 @@ from axolotl.utils.collators import ( V2BatchSamplerDataCollatorForSeq2Seq, ) from axolotl.utils.collators.mm_chat import MultiModalChatDataCollator -from axolotl.utils.models import ensure_dtype from axolotl.utils.schemas.enums import CustomSupportedOptimizers, RLType try: diff --git a/src/axolotl/core/trainers/grpo/trainer.py b/src/axolotl/core/trainers/grpo/trainer.py index a603ed860..b5b3912cf 100644 --- a/src/axolotl/core/trainers/grpo/trainer.py +++ b/src/axolotl/core/trainers/grpo/trainer.py @@ -43,7 +43,7 @@ from trl.trainer.utils import pad from axolotl.core.trainers.grpo.sampler import SequenceParallelRepeatRandomSampler from axolotl.core.trainers.mixins import RngLoaderMixin, SchedulerMixin -from axolotl.monkeypatch.ring_attn.patch import get_ring_attn_group +from axolotl.monkeypatch.ring_attn import get_ring_attn_group if is_peft_available(): # pylint: disable=unused-import diff --git a/src/axolotl/integrations/base.py b/src/axolotl/integrations/base.py index 97cbac693..2beaf667a 100644 --- a/src/axolotl/integrations/base.py +++ b/src/axolotl/integrations/base.py @@ -10,71 +10,73 @@ # License for the specific language governing permissions and limitations under # the License. -""" -Base class for all plugins. +"""Base class for all plugins. A plugin is a reusable, modular, and self-contained piece of code that extends the functionality of Axolotl. Plugins can be used to integrate third-party models, modify the training process, or add new features. To create a new plugin, you need to inherit from the BasePlugin class and implement the required methods. 
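+
+A minimal sketch of a custom plugin (illustrative only; `MyPlugin` and its body
+are hypothetical, while the hook name comes from `BasePlugin` below):
+
+    class MyPlugin(BasePlugin):
+        def pre_model_load(self, cfg):
+            # e.g., apply a monkeypatch or adjust cfg before model instantiation
+            ...
+
+Plugins are registered by their dotted path, e.g. "my_pkg.my_module.MyPlugin".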
""" + +from __future__ import annotations + import collections import importlib import logging -from typing import OrderedDict +from typing import TYPE_CHECKING, Callable, OrderedDict, Union -import torch +from peft import PeftModel +from torch.optim import Optimizer from torch.optim.lr_scheduler import LRScheduler +from transformers import PreTrainedModel, Trainer from axolotl.utils.dict import DictDefault +if TYPE_CHECKING: + from axolotl.common.datasets import TrainDatasetMeta + class BasePlugin: - """ - Base class for all plugins. Defines the interface for plugin methods. - - Attributes: - None + """Base class for all plugins. Defines the interface for plugin methods. Methods: - register(cfg): Registers the plugin with the given configuration. - load_datasets(cfg): Loads and preprocesses the dataset for training. - pre_model_load(cfg): Performs actions before the model is loaded. - post_model_build(cfg, model): Performs actions after the model is loaded, but before LoRA adapters are applied. - pre_lora_load(cfg, model): Performs actions before LoRA weights are loaded. - post_lora_load(cfg, model): Performs actions after LoRA weights are loaded. - post_model_load(cfg, model): Performs actions after the model is loaded, inclusive of any adapters. - post_trainer_create(cfg, trainer): Performs actions after the trainer is created. - create_optimizer(cfg, trainer): Creates and returns an optimizer for training. - create_lr_scheduler(cfg, trainer, optimizer, num_training_steps): Creates and returns a learning rate scheduler. - add_callbacks_pre_trainer(cfg, model): Adds callbacks to the trainer before training. - add_callbacks_post_trainer(cfg, trainer): Adds callbacks to the trainer after training. + register(cfg): Registers the plugin with the given configuration. + load_datasets(cfg): Loads and preprocesses the dataset for training. + pre_model_load(cfg): Performs actions before the model is loaded. + post_model_build(cfg, model): Performs actions after the model is loaded, but + before LoRA adapters are applied. + pre_lora_load(cfg, model): Performs actions before LoRA weights are loaded. + post_lora_load(cfg, model): Performs actions after LoRA weights are loaded. + post_model_load(cfg, model): Performs actions after the model is loaded, + inclusive of any adapters. + post_trainer_create(cfg, trainer): Performs actions after the trainer is + created. + create_optimizer(cfg, trainer): Creates and returns an optimizer for training. + create_lr_scheduler(cfg, trainer, optimizer, num_training_steps): Creates and + returns a learning rate scheduler. + add_callbacks_pre_trainer(cfg, model): Adds callbacks to the trainer before + training. + add_callbacks_post_trainer(cfg, trainer): Adds callbacks to the trainer after + training. """ def __init__(self): - """ - Initializes the BasePlugin. - """ + """Initializes the BasePlugin.""" def register(self, cfg): # pylint: disable=unused-argument - """ - Registers the plugin with the given configuration. + """Registers the plugin with the given configuration. - Parameters: - cfg (dict): The configuration for the plugin. - - Returns: - None + Args: + cfg: The configuration for the plugin. """ def get_input_args(self) -> str | None: - """ - Returns a pydantic model for the plugin's input arguments. - """ + """Returns a pydantic model for the plugin's input arguments.""" - def load_datasets(self, cfg: DictDefault, preprocess: bool = False): - """ - Loads and preprocesses the dataset for training. 
+ def load_datasets( + self, cfg: DictDefault, preprocess: bool = False + ) -> Union["TrainDatasetMeta", None]: + """Loads and preprocesses the dataset for training. Args: cfg: The configuration for the plugin. @@ -84,181 +86,164 @@ class BasePlugin: dataset_meta: The metadata for the training dataset. """ - def pre_model_load(self, cfg): # pylint: disable=unused-argument - """ - Performs actions before the model is loaded. + def pre_model_load(self, cfg: DictDefault): # pylint: disable=unused-argument + """Performs actions before the model is loaded. Args: - cfg (dict): The configuration for the plugin. + cfg: The configuration for the plugin. + """ + + # pylint: disable=unused-argument + def post_model_build(self, cfg: DictDefault, model: PreTrainedModel): + """Performs actions after the model is built/loaded, but before any adapters are applied. + + Args: + cfg: The configuration for the plugin. + """ + + # pylint: disable=unused-argument + def pre_lora_load(self, cfg: DictDefault, model: PreTrainedModel): + """Performs actions before LoRA weights are loaded. + + Args: + cfg: The configuration for the plugin. + model: The loaded model. + """ + + # pylint: disable=unused-argument + def post_lora_load(self, cfg: DictDefault, model: PreTrainedModel | PeftModel): + """Performs actions after LoRA weights are loaded. + + Args: + cfg: The configuration for the plugin. + model: The loaded model. + """ + + # pylint: disable=unused-argument + def post_model_load(self, cfg: DictDefault, model: PreTrainedModel | PeftModel): + """Performs actions after the model is loaded. + + Args: + cfg: The configuration for the plugin. + model: The loaded model. + """ + + # pylint: disable=unused-argument + def get_trainer_cls(self, cfg: DictDefault) -> Trainer | None: + """Returns a custom class for the trainer. + + Args: + cfg: The global axolotl configuration. Returns: - None + The first non-`None` trainer class returned by a plugin. """ - def post_model_build(self, cfg, model): # pylint: disable=unused-argument - """ - Performs actions after the model is built/loaded, but before any adapters are applied. + # pylint: disable=unused-argument + def post_trainer_create(self, cfg: DictDefault, trainer: Trainer): + """Performs actions after the trainer is created. Args: - cfg (dict): The configuration for the plugin. + cfg: The configuration for the plugin. + trainer: The trainer object for training. """ - def post_model_load(self, cfg, model): # pylint: disable=unused-argument - """ - Performs actions after the model is loaded. + # pylint: disable=unused-argument + def create_optimizer(self, cfg: DictDefault, trainer: Trainer) -> Optimizer | None: + """Creates and returns an optimizer for training. Args: - cfg (dict): The configuration for the plugin. - model (object): The loaded model. + cfg: The configuration for the plugin. + trainer: The trainer object for training. Returns: - None - """ - - def pre_lora_load(self, cfg, model): # pylint: disable=unused-argument - """ - Performs actions before LoRA weights are loaded. - - Args: - cfg (dict): The configuration for the plugin. - model (object): The loaded model. - - Returns: - None - """ - - def post_lora_load(self, cfg, model): # pylint: disable=unused-argument - """ - Performs actions after LoRA weights are loaded. - - Args: - cfg (dict): The configuration for the plugin. - model (object): The loaded model. - - Returns: - None - """ - - def get_trainer_cls(self, cfg): # pylint: disable=unused-argument): - """ - Returns a custom class for the trainer. 
- - Args: - cfg (dict): The global axolotl configuration. - - Returns: - class: The class for the trainer. - """ - - def post_trainer_create(self, cfg, trainer): # pylint: disable=unused-argument - """ - Performs actions after the trainer is created. - - Args: - cfg (dict): The configuration for the plugin. - trainer (object): The trainer object for training. - - Returns: - None - """ - - def create_optimizer(self, cfg, trainer): # pylint: disable=unused-argument - """ - Creates and returns an optimizer for training. - - Args: - cfg (dict): The configuration for the plugin. - trainer (object): The trainer object for training. - - Returns: - object: The created optimizer. + The created optimizer. """ + # pylint: disable=unused-argument def create_lr_scheduler( - self, cfg, trainer, optimizer, num_training_steps - ) -> LRScheduler | None: # pylint: disable=unused-argument - """ - Creates and returns a learning rate scheduler. + self, + cfg: DictDefault, + trainer: Trainer, + optimizer: Optimizer, + num_training_steps: int, + ) -> LRScheduler | None: + """Creates and returns a learning rate scheduler. Args: - cfg (dict): The configuration for the plugin. - trainer (object): The trainer object for training. - optimizer (object): The optimizer for training. - num_training_steps (int): Total number of training steps + cfg: The configuration for the plugin. + trainer: The trainer object for training. + optimizer: The optimizer for training. + num_training_steps: Total number of training steps Returns: - object (LRScheduler): The created learning rate scheduler. + The created learning rate scheduler. """ - def add_callbacks_pre_trainer(self, cfg, model): # pylint: disable=unused-argument - """ - setup callbacks before creating the trainer. + # pylint: disable=unused-argument + def add_callbacks_pre_trainer( + self, cfg: DictDefault, model: PreTrainedModel + ) -> list[Callable]: + """Set up callbacks before creating the trainer. Args: - cfg (dict): The configuration for the plugin. - model (object): The loaded model. + cfg: The configuration for the plugin. + model: The loaded model. Returns: - List[callable]: A list of callback functions to be added to the TrainingArgs + A list of callback functions to be added to the `TrainingArgs`. """ return [] + # pylint: disable=unused-argument def add_callbacks_post_trainer( - self, cfg, trainer - ): # pylint: disable=unused-argument - """ - Adds callbacks to the trainer after creating the trainer. - This is useful for callbacks that require access to the model or trainer. + self, cfg: DictDefault, trainer: Trainer + ) -> list[Callable]: + """Adds callbacks to the trainer after creating the trainer. This is useful for + callbacks that require access to the model or trainer. Args: - cfg (dict): The configuration for the plugin. - trainer (object): The trainer object for training. + cfg: The configuration for the plugin. + trainer: The trainer object for training. Returns: - List[callable]: A list of callback functions to be added + A list of callback functions to be added """ return [] - def post_train(self, cfg, model): # pylint: disable=unused-argument - """ - Performs actions after training is complete. + # pylint: disable=unused-argument + def post_train(self, cfg: DictDefault, model: PreTrainedModel | PeftModel): + """Performs actions after training is complete. Args: - cfg (dict): The axolotl configuration - model (object): The loaded model. - - Returns: - None + cfg: The axolotl configuration. + model: The loaded model. 
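+
+        Example (illustrative; the plugin and its behavior are hypothetical):
+
+            class LogParamsPlugin(BasePlugin):
+                def post_train(self, cfg, model):
+                    total = sum(p.numel() for p in model.parameters())
+                    print(f"training done; model has {total} parameters")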
""" - def post_train_unload(self, cfg): # pylint: disable=unused-argument - """ - Performs actions after training is complete and the model is unloaded. + def post_train_unload(self, cfg: DictDefault): # pylint: disable=unused-argument + """Performs actions after training is complete and the model is unloaded. Args: - cfg (dict): The configuration for the plugin. - - Returns: - None + cfg: The configuration for the plugin. """ def load_plugin(plugin_name: str) -> BasePlugin: - """ - Loads a plugin based on the given plugin name. + """Loads a plugin based on the given plugin name. - The plugin name should be in the format "module_name.class_name". - This function splits the plugin name into module and class, imports the module, - retrieves the class from the module, and creates an instance of the class. + The plugin name should be in the format "module_name.class_name". This function + splits the plugin name into module and class, imports the module, retrieves the + class from the module, and creates an instance of the class. - Parameters: - plugin_name (str): The name of the plugin to be loaded. The name should be in the format "module_name.class_name". + Args: + plugin_name: The name of the plugin to be loaded. The name should be in the + format "module_name.class_name". Returns: - BasePlugin: An instance of the loaded plugin. + An instance of the loaded plugin. Raises: - ImportError: If the plugin module cannot be imported. + ImportError: If the plugin module cannot be imported. """ # split the plugin name into module and class module_name, class_name = plugin_name.rsplit(".", 1) @@ -284,28 +269,25 @@ def load_plugin(plugin_name: str) -> BasePlugin: class PluginManager: - """ - The PluginManager class is responsible for loading and managing plugins. - It should be a singleton so it can be accessed from anywhere in the codebase. + """The `PluginManager` class is responsible for loading and managing plugins. It + should be a singleton so it can be accessed from anywhere in the codebase. Attributes: - plugins (List[BasePlugin]): A list of loaded plugins. + plugins: A list of loaded plugins. Methods: - get_instance(): Static method to get the singleton instance of PluginManager. - register(plugin_name: str): Registers a new plugin by its name. - pre_model_load(cfg): Calls the pre_model_load method of all registered plugins. + get_instance(): Static method to get the singleton instance of `PluginManager`. + register(plugin_name: str): Registers a new plugin by its name. + pre_model_load(cfg): Calls the pre_model_load method of all registered plugins. """ plugins: OrderedDict[str, BasePlugin] = collections.OrderedDict() - _instance = None - _cfg = None + _instance: PluginManager | None = None + _cfg: DictDefault | None = None def __new__(cls): - """ - Creates a new instance of PluginManager if it doesn't exist yet. - """ + """Creates a new instance of PluginManager if it doesn't exist yet.""" if cls._instance is None: cls._instance = super(PluginManager, cls).__new__(cls) cls._instance.plugins: OrderedDict[str, BasePlugin] = ( @@ -315,9 +297,8 @@ class PluginManager: @staticmethod def get_instance() -> "PluginManager": - """ - Returns the singleton instance of PluginManager. - If the instance doesn't exist, it creates a new one. + """Returns the singleton instance of PluginManager. If the instance doesn't + exist, it creates a new one. 
""" if PluginManager._instance is None: PluginManager() @@ -332,17 +313,13 @@ class PluginManager: self._cfg = cfg def register(self, plugin_name: str): - """ - Registers a new plugin by its name. + """Registers a new plugin by its name. - Parameters: - plugin_name (str): The name of the plugin to be registered. - - Returns: - None + Args: + plugin_name: The name of the plugin to be registered. Raises: - ImportError: If the plugin module cannot be imported. + ImportError: If the plugin module cannot be imported. """ try: logging.info(f"Attempting to load plugin: {plugin_name}") @@ -352,12 +329,11 @@ class PluginManager: except ImportError: logging.error(f"Failed to load plugin: {plugin_name}") - def get_input_args(self): - """ - Returns a list of Pydantic classes for all registered plugins' input arguments.' + def get_input_args(self) -> list[str]: + """Returns a list of Pydantic classes for all registered plugins' input arguments.' Returns: - list[str]: A list of Pydantic classes for all registered plugins' input arguments.' + A list of Pydantic classes for all registered plugins' input arguments.' """ input_args = [] for plugin in self.plugins.values(): @@ -366,16 +342,17 @@ class PluginManager: input_args.append(input_args_from_plugin) return input_args - def load_datasets(self, cfg, preprocess: bool = False): - """ - Calls the load_datasets method of each registered plugin. + def load_datasets( + self, cfg: DictDefault, preprocess: bool = False + ) -> Union["TrainDatasetMeta", None]: + """Calls the load_datasets method of each registered plugin. Args: cfg: The configuration for the plugins. - preprocess : Whether this is preprocess step of the datasets. + preprocess: Whether this is preprocess step of the datasets. Returns: - dataset_meta: The dataset metadata loaded from all registered plugins. + The dataset metadata loaded from all registered plugins. """ return_ds_meta = None for plugin in self.plugins.values(): @@ -387,83 +364,66 @@ class PluginManager: raise RuntimeError("Multiple plugins loaded datasets") return return_ds_meta - def pre_model_load(self, cfg): - """ - Calls the pre_model_load method of all registered plugins. + def pre_model_load(self, cfg: DictDefault): + """Calls the pre_model_load method of all registered plugins. - Parameters: - cfg (dict): The configuration for the plugins. - - Returns: - None + Args: + cfg: The configuration for the plugins. """ for plugin in self.plugins.values(): plugin.pre_model_load(cfg) - def post_model_build(self, cfg, model): - """ - Calls the post_model_build method of all registered plugins after the model has been built/loaded, - but before any adapters have been applied. + def post_model_build(self, cfg: DictDefault, model: PreTrainedModel): + """Calls the `post_model_build` method of all registered plugins after the + model has been built / loaded, but before any adapters have been applied. Args: - cfg (dict): The configuration for the plugins. - model (object): The loaded model. + cfg: The configuration for the plugins. + model: The loaded model. """ for plugin in self.plugins.values(): plugin.post_model_build(cfg, model) - def post_model_load(self, cfg, model): - """ - Calls the post_model_load method of all registered plugins after the model has been loaded - inclusive of any adapters + def pre_lora_load(self, cfg: DictDefault, model: PreTrainedModel): + """Calls the `pre_lora_load` method of all registered plugins. - Parameters: - cfg (dict): The configuration for the plugins. - model (object): The loaded model. 
- - Returns: - None - """ - for plugin in self.plugins.values(): - plugin.post_model_load(cfg, model) - - def pre_lora_load(self, cfg, model): - """ - Calls the pre_lora_load method of all registered plugins. - - Parameters: - cfg (dict): The configuration for the plugins. - model (object): The loaded model. - - Returns: - None + Args: + cfg: The configuration for the plugins. + model: The loaded model. """ for plugin in self.plugins.values(): plugin.pre_lora_load(cfg, model) - def post_lora_load(self, cfg, model): - """ - Calls the post_lora_load method of all registered plugins. + def post_lora_load(self, cfg: DictDefault, model: PreTrainedModel | PeftModel): + """Calls the `post_lora_load` method of all registered plugins. - Parameters: - cfg (dict): The configuration for the plugins. - model (object): The loaded model. - - Returns: - None + Args: + cfg: The configuration for the plugins. + model: The loaded model. """ for plugin in self.plugins.values(): plugin.post_lora_load(cfg, model) - def get_trainer_cls(self, cfg): - """ - Calls the get_trainer_cls method of all registered plugins and returns the first non-None trainer class. + def post_model_load(self, cfg: DictDefault, model: PreTrainedModel | PeftModel): + """Calls the `post_model_load` method of all registered plugins after the model + has been loaded inclusive of any adapters. - Parameters: - cfg (dict): The configuration for the plugins. + Args: + cfg: The configuration for the plugins. + model: The loaded model. + """ + for plugin in self.plugins.values(): + plugin.post_model_load(cfg, model) + + def get_trainer_cls(self, cfg: DictDefault) -> Trainer | None: + """Calls the `get_trainer_cls` method of all registered plugins and returns the + first non-`None` trainer class. + + Args: + cfg: The configuration for the plugins. Returns: - object: The trainer class, or None if none was found. + The first non-`None` trainer class returned by a plugin. """ for plugin in self.plugins.values(): trainer_cls = plugin.get_trainer_cls(cfg) @@ -471,29 +431,25 @@ class PluginManager: return trainer_cls return None - def post_trainer_create(self, cfg, trainer): - """ - Calls the post_trainer_create method of all registered plugins. + def post_trainer_create(self, cfg: DictDefault, trainer: Trainer): + """Calls the `post_trainer_create` method of all registered plugins. - Parameters: - cfg (dict): The configuration for the plugins. - trainer (object): The trainer object for training. - - Returns: - None + Args: + cfg: The configuration for the plugins. + trainer: The trainer object for training. """ for plugin in self.plugins.values(): plugin.post_trainer_create(cfg, trainer) - def create_optimizer(self, trainer): - """ - Calls the create_optimizer method of all registered plugins and returns the first non-None optimizer. + def create_optimizer(self, trainer: Trainer) -> Optimizer | None: + """Calls the `create_optimizer` method of all registered plugins and returns + the first non-`None` optimizer. - Parameters: - trainer (object): The trainer object for training. + Args: + trainer: The trainer object for training. Returns: - object: The created optimizer, or None if none was found. + The created optimizer, or `None` if none was found. 
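+
+        Example of a plugin-side override this method would pick up
+        (illustrative; `SGDPlugin` is hypothetical, and `cfg.learning_rate`
+        is assumed to be set in the axolotl config):
+
+            class SGDPlugin(BasePlugin):
+                def create_optimizer(self, cfg, trainer):
+                    from torch.optim import SGD
+                    return SGD(trainer.model.parameters(), lr=cfg.learning_rate)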
""" for plugin in self.plugins.values(): optimizer = plugin.create_optimizer(self.cfg, trainer) @@ -502,17 +458,17 @@ class PluginManager: return None def create_lr_scheduler( - self, trainer, optimizer, num_training_steps + self, trainer: Trainer, optimizer: Optimizer, num_training_steps: int ) -> LRScheduler | None: - """ - Calls the create_lr_scheduler method of all registered plugins and returns the first non-None scheduler. + """Calls the `create_lr_scheduler` method of all registered plugins and returns + the first non-`None` scheduler. - Parameters: - trainer (object): The trainer object for training. - optimizer (object): The optimizer for training. + Args: + trainer: The trainer object for training. + optimizer: The optimizer for training. Returns: - object: The created learning rate scheduler, or None if none was found. + The created learning rate scheduler, or `None` if not found. """ for plugin in self.plugins.values(): scheduler: LRScheduler | None = plugin.create_lr_scheduler( @@ -525,16 +481,17 @@ class PluginManager: return scheduler return None - def add_callbacks_pre_trainer(self, cfg, model): - """ - Calls the add_callbacks_pre_trainer method of all registered plugins. + def add_callbacks_pre_trainer( + self, cfg: DictDefault, model: PreTrainedModel + ) -> list[Callable]: + """Calls the add_callbacks_pre_trainer method of all registered plugins. - Parameters: - cfg (dict): The configuration for the plugins. - model (object): The loaded model. + Args: + cfg: The configuration for the plugins. + model: The loaded model. Returns: - List[callable]: A list of callback functions to be added to the TrainingArgs. + A list of callback functions to be added to the `TrainingArgs`. """ callbacks = [] for plugin in self.plugins.values(): @@ -543,16 +500,17 @@ class PluginManager: callbacks.extend(plugin_callbacks) return callbacks - def add_callbacks_post_trainer(self, cfg, trainer): - """ - Calls the add_callbacks_post_trainer method of all registered plugins. + def add_callbacks_post_trainer( + self, cfg: DictDefault, trainer: Trainer + ) -> list[Callable]: + """Calls the `add_callbacks_post_trainer` method of all registered plugins. - Parameters: - cfg (dict): The configuration for the plugins. - trainer (object): The trainer object for training. + Args: + cfg: The configuration for the plugins. + trainer: The trainer object for training. Returns: - List[callable]: A list of callback functions to be added to the TrainingArgs. + A list of callback functions to be added to the `TrainingArgs`. """ callbacks = [] for plugin in self.plugins.values(): @@ -561,41 +519,31 @@ class PluginManager: callbacks.extend(plugin_callbacks) return callbacks - def post_train(self, cfg, model): - """ - Calls the post_train method of all registered plugins. + def post_train(self, cfg: DictDefault, model: PreTrainedModel | PeftModel): + """Calls the post_train method of all registered plugins. - Parameters: - cfg (dict): The configuration for the plugins. - model (object): The loaded model. - - Returns: - None + Args: + cfg: The configuration for the plugins. + model: The loaded model. """ for plugin in self.plugins.values(): plugin.post_train(cfg, model) - def post_train_unload(self, cfg): - """ - Calls the post_train_unload method of all registered plugins. + def post_train_unload(self, cfg: DictDefault): + """Calls the post_train_unload method of all registered plugins. - Parameters: - cfg (dict): The configuration for the plugins. - model (object): The loaded model. 
- - Returns: - None + Args: + cfg: The configuration for the plugins. + model: The loaded model. """ for plugin in self.plugins.values(): plugin.post_train_unload(cfg) class BaseOptimizerFactory: - """ - Base class for factories to create custom optimizers - """ + """Base class for factories to create custom optimizers""" def __call__( self, opt_model, training_args, **optimizer_kwargs - ) -> "torch.optim.Optimizer": + ) -> Optimizer | None: pass diff --git a/src/axolotl/loaders/__init__.py b/src/axolotl/loaders/__init__.py new file mode 100644 index 000000000..3eef75e58 --- /dev/null +++ b/src/axolotl/loaders/__init__.py @@ -0,0 +1,10 @@ +"""Init for axolotl.loaders module""" + +# pylint: disable=unused-import +# flake8: noqa + +from .adapter import load_adapter, load_lora +from .constants import MULTIMODAL_AUTO_MODEL_MAPPING +from .model import ModelLoader +from .processor import load_processor +from .tokenizer import load_tokenizer diff --git a/src/axolotl/loaders/adapter.py b/src/axolotl/loaders/adapter.py new file mode 100644 index 000000000..f7a484e9b --- /dev/null +++ b/src/axolotl/loaders/adapter.py @@ -0,0 +1,206 @@ +"""Adapter loading functionality, including LoRA / QLoRA and associated utils""" + +import logging +import os +import types +from typing import Any + +import bitsandbytes as bnb +import torch +from bitsandbytes.nn import Params4bit +from peft import ( + AdaptionPromptConfig, + LoftQConfig, + LoraConfig, + PeftConfig, + PeftMixedModel, + PeftModel, + get_peft_model, +) +from transformers import PreTrainedModel + +from axolotl.loaders.utils import get_linear_embedding_layers +from axolotl.utils.dict import DictDefault + +LOG = logging.getLogger(__name__) + + +def setup_quantized_meta_for_peft(model: torch.nn.Module): + """Replaces `quant_state.to` with a dummy function to prevent PEFT from moving `quant_state` to meta device""" + + def temp_to_method(self, *args, **kwargs): # pylint: disable=unused-argument + return self + + for param in model.parameters(): + if isinstance(param, Params4bit): + param.quant_state._orig_to = ( # pylint: disable=protected-access + param.quant_state.to + ) + param.quant_state.to = types.MethodType(temp_to_method, param.quant_state) + + +def setup_quantized_peft_meta_for_training(model: torch.nn.Module): + """Replaces dummy `quant_state.to` method with the original function to allow training to continue""" + for param in model.parameters(): + if isinstance(param, Params4bit) and hasattr(param.quant_state, "_orig_to"): + param.quant_state.to = ( + param.quant_state._orig_to # pylint: disable=protected-access + ) + param.quant_state._orig_to = None # pylint: disable=protected-access + + +def find_all_linear_names(model): + cls = (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt, torch.nn.Linear) + lora_module_names = set() + for name, module in model.named_modules(): + if ( + isinstance(module, cls) + or "Linear" in module.__class__.__name__ + and module.__class__.__name__ not in ("LlamaLinearScalingRotaryEmbedding",) + ): + names = name.split(".") + lora_module_names.add(names[0] if len(names) == 1 else names[-1]) + + embedding_modules = get_linear_embedding_layers(model.config.model_type) + output_embedding = embedding_modules[1] + if output_embedding in lora_module_names: # needed for 16-bit + lora_module_names.remove(output_embedding) + + return list(lora_module_names) + + +def load_lora( + model: PreTrainedModel, + cfg: DictDefault, + inference: bool = False, + config_only: bool = False, +) -> tuple[PreTrainedModel | PeftModel | PeftMixedModel | 
None, PeftConfig | None]: + lora_target_modules = cfg.lora_target_modules or [] + + if cfg.lora_target_linear: + linear_names = find_all_linear_names(model) + LOG.info(f"found linear modules: {repr(sorted(linear_names))}") + lora_target_modules_as_list = ( + lora_target_modules + if isinstance(lora_target_modules, list) + else [lora_target_modules] + ) + lora_target_modules = list(set(lora_target_modules_as_list + linear_names)) + + lora_config_kwargs = {} + loftq_bits = cfg.peft and cfg.peft.loftq_config and cfg.peft.loftq_config.loftq_bits + if loftq_bits: + lora_config_kwargs["loftq_config"] = LoftQConfig(loftq_bits=loftq_bits) + lora_config_kwargs["init_lora_weights"] = "loftq" + if cfg.peft_init_lora_weights: + lora_config_kwargs["init_lora_weights"] = cfg.peft_init_lora_weights + if cfg.peft_use_dora: + lora_config_kwargs["use_dora"] = cfg.peft_use_dora + LOG.info("Initializing LoRA weights using dora. This might take longer.") + if cfg.peft_use_rslora: + lora_config_kwargs["use_rslora"] = cfg.peft_use_rslora + if cfg.peft_layer_replication: + lora_config_kwargs["layer_replication"] = cfg.peft_layer_replication + + lora_config = LoraConfig( + r=cfg.lora_r, + lora_alpha=cfg.lora_alpha, + target_modules=lora_target_modules, + layers_to_transform=cfg.peft_layers_to_transform, + layers_pattern=cfg.peft_layers_pattern, + lora_dropout=cfg.lora_dropout, + fan_in_fan_out=cfg.lora_fan_in_fan_out, + modules_to_save=cfg.lora_modules_to_save if cfg.lora_modules_to_save else None, + bias="none", + task_type="CAUSAL_LM", + **lora_config_kwargs, + ) + + if config_only: + return None, lora_config + + rank = int(os.environ.get("LOCAL_RANK", 0)) + + if ( + cfg.fsdp + and cfg.adapter + and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading + and rank != 0 + ): + setup_quantized_meta_for_peft(model) + + if cfg.lora_model_dir: + LOG.debug("Loading pretrained PEFT - LoRA") + model_kwargs: Any = {} + if cfg.lora_on_cpu: + model_kwargs["max_memory"] = {"cpu": "256GiB"} + model_kwargs["device_map"] = {"": "cpu"} + model = PeftModel.from_pretrained( + model, + cfg.lora_model_dir, + is_trainable=(not inference), + **model_kwargs, + ) + else: + model = get_peft_model(model, lora_config) + + if rank == 0: + try: + model.print_trainable_parameters() + except AttributeError as exc: + LOG.warning( + "Exception caught during model.print_trainable_parameters(): %s", exc + ) + elif ( + cfg.fsdp + and cfg.adapter + and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading + and rank != 0 + ): + setup_quantized_peft_meta_for_training(model) + + return model, lora_config + + +def load_adapter( + model: PreTrainedModel, + cfg: DictDefault, + adapter: str | None, + inference: bool = False, +) -> tuple[PreTrainedModel | PeftModel | PeftMixedModel, PeftConfig | None]: + if adapter is None: + return model, None + if hasattr(model, "enable_input_require_grads"): + model.enable_input_require_grads() + if adapter in ["lora", "qlora"]: + peft_model, lora_config = load_lora(model, cfg, inference=inference) + return peft_model, lora_config + if adapter == "llama-adapter": + peft_model, lora_config = load_llama_adapter(model, cfg) + return peft_model, lora_config + + raise NotImplementedError(f"{adapter} PEFT adapter not available") + + +def load_llama_adapter( + model: PreTrainedModel, cfg: DictDefault +) -> tuple[PeftModel | PeftMixedModel, PeftConfig]: + peft_config = AdaptionPromptConfig( + adapter_layers=cfg.peft_adapter.layers, # layers (L) + adapter_len=cfg.peft_adapter.len, # prompt length (K) + task_type="CAUSAL_LM", + ) + + if 
cfg.lora_model_dir: + LOG.debug("Loading pretrained PEFT - llama_adapter") + peft_model = PeftModel.from_pretrained( + model, + cfg.lora_model_dir, + torch_dtype=torch.float16, + ) + else: + peft_model = get_peft_model(model, peft_config) + + peft_model.print_trainable_parameters() + + return peft_model, peft_config diff --git a/src/axolotl/loaders/constants.py b/src/axolotl/loaders/constants.py new file mode 100644 index 000000000..c08518dd6 --- /dev/null +++ b/src/axolotl/loaders/constants.py @@ -0,0 +1,21 @@ +"""Shared constants for axolotl.loaders module""" + +from transformers import ( + Gemma3ForConditionalGeneration, + Llama4ForConditionalGeneration, + LlavaForConditionalGeneration, + Mistral3ForConditionalGeneration, + MllamaForConditionalGeneration, + Qwen2_5_VLForConditionalGeneration, + Qwen2VLForConditionalGeneration, +) + +MULTIMODAL_AUTO_MODEL_MAPPING = { + "mllama": MllamaForConditionalGeneration, + "llama4": Llama4ForConditionalGeneration, + "llava": LlavaForConditionalGeneration, + "qwen2_vl": Qwen2VLForConditionalGeneration, + "qwen2_5_vl": Qwen2_5_VLForConditionalGeneration, + "mistral3": Mistral3ForConditionalGeneration, + "gemma3": Gemma3ForConditionalGeneration, +} diff --git a/src/axolotl/loaders/model.py b/src/axolotl/loaders/model.py new file mode 100644 index 000000000..d7ac84a6d --- /dev/null +++ b/src/axolotl/loaders/model.py @@ -0,0 +1,754 @@ +"""Model loader class implementation for loading, configuring, and patching various +models. +""" + +import gc +import logging +import math +import os +from functools import cached_property +from importlib.util import find_spec +from typing import Any + +import peft +import torch +import transformers +import transformers.modeling_utils +from accelerate import init_empty_weights +from peft import PeftConfig, PeftMixedModel, PeftModel, prepare_model_for_kbit_training +from transformers import ( + AutoModelForCausalLM, + AutoModelForVision2Seq, + AwqConfig, + BitsAndBytesConfig, + GPTQConfig, + PreTrainedModel, + PreTrainedTokenizerBase, +) +from transformers.integrations.deepspeed import ( + HfTrainerDeepSpeedConfig, + is_deepspeed_zero3_enabled, +) + +from axolotl.common.architectures import MOE_ARCH_BLOCK +from axolotl.integrations.base import PluginManager +from axolotl.loaders.adapter import load_adapter, load_lora +from axolotl.loaders.constants import MULTIMODAL_AUTO_MODEL_MAPPING +from axolotl.loaders.patch_manager import PatchManager +from axolotl.loaders.utils import ( + get_linear_embedding_layers, + get_module_class_from_name, + load_model_config, +) +from axolotl.models.mamba import fix_mamba_attn_for_loss +from axolotl.utils.bench import log_gpu_memory_usage +from axolotl.utils.dict import DictDefault +from axolotl.utils.distributed import ( + get_device_count, + get_device_type, +) +from axolotl.utils.model_shard_quant import load_sharded_model_quant +from axolotl.utils.schemas.enums import RLType + +LOG = logging.getLogger(__name__) +PLUGIN_MANAGER = PluginManager.get_instance() + + +class ModelLoader: + """Manages model configuration, initialization and application of patches during + model loading. + + This class orchestrates the entire process of loading a model from configuration to + final preparation. It handles device mapping, quantization, attention mechanisms, + adapter integration, and various optimizations. 
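+
+    Typical usage mirrors `axolotl.cli.utils.load_model_and_tokenizer` earlier
+    in this diff:
+
+        model_loader = ModelLoader(cfg, tokenizer, inference=inference)
+        model, lora_config = model_loader.load()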
+
+    The loading process includes:
+        - Loading and validating model configuration
+        - Applying monkey patches for optimizations/fixes
+        - Setting up device mapping (including multi-GPU configurations)
+        - Configuring quantization
+        - Setting attention mechanisms (Flash Attention, SDPA, etc.)
+        - Loading and initializing the model
+        - Applying adapters (LoRA, QLoRA, etc.)
+
+    Attributes:
+        model: The loaded model instance (available after load() is called).
+        model_kwargs: Dictionary of keyword arguments passed to model initialization.
+        base_model: Name or path of the base model to load.
+        model_type: Type of model to load (e.g., `AutoModelForCausalLM`).
+        model_config: Configuration object for the model.
+        auto_model_loader: Class used for loading the model (default:
+            `AutoModelForCausalLM`).
+    """
+
+    def __init__(
+        self,
+        cfg: DictDefault,
+        tokenizer: PreTrainedTokenizerBase,
+        *,
+        inference: bool = False,
+        reference_model: bool = False,
+        **kwargs,  # pylint: disable=unused-argument
+    ):
+        """Initializes the ModelLoader.
+
+        Args:
+            cfg: Configuration dictionary with model and training settings.
+            tokenizer: Tokenizer instance associated with the model.
+            inference: Whether the model is being loaded for inference mode. Defaults
+                to False.
+            reference_model: Whether this is a reference model (used in setups like DPO
+                training). Defaults to False.
+            **kwargs: Additional keyword arguments (ignored).
+        """
+        self.cfg = cfg
+        self.tokenizer = tokenizer
+        self.inference: bool = inference
+        self.reference_model: bool = reference_model
+
+        # Init model kwargs
+        self.model_kwargs: dict[str, Any] = {}
+        if cfg.overrides_of_model_kwargs:
+            for key, val in cfg.overrides_of_model_kwargs.items():
+                self.model_kwargs[key] = val
+
+        # Init model
+        self.model: PreTrainedModel | PeftModel | PeftMixedModel
+        self.base_model = cfg.base_model
+        self.model_type = cfg.type_of_model
+
+        # Init model config
+        self.model_config = load_model_config(cfg)
+        self.auto_model_loader = AutoModelForCausalLM  # pylint: disable=invalid-name
+
+        # Initialize the patch manager
+        self.patch_manager = PatchManager(
+            cfg=cfg,
+            model_config=self.model_config,
+            inference=inference,
+        )
+
+    @cached_property
+    def has_flash_attn(self) -> bool:
+        """Check if flash attention is installed."""
+        return find_spec("flash_attn") is not None
+
+    @cached_property
+    def qlora_fsdp(self):
+        """Property that determines if FSDP with QLoRA is enabled."""
+        return self.cfg.fsdp and self.cfg.adapter == "qlora"
+
+    def load(self) -> tuple[PreTrainedModel, PeftConfig | None]:
+        """Load and prepare the model with all configurations and patches.
+
+        Returns:
+            A tuple with the loaded model and its LoRA configuration (if applicable).
+        """
+        # Initial setup and patches
+        self.patch_manager.apply_pre_model_load_patches()
+        self._apply_pre_model_load_setup()
+
+        # Build the model
+        PLUGIN_MANAGER.pre_model_load(self.cfg)
+        skip_move_to_device = self._build_model()
+        PLUGIN_MANAGER.post_model_build(self.cfg, self.model)
+
+        # Post-build model configuration
+        self._apply_post_model_load_setup()
+
+        # Load adapters (LoRA, etc.)
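+        # Plugin hooks bracket adapter loading so integrations can act both
+        # before PEFT wraps the model and after the adapter has been attached.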
+ PLUGIN_MANAGER.pre_lora_load(self.cfg, self.model) + lora_config = self._load_adapters() + PLUGIN_MANAGER.post_lora_load(self.cfg, self.model) + + # Apply remaining patches and finalize + self._apply_post_lora_load_setup(skip_move_to_device) + self.patch_manager.apply_post_model_load_patches(self.model) + PLUGIN_MANAGER.post_model_load(self.cfg, self.model) + + return self.model, lora_config + + def _apply_pre_model_load_setup(self): + """Apply patches and setup configurations before model loading.""" + self._set_auto_model_loader() + self._set_device_map_config() + if self.cfg.revision_of_model: + self.model_kwargs["revision"] = self.cfg.revision_of_model + self._set_quantization_config() + self._set_attention_config() + + def _apply_post_model_load_setup(self): + """Configure the model after it has been loaded.""" + # Handle PeftModel if needed + if ( + isinstance(self.model, (peft.PeftModel, peft.PeftModelForCausalLM)) + and not self.qlora_fsdp + ): + self.model = self.model.merge_and_unload() + + self._resize_token_embeddings() + self._adjust_model_config() + self._log_memory_usage() + self._configure_embedding_dtypes() + + def _resize_token_embeddings(self): + """Resize token embeddings if needed.""" + embeddings_len = ( + math.ceil(len(self.tokenizer) / 32) * 32 + if self.cfg.resize_token_embeddings_to_32x + else len(self.tokenizer) + ) + if hasattr(self.model, "get_input_embeddings") and ( + self.model.get_input_embeddings().num_embeddings < embeddings_len + or ( + self.model.get_input_embeddings().num_embeddings > embeddings_len + and self.cfg.shrink_embeddings + ) + ): + resize_kwargs = {} + if self.cfg.mean_resizing_embeddings is not None and ( + self.model_config.model_type != "llava" + ): + resize_kwargs["mean_resizing"] = self.cfg.mean_resizing_embeddings + self.model.resize_token_embeddings(embeddings_len, **resize_kwargs) + else: + self.model.tie_weights() + + def _adjust_model_config(self): + if ( + hasattr(self.model, "config") + and hasattr(self.model.config, "max_position_embeddings") + and self.model.config.max_position_embeddings + and self.cfg.sequence_len > self.model.config.max_position_embeddings + ): + LOG.warning( + "increasing model.config.max_position_embeddings from " + f"{self.model.config.max_position_embeddings} to {self.cfg.sequence_len}" + ) + self.model.config.max_position_embeddings = self.cfg.sequence_len + + if ( + hasattr(self.model, "config") + and hasattr(self.model.config, "bos_token_id") + and self.model.config.bos_token_id + and self.model.config.bos_token_id != self.tokenizer.bos_token_id + ): + self.model.config.bos_token_id = self.tokenizer.bos_token_id + + if ( + hasattr(self.model, "config") + and hasattr(self.model.config, "eos_token_id") + and self.model.config.eos_token_id + and self.model.config.eos_token_id != self.tokenizer.eos_token_id + ): + self.model.config.eos_token_id = self.tokenizer.eos_token_id + + def _log_memory_usage(self): + """Log device memory usage after model load.""" + if hasattr(self.model, "device") and self.model.device.type in ( + "cuda", + "mps", + "npu", + ): + log_gpu_memory_usage(LOG, "after model load", self.model.device) + + def _configure_embedding_dtypes(self): + """Configure embedding module dtypes.""" + # Get embedding modules + embedding_modules = get_linear_embedding_layers(self.cfg.model_config_type) + + # Initial dtype conversion + if not self.cfg.fsdp: + # We don't run this during FSDP because this will leave mixed and bfloat16 + # dtypes in the model which FSDP doesn't like + if 
self.cfg.load_in_4bit and self.cfg.embeddings_skip_upcast: + embedding_modules = [] + self._convert_embedding_modules_dtype( + embedding_modules, + dist_dtype=torch.float32, + before_kbit_train_or_finetune=True, + ) + + # Handle DeepSpeed Zero3 + if is_deepspeed_zero3_enabled(): + self._set_z3_leaf_modules() + + # Apply gradient checkpointing if needed + needs_fa2_dtype = self.cfg.adapter or self.cfg.fsdp + if self.cfg.adapter in ["lora", "qlora"]: + needs_fa2_dtype = True + if self.cfg.gradient_checkpointing: + self.model.gradient_checkpointing_enable( + gradient_checkpointing_kwargs=self.cfg.gradient_checkpointing_kwargs + ) + + self._prepare_model_for_quantization() + + # Convert dtypes if needed + should_convert = ( + # LlamaRMSNorm layers are in fp32 after kbit_training or full finetune, so + # we need to convert them back to fp16/bf16 for flash-attn compatibility. + ( + (needs_fa2_dtype or self.cfg.flash_attention or self.cfg.flex_attention) + and not self.qlora_fsdp + ) + # CCE requires embedding layers to be in fp16/bf16 for backward pass + or self.cfg.cut_cross_entropy + ) + + if should_convert: + LOG.info("Converting modules to %s", self.cfg.torch_dtype) + self._convert_embedding_modules_dtype( + embedding_modules=embedding_modules, + dist_dtype=self.cfg.torch_dtype, + before_kbit_train_or_finetune=False, + ) + + def _load_adapters(self) -> PeftConfig | None: + """Load LoRA or other adapters.""" + # Load LoRA or adapter + lora_config = None + if not self.reference_model or self.cfg.lora_model_dir: + # If we're not loading the reference model, then we're loading the model + # for training. Then, the DPO trainer doesn't want the PEFT model loaded + # over it, it just wants the LoRA / PEFT config. + if ( + self.cfg.adapter + and self.cfg.rl in [RLType.DPO, RLType.IPO, RLType.KTO] + and not self.cfg.merge_lora + ): + _, lora_config = load_lora( + self.model, self.cfg, inference=False, config_only=True + ) + else: + self.model, lora_config = load_adapter( + self.model, self.cfg, self.cfg.adapter + ) + + return lora_config + + def _apply_post_lora_load_setup(self, skip_move_to_device: bool): + """Apply final optimizations and patches.""" + # Place model on accelerator + if ( + self.cfg.ddp + and not self.cfg.load_in_8bit + and not (self.cfg.rl and self.cfg.load_in_4bit) + and not skip_move_to_device + ): + # TODO: validate this conditional + self.model.to(f"{str(get_device_type())}:{self.cfg.local_rank}") + + if get_device_count() > 1 and int(os.getenv("WORLD_SIZE", "1")) == 1: + self.model.is_parallelizable = True + self.model.model_parallel = True + + if not any( + param.requires_grad + for _, param in self.model.named_parameters(recurse=True) + ): + LOG.warning("There are no parameters that require gradient updates") + + if self.cfg.flash_optimum: + from optimum.bettertransformer import BetterTransformer + + self.model = BetterTransformer.transform(self.model) + + if self.cfg.adapter is not None: + log_gpu_memory_usage(LOG, "after adapters", self.model.device) + + for _ in range(3): + gc.collect() + torch.cuda.empty_cache() + + def _set_auto_model_loader(self): + """Set `self.auto_model_loader`. Defaults to `transformers.AutoModelForCausalLM` + (set at `__init__`). When using a multimodal model, `self.auto_model_loader` + should be set according to the type of the model. 
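+
+        For example, with `is_multimodal` set and a `llava` model type, the
+        mapping in `axolotl.loaders.constants` resolves to
+        `LlavaForConditionalGeneration`; multimodal types without an entry fall
+        back to `AutoModelForVision2Seq`.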
+ """ + if self.cfg.is_multimodal: + self.auto_model_loader = MULTIMODAL_AUTO_MODEL_MAPPING.get( + self.model_config.model_type, AutoModelForVision2Seq + ) + + def _set_device_map_config(self): + """Setup `device_map` according to config""" + device_map = self.cfg.device_map + max_memory = self.cfg.max_memory + + if self.cfg.gpu_memory_limit: + gpu_memory_limit = ( + str(self.cfg.gpu_memory_limit) + "GiB" + if isinstance(self.cfg.gpu_memory_limit, int) + else self.cfg.gpu_memory_limit + ) + + max_memory = {} + num_device = get_device_count() + for i in range(num_device): + max_memory[i] = gpu_memory_limit + max_memory["cpu"] = "256GiB" # something sufficiently large to fit anything + + if max_memory is not None: + # Based on https://github.com/togethercomputer/OpenChatKit/blob/main/inference/bot.py + from accelerate import infer_auto_device_map + + with init_empty_weights(): + model_canvas = self.auto_model_loader.from_config( + self.model_config, + trust_remote_code=self.cfg.trust_remote_code or False, + ) + model_canvas.tie_weights() + device_map = infer_auto_device_map( + model_canvas, + max_memory=max_memory, + dtype=self.cfg.torch_dtype, + ) + # We can discard max_memory now as we have a device map set up + max_memory = None + + self.model_kwargs["torch_dtype"] = self.cfg.torch_dtype + + if not is_deepspeed_zero3_enabled(): + self.model_kwargs["device_map"] = device_map + + cur_device = get_device_type() + if "mps" in str(cur_device): + self.model_kwargs["device_map"] = "mps:0" + elif "npu" in str(cur_device): + self.model_kwargs["device_map"] = "npu:0" + + # TODO: can we put the reference model on it's own gpu? I think we have to move + # logits around to calculate loss + # if cfg.rl: + # if torch.cuda.device_count() > 1: + # if reference_model: + # model_kwargs["device_map"] = "cuda:" + str( + # torch.cuda.current_device() + 1 + # ) + # else: + # model_kwargs["device_map"] = "cuda:" + str(torch.cuda.current_device()) + + def _set_quantization_config(self): + """Set up quantization config (bitsandbytes, awq, gptq, etc.)""" + self.model_kwargs["load_in_8bit"] = self.cfg.load_in_8bit + self.model_kwargs["load_in_4bit"] = self.cfg.load_in_4bit + + if self.cfg.gptq: + if not hasattr(self.model_config, "quantization_config"): + LOG.warning( + "model config does not contain quantization_config information" + ) + else: + if self.cfg.gptq_disable_exllama is not None: + self.model_config.quantization_config["disable_exllama"] = ( + self.cfg.gptq_disable_exllama + ) + self.model_kwargs["quantization_config"] = GPTQConfig( + **self.model_config.quantization_config + ) + if ( + self.cfg.adapter in ["qlora", "lora"] + and hasattr(self.model_config, "quantization_config") + and self.model_config.quantization_config["quant_method"] + in ["gptq", "awq", "bitsandbytes"] + ): + if self.model_config.quantization_config["quant_method"] == "gptq": + self.model_kwargs["quantization_config"] = GPTQConfig( + **self.model_config.quantization_config + ) + elif self.model_config.quantization_config["quant_method"] == "awq": + self.model_kwargs["quantization_config"] = AwqConfig( + **self.model_config.quantization_config + ) + elif ( + self.model_config.quantization_config["quant_method"] == "bitsandbytes" + ): + self.model_kwargs["quantization_config"] = BitsAndBytesConfig( + **self.model_config.quantization_config + ) + elif self.cfg.adapter == "qlora" and self.model_kwargs["load_in_4bit"]: + bnb_config = { + "load_in_4bit": True, + "llm_int8_threshold": 6.0, + "llm_int8_has_fp16_weight": False, + 
"bnb_4bit_compute_dtype": self.cfg.torch_dtype, + "bnb_4bit_use_double_quant": True, + "bnb_4bit_quant_type": "nf4", + "bnb_4bit_quant_storage": torch.bfloat16, + } + if self.cfg.model_config_type in ["jamba", "qwen2_moe"] and not ( + self.cfg.deepspeed or self.cfg.fsdp + ): + # for some reason, this causes the loss to be off by an order of magnitude + # but deepspeed needs this still in bfloat16 + bnb_config["bnb_4bit_quant_storage"] = torch.float32 + + if self.cfg.bnb_config_kwargs: + bnb_config.update(self.cfg.bnb_config_kwargs) + + self.model_kwargs["quantization_config"] = BitsAndBytesConfig( + **bnb_config, + ) + elif self.cfg.adapter == "lora" and self.model_kwargs["load_in_8bit"]: + bnb_config = { + "load_in_8bit": True, + } + # Exclude mamba blocks from int8 quantization for jamba + if self.cfg.model_config_type == "jamba": + bnb_config["llm_int8_skip_modules"] = ["mamba"] + self.model_kwargs["quantization_config"] = BitsAndBytesConfig( + **bnb_config, + ) + + # no longer needed per https://github.com/huggingface/transformers/pull/26610 + if "quantization_config" in self.model_kwargs or self.cfg.gptq: + self.model_kwargs.pop("load_in_8bit", None) + self.model_kwargs.pop("load_in_4bit", None) + + def _set_attention_config(self): + """Sample packing uses custom FA2 patch""" + if self.cfg.flex_attention: + self.model_kwargs["attn_implementation"] = "flex_attention" + self.model_config._attn_implementation = ( # pylint: disable=protected-access + "flex_attention" + ) + + elif self.cfg.flash_attention: + if not self.cfg.sample_packing and self.cfg.s2_attention: + pass + self.model_kwargs["attn_implementation"] = "flash_attention_2" + self.model_config._attn_implementation = ( # pylint: disable=protected-access + "flash_attention_2" + ) + elif self.cfg.sdp_attention: + self.model_kwargs["attn_implementation"] = "sdpa" + self.model_config._attn_implementation = ( # pylint: disable=protected-access + "sdpa" + ) + elif self.cfg.eager_attention: + self.model_kwargs["attn_implementation"] = "eager" + self.model_config._attn_implementation = ( # pylint: disable=protected-access + "eager" + ) + + if self.cfg.low_cpu_mem_usage: + self.model_kwargs["low_cpu_mem_usage"] = True + + def _configure_zero3_memory_efficient_loading(self): + """Set the deepspeed config to load the model into RAM first before moving + to VRAM. + + We need to return `hf_ds_cfg` as it needs to exist before model loading. 
+ """ + hf_ds_cfg = None + + if os.getenv("ACCELERATE_DEEPSPEED_ZERO_STAGE") == "3": + hf_ds_cfg = HfTrainerDeepSpeedConfig(self.cfg.deepspeed) + hf_ds_cfg.fill_match( + "train_micro_batch_size_per_gpu", self.cfg.micro_batch_size + ) + hf_ds_cfg.fill_match( + "gradient_accumulation_steps", self.cfg.gradient_accumulation_steps + ) + hf_ds_cfg.fill_match( + "train_batch_size", + int(os.getenv("WORLD_SIZE", "1")) + * self.cfg.micro_batch_size + * self.cfg.gradient_accumulation_steps, + ) + if "device_map" in self.model_kwargs: + del self.model_kwargs["device_map"] + + transformers.modeling_utils.is_deepspeed_zero3_enabled = lambda: True + transformers.integrations.deepspeed.is_deepspeed_zero3_enabled = ( + lambda: True + ) + + return hf_ds_cfg + + def _build_model(self) -> bool: + """Load model, with load strategy depending on config.""" + skip_move_to_device = False + if ( + self.qlora_fsdp + and self.cfg.fsdp_config.fsdp_cpu_ram_efficient_loading + and ( + self.cfg.model_config_type == "dbrx" + or self.cfg.qlora_sharded_model_loading + ) + ): + quant_storage = self.cfg.torch_dtype + quantization_config = getattr( + self.model_config, "quantization_config", None + ) + quantization_config = ( + quantization_config or self.model_kwargs["quantization_config"] + ) + self.model = load_sharded_model_quant( + self.base_model, + self.model_config, + self.cfg, + quant_storage=quant_storage, + quantization_config=quantization_config, + ) + skip_move_to_device = True + elif ( + self.model_config.model_type in ["llama", "llama4"] + and not self.cfg.trust_remote_code + and not self.cfg.gptq + ): + # TODO: Do we need to open this up for all models? + if self.cfg.fsdp and self.cfg.fsdp_config.fsdp_cpu_ram_efficient_loading: + skip_move_to_device = True + if "device_map" in self.model_kwargs: + del self.model_kwargs["device_map"] + + self._configure_zero3_memory_efficient_loading() + + # Load model with random initialization if specified + if self.cfg.random_init_weights: + # AutoModel classes support the from_config method + if self.auto_model_loader in [ + AutoModelForCausalLM, + AutoModelForVision2Seq, + ]: + self.model = self.auto_model_loader.from_config( + config=self.model_config, + ) + else: + self.model = self.auto_model_loader(config=self.model_config) + else: + self.model = self.auto_model_loader.from_pretrained( + self.base_model, + config=self.model_config, + **self.model_kwargs, + ) + elif self.model_type == "MambaLMHeadModel": + # FIXME this is janky at best and hacked together to make it work + MambaLMHeadModel = fix_mamba_attn_for_loss() # pylint: disable=invalid-name + + self.model_kwargs["dtype"] = self.model_kwargs["torch_dtype"] + self.model_kwargs["device"] = torch.cuda.current_device() + self.model_kwargs.pop("torch_dtype", None) + self.model_kwargs.pop("device_map", None) + + self.model = MambaLMHeadModel.from_pretrained( + self.base_model, + **self.model_kwargs, + ) + elif ( + self.model_type + and self.model_type != "AutoModelForCausalLM" + and not self.cfg.trust_remote_code + ): + if self.cfg.gptq: + self.model = self.auto_model_loader.from_pretrained( + self.base_model, + config=self.model_config, + trust_remote_code=self.cfg.trust_remote_code or False, + **self.model_kwargs, + ) + else: + self.model = getattr(transformers, self.model_type).from_pretrained( + self.base_model, + config=self.model_config, + trust_remote_code=self.cfg.trust_remote_code or False, + **self.model_kwargs, + ) + else: + if self.cfg.gptq: + self.model = self.auto_model_loader.from_pretrained( + 
self.base_model, + config=self.model_config, + trust_remote_code=self.cfg.trust_remote_code or False, + **self.model_kwargs, + ) + else: + if ( + self.cfg.fsdp + and self.cfg.fsdp_config.fsdp_cpu_ram_efficient_loading + ): + # disabling either of these two still leads to VRAM spike before setting back down + skip_move_to_device = True + if "device_map" in self.model_kwargs: + del self.model_kwargs["device_map"] + + self._configure_zero3_memory_efficient_loading() + + self.model = self.auto_model_loader.from_pretrained( + self.base_model, + config=self.model_config, + trust_remote_code=self.cfg.trust_remote_code or False, + **self.model_kwargs, + ) + if is_deepspeed_zero3_enabled(): + skip_move_to_device = True + + return skip_move_to_device + + def _set_z3_leaf_modules(self): + from deepspeed.utils import set_z3_leaf_modules + + if self.cfg.model_config_type in MOE_ARCH_BLOCK: + moe_blocks = MOE_ARCH_BLOCK[self.cfg.model_config_type] + moe_blocks = [moe_blocks] if isinstance(moe_blocks, str) else moe_blocks + set_z3_leaf_modules( + self.model, + [ + get_module_class_from_name(self.model, module_name) + for module_name in moe_blocks + ], + ) + + def _prepare_model_for_quantization(self): + """Prepare loaded model for quantization.""" + skip_prepare_model_for_kbit_training = False + if self.cfg.model_config_type == "qwen" and self.cfg.adapter == "lora": + # Qwen doesn't play nicely with LoRA if this is enabled + skip_prepare_model_for_kbit_training = True + + loftq_bits = ( + self.cfg.peft + and self.cfg.peft.loftq_config + and self.cfg.peft.loftq_config.loftq_bits + ) + if self.cfg.adapter == "lora" and loftq_bits: + skip_prepare_model_for_kbit_training = True + + if ( + self.qlora_fsdp + or (self.cfg.fsdp and self.cfg.fsdp_config.fsdp_cpu_ram_efficient_loading) + or is_deepspeed_zero3_enabled() + ): + # Make sure everything is in the same dtype + skip_prepare_model_for_kbit_training = True + + if ( + not skip_prepare_model_for_kbit_training + and self.cfg.adapter in ["lora", "qlora"] + and (self.cfg.load_in_8bit or self.cfg.load_in_4bit) + ): + LOG.info("converting PEFT model w/ prepare_model_for_kbit_training") + self.model = prepare_model_for_kbit_training( + self.model, use_gradient_checkpointing=self.cfg.gradient_checkpointing + ) + + def _convert_embedding_modules_dtype( + self, + embedding_modules: list[str], + dist_dtype: torch.dtype, + before_kbit_train_or_finetune: bool, + ): + for name, module in self.model.named_modules(): + if "norm" in name: + module.to(dist_dtype) + if before_kbit_train_or_finetune: + if name.endswith(".gate"): + module.to(dist_dtype) + if self.model_config.model_type == "btlm": + # don't upcast lm_head for btlm + continue + if any(m in name for m in embedding_modules) and hasattr(module, "weight"): + module.to(dist_dtype) diff --git a/src/axolotl/loaders/patch_manager.py b/src/axolotl/loaders/patch_manager.py new file mode 100644 index 000000000..f251f958d --- /dev/null +++ b/src/axolotl/loaders/patch_manager.py @@ -0,0 +1,380 @@ +"""Patch manager class implementation to complement `axolotl.loaders.ModelLoader`. + +Applies pre- and post-model load patches for various fixes and optimizations. 
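+
+As a rough sketch of the intended call pattern (`ModelLoader` wires this up
+internally; names follow the class below):
+
+    patch_manager = PatchManager(cfg, model_config, inference=False)
+    patch_manager.apply_pre_model_load_patches()
+    # ... model is loaded ...
+    patch_manager.apply_post_model_load_patches(model)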
+""" + +import importlib.util +import logging +from functools import cached_property + +import addict +import transformers +from transformers import PretrainedConfig, PreTrainedModel + +from axolotl.integrations.base import PluginManager +from axolotl.monkeypatch.multipack import ( + SUPPORTED_MULTIPACK_MODEL_TYPES, + patch_for_multipack, +) +from axolotl.utils.dict import DictDefault + +LOG = logging.getLogger(__name__) +PLUGIN_MANAGER = PluginManager.get_instance() + + +class PatchManager: + """Manages the application of patches during the model loading process.""" + + def __init__( + self, + cfg: DictDefault, + model_config: PretrainedConfig | addict.Dict, + inference: bool = False, + ): + """Initialize the `PatchManager`. + + Args: + cfg: Configuration dictionary with model and training settings. + model_config: Configuration object for the model. + inference: Whether the model is being loaded for inference mode. + """ + self.cfg = cfg + self.model_config = model_config + self.inference = inference + + @cached_property + def has_flash_attn(self) -> bool: + """Check if flash attention is installed.""" + return importlib.util.find_spec("flash_attn") is not None + + def apply_pre_model_load_patches(self): + """Apply pre-model load patches based on config.""" + self._apply_flash_attention_patches() + self._apply_fsdp_patches() + self._apply_adapter_patches() + self._apply_flex_attention_patches() + self._apply_model_specific_patches() + self._apply_fp8_patches() + self._apply_flash_attention_peft_patches() + self._apply_gradient_checkpointing_patches() + self._patch_attention() + self._apply_multipack_patches() + self._patch_llama_derived_model() + self._apply_mistral_cross_entropy_patch() + self._apply_unsloth_self_attention_patch() + + def apply_post_model_load_patches(self, model: PreTrainedModel): + """Apply patches that require the model instance.""" + self._apply_llama_flash_attn_patches(model) + self._apply_unsloth_patches(model) + self._apply_lora_kernel_patch(model) + + def _apply_flash_attention_patches(self): + """Apply patches related to Flash Attention.""" + if self.cfg.xformers_attention and self.cfg.sample_packing: + from axolotl.monkeypatch.attention import patch_xformers_attn_over_fa2 + + patch_xformers_attn_over_fa2() + self.cfg.flash_attention = True + + def _apply_fsdp_patches(self): + """Apply patches for FSDP configurations.""" + if self.cfg.fsdp_config and str(self.cfg.fsdp_config.fsdp_version) == "2": + from axolotl.monkeypatch.accelerate.fsdp2 import patch_accelerate_fsdp_utils + + patch_accelerate_fsdp_utils() + + def _apply_adapter_patches(self): + """Apply patches for adapter configurations.""" + if self.cfg.adapter and self.cfg.embeddings_skip_upcast: + from axolotl.monkeypatch.peft.utils import patch_peft_prep_code + + patch_peft_prep_code() + + def _apply_flex_attention_patches(self): + """Apply patches for flexible attention.""" + if self.cfg.flex_attention: + from axolotl.monkeypatch.attention.flex_attn import ( + patch_flex_make_mask, + patch_flex_wrapper, + ) + + flex_attn_compile_kwargs = self.cfg.flex_attn_compile_kwargs or {} + patch_flex_wrapper(**flex_attn_compile_kwargs) + patch_flex_make_mask() + + def _apply_model_specific_patches(self): + """Apply patches specific to model architectures.""" + if ( + self.cfg.model_config_type == "llama4" + and self.cfg.llama4_linearized_experts + ): + from axolotl.monkeypatch.models.llama4.modeling import ( + patch_llama4_linearized_modeling, + ) + + patch_llama4_linearized_modeling() + + if 
self.cfg.model_config_type == "gemma3": + from axolotl.monkeypatch.gemma3 import ( + patch_gemma3conditionalgeneration_forward, + ) + + patch_gemma3conditionalgeneration_forward() + + def _apply_fp8_patches(self): + """Apply patches for FP8 support.""" + if self.cfg.fp8: + from axolotl.monkeypatch.trainer_accelerator_args import ( + patch_create_accelerate_code_for_fp8, + ) + + patch_create_accelerate_code_for_fp8() + + def _apply_flash_attention_peft_patches(self): + """Apply patches for Flash Attention with PEFT.""" + if self.cfg.adapter: + from axolotl.monkeypatch.transformers_fa_utils import ( + patch_fa_peft_integration, + ) + + patch_fa_peft_integration() + + def _apply_gradient_checkpointing_patches(self): + """Apply patches for gradient checkpointing.""" + if self.cfg.gradient_checkpointing in ["unsloth", "offload"]: + from axolotl.monkeypatch.gradient_checkpointing import ( + hf_grad_checkpoint_offload_wrapper, + ) + + transformers.modeling_utils.checkpoint = hf_grad_checkpoint_offload_wrapper + if self.cfg.gradient_checkpointing == "offload_disk": + from axolotl.monkeypatch.gradient_checkpointing import ( + hf_grad_checkpoint_disk_offload_wrapper, + ) + + transformers.modeling_utils.checkpoint = ( + hf_grad_checkpoint_disk_offload_wrapper + ) + + def _apply_mistral_cross_entropy_patch(self): + """Apply Mistral cross entropy patch if configured.""" + if ( + self.cfg.model_config_type == "mistral" + and self.cfg.flash_attn_cross_entropy_loss + ): + from axolotl.monkeypatch.mistral_attn_hijack_flash import ( + patch_mistral_cross_entropy, + ) + + patch_mistral_cross_entropy() + + def _apply_unsloth_self_attention_patch(self): + """Apply Unsloth self-attention patches if configured.""" + if self.cfg.unsloth_lora_qkv or self.cfg.unsloth_lora_o: + from axolotl.monkeypatch.lora_kernels import patch_self_attn_lora + + patch_self_attn_lora(self.cfg) + + def _apply_multipack_patches(self): + """Apply multipack patches if necessary.""" + if ( + self.cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES + and (self.cfg.flash_attention or self.cfg.flex_attention) + and self.cfg.sample_packing + ): + # Get automap config if it exists + auto_map_config = None + if isinstance(self.model_config, dict) and "auto_map" in self.model_config: + auto_map_config = self.model_config["auto_map"] + elif hasattr(self.model_config, "auto_map"): + auto_map_config = self.model_config.auto_map + + # Determine if the model has remote code + if auto_map_config is not None: + has_remote_code = "AutoModelForCausalLM" in auto_map_config + else: + has_remote_code = False + + if has_remote_code and self.cfg.trust_remote_code is False: + # If explicitly set in YAML, prefer that + has_remote_code = self.cfg.trust_remote_code + + patch_for_multipack( + self.cfg.model_config_type, + model_name=self.cfg.base_model, + has_remote_code=has_remote_code, + ) + + if self.cfg.is_llama_derived_model: + self._patch_loss_llama() + + def _patch_attention(self): + """Apply attention-specific patches based on model type.""" + if not (self.cfg.flash_attention and hasattr(self.model_config, "model_type")): + return + + if self.model_config.model_type == "mllama" and self.cfg.flash_attention: + from axolotl.monkeypatch.attention.mllama import patch_mllama + + patch_mllama() + + if self.model_config.model_type == "btlm": + from axolotl.monkeypatch.btlm_attn_hijack_flash import ( + replace_btlm_attn_with_flash_attn, + ) + + replace_btlm_attn_with_flash_attn(self.cfg.base_model) + + if self.model_config.model_type == "stablelm_epoch" and 
self.cfg.sample_packing: + from axolotl.monkeypatch.stablelm_attn_hijack_flash import ( + replace_stablelm_attn_with_flash_attn, + ) + + replace_stablelm_attn_with_flash_attn(self.cfg.base_model) + + def _patch_loss_llama(self): + """Patch loss functions and other optimizations for LLaMA models.""" + if self.cfg.flash_attn_cross_entropy and self.has_flash_attn: + from axolotl.monkeypatch.llama_attn_hijack_flash import ( + patch_fa_llama_cross_entropy, + ) + + patch_fa_llama_cross_entropy() + elif self.cfg.unsloth_cross_entropy_loss: + from axolotl.monkeypatch.unsloth_ import integrate_cross_entropy_loss_patch + + integrate_cross_entropy_loss_patch(model_type="llama") + + if self.cfg.flash_attn_rms_norm and self.has_flash_attn: + from axolotl.monkeypatch.llama_attn_hijack_flash import patch_llama_rms_norm + + patch_llama_rms_norm() + elif self.cfg.unsloth_rms_norm: + from axolotl.monkeypatch.unsloth_ import patch_unsloth_layernorm + + patch_unsloth_layernorm() + + if self.cfg.unsloth_lora_qkv or self.cfg.unsloth_lora_o: + from axolotl.monkeypatch.unsloth_ import patch_self_attn_lora + + patch_self_attn_lora() + + def _patch_llama_flash_attention(self, packed=False): + """Apply Flash Attention patches for LLaMA models.""" + from axolotl.monkeypatch.llama_attn_hijack_flash import ( + replace_llama_attn_with_flash_attn, + ) + + if packed: + if self.cfg.device not in ["mps", "cpu"] and not self.inference: + LOG.info("patching with flash attention for sample packing") + replace_llama_attn_with_flash_attn( + packed=True, + cross_entropy=self.cfg.flash_attn_cross_entropy, + rms_norm=self.cfg.flash_attn_rms_norm, + ) + elif self.cfg.s2_attention: + LOG.info("patching w/ flash-enabled, shifted-sparse attention") + replace_llama_attn_with_flash_attn( + packed=False, + cross_entropy=self.cfg.flash_attn_cross_entropy, + rms_norm=self.cfg.flash_attn_rms_norm, + use_shifted_sparse_attn=True, + ) + elif self.cfg.flash_attn_cross_entropy or self.cfg.flash_attn_rms_norm: + replace_llama_attn_with_flash_attn( + packed=False, + cross_entropy=self.cfg.flash_attn_cross_entropy, + rms_norm=self.cfg.flash_attn_rms_norm, + ) + + def _patch_llama_xformers_attention(self): + """Apply xformers attention patches for LLaMA models.""" + from axolotl.monkeypatch.llama_attn_hijack_xformers import ( + hijack_llama_attention, + ) + + LOG.info("Patching with xformers attention...") + hijack_llama_attention() + + def _patch_llama_sample_packing(self): + """Apply sample packing patches for LLaMA models.""" + from axolotl.monkeypatch.llama_patch_multipack import ( + hijack_llama_prepare_4d_mask, + ) + + LOG.info("Patching llama _prepare_4d_causal_attention_mask*...") + hijack_llama_prepare_4d_mask() + + def _patch_llama_derived_model(self): + """Modify all llama derived models in one block.""" + if self.cfg.is_llama_derived_model and not ( + self.cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES + and (self.cfg.flash_attention or self.cfg.flex_attention) + and self.cfg.sample_packing + ): + self._patch_loss_llama() + + if self.cfg.flash_attention: + self._patch_llama_flash_attention(packed=self.cfg.sample_packing) + elif self.cfg.xformers_attention: + self._patch_llama_xformers_attention() + elif self.cfg.sample_packing: + self._patch_llama_sample_packing() + elif self.cfg.s2_attention: + raise NotImplementedError( + "Shifted-sparse attention not currently implemented without flash attention." 
+            )
+
+    def _apply_llama_flash_attn_patches(self, model):
+        """Apply LLaMA-specific flash attention patches."""
+        if (
+            self.model_config.model_type in ["llama", "llama4"]
+            and not self.cfg.trust_remote_code
+            and not self.cfg.gptq
+            and self.cfg.flash_attention
+            and not self.inference
+        ):
+            # TODO(MengqingCao): split these patches separately
+            from axolotl.monkeypatch.llama_attn_hijack_flash import (
+                is_xformers_swiglu_available,
+                replace_llama_mlp_with_swiglu,
+                replace_llama_qkv_with_fused,
+            )
+
+            if self.cfg.flash_attn_fuse_mlp and is_xformers_swiglu_available():
+                LOG.info("Patching with SwiGLU...")
+                replace_llama_mlp_with_swiglu(model)
+
+            if self.cfg.flash_attn_fuse_qkv:
+                LOG.info("Patching with fused QKV...")
+                replace_llama_qkv_with_fused(model)
+
+    def _apply_unsloth_patches(self, model):
+        """Apply unsloth optimization patches."""
+        if self.cfg.unsloth_lora_mlp:
+            from axolotl.monkeypatch.unsloth_ import integrate_lora_mlp_patch
+
+            integrate_lora_mlp_patch(peft_model=model)
+
+        if self.cfg.unsloth_lora_qkv or self.cfg.unsloth_lora_o:
+            from axolotl.monkeypatch.unsloth_ import integrate_lora_patch
+
+            integrate_lora_patch(peft_model=model, cfg=self.cfg)
+
+        if self.cfg.unsloth_rope:
+            from axolotl.monkeypatch.unsloth_ import integrate_rope_embeddings
+
+            integrate_rope_embeddings()
+
+    def _apply_lora_kernel_patch(self, model):
+        """Apply LoRA kernel patches."""
+        if (
+            self.cfg.lora_mlp_kernel
+            or self.cfg.lora_qkv_kernel
+            or self.cfg.lora_o_kernel
+        ):
+            from axolotl.monkeypatch.lora_kernels import apply_lora_kernel_patches
+
+            apply_lora_kernel_patches(model=model, cfg=self.cfg)
diff --git a/src/axolotl/loaders/processor.py b/src/axolotl/loaders/processor.py
new file mode 100644
index 000000000..57394bc67
--- /dev/null
+++ b/src/axolotl/loaders/processor.py
@@ -0,0 +1,56 @@
+"""Processor loading functionality for multi-modal models"""
+
+import logging
+from typing import Any
+
+import transformers
+from transformers import (
+    AutoProcessor,
+    PreTrainedTokenizerBase,
+)
+
+from axolotl.utils.dict import DictDefault
+
+LOG = logging.getLogger(__name__)
+
+
+def load_processor(cfg: DictDefault, tokenizer: PreTrainedTokenizerBase):
+    """Load a processor for multi-modal models based on the given config."""
+    processor_kwargs: dict[str, Any] = {}  # Do we actually need this?
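+
+    # Illustrative usage (a sketch; the processor repo id is hypothetical):
+    #
+    #   cfg = DictDefault({"processor_config": "org/some-multimodal-model"})
+    #   processor = load_processor(cfg, tokenizer)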
+ + processor_cls = AutoProcessor + if cfg.processor_type: + processor_cls = getattr(transformers, cfg.processor_type) + + processor = processor_cls.from_pretrained( + cfg.processor_config, + trust_remote_code=cfg.trust_remote_code or False, + tokenizer=tokenizer, + **processor_kwargs, + ) + + # Attempt to load image size from processor if available + if ( + cfg.image_size is None + and hasattr(processor, "size") + and any(dim in processor.size for dim in ["width", "height"]) + ): + im_width = None + im_height = None + if "width" in processor.size: + im_width = processor.size["width"] + if "height" in processor.size: + im_height = processor.size["height"] + + # If both width and height are set, use a tuple + if im_width is not None and im_height is not None: + cfg.image_size = (im_width, im_height) + # If only width is set, use as integer + elif im_width is not None: + cfg.image_size = im_width + # If only height is set, use as integer + elif im_height is not None: + cfg.image_size = im_height + + LOG.debug(f"Loaded image size: {cfg.image_size} from processor") + + return processor diff --git a/src/axolotl/loaders/tokenizer.py b/src/axolotl/loaders/tokenizer.py new file mode 100644 index 000000000..ec9d69e8a --- /dev/null +++ b/src/axolotl/loaders/tokenizer.py @@ -0,0 +1,281 @@ +"""Tokenizer loading functionality and associated utils""" + +import json +import logging +import os + +import transformers +from transformers import ( + AddedToken, + AutoTokenizer, +) + +from axolotl.integrations.base import PluginManager +from axolotl.loaders.utils import get_linear_embedding_layers, load_model_config +from axolotl.prompt_tokenizers import LLAMA_DEFAULT_EOS_TOKEN +from axolotl.utils.chat_templates import get_chat_template_from_config +from axolotl.utils.distributed import ( + barrier, + is_local_main_process, + is_main_process, +) + +LOG = logging.getLogger(__name__) +PLUGIN_MANAGER = PluginManager.get_instance() + + +def modify_tokenizer_files( + tokenizer_path: str, token_mappings: dict[int, str], output_dir: str +) -> str: + """ + Modify tokenizer files to replace added_tokens strings, save to output directory, + and return the path to the modified tokenizer. + + This only works with reserved tokens that were added to the tokenizer, not tokens + already part of the vocab. + + Args: + tokenizer_path: Path or name of the original tokenizer + token_mappings: Dict mapping {token_id (int): new_token_string} + output_dir: Directory to save the modified tokenizer + + Returns: + Path to the modified tokenizer directory + + Ref: https://github.com/huggingface/transformers/issues/27974#issuecomment-1854188941 + """ + # Create the tokenizer directory in output_dir if it doesn't exist + tokenizer_dir = os.path.join(output_dir, "tokenizer") + os.makedirs(tokenizer_dir, exist_ok=True) + + if is_local_main_process(): # pylint: disable=too-many-nested-blocks + # Load the tokenizer + temp_tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, use_fast=True) + + # Save the tokenizer to the output directory + temp_tokenizer.save_pretrained(tokenizer_dir) + + # Get the token IDs and map them to their new values + token_id_mappings = { + int(token_id): new_value for token_id, new_value in token_mappings.items() + } + + # 1. 
Update tokenizer_config.json - added_tokens_decoder + config_path = os.path.join(tokenizer_dir, "tokenizer_config.json") + if os.path.exists(config_path): + with open(config_path, "r", encoding="utf-8") as f: + config_data = json.load(f) + + # Update added_tokens_decoder + if "added_tokens_decoder" in config_data: + for token_id, new_value in token_id_mappings.items(): + token_id_str = str(token_id) + if token_id_str in config_data["added_tokens_decoder"]: + config_data["added_tokens_decoder"][token_id_str][ + "content" + ] = new_value + else: + raise ValueError( + f"Token ID {token_id_str} not found in added_tokens_decoder" + ) + + # Write the updated config back + with open(config_path, "w", encoding="utf-8") as f: + json.dump(config_data, f, indent=2) + + # 2. Update tokenizer.json - added_tokens + tokenizer_path = os.path.join(tokenizer_dir, "tokenizer.json") + if os.path.exists(tokenizer_path): + with open(tokenizer_path, "r", encoding="utf-8") as f: + tokenizer_data = json.load(f) + + # Update added_tokens + if "added_tokens" in tokenizer_data: + for token_id, new_value in token_id_mappings.items(): + for i, token_entry in enumerate(tokenizer_data["added_tokens"]): + if token_entry["id"] == token_id: + tokenizer_data["added_tokens"][i]["content"] = new_value + break + else: + # Reaching this section means the token_id was not found in tokenizer.json added_tokens + raise ValueError( + f"Token ID {token_id} not found in added_tokens" + ) + if "model" in tokenizer_data and "vocab" in tokenizer_data["model"]: + for token_id, new_value in token_id_mappings.items(): + for entry_val, entry_id in tokenizer_data["model"]["vocab"].items(): + if entry_id == token_id: + del tokenizer_data["model"]["vocab"][entry_val] + tokenizer_data["model"]["vocab"][new_value] = token_id + break + + # Write the updated tokenizer data back + with open(tokenizer_path, "w", encoding="utf-8") as f: + json.dump(tokenizer_data, f, indent=2) + + barrier() + return tokenizer_dir + + +def load_tokenizer(cfg): + """Load and configure the tokenizer based on the provided config.""" + model_config = load_model_config(cfg) + tokenizer_kwargs = {} + use_fast = True # this is the default + + if cfg.tokenizer_use_fast is not None: + use_fast = cfg.tokenizer_use_fast + if cfg.tokenizer_legacy is not None: + # True is the default w/ https://github.com/huggingface/transformers/pull/25224 + tokenizer_kwargs["legacy"] = cfg.tokenizer_legacy + + tokenizer_cls = AutoTokenizer + if cfg.tokenizer_type: + tokenizer_cls = getattr(transformers, cfg.tokenizer_type) + + # Set base tokenizer path + tokenizer_path = cfg.tokenizer_config + + # Apply token string overrides if specified + if cfg.added_tokens_overrides: + # Modify tokenizer files and get path to modified tokenizer + tokenizer_path = modify_tokenizer_files( + tokenizer_path, cfg.added_tokens_overrides, output_dir=cfg.output_dir + ) + + tokenizer = tokenizer_cls.from_pretrained( + tokenizer_path, + trust_remote_code=cfg.trust_remote_code or False, + use_fast=use_fast, + **tokenizer_kwargs, + ) + + if ( + tokenizer.__class__.__name__ + in [ + "LlamaTokenizer", + "LlamaTokenizerFast", + "CodeLlamaTokenizer", + "CodeLlamaTokenizerFast", + ] + and hasattr(tokenizer, "pad_token") + and not tokenizer.pad_token + ): + # set a pad_token, but use eos_token so we don't add a new token + tokenizer.pad_token = LLAMA_DEFAULT_EOS_TOKEN + + if tokenizer.__class__.__name__ == "GPTNeoXTokenizerFast": + tokenizer.add_special_tokens({"pad_token": "[PAD]"}) + os.environ["TOKENIZERS_PARALLELISM"] = 
"false" + + # Mistral's official FA implementation requires left padding + if cfg.is_mistral_derived_model and cfg.flash_attention and not cfg.sample_packing: + tokenizer.padding_side = "left" + + # Qwen base only has single token, so we need to set the special tokens + if cfg.is_qwen_derived_model: + token_ids = ["bos_token_id", "eos_token_id", "pad_token_id", "unk_token_id"] + for attr_name in token_ids: + if getattr(tokenizer, attr_name) is None: + setattr(tokenizer, attr_name, tokenizer.eod_id) + + token_names = ["bos_token", "eos_token", "pad_token", "unk_token"] + for attr_name in token_names: + if getattr(tokenizer, attr_name) is None: + setattr(tokenizer, attr_name, "<|endoftext|>") + + additional_special_tokens = None + if cfg.special_tokens: + special_tokens = cfg.special_tokens.to_dict() + additional_special_tokens = special_tokens.pop( + "additional_special_tokens", None + ) + lora_modules_to_save = get_linear_embedding_layers(model_config.model_type) + for k, val in special_tokens.items(): + # check if new special token is not already in tokenizer and + # is adapter training to make sure lora_modules_to_save is set + # pylint: disable=too-many-boolean-expressions + if ( + (getattr(tokenizer, k) is None or getattr(tokenizer, k) != val) + and (len(tokenizer.encode(val, add_special_tokens=False)) > 2) + and cfg.adapter + and ( + not cfg.lora_modules_to_save + or not all( + x in cfg.lora_modules_to_save for x in lora_modules_to_save + ) + ) + and k != "pad_token" + ): + lora_modules_to_save = ", ".join( + [f"`{x}`" for x in lora_modules_to_save] + ) + raise ValueError( + f"Please set lora_modules_to_save to [{lora_modules_to_save}] when using an adapter and changing the special tokens." + ) + + tokenizer.add_special_tokens( + {k: AddedToken(val, rstrip=False, lstrip=False, normalized=False)} + ) + + # If we add bos_token and eos_token, we need to update the post processor to + # handle them correctly. + # https://github.com/huggingface/transformers/pull/24132 + bos_or_eos_in_special_tokens = ( + "bos_token" in cfg.special_tokens and "eos_token" in cfg.special_tokens + ) + if ( + tokenizer.__class__.__name__ + in ( + "LlamaTokenizerFast", + "CodeLlamaTokenizerFast", + ) + and bos_or_eos_in_special_tokens + ): + tokenizer.update_post_processor() + + if cfg.tokens: + tokenizer.add_tokens( + [ + AddedToken(token, rstrip=False, lstrip=False, normalized=False) + for token in cfg.tokens + ] + ) + + # Additional special tokens are a List, and need to be treated differently than regular special + # tokens. We add them after we have called `add_tokens` in case these additional special tokens + # are new tokens. 
+ # + # Usage: + # + # ```py + # special_tokens: + # additional_special_tokens: ["<|im_start|>", "<|im_end|>"] + # ``` + if additional_special_tokens is not None: + tokenizer.add_special_tokens( + {"additional_special_tokens": additional_special_tokens} + ) + + if is_main_process(use_environ=True): + LOG.debug(f"EOS: {tokenizer.eos_token_id} / {tokenizer.eos_token}") + LOG.debug(f"BOS: {tokenizer.bos_token_id} / {tokenizer.bos_token}") + LOG.debug(f"PAD: {tokenizer.pad_token_id} / {tokenizer.pad_token}") + LOG.debug(f"UNK: {tokenizer.unk_token_id} / {tokenizer.unk_token}") + + if cfg.chat_template: + chat_template_string = get_chat_template_from_config( + cfg=cfg, + tokenizer=tokenizer, + ) + if cfg.default_system_message and cfg.chat_template == "chatml": + chat_template_string = chat_template_string.replace( + "You are a helpful assistant.", cfg.default_system_message + ) + + tokenizer.chat_template = chat_template_string + else: + LOG.info( + "No Chat template selected. Consider adding a chat template for easier inference." + ) + return tokenizer diff --git a/src/axolotl/loaders/utils.py b/src/axolotl/loaders/utils.py new file mode 100644 index 000000000..1aae4834d --- /dev/null +++ b/src/axolotl/loaders/utils.py @@ -0,0 +1,211 @@ +"""Utilities for axolotl.loaders module""" + +import contextlib +import logging +from typing import Type + +import addict +import torch +from transformers import AutoConfig, PretrainedConfig, PreTrainedModel + +from axolotl.utils.dict import DictDefault + +LOG = logging.getLogger(__name__) + + +def get_module_class_from_name( + module: torch.nn.Module, name: str +) -> Type[torch.nn.Module] | None: + """Gets a class from a module by its name. Copied from `accelerate.utils.dataclasses` + (https://github.com/huggingface/accelerate/blob/main/src/accelerate/utils/dataclasses.py#L2805). + + Args: + module: The module to get the class from. + name: The name of the class. + + Returns: + The class type of the matching module, or `None` if no match is found. + """ + modules_children = list(module.children()) + if module.__class__.__name__ == name: + return module.__class__ + + if len(modules_children) == 0: + return None + + for child_module in modules_children: + module_class = get_module_class_from_name(child_module, name) + if module_class is not None: + return module_class + + return None + + +def check_model_config(cfg: DictDefault, model_config: PretrainedConfig): + """Validates and adjusts model config based on `axolotl` config. + + This function performs several important checks and adjustments: + - Disables model caching for better memory efficiency + - Handles multimodal model-specific configurations + - Validates quantization settings + - Ensures proper LoRA configuration when using adapters with new tokens + + Args: + cfg: Dictionary mapping `axolotl` config keys to values. + model_config: The model's configuration object from `transformers`. + + Raises: + ValueError: If a multimodal model lacks text configuration, if GPTQ settings + are inconsistent, or if LoRA `modules_to_save` is improperly configured + with new tokens. + """ + if hasattr(model_config, "use_cache"): + model_config.use_cache = False + + if cfg.is_multimodal: + # For multimodal configs, use_cache is set in the text_config + if hasattr(model_config, "get_text_config"): + text_config = model_config.get_text_config() + if hasattr(text_config, "use_cache"): + text_config.use_cache = False + else: + raise ValueError( + "No text config found for multimodal model. 
Please raise an Issue with model details." + ) + + # Check if image_size is not set and load image size from model config if available + if ( + cfg.image_size is None + and hasattr(model_config, "vision_config") + and hasattr(model_config.vision_config, "image_size") + ): + cfg.image_size = model_config.vision_config.image_size + LOG.debug(f"Loaded image size: {cfg.image_size} from model config") + + quant_config_exists = ( + hasattr(model_config, "quantization_config") + and model_config.quantization_config + ) + + # Detect compressed-tensors config + is_compressed_tensors_config = ( + quant_config_exists + and model_config.quantization_config.get("quant_method") == "compressed-tensors" + ) + + if is_compressed_tensors_config: + if model_config.quantization_config.get("config_groups"): + LOG.warning( + "Found `config_groups` in a compressed-tensors config. " + "QAT integration with llmcompressor is not tested." + ) + # Skip further quant checks for compressed-tensors + return + + quant_config_method_is_gptq = ( + quant_config_exists + and "quant_method" in model_config.quantization_config + and model_config.quantization_config["quant_method"] == "gptq" + ) + + if cfg.gptq and not quant_config_method_is_gptq: + raise ValueError( + "model_config.quantization_config is not set or quant_method is not set to gptq. " + "Please make sure to point to a GPTQ model." + ) + + lora_modules_to_save = get_linear_embedding_layers(model_config.model_type) + if ( + cfg.adapter + and cfg.tokens + and ( + not cfg.lora_modules_to_save + or not all(x in cfg.lora_modules_to_save for x in lora_modules_to_save) + ) + ): + lora_modules_to_save_joined = ", ".join( + map(lambda x: f"`{x}`", lora_modules_to_save) + ) + raise ValueError( + "`lora_modules_to_save` not properly set when adding new tokens. " + f"Please include [{lora_modules_to_save_joined}] in `lora_modules_to_save`." + ) + + +def load_model_config(cfg: DictDefault) -> PretrainedConfig | addict.Dict: + """Loads and configures a model configuration from HuggingFace or local sources. + + This function determines the appropriate model config source, loads it, applies any + necessary overrides, and validates it for compatibility with the `axolotl` config. + + Args: + cfg: Dictionary mapping `axolotl` config keys to values. + + Returns: + A configured model configuration object (`AutoConfig` instance), or a simple + dictionary configuration for special cases like Mamba models. + + Raises: + ValueError: If configuration loading fails for reasons other than special cases + that are handled (e.g., Mamba models). 
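+
+    Example (a minimal sketch; the model id is hypothetical):
+
+        cfg = DictDefault({"base_model": "org/some-model"})
+        model_config = load_model_config(cfg)
+        model_config.model_type  # e.g. "llama"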
+ """ + model_config_name = cfg.base_model_config or cfg.base_model + if not model_config_name and cfg.tokenizer_config: + model_config_name = cfg.tokenizer_config + trust_remote_code = cfg.trust_remote_code is True + config_kwargs = {} + if cfg.revision_of_model: + config_kwargs["revision"] = cfg.revision_of_model + if cfg.num_labels: + # num_labels is used to initialize classifier models + config_kwargs["num_labels"] = cfg.num_labels + try: + model_config = AutoConfig.from_pretrained( + model_config_name, + trust_remote_code=trust_remote_code, + **config_kwargs, + ) + except ValueError as error: + if "mamba" in model_config_name: + return addict.Dict( + { + "model_type": "mamba", + } + ) + raise error + + if cfg.overrides_of_model_config: + for key, val in cfg.overrides_of_model_config.items(): + setattr(model_config, key, val) + + check_model_config(cfg, model_config) + + return model_config + + +def ensure_dtype(model: PreTrainedModel, dtype: torch.dtype = torch.bfloat16): + """Ensures all modules in the model are converted to the specified data type.""" + for name, module in model.named_modules(): + weight_mismatch = False + with contextlib.suppress(AttributeError): + weight_mismatch = module.weight.dtype != dtype + + bias_mismatch = False + with contextlib.suppress(AttributeError): + bias_mismatch = module.bias.dtype != dtype + + if weight_mismatch: + print(f"Converting module {name}.weight: {module.weight.dtype} -> {dtype}") + if bias_mismatch: + print(f"Converting module {name}.bias: {module.bias.dtype} -> {dtype}") + if weight_mismatch or bias_mismatch: + module.to(dtype) + + +def get_linear_embedding_layers(model_type: str) -> list[str]: + """Returns layer names of linear embeddings needed for LoRA based on model type.""" + if model_type == "gpt_neox": + return ["embed_in", "embed_out"] + if model_type == "falcon": + return ["word_embeddings", "lm_head"] + return ["embed_tokens", "lm_head"] diff --git a/src/axolotl/utils/gradient_checkpointing/__init__.py b/src/axolotl/monkeypatch/gradient_checkpointing/__init__.py similarity index 91% rename from src/axolotl/utils/gradient_checkpointing/__init__.py rename to src/axolotl/monkeypatch/gradient_checkpointing/__init__.py index ae0c559e9..5d631776b 100644 --- a/src/axolotl/utils/gradient_checkpointing/__init__.py +++ b/src/axolotl/monkeypatch/gradient_checkpointing/__init__.py @@ -5,10 +5,10 @@ from functools import partial from packaging import version -from axolotl.utils.gradient_checkpointing.offload_cpu import ( +from axolotl.monkeypatch.gradient_checkpointing.offload_cpu import ( CPU_Offloaded_Gradient_Checkpointer, ) -from axolotl.utils.gradient_checkpointing.offload_disk import ( +from axolotl.monkeypatch.gradient_checkpointing.offload_disk import ( Disco, ) diff --git a/src/axolotl/utils/gradient_checkpointing/offload_cpu.py b/src/axolotl/monkeypatch/gradient_checkpointing/offload_cpu.py similarity index 100% rename from src/axolotl/utils/gradient_checkpointing/offload_cpu.py rename to src/axolotl/monkeypatch/gradient_checkpointing/offload_cpu.py diff --git a/src/axolotl/utils/gradient_checkpointing/offload_disk.py b/src/axolotl/monkeypatch/gradient_checkpointing/offload_disk.py similarity index 100% rename from src/axolotl/utils/gradient_checkpointing/offload_disk.py rename to src/axolotl/monkeypatch/gradient_checkpointing/offload_disk.py diff --git a/src/axolotl/monkeypatch/peft/utils.py b/src/axolotl/monkeypatch/peft/utils.py index b3703d398..fdc49c5f6 100644 --- a/src/axolotl/monkeypatch/peft/utils.py +++ 
b/src/axolotl/monkeypatch/peft/utils.py @@ -75,4 +75,4 @@ def patch_peft_prep_code(): exec(prep_code, globals()) # pylint: disable=exec-used # nosec B102 LOG.info("patching prepare_model_for_kbit_training to allow for overrides") peft.utils.other.prepare_model_for_kbit_training = fixed_prepare_model_for_kbit_training # pylint: disable=protected-access # pylint: disable=undefined-variable # noqa: F821 - axolotl.utils.models.prepare_model_for_kbit_training = fixed_prepare_model_for_kbit_training # pylint: disable=protected-access # pylint: disable=undefined-variable # noqa: F821 + axolotl.loaders.model.prepare_model_for_kbit_training = fixed_prepare_model_for_kbit_training # pylint: disable=protected-access # pylint: disable=undefined-variable # noqa: F821 diff --git a/src/axolotl/train.py b/src/axolotl/train.py index 46f722eeb..52ec8f22b 100644 --- a/src/axolotl/train.py +++ b/src/axolotl/train.py @@ -28,11 +28,15 @@ from axolotl.contribs.lgpl import ( # pylint: disable = no-name-in-module ) from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder from axolotl.integrations.base import PluginManager +from axolotl.loaders import ( + ModelLoader, + load_processor, + load_tokenizer, +) from axolotl.utils.ctx_managers.sequence_parallel import SequenceParallelContextManager from axolotl.utils.dict import DictDefault from axolotl.utils.distributed import cleanup_distributed from axolotl.utils.freeze import freeze_layers_except -from axolotl.utils.models import load_model, load_processor, load_tokenizer from axolotl.utils.schemas.enums import RLType from axolotl.utils.trainer import setup_trainer @@ -76,7 +80,8 @@ def setup_model_and_tokenizer( msg += " and peft_config..." LOG.debug(msg) - model, peft_config = load_model(cfg, tokenizer, processor=processor) + model_loader = ModelLoader(cfg, tokenizer, processor=processor) + model, peft_config = model_loader.load() if model.generation_config is not None: model.generation_config.do_sample = True @@ -113,7 +118,8 @@ def setup_reference_model( model_ref = None # explicit setting to None else: # load the model again for model_ref/baseline - model_ref, _ = load_model(cfg, tokenizer, reference_model=True) + model_loader = ModelLoader(cfg, tokenizer, reference_model=True) + model_ref, _ = model_loader.load() return model_ref diff --git a/src/axolotl/utils/config/__init__.py b/src/axolotl/utils/config/__init__.py index a96cc1286..49e4cfc6f 100644 --- a/src/axolotl/utils/config/__init__.py +++ b/src/axolotl/utils/config/__init__.py @@ -11,9 +11,10 @@ from transformers.utils.import_utils import is_torch_npu_available from axolotl.integrations.base import PluginManager from axolotl.integrations.config import merge_input_args +from axolotl.loaders import MULTIMODAL_AUTO_MODEL_MAPPING +from axolotl.loaders.utils import load_model_config from axolotl.utils.bench import log_gpu_memory_usage from axolotl.utils.dict import DictDefault -from axolotl.utils.models import MULTIMODAL_AUTO_MODEL_MAPPING, load_model_config from axolotl.utils.schemas.config import ( AxolotlConfigWCapabilities as AxolotlConfigWCapabilitiesBase, ) diff --git a/src/axolotl/utils/ctx_managers/sequence_parallel.py b/src/axolotl/utils/ctx_managers/sequence_parallel.py index 2ae93acad..491cb9877 100644 --- a/src/axolotl/utils/ctx_managers/sequence_parallel.py +++ b/src/axolotl/utils/ctx_managers/sequence_parallel.py @@ -10,7 +10,7 @@ from torch.utils.hooks import RemovableHandle from transformers.modeling_outputs import CausalLMOutputWithPast from transformers.utils 
import ModelOutput -from axolotl.monkeypatch.ring_attn.patch import ( +from axolotl.monkeypatch.ring_attn import ( get_ring_attn_group, patch_prepare_data_loader, patch_prepare_device_mesh, diff --git a/src/axolotl/utils/data/rl.py b/src/axolotl/utils/data/rl.py index dc5920099..15744d4c6 100644 --- a/src/axolotl/utils/data/rl.py +++ b/src/axolotl/utils/data/rl.py @@ -10,6 +10,7 @@ import yaml from datasets import Dataset, DatasetDict, concatenate_datasets, load_from_disk from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH +from axolotl.loaders import load_tokenizer from axolotl.prompt_strategies.dpo import load as load_dpo from axolotl.prompt_strategies.kto import load as load_kto from axolotl.prompt_strategies.orpo import load as load_orpo @@ -17,7 +18,6 @@ from axolotl.utils.data.shared import datasets_w_name_generator, load_dataset_w_ from axolotl.utils.data.utils import deduplicate_and_log_datasets, md5 from axolotl.utils.dict import DictDefault from axolotl.utils.distributed import is_main_process, zero_first -from axolotl.utils.models import load_tokenizer from axolotl.utils.schemas.enums import RLType LOG = logging.getLogger(__name__) diff --git a/src/axolotl/utils/lora_embeddings.py b/src/axolotl/utils/lora_embeddings.py deleted file mode 100644 index 70f56655e..000000000 --- a/src/axolotl/utils/lora_embeddings.py +++ /dev/null @@ -1,14 +0,0 @@ -""" -helpers for lora embeddings -""" - - -def get_linear_embedding_layers(model_type): - """ - returns the linear embedding layers needed for loras, dependent on the model arch - """ - if model_type == "gpt_neox": - return ["embed_in", "embed_out"] - if model_type == "falcon": - return ["word_embeddings", "lm_head"] - return ["embed_tokens", "lm_head"] diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py deleted file mode 100644 index cd7499869..000000000 --- a/src/axolotl/utils/models.py +++ /dev/null @@ -1,1648 +0,0 @@ -"""Module for models and model loading""" - -# pylint: disable=too-many-lines -import gc -import importlib -import logging -import math -import os -import types -from functools import cached_property -from typing import Any, Dict, Optional, Tuple - -import addict -import bitsandbytes as bnb -import torch -import transformers -import transformers.modeling_utils -from accelerate import init_empty_weights -from bitsandbytes.nn import Params4bit -from peft import ( - LoftQConfig, - PeftConfig, - PeftModel, - PeftModelForCausalLM, - prepare_model_for_kbit_training, -) -from torch import nn -from transformers import ( - AddedToken, - AutoConfig, - AutoModelForCausalLM, - AutoModelForVision2Seq, - AutoProcessor, - AutoTokenizer, - AwqConfig, - BitsAndBytesConfig, - Gemma3ForConditionalGeneration, - GPTQConfig, - Llama4ForConditionalGeneration, - LlavaForConditionalGeneration, - Mistral3ForConditionalGeneration, - MllamaForConditionalGeneration, - PretrainedConfig, - PreTrainedModel, - PreTrainedTokenizerBase, - ProcessorMixin, - Qwen2_5_VLForConditionalGeneration, - Qwen2VLForConditionalGeneration, -) -from transformers.integrations.deepspeed import ( - HfTrainerDeepSpeedConfig, - is_deepspeed_zero3_enabled, -) - -from axolotl.common.architectures import MOE_ARCH_BLOCK -from axolotl.integrations.base import PluginManager -from axolotl.models.mamba import fix_mamba_attn_for_loss -from axolotl.monkeypatch.multipack import ( - SUPPORTED_MULTIPACK_MODEL_TYPES, - patch_for_multipack, -) -from axolotl.prompt_tokenizers import LLAMA_DEFAULT_EOS_TOKEN -from axolotl.utils.bench import log_gpu_memory_usage 
-from axolotl.utils.chat_templates import get_chat_template_from_config -from axolotl.utils.dict import DictDefault -from axolotl.utils.distributed import ( - barrier, - get_device_count, - get_device_type, - is_local_main_process, - is_main_process, -) -from axolotl.utils.gradient_checkpointing import ( - hf_grad_checkpoint_disk_offload_wrapper, - hf_grad_checkpoint_offload_wrapper, -) -from axolotl.utils.lora_embeddings import get_linear_embedding_layers -from axolotl.utils.model_shard_quant import load_sharded_model, load_sharded_model_quant -from axolotl.utils.schemas.enums import RLType - -LOG = logging.getLogger(__name__) -PLUGIN_MANAGER = PluginManager.get_instance() - -MULTIMODAL_AUTO_MODEL_MAPPING = { - "mllama": MllamaForConditionalGeneration, - "llama4": Llama4ForConditionalGeneration, - "llava": LlavaForConditionalGeneration, - "qwen2_vl": Qwen2VLForConditionalGeneration, - "qwen2_5_vl": Qwen2_5_VLForConditionalGeneration, - "mistral3": Mistral3ForConditionalGeneration, - "gemma3": Gemma3ForConditionalGeneration, -} - - -# copied from accelerator.FullyShardedDataParallelPlugin -def get_module_class_from_name(module, name): - """ - Gets a class from a module by its name. - - Args: - module (`torch.nn.Module`): The module to get the class from. - name (`str`): The name of the class. - """ - modules_children = list(module.children()) - if module.__class__.__name__ == name: - return module.__class__ - - if len(modules_children) == 0: - return None - - for child_module in modules_children: - module_class = get_module_class_from_name(child_module, name) - if module_class is not None: - return module_class - - return None - - -def check_model_config(cfg: DictDefault, model_config: PretrainedConfig): - # Set use_cache to False - if hasattr(model_config, "use_cache"): - model_config.use_cache = False - - if cfg.is_multimodal: - # For multimodal configs, use_cache is set in the text_config - if hasattr(model_config, "get_text_config"): - text_config = model_config.get_text_config() - if hasattr(text_config, "use_cache"): - text_config.use_cache = False - else: - raise ValueError( - "No text config found for multimodal model. Please raise an Issue with model details." - ) - - # check if image_size is not set and load image size from model config if available - if ( - cfg.image_size is None - and hasattr(model_config, "vision_config") - and hasattr(model_config.vision_config, "image_size") - ): - cfg.image_size = model_config.vision_config.image_size - LOG.debug(f"Loaded image size: {cfg.image_size} from model config") - - quant_config_exists = ( - hasattr(model_config, "quantization_config") - and model_config.quantization_config - ) - - # Detect compressed-tensors config - is_compressed_tensors_config = ( - quant_config_exists - and model_config.quantization_config.get("quant_method") == "compressed-tensors" - ) - - if is_compressed_tensors_config: - if model_config.quantization_config.get("config_groups"): - LOG.warning( - "Found `config_groups` in a compressed-tensors config. " - "QAT integration with llmcompressor is not tested." - ) - # Skip further quant checks for compressed-tensors - return - - quant_config_method_is_gptq = ( - quant_config_exists - and "quant_method" in model_config.quantization_config - and model_config.quantization_config["quant_method"] == "gptq" - ) - - if cfg.gptq and not quant_config_method_is_gptq: - raise ValueError( - "model_config.quantization_config is not set or quant_method is not set to gptq. " - "Please make sure to point to a GPTQ model." 
- ) - - lora_modules_to_save = get_linear_embedding_layers(model_config.model_type) - if ( - cfg.adapter - and cfg.tokens - and ( - not cfg.lora_modules_to_save - or not all(x in cfg.lora_modules_to_save for x in lora_modules_to_save) - ) - ): - lora_modules_to_save = ", ".join(map(lambda x: f"`{x}`", lora_modules_to_save)) - raise ValueError( - f"`lora_modules_to_save` not properly set when adding new tokens. Please include [{lora_modules_to_save}] in `lora_modules_to_save`." - ) - - -def load_model_config(cfg): - model_config_name = cfg.base_model_config or cfg.base_model - if not model_config_name and cfg.tokenizer_config: - model_config_name = cfg.tokenizer_config - trust_remote_code = cfg.trust_remote_code is True - config_kwargs = {} - if cfg.revision_of_model: - config_kwargs["revision"] = cfg.revision_of_model - if cfg.num_labels: - # num_labels is used to initialize classifier models - config_kwargs["num_labels"] = cfg.num_labels - try: - model_config = AutoConfig.from_pretrained( - model_config_name, - trust_remote_code=trust_remote_code, - **config_kwargs, - ) - except ValueError as err: - if "mamba" in model_config_name: - return addict.Dict( - { - "model_type": "mamba", - } - ) - raise err - - if cfg.overrides_of_model_config: - for key, val in cfg.overrides_of_model_config.items(): - setattr(model_config, key, val) - - check_model_config(cfg, model_config) - - return model_config - - -def modify_tokenizer_files( - tokenizer_path: str, token_mappings: Dict[int, str], output_dir: str -) -> str: - """ - Modify tokenizer files to replace added_tokens strings, save to output directory, and return the path to the modified tokenizer. - - This only works with reserved tokens that were added to the tokenizer, not tokens already part of the vocab. - - Args: - tokenizer_path: Path or name of the original tokenizer - token_mappings: Dict mapping {token_id (int): new_token_string} - output_dir: Directory to save the modified tokenizer - - Returns: - Path to the modified tokenizer directory - - Ref: https://github.com/huggingface/transformers/issues/27974#issuecomment-1854188941 - """ - - import json - - # Create the tokenizer directory in output_dir if it doesn't exist - tokenizer_dir = os.path.join(output_dir, "tokenizer") - os.makedirs(tokenizer_dir, exist_ok=True) - - if is_local_main_process(): # pylint: disable=too-many-nested-blocks - # Load the tokenizer - temp_tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, use_fast=True) - - # Save the tokenizer to the output directory - temp_tokenizer.save_pretrained(tokenizer_dir) - - # Get the token IDs and map them to their new values - token_id_mappings = { - int(token_id): new_value for token_id, new_value in token_mappings.items() - } - - # 1. Update tokenizer_config.json - added_tokens_decoder - config_path = os.path.join(tokenizer_dir, "tokenizer_config.json") - if os.path.exists(config_path): - with open(config_path, "r", encoding="utf-8") as f: - config_data = json.load(f) - - # Update added_tokens_decoder - if "added_tokens_decoder" in config_data: - for token_id, new_value in token_id_mappings.items(): - token_id_str = str(token_id) - if token_id_str in config_data["added_tokens_decoder"]: - config_data["added_tokens_decoder"][token_id_str][ - "content" - ] = new_value - else: - raise ValueError( - f"Token ID {token_id_str} not found in added_tokens_decoder" - ) - - # Write the updated config back - with open(config_path, "w", encoding="utf-8") as f: - json.dump(config_data, f, indent=2) - - # 2. 
Update tokenizer.json - added_tokens - tokenizer_path = os.path.join(tokenizer_dir, "tokenizer.json") - if os.path.exists(tokenizer_path): - with open(tokenizer_path, "r", encoding="utf-8") as f: - tokenizer_data = json.load(f) - - # Update added_tokens - if "added_tokens" in tokenizer_data: - for token_id, new_value in token_id_mappings.items(): - for i, token_entry in enumerate(tokenizer_data["added_tokens"]): - if token_entry["id"] == token_id: - tokenizer_data["added_tokens"][i]["content"] = new_value - break - else: - # Reaching this section means the token_id was not found in tokenizer.json added_tokens - raise ValueError( - f"Token ID {token_id} not found in added_tokens" - ) - if "model" in tokenizer_data and "vocab" in tokenizer_data["model"]: - for token_id, new_value in token_id_mappings.items(): - for entry_val, entry_id in tokenizer_data["model"]["vocab"].items(): - if entry_id == token_id: - del tokenizer_data["model"]["vocab"][entry_val] - tokenizer_data["model"]["vocab"][new_value] = token_id - break - - # Write the updated tokenizer data back - with open(tokenizer_path, "w", encoding="utf-8") as f: - json.dump(tokenizer_data, f, indent=2) - - barrier() - return tokenizer_dir - - -def load_tokenizer(cfg): - """Load and configure the tokenizer based on the provided config.""" - model_config = load_model_config(cfg) - tokenizer_kwargs = {} - use_fast = True # this is the default - - if cfg.tokenizer_use_fast is not None: - use_fast = cfg.tokenizer_use_fast - if cfg.tokenizer_legacy is not None: - # True is the default w/ https://github.com/huggingface/transformers/pull/25224 - tokenizer_kwargs["legacy"] = cfg.tokenizer_legacy - - tokenizer_cls = AutoTokenizer - if cfg.tokenizer_type: - tokenizer_cls = getattr(transformers, cfg.tokenizer_type) - - # Set base tokenizer path - tokenizer_path = cfg.tokenizer_config - - # Apply token string overrides if specified - if cfg.added_tokens_overrides: - # Modify tokenizer files and get path to modified tokenizer - tokenizer_path = modify_tokenizer_files( - tokenizer_path, cfg.added_tokens_overrides, output_dir=cfg.output_dir - ) - - tokenizer = tokenizer_cls.from_pretrained( - tokenizer_path, - trust_remote_code=cfg.trust_remote_code or False, - use_fast=use_fast, - **tokenizer_kwargs, - ) - - if ( - tokenizer.__class__.__name__ - in [ - "LlamaTokenizer", - "LlamaTokenizerFast", - "CodeLlamaTokenizer", - "CodeLlamaTokenizerFast", - ] - and hasattr(tokenizer, "pad_token") - and not tokenizer.pad_token - ): - # set a pad_token, but use eos_token so we don't add a new token - tokenizer.pad_token = LLAMA_DEFAULT_EOS_TOKEN - - if tokenizer.__class__.__name__ == "GPTNeoXTokenizerFast": - tokenizer.add_special_tokens({"pad_token": "[PAD]"}) - os.environ["TOKENIZERS_PARALLELISM"] = "false" - - # Mistral's official FA implementation requires left padding - if cfg.is_mistral_derived_model and cfg.flash_attention and not cfg.sample_packing: - tokenizer.padding_side = "left" - - # Qwen base only has single token, so we need to set the special tokens - if cfg.is_qwen_derived_model: - token_ids = ["bos_token_id", "eos_token_id", "pad_token_id", "unk_token_id"] - for attr_name in token_ids: - if getattr(tokenizer, attr_name) is None: - setattr(tokenizer, attr_name, tokenizer.eod_id) - - token_names = ["bos_token", "eos_token", "pad_token", "unk_token"] - for attr_name in token_names: - if getattr(tokenizer, attr_name) is None: - setattr(tokenizer, attr_name, "<|endoftext|>") - - additional_special_tokens = None - if cfg.special_tokens: - 
special_tokens = cfg.special_tokens.to_dict() - additional_special_tokens = special_tokens.pop( - "additional_special_tokens", None - ) - lora_modules_to_save = get_linear_embedding_layers(model_config.model_type) - for k, val in special_tokens.items(): - # check if new special token is not already in tokenizer and - # is adapter training to make sure lora_modules_to_save is set - # pylint: disable=too-many-boolean-expressions - if ( - (getattr(tokenizer, k) is None or getattr(tokenizer, k) != val) - and (len(tokenizer.encode(val, add_special_tokens=False)) > 2) - and cfg.adapter - and ( - not cfg.lora_modules_to_save - or not all( - x in cfg.lora_modules_to_save for x in lora_modules_to_save - ) - ) - and k != "pad_token" - ): - lora_modules_to_save = ", ".join( - [f"`{x}`" for x in lora_modules_to_save] - ) - raise ValueError( - f"Please set lora_modules_to_save to [{lora_modules_to_save}] when using an adapter and changing the special tokens." - ) - - tokenizer.add_special_tokens( - {k: AddedToken(val, rstrip=False, lstrip=False, normalized=False)} - ) - - # If we add bos_token and eos_token, we need to update the post processor to - # handle them correctly. - # https://github.com/huggingface/transformers/pull/24132 - bos_or_eos_in_special_tokens = ( - "bos_token" in cfg.special_tokens and "eos_token" in cfg.special_tokens - ) - if ( - tokenizer.__class__.__name__ - in ( - "LlamaTokenizerFast", - "CodeLlamaTokenizerFast", - ) - and bos_or_eos_in_special_tokens - ): - tokenizer.update_post_processor() - - if cfg.tokens: - tokenizer.add_tokens( - [ - AddedToken(token, rstrip=False, lstrip=False, normalized=False) - for token in cfg.tokens - ] - ) - - # Additional special tokens are a List, and need to be treated differently than regular special - # tokens. We add them after we have called `add_tokens` in case these additional special tokens - # are new tokens. - # - # Usage: - # - # ```py - # special_tokens: - # additional_special_tokens: ["<|im_start|>", "<|im_end|>"] - # ``` - if additional_special_tokens is not None: - tokenizer.add_special_tokens( - {"additional_special_tokens": additional_special_tokens} - ) - - if is_main_process(use_environ=True): - LOG.debug(f"EOS: {tokenizer.eos_token_id} / {tokenizer.eos_token}") - LOG.debug(f"BOS: {tokenizer.bos_token_id} / {tokenizer.bos_token}") - LOG.debug(f"PAD: {tokenizer.pad_token_id} / {tokenizer.pad_token}") - LOG.debug(f"UNK: {tokenizer.unk_token_id} / {tokenizer.unk_token}") - - if cfg.chat_template: - chat_template_string = get_chat_template_from_config( - cfg=cfg, - tokenizer=tokenizer, - ) - if cfg.default_system_message and cfg.chat_template == "chatml": - chat_template_string = chat_template_string.replace( - "You are a helpful assistant.", cfg.default_system_message - ) - - tokenizer.chat_template = chat_template_string - else: - LOG.info( - "No Chat template selected. Consider adding a chat template for easier inference." - ) - return tokenizer - - -def load_processor(cfg: DictDefault, tokenizer: PreTrainedTokenizerBase): - processor_kwargs: Dict[str, Any] = {} # do we actually need this? 
-
-    processor_cls = AutoProcessor
-    if cfg.processor_type:
-        processor_cls = getattr(transformers, cfg.processor_type)
-
-    processor = processor_cls.from_pretrained(
-        cfg.processor_config,
-        trust_remote_code=cfg.trust_remote_code or False,
-        tokenizer=tokenizer,
-        **processor_kwargs,
-    )
-
-    # Attempt to load image size from processor if available
-    if (
-        cfg.image_size is None
-        and hasattr(processor, "size")
-        and any(dim in processor.size for dim in ["width", "height"])
-    ):
-        im_width = None
-        im_height = None
-        if "width" in processor.size:
-            im_width = processor.size["width"]
-        if "height" in processor.size:
-            im_height = processor.size["height"]
-
-        # If both width and height are set, use a tuple
-        if im_width is not None and im_height is not None:
-            cfg.image_size = (im_width, im_height)
-        # If only width is set, use as integer
-        elif im_width is not None:
-            cfg.image_size = im_width
-        # If only height is set, use as integer
-        elif im_height is not None:
-            cfg.image_size = im_height
-
-        LOG.debug(f"Loaded image size: {cfg.image_size} from processor")
-
-    return processor
-
-
-class ModelLoader:
-    """
-    Manages model configuration and monkey patches while loading the model.
-    """
-
-    def __init__(
-        self,
-        cfg: DictDefault,
-        tokenizer: PreTrainedTokenizerBase,
-        *,
-        processor: ProcessorMixin = None,  # pylint: disable=unused-argument
-        inference: bool = False,
-        reference_model: bool = False,
-        **kwargs,  # pylint: disable=unused-argument
-    ) -> None:
-        self.cfg = cfg
-        self.tokenizer = tokenizer
-        self.inference: bool = inference
-        self.reference_model: bool = reference_model
-
-        # init model kwargs
-        self.model_kwargs: Dict[str, Any] = {}
-        if cfg.overrides_of_model_kwargs:
-            for key, val in cfg.overrides_of_model_kwargs.items():
-                self.model_kwargs[key] = val
-
-        # init model
-        self.model: PreTrainedModel
-        self.base_model = cfg.base_model
-        self.model_type = cfg.type_of_model
-
-        # init model config
-        self.model_config = load_model_config(cfg)
-
-        self.auto_model_loader = AutoModelForCausalLM  # pylint: disable=invalid-name
-
-    def apply_patches(self) -> None:
-        if self.cfg.xformers_attention and self.cfg.sample_packing:
-            from axolotl.monkeypatch.attention import patch_xformers_attn_over_fa2
-
-            patch_xformers_attn_over_fa2()
-            self.cfg.flash_attention = True
-        if self.cfg.fsdp_config and str(self.cfg.fsdp_config.fsdp_version) == "2":
-            from axolotl.monkeypatch.accelerate.fsdp2 import patch_accelerate_fsdp_utils
-
-            patch_accelerate_fsdp_utils()
-
-        if self.cfg.adapter and self.cfg.embeddings_skip_upcast:
-            from axolotl.monkeypatch.peft.utils import patch_peft_prep_code
-
-            patch_peft_prep_code()
-
-        if self.cfg.flex_attention:
-            from axolotl.monkeypatch.attention.flex_attn import (
-                patch_flex_make_mask,
-                patch_flex_wrapper,
-            )
-
-            flex_attn_compile_kwargs = self.cfg.flex_attn_compile_kwargs or {}
-            patch_flex_wrapper(**flex_attn_compile_kwargs)
-            patch_flex_make_mask()
-
-        if self.cfg.model_config_type == "llama4":
-            if self.cfg.llama4_linearized_experts:
-                from axolotl.monkeypatch.models.llama4.modeling import (
-                    patch_llama4_linearized_modeling,
-                )
-
-                patch_llama4_linearized_modeling()
-
-        # patch gemma3 conditional generation forward before loading plugins
-        # as it could be overridden by plugins
-        if self.cfg.model_config_type == "gemma3":
-            from axolotl.monkeypatch.gemma3 import (
-                patch_gemma3conditionalgeneration_forward,
-            )
-
-            patch_gemma3conditionalgeneration_forward()
-
-        # load any patches from plugins
-        PLUGIN_MANAGER.pre_model_load(self.cfg)
-
-        # monkey patch to allow additional Accelerator init kwargs
-        if self.cfg.fp8:
-            from axolotl.monkeypatch.trainer_accelerator_args import (
-                patch_create_accelerate_code_for_fp8,
-            )
-
-            patch_create_accelerate_code_for_fp8()
-
-        if self.cfg.adapter:
-            from axolotl.monkeypatch.transformers_fa_utils import (
-                patch_fa_peft_integration,
-            )
-
-            patch_fa_peft_integration()
-
-        if self.cfg.gradient_checkpointing in ["unsloth", "offload"]:
-            transformers.modeling_utils.checkpoint = hf_grad_checkpoint_offload_wrapper
-        if self.cfg.gradient_checkpointing == "offload_disk":
-            transformers.modeling_utils.checkpoint = (
-                hf_grad_checkpoint_disk_offload_wrapper
-            )
-
-        if self.cfg.flash_attention:
-            self.patch_attention()
-
-        if self.cfg.sample_packing and self.cfg.s2_attention:
-            raise ValueError(
-                "Received `sample_packing=true` and `s2_attention=true`; however, "
-                "shifted-sparse attention does not currently support sample packing."
-            )
-
-        if (
-            self.cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES
-            and (self.cfg.flash_attention or self.cfg.flex_attention)
-            and self.cfg.sample_packing
-        ):
-            if "auto_map" in self.model_config:
-                try:
-                    auto_map_config = self.model_config["auto_map"]
-                except TypeError:
-                    auto_map_config = self.model_config.auto_map
-                has_remote_code = "AutoModelForCausalLM" in auto_map_config
-            else:
-                has_remote_code = False
-
-            if has_remote_code and self.cfg.trust_remote_code is False:
-                # if explicitly set in the YAML, we should prefer that, for example if explicitly disabled
-                has_remote_code = self.cfg.trust_remote_code
-            patch_for_multipack(
-                self.cfg.model_config_type,
-                model_name=self.cfg.base_model,
-                has_remote_code=has_remote_code,
-            )
-
-        if self.cfg.is_llama_derived_model:
-            self.patch_loss_llama()
-
-        if (
-            self.cfg.model_config_type == "mistral"
-            and self.cfg.flash_attn_cross_entropy_loss
-        ):
-            from axolotl.monkeypatch.mistral_attn_hijack_flash import (
-                patch_mistral_cross_entropy,
-            )
-
-            patch_mistral_cross_entropy()
-
-        if self.cfg.unsloth_lora_qkv or self.cfg.unsloth_lora_o:
-            from axolotl.monkeypatch.lora_kernels import patch_self_attn_lora
-
-            patch_self_attn_lora(self.cfg)
-
-    def patch_attention(self) -> None:
-        if hasattr(self.model_config, "model_type"):
-            if self.model_config.model_type == "mllama" and self.cfg.flash_attention:
-                from axolotl.monkeypatch.attention.mllama import patch_mllama
-
-                patch_mllama()
-
-            if self.model_config.model_type == "btlm":
-                from axolotl.monkeypatch.btlm_attn_hijack_flash import (
-                    replace_btlm_attn_with_flash_attn,
-                )
-
-                replace_btlm_attn_with_flash_attn(self.cfg.base_model)
-
-            if (
-                self.model_config.model_type == "stablelm_epoch"
-                and self.cfg.sample_packing
-            ):
-                from axolotl.monkeypatch.stablelm_attn_hijack_flash import (
-                    replace_stablelm_attn_with_flash_attn,
-                )
-
-                replace_stablelm_attn_with_flash_attn(self.cfg.base_model)
-
-    @cached_property
-    def has_flash_attn(self) -> bool:
-        """Check if flash attention is installed"""
-        return importlib.util.find_spec("flash_attn") is not None
-
-    def patch_loss_llama(self) -> None:
-        """Patch loss functions and other optimizations"""
-        if self.has_flash_attn:
-            from axolotl.monkeypatch.llama_attn_hijack_flash import (
-                patch_fa_llama_cross_entropy,
-                patch_llama_rms_norm,
-            )
-
-        if self.cfg.flash_attn_cross_entropy and self.has_flash_attn:
-            patch_fa_llama_cross_entropy()
-        elif self.cfg.unsloth_cross_entropy_loss:
-            from axolotl.monkeypatch.unsloth_ import integrate_cross_entropy_loss_patch
-
-            integrate_cross_entropy_loss_patch(model_type="llama")
-
-        if self.cfg.flash_attn_rms_norm and self.has_flash_attn:
-            patch_llama_rms_norm()
-        elif self.cfg.unsloth_rms_norm:
-            from axolotl.monkeypatch.unsloth_ import patch_unsloth_layernorm
-
-            patch_unsloth_layernorm()
-
-        if self.cfg.unsloth_lora_qkv or self.cfg.unsloth_lora_o:
-            from axolotl.monkeypatch.unsloth_ import patch_self_attn_lora
-
-            patch_self_attn_lora()
-
-    def patch_llama_derived_model(self):
-        """Modify all llama derived models in one block"""
-        self.patch_loss_llama()
-
-        if self.cfg.flash_attention:
-            from axolotl.monkeypatch.llama_attn_hijack_flash import (
-                replace_llama_attn_with_flash_attn,
-            )
-
-            if self.cfg.sample_packing:
-                if self.cfg.device not in ["mps", "cpu"] and not self.inference:
-                    LOG.info("patching with flash attention for sample packing")
-                    replace_llama_attn_with_flash_attn(
-                        packed=True,
-                        cross_entropy=self.cfg.flash_attn_cross_entropy,
-                        rms_norm=self.cfg.flash_attn_rms_norm,
-                    )
-            elif self.cfg.s2_attention:
-                LOG.info("patching w/ flash-enabled, shifted-sparse attention")
-                replace_llama_attn_with_flash_attn(
-                    packed=False,
-                    cross_entropy=self.cfg.flash_attn_cross_entropy,
-                    rms_norm=self.cfg.flash_attn_rms_norm,
-                    use_shifted_sparse_attn=True,
-                )
-            elif self.cfg.flash_attn_cross_entropy or self.cfg.flash_attn_rms_norm:
-                replace_llama_attn_with_flash_attn(
-                    packed=False,
-                    cross_entropy=self.cfg.flash_attn_cross_entropy,
-                    rms_norm=self.cfg.flash_attn_rms_norm,
-                )
-        elif self.cfg.xformers_attention:
-            from axolotl.monkeypatch.llama_attn_hijack_xformers import (
-                hijack_llama_attention,
-            )
-
-            LOG.info("patching with xformers attention")
-            hijack_llama_attention()
-        elif self.cfg.sample_packing:
-            from axolotl.monkeypatch.llama_patch_multipack import (
-                hijack_llama_prepare_4d_mask,
-            )
-
-            LOG.info("patching llama _prepare_4d_causal_attention_mask*")
-            hijack_llama_prepare_4d_mask()
-        elif self.cfg.s2_attention:
-            raise NotImplementedError(
-                "Shifted-sparse attention not currently implemented without flash attention."
-            )
-
-    def set_auto_model_loader(self):
-        """
-        Set self.auto_model_loader. Defaults to `transformers.AutoModelForCausalLM`
-        (set at `__init__`). When using a multimodal model, `self.auto_model_loader`
-        should be set according to the type of the model.
- """ - if self.cfg.is_multimodal: - self.auto_model_loader = MULTIMODAL_AUTO_MODEL_MAPPING.get( - self.model_config.model_type, AutoModelForVision2Seq - ) - - def set_device_map_config(self) -> None: - device_map = self.cfg.device_map - max_memory = self.cfg.max_memory - - if self.cfg.gpu_memory_limit: - gpu_memory_limit = ( - str(self.cfg.gpu_memory_limit) + "GiB" - if isinstance(self.cfg.gpu_memory_limit, int) - else self.cfg.gpu_memory_limit - ) - - max_memory = {} - num_device = get_device_count() - for i in range(num_device): - max_memory[i] = gpu_memory_limit - max_memory["cpu"] = "256GiB" # something sufficiently large to fit anything - - if max_memory is not None: - # Based on https://github.com/togethercomputer/OpenChatKit/blob/main/inference/bot.py - from accelerate import infer_auto_device_map - - with init_empty_weights(): - model_canvas = self.auto_model_loader.from_config( - self.model_config, - trust_remote_code=self.cfg.trust_remote_code or False, - ) - model_canvas.tie_weights() - device_map = infer_auto_device_map( - model_canvas, - max_memory=max_memory, - dtype=self.cfg.torch_dtype, - ) - # We can discard max_memory now as we have a device map set up for us - max_memory = None - - self.model_kwargs["device_map"] = device_map - self.model_kwargs["torch_dtype"] = self.cfg.torch_dtype - - cur_device = get_device_type() - if "mps" in str(cur_device): - self.model_kwargs["device_map"] = "mps:0" - elif "npu" in str(cur_device): - self.model_kwargs["device_map"] = "npu:0" - - # TODO can we put the reference model on it's own gpu? I think we have to move logits around to calculate loss - # if cfg.rl: - # if torch.cuda.device_count() > 1: - # if reference_model: - # model_kwargs["device_map"] = "cuda:" + str( - # torch.cuda.current_device() + 1 - # ) - # else: - # model_kwargs["device_map"] = "cuda:" + str(torch.cuda.current_device()) - - if is_deepspeed_zero3_enabled(): - del self.model_kwargs["device_map"] - - def set_quantization_config(self) -> None: - self.model_kwargs["load_in_8bit"] = self.cfg.load_in_8bit - self.model_kwargs["load_in_4bit"] = self.cfg.load_in_4bit - - if self.cfg.gptq: - if not hasattr(self.model_config, "quantization_config"): - LOG.warning( - "model config does not contain quantization_config information" - ) - else: - if self.cfg.gptq_disable_exllama is not None: - self.model_config.quantization_config["disable_exllama"] = ( - self.cfg.gptq_disable_exllama - ) - self.model_kwargs["quantization_config"] = GPTQConfig( - **self.model_config.quantization_config - ) - if ( - self.cfg.adapter in ["qlora", "lora"] - and hasattr(self.model_config, "quantization_config") - and self.model_config.quantization_config["quant_method"] - in ["gptq", "awq", "bitsandbytes"] - ): - if self.model_config.quantization_config["quant_method"] == "gptq": - self.model_kwargs["quantization_config"] = GPTQConfig( - **self.model_config.quantization_config - ) - elif self.model_config.quantization_config["quant_method"] == "awq": - self.model_kwargs["quantization_config"] = AwqConfig( - **self.model_config.quantization_config - ) - elif ( - self.model_config.quantization_config["quant_method"] == "bitsandbytes" - ): - self.model_kwargs["quantization_config"] = BitsAndBytesConfig( - **self.model_config.quantization_config - ) - elif self.cfg.adapter == "qlora" and self.model_kwargs["load_in_4bit"]: - bnb_config = { - "load_in_4bit": True, - "llm_int8_threshold": 6.0, - "llm_int8_has_fp16_weight": False, - "bnb_4bit_compute_dtype": self.cfg.torch_dtype, - 
"bnb_4bit_use_double_quant": True, - "bnb_4bit_quant_type": "nf4", - "bnb_4bit_quant_storage": torch.bfloat16, - } - if self.cfg.model_config_type in ["jamba", "qwen2_moe"] and not ( - self.cfg.deepspeed or self.cfg.fsdp - ): - # for some reason, this causes the loss to be off by an order of magnitude - # but deepspeed needs this still in bfloat16 - bnb_config["bnb_4bit_quant_storage"] = torch.float32 - - if self.cfg.bnb_config_kwargs: - bnb_config.update(self.cfg.bnb_config_kwargs) - - self.model_kwargs["quantization_config"] = BitsAndBytesConfig( - **bnb_config, - ) - elif self.cfg.adapter == "lora" and self.model_kwargs["load_in_8bit"]: - bnb_config = { - "load_in_8bit": True, - } - # Exclude mamba blocks from int8 quantization for jamba - if self.cfg.model_config_type == "jamba": - bnb_config["llm_int8_skip_modules"] = ["mamba"] - self.model_kwargs["quantization_config"] = BitsAndBytesConfig( - **bnb_config, - ) - - # no longer needed per https://github.com/huggingface/transformers/pull/26610 - if "quantization_config" in self.model_kwargs or self.cfg.gptq: - self.model_kwargs.pop("load_in_8bit", None) - self.model_kwargs.pop("load_in_4bit", None) - - def set_attention_config(self) -> None: - """ - sample packing uses custom FA2 patch - """ - if self.cfg.flex_attention: - self.model_kwargs["attn_implementation"] = "flex_attention" - self.model_config._attn_implementation = ( # pylint: disable=protected-access - "flex_attention" - ) - - elif self.cfg.flash_attention: - if not self.cfg.sample_packing and self.cfg.s2_attention: - pass - self.model_kwargs["attn_implementation"] = "flash_attention_2" - self.model_config._attn_implementation = ( # pylint: disable=protected-access - "flash_attention_2" - ) - elif self.cfg.sdp_attention: - self.model_kwargs["attn_implementation"] = "sdpa" - self.model_config._attn_implementation = ( # pylint: disable=protected-access - "sdpa" - ) - elif self.cfg.eager_attention: - self.model_kwargs["attn_implementation"] = "eager" - self.model_config._attn_implementation = ( # pylint: disable=protected-access - "eager" - ) - - if self.cfg.low_cpu_mem_usage: - self.model_kwargs["low_cpu_mem_usage"] = True - - def build_model(self, qlora_fsdp) -> bool: - def _configure_zero3_memory_efficient_loading(): - """ - Set the deepspeed config to load the model into RAM first before moving to VRAM. - - We need to return hf_ds_cfg as it needs to exist before model loading. 
- """ - hf_ds_cfg = None - - if os.getenv("ACCELERATE_DEEPSPEED_ZERO_STAGE") == "3": - hf_ds_cfg = HfTrainerDeepSpeedConfig(self.cfg.deepspeed) - hf_ds_cfg.fill_match( - "train_micro_batch_size_per_gpu", self.cfg.micro_batch_size - ) - hf_ds_cfg.fill_match( - "gradient_accumulation_steps", self.cfg.gradient_accumulation_steps - ) - hf_ds_cfg.fill_match( - "train_batch_size", - int(os.getenv("WORLD_SIZE", "1")) - * self.cfg.micro_batch_size - * self.cfg.gradient_accumulation_steps, - ) - if "device_map" in self.model_kwargs: - del self.model_kwargs["device_map"] - - transformers.modeling_utils.is_deepspeed_zero3_enabled = lambda: True - transformers.integrations.deepspeed.is_deepspeed_zero3_enabled = ( - lambda: True - ) - - return hf_ds_cfg - - skip_move_to_device = False - if ( # pylint: disable=condition-evals-to-constant) - (self.cfg.fsdp and self.cfg.fsdp_config.fsdp_cpu_ram_efficient_loading) - and not qlora_fsdp - and False - ): - self.model = load_sharded_model( - self.base_model, - self.model_config, - self.cfg, - torch_dtype=self.cfg.torch_dtype, - ) - skip_move_to_device = True - elif ( - qlora_fsdp - and self.cfg.fsdp_config.fsdp_cpu_ram_efficient_loading - and ( - self.cfg.model_config_type == "dbrx" - or self.cfg.qlora_sharded_model_loading - ) - ): - quant_storage = self.cfg.torch_dtype - quantization_config = hasattr( - self.model_config, "quantization_config" - ) and getattr(self.model_config, "quantization_config") - quantization_config = ( - quantization_config or self.model_kwargs["quantization_config"] - ) - self.model = load_sharded_model_quant( - self.base_model, - self.model_config, - self.cfg, - quant_storage=quant_storage, - quantization_config=quantization_config, - ) - skip_move_to_device = True - elif ( - self.model_config.model_type in ["llama", "llama4"] - and not self.cfg.trust_remote_code - and not self.cfg.gptq - ): - # TODO do we need to open this up for all models? 
-            if self.cfg.fsdp and self.cfg.fsdp_config.fsdp_cpu_ram_efficient_loading:
-                skip_move_to_device = True
-                if "device_map" in self.model_kwargs:
-                    del self.model_kwargs["device_map"]
-
-            _ = _configure_zero3_memory_efficient_loading()
-
-            # Load model with random initialization if specified
-            if self.cfg.random_init_weights:
-                # AutoModel classes support the from_config method
-                if self.auto_model_loader in [
-                    AutoModelForCausalLM,
-                    AutoModelForVision2Seq,
-                ]:
-                    self.model = self.auto_model_loader.from_config(
-                        config=self.model_config,
-                    )
-                else:
-                    self.model = self.auto_model_loader(
-                        config=self.model_config,
-                    )
-            else:
-                self.model = self.auto_model_loader.from_pretrained(
-                    self.base_model,
-                    config=self.model_config,
-                    **self.model_kwargs,
-                )
-
-            # TODO (MengqingCao) split these patches separately
-            if self.cfg.flash_attention and not self.inference:
-                from axolotl.monkeypatch.llama_attn_hijack_flash import (
-                    is_xformers_swiglu_available,
-                    replace_llama_mlp_with_swiglu,
-                    replace_llama_qkv_with_fused,
-                )
-
-                if self.cfg.flash_attn_fuse_mlp and is_xformers_swiglu_available():
-                    LOG.info("patching with SwiGLU")
-                    replace_llama_mlp_with_swiglu(self.model)
-
-                if self.cfg.flash_attn_fuse_qkv:
-                    LOG.info("patching with fused QKV")
-                    replace_llama_qkv_with_fused(self.model)
-        elif self.model_type == "MambaLMHeadModel":
-            # FIXME this is janky at best and hacked together to make it work
-            MambaLMHeadModel = fix_mamba_attn_for_loss()  # pylint: disable=invalid-name
-
-            self.model_kwargs["dtype"] = self.model_kwargs["torch_dtype"]
-            self.model_kwargs["device"] = torch.cuda.current_device()
-            del self.model_kwargs["torch_dtype"]
-            del self.model_kwargs["device_map"]
-
-            self.model = MambaLMHeadModel.from_pretrained(
-                self.base_model,
-                **self.model_kwargs,
-            )
-        elif (
-            self.model_type
-            and self.model_type != "AutoModelForCausalLM"
-            and not self.cfg.trust_remote_code
-        ):
-            if self.cfg.gptq:
-                self.model = self.auto_model_loader.from_pretrained(
-                    self.base_model,
-                    config=self.model_config,
-                    trust_remote_code=self.cfg.trust_remote_code or False,
-                    **self.model_kwargs,
-                )
-            else:
-                self.model = getattr(transformers, self.model_type).from_pretrained(
-                    self.base_model,
-                    config=self.model_config,
-                    trust_remote_code=self.cfg.trust_remote_code or False,
-                    **self.model_kwargs,
-                )
-        else:
-            if self.cfg.gptq:
-                self.model = self.auto_model_loader.from_pretrained(
-                    self.base_model,
-                    config=self.model_config,
-                    trust_remote_code=self.cfg.trust_remote_code or False,
-                    **self.model_kwargs,
-                )
-            else:
-                if (
-                    self.cfg.fsdp
-                    and self.cfg.fsdp_config.fsdp_cpu_ram_efficient_loading
-                ):
-                    # disabling either of these two still leads to VRAM spike before setting back down
-                    skip_move_to_device = True
-                    if "device_map" in self.model_kwargs:
-                        del self.model_kwargs["device_map"]
-
-                    _ = _configure_zero3_memory_efficient_loading()
-
-                self.model = self.auto_model_loader.from_pretrained(
-                    self.base_model,
-                    config=self.model_config,
-                    trust_remote_code=self.cfg.trust_remote_code or False,
-                    **self.model_kwargs,
-                )
-        if is_deepspeed_zero3_enabled():
-            skip_move_to_device = True
-
-        return skip_move_to_device
-
-    def adjust_model_config(self) -> None:
-        if (
-            hasattr(self.model, "config")
-            and hasattr(self.model.config, "max_position_embeddings")
-            and self.model.config.max_position_embeddings
-            and self.cfg.sequence_len > self.model.config.max_position_embeddings
-        ):
-            LOG.warning(
-                f"increasing model.config.max_position_embeddings from {self.model.config.max_position_embeddings} to {self.cfg.sequence_len}"
-            )
-            self.model.config.max_position_embeddings = self.cfg.sequence_len
-
-        if (
-            hasattr(self.model, "config")
-            and hasattr(self.model.config, "bos_token_id")
-            and self.model.config.bos_token_id
-            and self.model.config.bos_token_id != self.tokenizer.bos_token_id
-        ):
-            self.model.config.bos_token_id = self.tokenizer.bos_token_id
-
-        if (
-            hasattr(self.model, "config")
-            and hasattr(self.model.config, "eos_token_id")
-            and self.model.config.eos_token_id
-            and self.model.config.eos_token_id != self.tokenizer.eos_token_id
-        ):
-            self.model.config.eos_token_id = self.tokenizer.eos_token_id
-
-    def set_z3_leaf_modules(self) -> None:
-        from deepspeed.utils import (  # pylint: disable=no-name-in-module
-            set_z3_leaf_modules,
-        )
-
-        if self.cfg.model_config_type in MOE_ARCH_BLOCK:
-            moe_blocks = MOE_ARCH_BLOCK[self.cfg.model_config_type]
-            moe_blocks = [moe_blocks] if isinstance(moe_blocks, str) else moe_blocks
-            set_z3_leaf_modules(
-                self.model,
-                [
-                    get_module_class_from_name(self.model, module_name)
-                    for module_name in moe_blocks
-                ],
-            )
-
-    def prepare_model(self, qlora_fsdp: bool) -> None:
-        skip_prepare_model_for_kbit_training = False
-        if self.cfg.model_config_type == "qwen" and self.cfg.adapter == "lora":
-            # Qwen doesn't play nicely with LoRA if this is enabled
-            skip_prepare_model_for_kbit_training = True
-
-        loftq_bits = (
-            self.cfg.peft
-            and self.cfg.peft.loftq_config
-            and self.cfg.peft.loftq_config.loftq_bits
-        )
-        if self.cfg.adapter == "lora" and loftq_bits:
-            skip_prepare_model_for_kbit_training = True
-
-        if qlora_fsdp or (
-            self.cfg.fsdp and self.cfg.fsdp_config.fsdp_cpu_ram_efficient_loading
-        ):
-            # make sure everything is in the same dtype
-            skip_prepare_model_for_kbit_training = True
-
-        if is_deepspeed_zero3_enabled():
-            skip_prepare_model_for_kbit_training = True
-
-        if (
-            not skip_prepare_model_for_kbit_training
-            and self.cfg.adapter in ["lora", "qlora"]
-            and (self.cfg.load_in_8bit or self.cfg.load_in_4bit)
-        ):
-            LOG.info("converting PEFT model w/ prepare_model_for_kbit_training")
-            self.model = prepare_model_for_kbit_training(
-                self.model, use_gradient_checkpointing=self.cfg.gradient_checkpointing
-            )
-
-    def convert_embedding_modules_dtype(
-        self, embedding_modules, dist_dtype, before_kbit_train_or_finetune
-    ) -> None:
-        for name, module in self.model.named_modules():
-            if "norm" in name:
-                module.to(dist_dtype)
-            if before_kbit_train_or_finetune:
-                if name.endswith(".gate"):
-                    module.to(dist_dtype)
-            if self.model_config.model_type == "btlm":
-                # don't upcast lm_head for btlm
-                continue
-            if any(m in name for m in embedding_modules):
-                if hasattr(module, "weight"):
-                    module.to(dist_dtype)
-
-    # TODO: Deprecate this.
-    def apply_unsloth_lora_patch(self) -> None:
-        if self.cfg.unsloth_lora_mlp:
-            from axolotl.monkeypatch.unsloth_ import integrate_lora_mlp_patch
-
-            integrate_lora_mlp_patch(self.model)
-        if self.cfg.unsloth_lora_qkv or self.cfg.unsloth_lora_o:
-            from axolotl.monkeypatch.unsloth_ import integrate_lora_patch
-
-            integrate_lora_patch(self.model, self.cfg)
-        if self.cfg.unsloth_rope:
-            from axolotl.monkeypatch.unsloth_ import integrate_rope_embeddings
-
-            integrate_rope_embeddings()
-
-    def apply_lora_patch(self) -> None:
-        if (
-            self.cfg.lora_mlp_kernel
-            or self.cfg.lora_qkv_kernel
-            or self.cfg.lora_o_kernel
-        ):
-            from axolotl.monkeypatch.lora_kernels import apply_lora_kernel_patches
-
-            apply_lora_kernel_patches(self.model, self.cfg)
-
-    def load_model(self) -> Tuple[PreTrainedModel, Optional[PeftConfig]]:
-        self.apply_patches()
-        self.set_auto_model_loader()
-        self.set_device_map_config()
-        if self.cfg.revision_of_model:
-            self.model_kwargs["revision"] = self.cfg.revision_of_model
-        self.set_quantization_config()
-        self.set_attention_config()
-
-        qlora_fsdp = self.cfg.fsdp and self.cfg.adapter == "qlora"
-        skip_move_to_device = False
-
-        try:
-            skip_move_to_device = self.build_model(qlora_fsdp)
-            PLUGIN_MANAGER.post_model_build(self.cfg, self.model)
-        except Exception as err:  # pylint: disable=broad-exception-caught
-            LOG.exception(err)
-            raise err
-
-        if isinstance(self.model, (PeftModel, PeftModelForCausalLM)) and not qlora_fsdp:
-            self.model = self.model.merge_and_unload()
-
-        embeddings_len = (
-            math.ceil(len(self.tokenizer) / 32) * 32
-            if self.cfg.resize_token_embeddings_to_32x
-            else len(self.tokenizer)
-        )
-        if hasattr(self.model, "get_input_embeddings") and (
-            self.model.get_input_embeddings().num_embeddings < embeddings_len
-            or (
-                self.model.get_input_embeddings().num_embeddings > embeddings_len
-                and self.cfg.shrink_embeddings
-            )
-        ):
-            resize_kwargs = {}
-            if self.cfg.mean_resizing_embeddings is not None and not (
-                self.model_config.model_type == "llava"
-            ):
-                resize_kwargs["mean_resizing"] = self.cfg.mean_resizing_embeddings
-            self.model.resize_token_embeddings(embeddings_len, **resize_kwargs)
-        else:
-            self.model.tie_weights()
-
-        self.adjust_model_config()
-
-        # log device memory usage
-        if hasattr(self.model, "device") and self.model.device.type in (
-            "cuda",
-            "mps",
-            "npu",
-        ):
-            log_gpu_memory_usage(LOG, "after model load", self.model.device)
-
-        # make sure these are fp32 per Ramesh et al. (2021)
-        embedding_modules = get_linear_embedding_layers(self.cfg.model_config_type)
-        if not self.cfg.fsdp:
-            # we don't run this during FSDP because this will leave mixed
-            # float and bfloat16 dtypes in the model which FSDP doesn't like
-            if self.cfg.load_in_4bit and self.cfg.embeddings_skip_upcast:
-                embedding_modules = []
-            self.convert_embedding_modules_dtype(
-                embedding_modules,
-                dist_dtype=torch.float32,
-                before_kbit_train_or_finetune=True,
-            )
-
-        if is_deepspeed_zero3_enabled():
-            self.set_z3_leaf_modules()
-
-        needs_fa2_dtype = self.cfg.adapter or self.cfg.fsdp
-        if self.cfg.adapter in ["lora", "qlora"]:
-            needs_fa2_dtype = True
-        if self.cfg.gradient_checkpointing:
-            self.model.gradient_checkpointing_enable(
-                gradient_checkpointing_kwargs=self.cfg.gradient_checkpointing_kwargs
-            )
-
-        self.prepare_model(qlora_fsdp)
-
-        should_convert = (
-            # LlamaRMSNorm layers are in fp32 after kbit_training or full finetune, so we need to
-            # convert them back to fp16/bf16 for flash-attn compatibility.
-            (
-                (needs_fa2_dtype or self.cfg.flash_attention or self.cfg.flex_attention)
-                and not qlora_fsdp
-            )
-            or self.cfg.cut_cross_entropy  # Cut cross entropy requires embedding layers to be in fp16/bf16 for backward pass
-        )
-
-        if should_convert:
-            LOG.info("Converting modules to %s", self.cfg.torch_dtype)
-            self.convert_embedding_modules_dtype(
-                embedding_modules=embedding_modules,
-                dist_dtype=self.cfg.torch_dtype,
-                before_kbit_train_or_finetune=False,
-            )
-
-        PLUGIN_MANAGER.pre_lora_load(self.cfg, self.model)
-
-        # ---------------------------------------------------------
-        # load lora or adapter
-        # ---------------------------------------------------------
-        lora_config = None
-        if not self.reference_model or self.cfg.lora_model_dir:
-            # if we're not loading the reference model, then we're loading the model for
-            # training; the DPO trainer doesn't want the PEFT model loaded over it, it
-            # just wants the lora/peft config
-            if (
-                self.cfg.adapter
-                and self.cfg.rl in [RLType.DPO, RLType.IPO, RLType.KTO]
-                and not self.cfg.merge_lora
-            ):
-                _, lora_config = load_lora(
-                    self.model, self.cfg, inference=False, config_only=True
-                )
-            else:
-                self.model, lora_config = load_adapter(
-                    self.model, self.cfg, self.cfg.adapter
-                )
-
-        # ---------------------------------------------------------
-        # put model to accelerator
-        # ---------------------------------------------------------
-        if (
-            self.cfg.ddp
-            and not self.cfg.load_in_8bit
-            and not (self.cfg.rl and self.cfg.load_in_4bit)
-            and not skip_move_to_device
-        ):
-            # TODO revalidate this conditional
-            self.model.to(f"{str(get_device_type())}:{self.cfg.local_rank}")
-
-        if get_device_count() > 1 and int(os.getenv("WORLD_SIZE", "1")) == 1:
-            setattr(self.model, "is_parallelizable", True)
-            setattr(self.model, "model_parallel", True)
-
-        # ---------------------------------------------------------
-        # parameters that require gradient updates
-        # ---------------------------------------------------------
-        requires_grad = []
-        for name, param in self.model.named_parameters(recurse=True):
-            if param.requires_grad:
-                requires_grad.append(f"{name}: {param.requires_grad}")
-        if len(requires_grad) == 0:
-            LOG.warning("there are no parameters that require gradient updates")
-
-        if self.cfg.flash_optimum:
-            from optimum.bettertransformer import BetterTransformer
-
-            self.model = BetterTransformer.transform(self.model)
-
-        if self.cfg.adapter is not None:
-            log_gpu_memory_usage(LOG, "after adapters", self.model.device)
-
-        self.apply_unsloth_lora_patch()
-        self.apply_lora_patch()
-
-        for _ in range(3):
-            gc.collect()
-            torch.cuda.empty_cache()
-
-        PLUGIN_MANAGER.post_model_load(self.cfg, self.model)
-        return self.model, lora_config
-
-
-def load_model(
-    cfg: DictDefault,
-    tokenizer: PreTrainedTokenizerBase,
-    *,
-    processor: ProcessorMixin = None,  # pylint: disable=unused-argument
-    inference: bool = False,
-    reference_model: bool = False,
-    **kwargs,  # pylint: disable=unused-argument
-) -> Tuple[PreTrainedModel, Optional[PeftConfig]]:
-    """
-    Load a model for a given configuration and tokenizer.
-    """
-    model_loader = ModelLoader(
-        cfg,
-        tokenizer,
-        processor=processor,
-        inference=inference,
-        reference_model=reference_model,
-        **kwargs,
-    )
-    return model_loader.load_model()
-
-
-def load_adapter(model, cfg, adapter, inference=False):
-    # type: (PreTrainedModel, DictDefault, Optional[str], bool) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
-    if adapter is None:
-        return model, None
-    if hasattr(model, "enable_input_require_grads"):
-        model.enable_input_require_grads()
-    if adapter in ["lora", "qlora"]:
-        model, lora_config = load_lora(model, cfg, inference=inference)
-        PLUGIN_MANAGER.post_lora_load(cfg, model)
-        return model, lora_config
-    if adapter == "llama-adapter":
-        model, lora_config = load_llama_adapter(model, cfg)
-        PLUGIN_MANAGER.post_lora_load(cfg, model)
-        return model, lora_config
-
-    raise NotImplementedError(f"{adapter} peft adapter not available")
-
-
-def load_llama_adapter(model, cfg):
-    # type: (PreTrainedModel, DictDefault) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
-    from peft import AdaptionPromptConfig, get_peft_model
-
-    peft_config = AdaptionPromptConfig(
-        adapter_layers=cfg.peft_adapter.layers,  # layers (L)
-        adapter_len=cfg.peft_adapter.len,  # prompt length (K)
-        task_type="CAUSAL_LM",
-    )
-
-    if cfg.lora_model_dir:
-        LOG.debug("Loading pretrained PEFT - llama_adapter")
-        model = PeftModel.from_pretrained(
-            model,
-            cfg.lora_model_dir,
-            torch_dtype=torch.float16,
-        )
-    else:
-        model = get_peft_model(model, peft_config)
-
-    model.print_trainable_parameters()
-
-    return model, peft_config
-
-
-def find_all_linear_names(model):
-    cls = (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt, torch.nn.Linear)
-    lora_module_names = set()
-    for name, module in model.named_modules():
-        if isinstance(module, cls) or (
-            "Linear" in module.__class__.__name__
-            and module.__class__.__name__ not in ("LlamaLinearScalingRotaryEmbedding",)
-        ):
-            names = name.split(".")
-            lora_module_names.add(names[0] if len(names) == 1 else names[-1])
-
-    embedding_modules = get_linear_embedding_layers(model.config.model_type)
-    output_embedding = embedding_modules[1]
-    if output_embedding in lora_module_names:  # needed for 16-bit
-        lora_module_names.remove(output_embedding)
-
-    return list(lora_module_names)
-
-
-def setup_quantized_meta_for_peft(model: nn.Module):
-    """Replaces `quant_state.to` with a dummy function to prevent PEFT from moving `quant_state` to meta device"""
-
-    def temp_to_method(self, *args, **kwargs):  # pylint: disable=unused-argument
-        return self
-
-    for param in model.parameters():
-        if isinstance(param, Params4bit):
-            param.quant_state._orig_to = (  # pylint: disable=protected-access
-                param.quant_state.to
-            )
-            param.quant_state.to = types.MethodType(temp_to_method, param.quant_state)
-
-
-def setup_quantized_peft_meta_for_training(model: nn.Module):
-    """Replaces dummy `quant_state.to` method with the original function to allow training to continue"""
-    for param in model.parameters():
-        if isinstance(param, Params4bit) and hasattr(param.quant_state, "_orig_to"):
-            param.quant_state.to = (
-                param.quant_state._orig_to  # pylint: disable=protected-access
-            )
-            param.quant_state._orig_to = None  # pylint: disable=protected-access
-
-
-def load_lora(model, cfg, inference=False, config_only=False):
-    # type: (PreTrainedModel, DictDefault, bool, bool) -> Tuple[Optional[PreTrainedModel], Optional[PeftConfig]]
-    from peft import LoraConfig, get_peft_model
-
-    lora_target_modules = cfg.lora_target_modules or []
-
-    if cfg.lora_target_linear:
-        linear_names = find_all_linear_names(model)
-        LOG.info(f"found linear modules: {repr(sorted(linear_names))}")
-        lora_target_modules_as_list = (
-            lora_target_modules
-            if isinstance(lora_target_modules, list)
-            else [lora_target_modules]
-        )
-        lora_target_modules = list(set(lora_target_modules_as_list + linear_names))
-
-    lora_config_kwargs = {}
-    loftq_bits = cfg.peft and cfg.peft.loftq_config and cfg.peft.loftq_config.loftq_bits
-    if loftq_bits:
-        lora_config_kwargs["loftq_config"] = LoftQConfig(loftq_bits=loftq_bits)
-        lora_config_kwargs["init_lora_weights"] = "loftq"
-    if cfg.peft_init_lora_weights:
-        lora_config_kwargs["init_lora_weights"] = cfg.peft_init_lora_weights
-    if cfg.peft_use_dora:
-        lora_config_kwargs["use_dora"] = cfg.peft_use_dora
-        LOG.info("Initializing LoRA weights using dora. This might take longer.")
-    if cfg.peft_use_rslora:
-        lora_config_kwargs["use_rslora"] = cfg.peft_use_rslora
-    if cfg.peft_layer_replication:
-        lora_config_kwargs["layer_replication"] = cfg.peft_layer_replication
-
-    lora_config = LoraConfig(
-        r=cfg.lora_r,
-        lora_alpha=cfg.lora_alpha,
-        target_modules=lora_target_modules,
-        layers_to_transform=cfg.peft_layers_to_transform,
-        layers_pattern=cfg.peft_layers_pattern,
-        lora_dropout=cfg.lora_dropout,
-        fan_in_fan_out=cfg.lora_fan_in_fan_out,
-        modules_to_save=cfg.lora_modules_to_save if cfg.lora_modules_to_save else None,
-        bias="none",
-        task_type="CAUSAL_LM",
-        **lora_config_kwargs,
-    )
-
-    if config_only:
-        return None, lora_config
-
-    rank = int(os.environ.get("LOCAL_RANK", 0))
-
-    if (
-        cfg.fsdp
-        and cfg.adapter
-        and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading
-        and rank != 0
-    ):
-        setup_quantized_meta_for_peft(model)
-
-    if cfg.lora_model_dir:
-        LOG.debug("Loading pretrained PEFT - LoRA")
-        model_kwargs: Any = {}
-        if cfg.lora_on_cpu:
-            model_kwargs["max_memory"] = {"cpu": "256GiB"}
-            model_kwargs["device_map"] = {"": "cpu"}
-        model = PeftModel.from_pretrained(
-            model,
-            cfg.lora_model_dir,
-            is_trainable=(not inference),
-            **model_kwargs,
-        )
-    else:
-        model = get_peft_model(model, lora_config)
-
-    if rank == 0:
-        try:
-            model.print_trainable_parameters()
-        except AttributeError as exc:
-            LOG.warning(
-                "Exception caught during model.print_trainable_parameters(): %s", exc
-            )
-    elif (
-        cfg.fsdp
-        and cfg.adapter
-        and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading
-        and rank != 0
-    ):
-        setup_quantized_peft_meta_for_training(model)
-
-    return model, lora_config
-
-
-def ensure_dtype(model, dtype=torch.bfloat16):
-    for name, module in model.named_modules():
-        weight_mismatch = False
-        bias_mismatch = False
-        try:
-            weight_mismatch = module.weight.dtype != dtype
-        except AttributeError:
-            pass
-        try:
-            bias_mismatch = module.bias.dtype != dtype
-        except AttributeError:
-            pass
-
-        if weight_mismatch:
-            print(f"Converting module {name}.weight: {module.weight.dtype} -> {dtype}")
-        if bias_mismatch:
-            print(f"Converting module {name}.bias: {module.bias.dtype} -> {dtype}")
-        if weight_mismatch or bias_mismatch:
-            module.to(dtype)
diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py
index 8ae9d5c04..cc5f54ac4 100644
--- a/src/axolotl/utils/schemas/config.py
+++ b/src/axolotl/utils/schemas/config.py
@@ -470,6 +470,16 @@ class AxolotlInputConfig(
 
         return data
 
+    @model_validator(mode="before")
+    @classmethod
+    def check_sample_packing_with_s2attn(cls, data):
+        if data.get("sample_packing") and data.get("s2_attention"):
+            raise ValueError(
+                "Received `sample_packing=true` and `s2_attention=true`; however, "
+                "shifted-sparse attention does not currently support sample packing."
+            )
+        return data
+
     @model_validator(mode="before")
     @classmethod
     def check_batch_flattening_fa(cls, data):
diff --git a/tests/core/test_trainer_builder.py b/tests/core/test_trainer_builder.py
index d1ad273ea..492578c40 100644
--- a/tests/core/test_trainer_builder.py
+++ b/tests/core/test_trainer_builder.py
@@ -1,13 +1,11 @@
-"""
-unit tests for axolotl.core.trainer_builder
-"""
+"""Unit tests for axolotl.core.trainer_builder"""
 
 import pytest
 
 from axolotl.core.trainer_builder import HFRLTrainerBuilder
+from axolotl.loaders import ModelLoader, load_tokenizer
 from axolotl.utils.config import normalize_config
 from axolotl.utils.dict import DictDefault
-from axolotl.utils.models import load_model, load_tokenizer
 from axolotl.utils.schemas.enums import RLType
 
 
@@ -50,7 +48,7 @@ def fixture_tokenizer(cfg):
 
 @pytest.fixture(name="model")
 def fixture_model(cfg, tokenizer):
-    return load_model(cfg, tokenizer)
+    return ModelLoader(cfg, tokenizer).load()
 
 
 class TestHFRLTrainerBuilder:
diff --git a/tests/e2e/patched/test_model_patches.py b/tests/e2e/patched/test_model_patches.py
index 26090e697..5ea88b001 100644
--- a/tests/e2e/patched/test_model_patches.py
+++ b/tests/e2e/patched/test_model_patches.py
@@ -6,9 +6,9 @@ import unittest
 
 import transformers
 
+from axolotl.loaders import ModelLoader, load_tokenizer
 from axolotl.utils.config import normalize_config, validate_config
 from axolotl.utils.dict import DictDefault
-from axolotl.utils.models import load_model, load_tokenizer
 
 from ..utils import with_temp_dir
 
@@ -50,7 +50,7 @@ class TestModelPatches(unittest.TestCase):
         cfg = validate_config(cfg)
         normalize_config(cfg)
         tokenizer = load_tokenizer(cfg)
-        load_model(cfg, tokenizer, inference=False)
+        ModelLoader(cfg, tokenizer, inference=False).load()
 
     @with_temp_dir
     def test_mistral_multipack(self, temp_dir):
@@ -83,7 +83,7 @@ class TestModelPatches(unittest.TestCase):
         cfg = validate_config(cfg)
         normalize_config(cfg)
         tokenizer = load_tokenizer(cfg)
-        load_model(cfg, tokenizer, inference=False)
+        ModelLoader(cfg, tokenizer, inference=False).load()
 
         assert (
             "torch.jit"
diff --git a/tests/e2e/test_load_model.py b/tests/e2e/test_load_model.py
index 96745c040..5061945b4 100644
--- a/tests/e2e/test_load_model.py
+++ b/tests/e2e/test_load_model.py
@@ -6,8 +6,8 @@ import tempfile
 import pytest
 import torch
 
+from axolotl.loaders import ModelLoader, load_tokenizer
 from axolotl.utils.dict import DictDefault
-from axolotl.utils.models import ModelLoader, load_model, load_tokenizer
 
 
 @pytest.fixture(name="temp_dir")
@@ -58,6 +58,8 @@ class TestLoadModelUtils:
             ModelLoader(
                 cfg=self.cfg,
                 tokenizer="",
+                inference=False,
+                reference_model=True,
             )
         )
 
@@ -71,13 +73,8 @@ class TestLoadModelUtils:
     ):
         self.cfg.output_dir = temp_dir
         self.model_loader.tokenizer = load_tokenizer(self.cfg)  # pylint: disable=all
-        self.model_loader.model, _ = load_model(
-            self.cfg,
-            self.model_loader.tokenizer,
-            inference=False,
-            reference_model=True,
-        )
-        self.model_loader.convert_embedding_modules_dtype(
+        self.model_loader.load()
+        self.model_loader._convert_embedding_modules_dtype(
             embedding_modules, dist_dtype, before_kbit_train_or_finetune
         )
         for name, module in self.model_loader.model.named_modules():
diff --git a/tests/patched/test_validation.py b/tests/patched/test_validation.py
index 683db61b2..1c7325dff 100644
--- a/tests/patched/test_validation.py
+++ b/tests/patched/test_validation.py
@@ -9,11 +9,11 @@ from typing import Optional
 import pytest
 from pydantic import ValidationError
+from axolotl.loaders.utils import check_model_config
 from axolotl.utils import is_comet_available
 from axolotl.utils.config import validate_config
 from axolotl.utils.dict import DictDefault
 from axolotl.utils.mlflow_ import setup_mlflow_env_vars
-from axolotl.utils.models import check_model_config
 from axolotl.utils.schemas.config import AxolotlConfigWCapabilities
 from axolotl.utils.wandb_ import setup_wandb_env_vars
 
@@ -1215,6 +1215,20 @@ class TestValidation(BaseValidation):
             cfg, capabilities=capabilities, env_capabilities=env_capabilities
         )
 
+    def test_cfg_throws_error_with_s2_attention_and_sample_packing(self, minimal_cfg):
+        test_cfg = DictDefault(
+            {
+                "s2_attention": True,
+                "sample_packing": True,
+            }
+            | minimal_cfg
+        )
+        with pytest.raises(
+            ValidationError,
+            match=r".*shifted-sparse attention does not currently support sample packing.*",
+        ):
+            validate_config(test_cfg)
+
 
 class TestTorchCompileValidation(BaseValidation):
     """
diff --git a/tests/test_exact_deduplication.py b/tests/test_exact_deduplication.py
index 1d41a248d..29672c9e5 100644
--- a/tests/test_exact_deduplication.py
+++ b/tests/test_exact_deduplication.py
@@ -1,7 +1,8 @@
-"""
-Test suite for functions in the axolotl.utils.data.utils module, focusing on the deduplicate_and_log_datasets function.
+"""Test suite for functions in the `axolotl.utils.data.utils` module, focusing on the
+`deduplicate_and_log_datasets` function.
 
-Additionally, this test suite includes tests for functions that indirectly call deduplicate_and_log_datasets during the execution of the preprocess command.
+Additionally, this test suite includes tests for functions that indirectly call
+`deduplicate_and_log_datasets` during the execution of the preprocess command.
 """
 
 import hashlib
@@ -11,20 +12,19 @@ from unittest.mock import patch
 import pytest
 from datasets import Dataset
 
+from axolotl.loaders import load_processor, load_tokenizer
 from axolotl.utils.config import normalize_config, validate_config
 from axolotl.utils.data import prepare_dataset
 from axolotl.utils.data.rl import load_prepare_preference_datasets
 from axolotl.utils.data.utils import deduplicate_and_log_datasets
 from axolotl.utils.dict import DictDefault
-from axolotl.utils.models import load_processor, load_tokenizer
 from tests.constants import ALPACA_MESSAGES_CONFIG_REVISION
 from tests.hf_offline_utils import enable_hf_offline
 
 
 def verify_deduplication(actual_dataset, expected_dataset, dataset_name):
-    """
-    Validates deduplication results and size consistency.
+    """Validates deduplication results and size consistency.
 
     Parameters:
     - actual_dataset: Deduplicated dataset.
@@ -49,9 +49,7 @@ def verify_deduplication(actual_dataset, expected_dataset, dataset_name):
 
 
 class TestDeduplicateIndividualFunctions(unittest.TestCase):
-    """
-    test class for deduplication function in data utils
-    """
+    """Test class for deduplication function in data utils"""
 
     def setUp(self):
         # Sample data with duplicates
@@ -248,7 +246,7 @@ class TestDeduplicateRLDataset:
         # pylint: disable=duplicate-code
         with (
             patch("axolotl.utils.data.rl.load_dataset_w_config") as mock_load_dataset,
-            patch("axolotl.utils.models.load_tokenizer") as mock_load_tokenizer,
+            patch("axolotl.loaders.load_tokenizer") as mock_load_tokenizer,
         ):
             # Set up the mock to return different values on successive calls
            mock_load_dataset.side_effect = [
@@ -272,7 +270,7 @@ class TestDeduplicateRLDataset:
         # pylint: disable=duplicate-code
         with (
             patch("axolotl.utils.data.rl.load_dataset_w_config") as mock_load_dataset,
-            patch("axolotl.utils.models.load_tokenizer") as mock_load_tokenizer,
+            patch("axolotl.loaders.load_tokenizer") as mock_load_tokenizer,
         ):
             # Set up the mock to return different values on successive calls
             mock_load_dataset.side_effect = [
@@ -411,7 +409,7 @@ class TestDeduplicateNonRL(unittest.TestCase):
 
 
 class TestWrongCollisions(unittest.TestCase):
-    """Creating mock datasets for testing wrong collisions"""
+    """Creating mock datasets for testing wrong collisions."""
 
     def setUp(self):
         self.train_data = {"text": ["sample 5", "sample 6"], "label": [1, 2]}
diff --git a/tests/utils/test_models.py b/tests/test_loaders.py
similarity index 83%
rename from tests/utils/test_models.py
rename to tests/test_loaders.py
index bcc1ba5d1..7313a8267 100644
--- a/tests/utils/test_models.py
+++ b/tests/test_loaders.py
@@ -1,18 +1,18 @@
-"""Module for testing models utils file."""
+"""Module for `axolotl.loaders`."""
 
-from unittest.mock import MagicMock, patch
+from unittest.mock import MagicMock
 
 import pytest
 from transformers import BitsAndBytesConfig, PreTrainedTokenizerBase
 from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
 from transformers.utils.import_utils import is_torch_mps_available
 
+from axolotl.loaders import ModelLoader
 from axolotl.utils.dict import DictDefault
-from axolotl.utils.models import ModelLoader, load_model
 
 
 class TestModelsUtils:
-    """Testing module for models utils."""
+    """Testing module for `axolotl.loaders`."""
 
     def setup_method(self) -> None:
         # load config
@@ -50,7 +50,8 @@ class TestModelsUtils:
         device_map = self.cfg.device_map
         if is_torch_mps_available():
             device_map = "mps"
-        self.model_loader.set_device_map_config()
+        # pylint: disable=protected-access
+        self.model_loader._set_device_map_config()
         if is_deepspeed_zero3_enabled():
             assert "device_map" not in self.model_loader.model_kwargs
         else:
@@ -59,29 +60,6 @@ class TestModelsUtils:
         # check torch_dtype
         assert self.cfg.torch_dtype == self.model_loader.model_kwargs["torch_dtype"]
 
-    def test_cfg_throws_error_with_s2_attention_and_sample_packing(self):
-        cfg = DictDefault(
-            {
-                "s2_attention": True,
-                "sample_packing": True,
-                "base_model": "",
-                "model_type": "AutoModelForCausalLM",
-            }
-        )
-
-        # Mock out call to HF hub
-        with patch(
-            "axolotl.utils.models.load_model_config"
-        ) as mocked_load_model_config:
-            mocked_load_model_config.return_value = {}
-            with pytest.raises(ValueError) as exc:
-                # Should error before hitting tokenizer, so we pass in an empty str
-                load_model(cfg, tokenizer="")  # type: ignore
-            assert (
-                "shifted-sparse attention does not currently support sample packing"
-                in str(exc.value)
-            )
-
     @pytest.mark.parametrize("adapter", ["lora", "qlora", None])
     @pytest.mark.parametrize("load_in_8bit", [True, False])
     @pytest.mark.parametrize("load_in_4bit", [True, False])
@@ -99,7 +77,8 @@ class TestModelsUtils:
         self.cfg.gptq = gptq
         self.cfg.adapter = adapter
 
-        self.model_loader.set_quantization_config()
+        # pylint: disable=protected-access
+        self.model_loader._set_quantization_config()
         if "quantization_config" in self.model_loader.model_kwargs or self.cfg.gptq:
             assert not (
                 hasattr(self.model_loader.model_kwargs, "load_in_8bit")
diff --git a/tests/test_lora.py b/tests/test_lora.py
index 540371bef..6edcdd88e 100644
--- a/tests/test_lora.py
+++ b/tests/test_lora.py
@@ -2,9 +2,9 @@
 tests for loading loras
 """
 
+from axolotl.loaders import ModelLoader, load_tokenizer
 from axolotl.utils.config import normalize_config, validate_config
 from axolotl.utils.dict import DictDefault
-from axolotl.utils.models import load_model, load_tokenizer
 
 # pylint: disable=duplicate-code
 minimal_config = DictDefault(
@@ -46,7 +46,7 @@ class TestLoRALoad:
         cfg = validate_config(cfg)
         normalize_config(cfg)
         tokenizer = load_tokenizer(cfg)
-        load_model(cfg, tokenizer)
+        ModelLoader(cfg, tokenizer).load()
 
     def test_load_lora_weights_empty_dropout(self):
         cfg = DictDefault(
@@ -67,4 +67,4 @@ class TestLoRALoad:
         normalize_config(cfg)
         assert cfg.lora_dropout == 0.0
         tokenizer = load_tokenizer(cfg)
-        load_model(cfg, tokenizer)
+        ModelLoader(cfg, tokenizer).load()
diff --git a/tests/test_tokenizers.py b/tests/test_tokenizers.py
index ffd51bc29..406462038 100644
--- a/tests/test_tokenizers.py
+++ b/tests/test_tokenizers.py
@@ -6,8 +6,8 @@ import unittest
 
 import pytest
 
+from axolotl.loaders import load_tokenizer
 from axolotl.utils.dict import DictDefault
-from axolotl.utils.models import load_tokenizer
 from tests.hf_offline_utils import enable_hf_offline
 
 
diff --git a/tests/utils/__init__.py b/tests/utils/__init__.py
deleted file mode 100644
index e69de29bb..000000000
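
Note for reviewers migrating downstream call sites: the sketch below summarizes the new `axolotl.loaders` entry points as exercised by the updated tests in this patch. It is a minimal illustration, not part of the diff; the config values shown (`base_model`, `tokenizer_config`, `sequence_len`) are placeholders, and a real config would carry the usual additional keys.

    # Minimal migration sketch (assumes a valid, fuller config in practice).
    # Old API: model, peft_config = axolotl.utils.models.load_model(cfg, tokenizer)
    # New API: construct a ModelLoader and call .load(), which still returns a
    # (model, peft_config) tuple, per src/axolotl/cli/utils.py in this patch.
    from axolotl.loaders import ModelLoader, load_tokenizer
    from axolotl.utils.dict import DictDefault

    cfg = DictDefault(
        {
            "base_model": "some-org/some-base-model",  # placeholder
            "tokenizer_config": "some-org/some-base-model",  # placeholder
            "sequence_len": 1024,
        }
    )

    tokenizer = load_tokenizer(cfg)
    model, peft_config = ModelLoader(cfg, tokenizer, inference=False).load()

Keeping `load_tokenizer`/`load_processor` as free functions while folding model construction into the `ModelLoader` class mirrors how the tests above were rewritten, and the formerly public helpers (`set_device_map_config`, `convert_embedding_modules_dtype`, ...) are now private methods, hence the `protected-access` pylint disables in the tests.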