diff --git a/.mypy.ini b/.mypy.ini
index ede9fef88..c6d837d3f 100644
--- a/.mypy.ini
+++ b/.mypy.ini
@@ -11,6 +11,9 @@ ignore_errors = True
 [mypy-axolotl.models.mixtral.*]
 ignore_errors = True
 
+[mypy-axolotl.integrations.liger.models.*]
+ignore_errors = True
+
 [mypy-axolotl.models.phi.*]
 ignore_errors = True
 
diff --git a/requirements.txt b/requirements.txt
index be0c4927e..f5fb547a2 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -33,6 +33,8 @@ gradio==3.50.2
 tensorboard
 python-dotenv==1.0.1
 autoawq>=0.2.5
+triton>=2.3.0
+liger-kernel
 mamba-ssm==1.2.0.post1
 
diff --git a/src/axolotl/cli/__init__.py b/src/axolotl/cli/__init__.py
index a05ee84e9..aaa62423c 100644
--- a/src/axolotl/cli/__init__.py
+++ b/src/axolotl/cli/__init__.py
@@ -27,6 +27,7 @@ from transformers.utils import is_torch_bf16_gpu_available
 from transformers.utils.import_utils import _is_package_available
 
 from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer
+from axolotl.integrations.base import PluginManager
 from axolotl.logging_config import configure_logging
 from axolotl.train import TrainDatasetMeta
 from axolotl.utils.config import (
@@ -365,6 +366,11 @@ def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs):
 
     cfg.axolotl_config_path = config
 
+    if cfg.get("plugins"):
+        plugin_manager = PluginManager.get_instance()
+        for plugin_name in cfg["plugins"]:
+            plugin_manager.register(plugin_name)
+
     try:
         device_props = torch.cuda.get_device_properties("cuda")
         gpu_version = "sm_" + str(device_props.major) + str(device_props.minor)
diff --git a/src/axolotl/integrations/base.py b/src/axolotl/integrations/base.py
new file mode 100644
index 000000000..d26eed90f
--- /dev/null
+++ b/src/axolotl/integrations/base.py
@@ -0,0 +1,383 @@
+# Copyright 2024 Axolotl AI. All rights reserved.
+#
+# This software may be used and distributed according to
+# the terms of the Axolotl Community License Agreement (the "License");
+# you may not use this file except in compliance with the License.
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations under
+# the License.
+
+"""
+Base class for all plugins.
+
+A plugin is a reusable, modular, and self-contained piece of code that extends the functionality of Axolotl.
+Plugins can be used to integrate third-party models, modify the training process, or add new features.
+
+To create a new plugin, you need to inherit from the BasePlugin class and implement the required methods.
+"""
+import importlib
+import logging
+from typing import List
+
+
+class BasePlugin:
+    """
+    Base class for all plugins. Defines the interface for plugin methods.
+
+    Attributes:
+        None
+
+    Methods:
+        register(cfg): Registers the plugin with the given configuration.
+        pre_model_load(cfg): Performs actions before the model is loaded.
+        post_model_load(cfg, model): Performs actions after the model is loaded.
+        pre_lora_load(cfg, model): Performs actions before LoRA weights are loaded.
+        post_lora_load(cfg, model): Performs actions after LoRA weights are loaded.
+        create_optimizer(cfg, trainer): Creates and returns an optimizer for training.
+        create_lr_scheduler(cfg, trainer, optimizer): Creates and returns a learning rate scheduler.
+        add_callbacks_pre_trainer(cfg, model): Adds callbacks to the trainer before training.
+        add_callbacks_post_trainer(cfg, trainer): Adds callbacks to the trainer after training.
+    """
+
+    def __init__(self):
+        """
+        Initializes the BasePlugin.
+        """
+
+    def register(self, cfg):
+        """
+        Registers the plugin with the given configuration.
+
+        Parameters:
+            cfg (dict): The configuration for the plugin.
+
+        Returns:
+            None
+        """
+
+    def get_input_args(self):
+        """
+        Returns the dotted import path ("module.ClassName") of a Pydantic model
+        describing the plugin's input arguments, or None if the plugin adds none.
+        """
+
+    def pre_model_load(self, cfg):
+        """
+        Performs actions before the model is loaded.
+
+        Parameters:
+            cfg (dict): The configuration for the plugin.
+
+        Returns:
+            None
+        """
+
+    def post_model_load(self, cfg, model):
+        """
+        Performs actions after the model is loaded.
+
+        Parameters:
+            cfg (dict): The configuration for the plugin.
+            model (object): The loaded model.
+
+        Returns:
+            None
+        """
+
+    def pre_lora_load(self, cfg, model):
+        """
+        Performs actions before LoRA weights are loaded.
+
+        Parameters:
+            cfg (dict): The configuration for the plugin.
+            model (object): The loaded model.
+
+        Returns:
+            None
+        """
+
+    def post_lora_load(self, cfg, model):
+        """
+        Performs actions after LoRA weights are loaded.
+
+        Parameters:
+            cfg (dict): The configuration for the plugin.
+            model (object): The loaded model.
+
+        Returns:
+            None
+        """
+
+    def create_optimizer(self, cfg, trainer):
+        """
+        Creates and returns an optimizer for training.
+
+        Parameters:
+            cfg (dict): The configuration for the plugin.
+            trainer (object): The trainer object for training.
+
+        Returns:
+            object: The created optimizer.
+        """
+
+    def create_lr_scheduler(self, cfg, trainer, optimizer):
+        """
+        Creates and returns a learning rate scheduler.
+
+        Parameters:
+            cfg (dict): The configuration for the plugin.
+            trainer (object): The trainer object for training.
+            optimizer (object): The optimizer for training.
+
+        Returns:
+            object: The created learning rate scheduler.
+        """
+
+    def add_callbacks_pre_trainer(self, cfg, model):
+        """
+        Adds callbacks to the trainer before training.
+
+        Parameters:
+            cfg (dict): The configuration for the plugin.
+            model (object): The loaded model.
+
+        Returns:
+            List[callable]: A list of callback functions to be added to the trainer.
+        """
+
+    def add_callbacks_post_trainer(self, cfg, trainer):
+        """
+        Adds callbacks to the trainer after training.
+
+        Parameters:
+            cfg (dict): The configuration for the plugin.
+            trainer (object): The trainer object for training.
+
+        Returns:
+            List[callable]: A list of callback functions to be added to the trainer.
+        """
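+
+
+# Illustrative sketch (hypothetical names: MyPlugin, MyPluginArgs, my_package,
+# my_flag): a minimal plugin subclasses BasePlugin and overrides only the
+# hooks it needs, e.g.
+#
+#   class MyPlugin(BasePlugin):
+#       def get_input_args(self):
+#           return "my_package.MyPluginArgs"
+#
+#       def pre_model_load(self, cfg):
+#           if cfg.my_flag:
+#               ...  # patch modeling code before the model is instantiated
+#
+# and is enabled from the YAML config by its dotted path:
+#
+#   plugins:
+#     - my_package.MyPlugin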
+ """ + # split the plugin name into module and class + module_name, class_name = plugin_name.rsplit(".", 1) + + # import the module + module = importlib.import_module(module_name) + # instantiate the class + plugin_class = getattr(module, class_name) + # create an instance of the class + plugin = plugin_class() + + return plugin + + +class PluginManager: + """ + The PluginManager class is responsible for loading and managing plugins. + It should be a singleton so it can be accessed from anywhere in the codebase. + + Attributes: + plugins (List[BasePlugin]): A list of loaded plugins. + + Methods: + get_instance(): Static method to get the singleton instance of PluginManager. + register(plugin_name: str): Registers a new plugin by its name. + pre_model_load(cfg): Calls the pre_model_load method of all registered plugins. + """ + + plugins: List[BasePlugin] = [] + + _instance = None + + def __new__(cls): + """ + Creates a new instance of PluginManager if it doesn't exist yet. + """ + if cls._instance is None: + cls._instance = super(PluginManager, cls).__new__(cls) + cls._instance.plugins: List[BasePlugin] = [] + return cls._instance + + @staticmethod + def get_instance() -> "PluginManager": + """ + Returns the singleton instance of PluginManager. + If the instance doesn't exist, it creates a new one. + """ + if PluginManager._instance is None: + PluginManager() + return PluginManager._instance # type: ignore + + def register(self, plugin_name: str): + """ + Registers a new plugin by its name. + + Parameters: + plugin_name (str): The name of the plugin to be registered. + + Returns: + None + + Raises: + ImportError: If the plugin module cannot be imported. + """ + try: + plugin = load_plugin(plugin_name) + self.plugins.append(plugin) + except ImportError: + logging.error(f"Failed to load plugin: {plugin_name}") + + def get_input_args(self): + """ + Returns a list of Pydantic classes for all registered plugins' input arguments.' + + Returns: + list[str]: A list of Pydantic classes for all registered plugins' input arguments.' + """ + input_args = [] + for plugin in self.plugins: + input_args_from_plugin = plugin.get_input_args() + if input_args_from_plugin is not None: + input_args.append(input_args_from_plugin) + return input_args + + def pre_model_load(self, cfg): + """ + Calls the pre_model_load method of all registered plugins. + + Parameters: + cfg (dict): The configuration for the plugins. + + Returns: + None + """ + for plugin in self.plugins: + plugin.pre_model_load(cfg) + + def post_model_load(self, cfg, model): + """ + Calls the post_model_load method of all registered plugins. + + Parameters: + cfg (dict): The configuration for the plugins. + model (object): The loaded model. + + Returns: + None + """ + for plugin in self.plugins: + plugin.post_model_load(cfg, model) + + def pre_lora_load(self, cfg, model): + """ + Calls the pre_lora_load method of all registered plugins. + + Parameters: + cfg (dict): The configuration for the plugins. + model (object): The loaded model. + + Returns: + None + """ + for plugin in self.plugins: + plugin.pre_lora_load(cfg, model) + + def post_lora_load(self, cfg, model): + """ + Calls the post_lora_load method of all registered plugins. + + Parameters: + cfg (dict): The configuration for the plugins. + model (object): The loaded model. 
+
+    def get_input_args(self):
+        """
+        Returns a list of the dotted import paths for all registered plugins' input-argument models.
+
+        Returns:
+            list[str]: The import paths ("module.ClassName") of the Pydantic classes for all registered plugins' input arguments.
+        """
+        input_args = []
+        for plugin in self.plugins:
+            input_args_from_plugin = plugin.get_input_args()
+            if input_args_from_plugin is not None:
+                input_args.append(input_args_from_plugin)
+        return input_args
+
+    def pre_model_load(self, cfg):
+        """
+        Calls the pre_model_load method of all registered plugins.
+
+        Parameters:
+            cfg (dict): The configuration for the plugins.
+
+        Returns:
+            None
+        """
+        for plugin in self.plugins:
+            plugin.pre_model_load(cfg)
+
+    def post_model_load(self, cfg, model):
+        """
+        Calls the post_model_load method of all registered plugins.
+
+        Parameters:
+            cfg (dict): The configuration for the plugins.
+            model (object): The loaded model.
+
+        Returns:
+            None
+        """
+        for plugin in self.plugins:
+            plugin.post_model_load(cfg, model)
+
+    def pre_lora_load(self, cfg, model):
+        """
+        Calls the pre_lora_load method of all registered plugins.
+
+        Parameters:
+            cfg (dict): The configuration for the plugins.
+            model (object): The loaded model.
+
+        Returns:
+            None
+        """
+        for plugin in self.plugins:
+            plugin.pre_lora_load(cfg, model)
+
+    def post_lora_load(self, cfg, model):
+        """
+        Calls the post_lora_load method of all registered plugins.
+
+        Parameters:
+            cfg (dict): The configuration for the plugins.
+            model (object): The loaded model.
+
+        Returns:
+            None
+        """
+        for plugin in self.plugins:
+            plugin.post_lora_load(cfg, model)
+
+    def create_optimizer(self, cfg, trainer):
+        """
+        Calls the create_optimizer method of all registered plugins and returns the first non-None optimizer.
+
+        Parameters:
+            cfg (dict): The configuration for the plugins.
+            trainer (object): The trainer object for training.
+
+        Returns:
+            object: The created optimizer, or None if none was found.
+        """
+        for plugin in self.plugins:
+            optimizer = plugin.create_optimizer(cfg, trainer)
+            if optimizer is not None:
+                return optimizer
+        return None
+
+    def create_lr_scheduler(self, cfg, trainer, optimizer):
+        """
+        Calls the create_lr_scheduler method of all registered plugins and returns the first non-None scheduler.
+
+        Parameters:
+            cfg (dict): The configuration for the plugins.
+            trainer (object): The trainer object for training.
+            optimizer (object): The optimizer for training.
+
+        Returns:
+            object: The created learning rate scheduler, or None if none was found.
+        """
+        for plugin in self.plugins:
+            scheduler = plugin.create_lr_scheduler(cfg, trainer, optimizer)
+            if scheduler is not None:
+                return scheduler
+        return None
+
+    def add_callbacks_pre_trainer(self, cfg, model):
+        """
+        Calls the add_callbacks_pre_trainer method of all registered plugins.
+
+        Parameters:
+            cfg (dict): The configuration for the plugins.
+            model (object): The loaded model.
+
+        Returns:
+            List[callable]: A list of callback functions to be added to the trainer.
+        """
+        callbacks = []
+        for plugin in self.plugins:
+            callbacks.extend(plugin.add_callbacks_pre_trainer(cfg, model))
+        return callbacks
+
+    def add_callbacks_post_trainer(self, cfg, trainer):
+        """
+        Calls the add_callbacks_post_trainer method of all registered plugins.
+
+        Parameters:
+            cfg (dict): The configuration for the plugins.
+            trainer (object): The trainer object for training.
+
+        Returns:
+            List[callable]: A list of callback functions to be added to the trainer.
+        """
+        callbacks = []
+        for plugin in self.plugins:
+            callbacks.extend(plugin.add_callbacks_post_trainer(cfg, trainer))
+        return callbacks
diff --git a/src/axolotl/integrations/config.py b/src/axolotl/integrations/config.py
new file mode 100644
index 000000000..b4ffd6758
--- /dev/null
+++ b/src/axolotl/integrations/config.py
@@ -0,0 +1,65 @@
+# Copyright 2024 Axolotl AI. All rights reserved.
+#
+# This software may be used and distributed according to
+# the terms of the Axolotl Community License Agreement (the "License");
+# you may not use this file except in compliance with the License.
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations under
+# the License.
+
+"""
+Module to handle merging the plugins' input arguments with the base configurations.
+
+This was moved here to prevent circular imports.
+"""
+
+from typing import Any, Dict, List
+
+from axolotl.utils.config.models.input.v0_4_1 import (
+    AxolotlConfigWCapabilities as AxolotlConfigWCapabilitiesBase,
+)
+from axolotl.utils.config.models.input.v0_4_1 import (
+    AxolotlInputConfig as AxolotlInputConfigBase,
+)
+
+
+def merge_input_args():
+    """
+    Merges input arguments from registered plugins with the base configurations.
+
+    This function retrieves the input arguments from registered plugins using the PluginManager.
+    It then dynamically creates new classes, AxolotlConfigWCapabilities and AxolotlInputConfig,
+    that inherit from the base configurations and include the input arguments from the plugins.
+    If no registered plugin contributes input arguments, the base classes are returned unchanged.
+
+    Returns:
+        tuple: A tuple containing the AxolotlConfigWCapabilities and AxolotlInputConfig classes.
+    """
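+    # Illustrative example of the generated source: if a registered plugin's
+    # get_input_args() returns "axolotl.integrations.liger.LigerArgs", the
+    # dynamically built snippet is roughly
+    #
+    #   from axolotl.integrations.liger import LigerArgs
+    #
+    #   class AxolotlConfigWCapabilities(AxolotlConfigWCapabilitiesBase, LigerArgs):
+    #       pass
+    #
+    #   class AxolotlInputConfig(AxolotlInputConfigBase, LigerArgs):
+    #       pass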
+    from axolotl.integrations.base import PluginManager
+
+    plugin_manager = PluginManager.get_instance()
+    input_args: List[str] = plugin_manager.get_input_args()
+    plugin_classes = []
+    dynamic_input = ""
+    for plugin_args in input_args:
+        plugin_module, plugin_cls = plugin_args.rsplit(".", 1)
+        dynamic_input += f"from {plugin_module} import {plugin_cls}\n"
+        plugin_classes.append(plugin_cls)
+    if dynamic_input:
+        dynamic_input += f"class AxolotlConfigWCapabilities(AxolotlConfigWCapabilitiesBase, {', '.join(plugin_classes)}):\n    pass\n"
+        dynamic_input += f"class AxolotlInputConfig(AxolotlInputConfigBase, {', '.join(plugin_classes)}):\n    pass\n"
+
+        namespace: Dict[Any, Any] = {}
+        exec(  # pylint: disable=exec-used  # nosec B102
+            dynamic_input, globals(), namespace
+        )
+        AxolotlInputConfig = namespace[  # pylint: disable=invalid-name
+            "AxolotlInputConfig"
+        ]
+        AxolotlConfigWCapabilities = namespace[  # pylint: disable=invalid-name
+            "AxolotlConfigWCapabilities"
+        ]
+        return AxolotlConfigWCapabilities, AxolotlInputConfig
+    return AxolotlConfigWCapabilitiesBase, AxolotlInputConfigBase
diff --git a/src/axolotl/integrations/liger/LICENSE b/src/axolotl/integrations/liger/LICENSE
new file mode 100644
index 000000000..d64569567
--- /dev/null
+++ b/src/axolotl/integrations/liger/LICENSE
@@ -0,0 +1,202 @@
+
+                                 Apache License
+                           Version 2.0, January 2004
+                        http://www.apache.org/licenses/
+
+   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+   1. Definitions.
+
+      "License" shall mean the terms and conditions for use, reproduction,
+      and distribution as defined by Sections 1 through 9 of this document.
+
+      "Licensor" shall mean the copyright owner or entity authorized by
+      the copyright owner that is granting the License.
+
+      "Legal Entity" shall mean the union of the acting entity and all
+      other entities that control, are controlled by, or are under common
+      control with that entity. For the purposes of this definition,
+      "control" means (i) the power, direct or indirect, to cause the
+      direction or management of such entity, whether by contract or
+      otherwise, or (ii) ownership of fifty percent (50%) or more of the
+      outstanding shares, or (iii) beneficial ownership of such entity.
+
+      "You" (or "Your") shall mean an individual or Legal Entity
+      exercising permissions granted by this License.
+
+      "Source" form shall mean the preferred form for making modifications,
+      including but not limited to software source code, documentation
+      source, and configuration files.
+
+      "Object" form shall mean any form resulting from mechanical
+      transformation or translation of a Source form, including but
+      not limited to compiled object code, generated documentation,
+      and conversions to other media types.
+
+      "Work" shall mean the work of authorship, whether in Source or
+      Object form, made available under the License, as indicated by a
+      copyright notice that is included in or attached to the work
+      (an example is provided in the Appendix below).
+
+      "Derivative Works" shall mean any work, whether in Source or Object
+      form, that is based on (or derived from) the Work and for which the
+      editorial revisions, annotations, elaborations, or other modifications
+      represent, as a whole, an original work of authorship. For the purposes
+      of this License, Derivative Works shall not include works that remain
+      separable from, or merely link (or bind by name) to the interfaces of,
+      the Work and Derivative Works thereof.
+
+      "Contribution" shall mean any work of authorship, including
+      the original version of the Work and any modifications or additions
+      to that Work or Derivative Works thereof, that is intentionally
+      submitted to Licensor for inclusion in the Work by the copyright owner
+      or by an individual or Legal Entity authorized to submit on behalf of
+      the copyright owner. For the purposes of this definition, "submitted"
+      means any form of electronic, verbal, or written communication sent
+      to the Licensor or its representatives, including but not limited to
+      communication on electronic mailing lists, source code control systems,
+      and issue tracking systems that are managed by, or on behalf of, the
+      Licensor for the purpose of discussing and improving the Work, but
+      excluding communication that is conspicuously marked or otherwise
+      designated in writing by the copyright owner as "Not a Contribution."
+
+      "Contributor" shall mean Licensor and any individual or Legal Entity
+      on behalf of whom a Contribution has been received by Licensor and
+      subsequently incorporated within the Work.
+
+   2. Grant of Copyright License. Subject to the terms and conditions of
+      this License, each Contributor hereby grants to You a perpetual,
+      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+      copyright license to reproduce, prepare Derivative Works of,
+      publicly display, publicly perform, sublicense, and distribute the
+      Work and such Derivative Works in Source or Object form.
+
+   3. Grant of Patent License. Subject to the terms and conditions of
+      this License, each Contributor hereby grants to You a perpetual,
+      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+      (except as stated in this section) patent license to make, have made,
+      use, offer to sell, sell, import, and otherwise transfer the Work,
+      where such license applies only to those patent claims licensable
+      by such Contributor that are necessarily infringed by their
+      Contribution(s) alone or by combination of their Contribution(s)
+      with the Work to which such Contribution(s) was submitted. If You
+      institute patent litigation against any entity (including a
+      cross-claim or counterclaim in a lawsuit) alleging that the Work
+      or a Contribution incorporated within the Work constitutes direct
+      or contributory patent infringement, then any patent licenses
+      granted to You under this License for that Work shall terminate
+      as of the date such litigation is filed.
+
+   4. Redistribution. You may reproduce and distribute copies of the
+      Work or Derivative Works thereof in any medium, with or without
+      modifications, and in Source or Object form, provided that You
+      meet the following conditions:
+
+      (a) You must give any other recipients of the Work or
+          Derivative Works a copy of this License; and
+
+      (b) You must cause any modified files to carry prominent notices
+          stating that You changed the files; and
+
+      (c) You must retain, in the Source form of any Derivative Works
+          that You distribute, all copyright, patent, trademark, and
+          attribution notices from the Source form of the Work,
+          excluding those notices that do not pertain to any part of
+          the Derivative Works; and
+
+      (d) If the Work includes a "NOTICE" text file as part of its
+          distribution, then any Derivative Works that You distribute must
+          include a readable copy of the attribution notices contained
+          within such NOTICE file, excluding those notices that do not
+          pertain to any part of the Derivative Works, in at least one
+          of the following places: within a NOTICE text file distributed
+          as part of the Derivative Works; within the Source form or
+          documentation, if provided along with the Derivative Works; or,
+          within a display generated by the Derivative Works, if and
+          wherever such third-party notices normally appear. The contents
+          of the NOTICE file are for informational purposes only and
+          do not modify the License. You may add Your own attribution
+          notices within Derivative Works that You distribute, alongside
+          or as an addendum to the NOTICE text from the Work, provided
+          that such additional attribution notices cannot be construed
+          as modifying the License.
+
+      You may add Your own copyright statement to Your modifications and
+      may provide additional or different license terms and conditions
+      for use, reproduction, or distribution of Your modifications, or
+      for any such Derivative Works as a whole, provided Your use,
+      reproduction, and distribution of the Work otherwise complies with
+      the conditions stated in this License.
+
+   5. Submission of Contributions. Unless You explicitly state otherwise,
+      any Contribution intentionally submitted for inclusion in the Work
+      by You to the Licensor shall be under the terms and conditions of
+      this License, without any additional terms or conditions.
+      Notwithstanding the above, nothing herein shall supersede or modify
+      the terms of any separate license agreement you may have executed
+      with Licensor regarding such Contributions.
+
+   6. Trademarks. This License does not grant permission to use the trade
+      names, trademarks, service marks, or product names of the Licensor,
+      except as required for reasonable and customary use in describing the
+      origin of the Work and reproducing the content of the NOTICE file.
+
+   7. Disclaimer of Warranty. Unless required by applicable law or
+      agreed to in writing, Licensor provides the Work (and each
+      Contributor provides its Contributions) on an "AS IS" BASIS,
+      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+      implied, including, without limitation, any warranties or conditions
+      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+      PARTICULAR PURPOSE. You are solely responsible for determining the
+      appropriateness of using or redistributing the Work and assume any
+      risks associated with Your exercise of permissions under this License.
+
+   8. Limitation of Liability. In no event and under no legal theory,
+      whether in tort (including negligence), contract, or otherwise,
+      unless required by applicable law (such as deliberate and grossly
+      negligent acts) or agreed to in writing, shall any Contributor be
+      liable to You for damages, including any direct, indirect, special,
+      incidental, or consequential damages of any character arising as a
+      result of this License or out of the use or inability to use the
+      Work (including but not limited to damages for loss of goodwill,
+      work stoppage, computer failure or malfunction, or any and all
+      other commercial damages or losses), even if such Contributor
+      has been advised of the possibility of such damages.
+
+   9. Accepting Warranty or Additional Liability. While redistributing
+      the Work or Derivative Works thereof, You may choose to offer,
+      and charge a fee for, acceptance of support, warranty, indemnity,
+      or other liability obligations and/or rights consistent with this
+      License. However, in accepting such obligations, You may act only
+      on Your own behalf and on Your sole responsibility, not on behalf
+      of any other Contributor, and only if You agree to indemnify,
+      defend, and hold each Contributor harmless for any liability
+      incurred by, or claims asserted against, such Contributor by reason
+      of your accepting any such warranty or additional liability.
+
+   END OF TERMS AND CONDITIONS
+
+   APPENDIX: How to apply the Apache License to your work.
+
+      To apply the Apache License to your work, attach the following
+      boilerplate notice, with the fields enclosed by brackets "[]"
+      replaced with your own identifying information. (Don't include
+      the brackets!) The text should be enclosed in the appropriate
+      comment syntax for the file format. We also recommend that a
+      file or class name and description of purpose be included on the
+      same "printed page" as the copyright notice for easier
+      identification within third-party archives.
+
+   Copyright [yyyy] [name of copyright owner]
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+       http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License.
diff --git a/src/axolotl/integrations/liger/__init__.py b/src/axolotl/integrations/liger/__init__.py
new file mode 100644
index 000000000..d4c1ad9a4
--- /dev/null
+++ b/src/axolotl/integrations/liger/__init__.py
@@ -0,0 +1,104 @@
+# Copyright 2024 Axolotl AI. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Module for the Plugin for LIGER integration with Axolotl.
+
+Liger Kernel is a collection of Triton-native kernels for LLM training.
+It is designed to be performant, correct, and lightweight.
+"""
+import logging
+
+from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
+from liger_kernel.transformers.geglu import LigerGEGLUMLP
+from liger_kernel.transformers.model.llama import lce_forward
+from liger_kernel.transformers.rms_norm import LigerRMSNorm
+from liger_kernel.transformers.rope import liger_rotary_pos_emb
+from liger_kernel.transformers.swiglu import LigerSwiGLUMLP
+
+from axolotl.integrations.base import BasePlugin
+
+from .args import LigerArgs  # pylint: disable=unused-import  # noqa: F401
+
+
+class LigerPlugin(BasePlugin):
+    """
+    Plugin for LIGER integration with Axolotl.
+    """
+
+    def get_input_args(self):
+        return "axolotl.integrations.liger.LigerArgs"
+
+    def pre_model_load(self, cfg):
+        if cfg.model_config_type == "llama":
+            from transformers.models.llama import modeling_llama
+
+            if cfg.liger_rope:
+                modeling_llama.apply_rotary_pos_emb = liger_rotary_pos_emb
+            if cfg.liger_rms_norm:
+                modeling_llama.LlamaRMSNorm = LigerRMSNorm
+            if cfg.liger_swiglu:
+                modeling_llama.LlamaMLP = LigerSwiGLUMLP
+            if cfg.liger_cross_entropy:
+                modeling_llama.CrossEntropyLoss = LigerCrossEntropyLoss
+            elif cfg.liger_fused_linear_cross_entropy:
+                modeling_llama.LlamaForCausalLM.forward = lce_forward
+
+        elif cfg.model_config_type == "mistral":
+            from transformers.models.mistral import modeling_mistral
+
+            if cfg.liger_rope:
+                modeling_mistral.apply_rotary_pos_emb = liger_rotary_pos_emb
+            if cfg.liger_rms_norm:
+                modeling_mistral.MistralRMSNorm = LigerRMSNorm
+            if cfg.liger_swiglu:
+                modeling_mistral.MistralMLP = LigerSwiGLUMLP
+            if cfg.liger_cross_entropy:
+                modeling_mistral.CrossEntropyLoss = LigerCrossEntropyLoss
+            if cfg.liger_fused_linear_cross_entropy:
+                logging.warning(
+                    "Fused linear cross entropy is not supported for Mistral."
+                )
+
+        elif cfg.model_config_type == "gemma":
+            from transformers.models.gemma import modeling_gemma
+
+            if cfg.liger_rope:
+                modeling_gemma.apply_rotary_pos_emb = liger_rotary_pos_emb
+            if cfg.liger_rms_norm:
+                modeling_gemma.GemmaRMSNorm = LigerRMSNorm
+            if cfg.liger_swiglu:
+                # Gemma uses a GeLU-gated MLP (GeGLU), so the swiglu flag maps
+                # to the GEGLU kernel here
+                modeling_gemma.GemmaMLP = LigerGEGLUMLP
+            if cfg.liger_cross_entropy:
+                modeling_gemma.CrossEntropyLoss = LigerCrossEntropyLoss
+            if cfg.liger_fused_linear_cross_entropy:
+                logging.warning(
+                    "Fused linear cross entropy is not supported for Gemma."
+                )
+
+        elif cfg.model_config_type == "jamba":
+            from transformers.models.jamba import modeling_jamba
+
+            from .models.jamba import lce_forward as jamba_lce_forward
+
+            if cfg.liger_rope:
+                modeling_jamba.apply_rotary_pos_emb = liger_rotary_pos_emb
+            if cfg.liger_rms_norm:
+                modeling_jamba.JambaRMSNorm = LigerRMSNorm
+            if cfg.liger_swiglu:
+                modeling_jamba.JambaMLP = LigerSwiGLUMLP
+            if cfg.liger_cross_entropy:
+                modeling_jamba.CrossEntropyLoss = LigerCrossEntropyLoss
+            if cfg.liger_fused_linear_cross_entropy:
+                modeling_jamba.JambaForCausalLM.forward = jamba_lce_forward
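+
+
+# Note on the patches above: for llama, liger_cross_entropy takes precedence
+# over liger_fused_linear_cross_entropy (hence the elif); llama and jamba are
+# the only model types wired up with a fused linear cross entropy forward,
+# while mistral and gemma only log a warning for that flag.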
diff --git a/src/axolotl/integrations/liger/args.py b/src/axolotl/integrations/liger/args.py
new file mode 100644
index 000000000..decdb3775
--- /dev/null
+++ b/src/axolotl/integrations/liger/args.py
@@ -0,0 +1,32 @@
+# Copyright 2024 Axolotl AI. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Module for handling LIGER input arguments.
+"""
+from typing import Optional
+
+from pydantic import BaseModel
+
+
+class LigerArgs(BaseModel):
+    """
+    Input args for LIGER.
+    """
+
+    liger_rope: Optional[bool] = None
+    liger_rms_norm: Optional[bool] = None
+    liger_swiglu: Optional[bool] = None
+    liger_cross_entropy: Optional[bool] = None
+    liger_fused_linear_cross_entropy: Optional[bool] = None
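+
+
+# These fields map one-to-one onto YAML config keys once the plugin is
+# registered; an illustrative config (values are examples, not defaults):
+#
+#   plugins:
+#     - axolotl.integrations.liger.LigerPlugin
+#   liger_rope: true
+#   liger_rms_norm: true
+#   liger_swiglu: true
+#   liger_fused_linear_cross_entropy: true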
diff --git a/src/axolotl/integrations/liger/models/jamba.py b/src/axolotl/integrations/liger/models/jamba.py
new file mode 100644
index 000000000..40cec63a4
--- /dev/null
+++ b/src/axolotl/integrations/liger/models/jamba.py
@@ -0,0 +1,173 @@
+"""
+Jamba model with LigerFusedLinearCrossEntropyLoss
+"""
+# pylint: disable=duplicate-code
+
+from typing import Optional, Tuple, Union
+
+import torch
+from liger_kernel.transformers.fused_linear_cross_entropy import (
+    LigerFusedLinearCrossEntropyLoss,
+)
+from torch.nn import CrossEntropyLoss
+from transformers.modeling_outputs import MoeCausalLMOutputWithPast
+from transformers.models.jamba.modeling_jamba import (
+    _CONFIG_FOR_DOC,
+    JAMBA_INPUTS_DOCSTRING,
+    HybridMambaAttentionDynamicCache,
+    load_balancing_loss_func,
+)
+from transformers.utils import (
+    add_start_docstrings_to_model_forward,
+    replace_return_docstrings,
+)
+
+
+@add_start_docstrings_to_model_forward(JAMBA_INPUTS_DOCSTRING)
+@replace_return_docstrings(
+    output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
+)
+def lce_forward(
+    self,
+    input_ids: torch.LongTensor = None,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_values: Optional[HybridMambaAttentionDynamicCache] = None,
+    inputs_embeds: Optional[torch.FloatTensor] = None,
+    labels: Optional[torch.LongTensor] = None,
+    use_cache: Optional[bool] = None,
+    output_attentions: Optional[bool] = None,
+    output_hidden_states: Optional[bool] = None,
+    output_router_logits: Optional[bool] = None,
+    return_dict: Optional[bool] = None,
+    cache_position: Optional[torch.LongTensor] = None,
+    num_logits_to_keep: Optional[int] = None,
+) -> Union[Tuple, MoeCausalLMOutputWithPast]:
+    r"""
+    Args:
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+        num_logits_to_keep (`int` or `None`, *optional*):
+            Calculate logits for the last `num_logits_to_keep` tokens. If `None`, calculate logits for all
+            `input_ids`. Only last token logits are needed for generation, and calculating them only for that token
+            can save memory, which becomes pretty significant for long sequences.
+
+    Returns:
+
+    Example:
+
+    ```python
+    >>> from transformers import AutoTokenizer, JambaForCausalLM
+
+    >>> model = JambaForCausalLM.from_pretrained("ai21labs/Jamba-v0.1")
+    >>> tokenizer = AutoTokenizer.from_pretrained("ai21labs/Jamba-v0.1")
+
+    >>> prompt = "Hey, are you conscious? Can you talk to me?"
+    >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+    >>> # Generate
+    >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+    >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+    "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
+    ```"""
+
+    output_attentions = (
+        output_attentions
+        if output_attentions is not None
+        else self.config.output_attentions
+    )
+    output_router_logits = (
+        output_router_logits
+        if output_router_logits is not None
+        else self.config.output_router_logits
+    )
+
+    output_hidden_states = (
+        output_hidden_states
+        if output_hidden_states is not None
+        else self.config.output_hidden_states
+    )
+    return_dict = (
+        return_dict if return_dict is not None else self.config.use_return_dict
+    )
+
+    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+    outputs = self.model(
+        input_ids=input_ids,
+        attention_mask=attention_mask,
+        position_ids=position_ids,
+        past_key_values=past_key_values,
+        inputs_embeds=inputs_embeds,
+        use_cache=use_cache,
+        output_attentions=output_attentions,
+        output_hidden_states=output_hidden_states,
+        output_router_logits=output_router_logits,
+        cache_position=cache_position,
+        return_dict=return_dict,
+    )
+
+    hidden_states = outputs[0]
+
+    loss = None
+    logits = None
+
+    if self.training and labels is not None:
+        shift_hidden_states = hidden_states[..., :-1, :].contiguous()
+        shift_labels = labels[..., 1:].contiguous()
+
+        # flatten tokens
+        shift_hidden_states = shift_hidden_states.view(-1, self.config.hidden_size)
+        shift_labels = shift_labels.view(-1)
+
+        lce = LigerFusedLinearCrossEntropyLoss()
+        loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)
+    else:
+        if num_logits_to_keep is None:
+            logits = self.lm_head(hidden_states)
+        else:
+            logits = self.lm_head(hidden_states[..., -num_logits_to_keep:, :])
+        logits = logits.float()
+
+        if labels is not None:
+            # Shift so that tokens < n predict n
+            shift_logits = logits[..., :-1, :].contiguous()
+            shift_labels = labels[..., 1:].contiguous()
+            # Flatten the tokens
+            loss_fct = CrossEntropyLoss()
+            shift_logits = shift_logits.view(-1, self.config.vocab_size)
+            shift_labels = shift_labels.view(-1)
+            # Enable model parallelism
+            shift_labels = shift_labels.to(shift_logits.device)
+            loss = loss_fct(shift_logits, shift_labels)
+
+    aux_loss = None
+    if output_router_logits:
+        aux_loss = load_balancing_loss_func(
+            outputs.router_logits if return_dict else outputs[-1],
+            self.num_experts,
+            self.num_experts_per_tok,
+            attention_mask,
+        )
+        if labels is not None:
+            loss += self.router_aux_loss_coef * aux_loss.to(
+                loss.device
+            )  # make sure to reside in the same device
+
+    if not return_dict:
+        output = (logits,) + outputs[1:]
+        if output_router_logits:
+            output = (aux_loss,) + output
+        return (loss,) + output if loss is not None else output
+
+    return MoeCausalLMOutputWithPast(
+        loss=loss,
+        aux_loss=aux_loss,
+        logits=logits,
+        past_key_values=outputs.past_key_values,
+        hidden_states=outputs.hidden_states,
+        attentions=outputs.attentions,
+        router_logits=outputs.router_logits,
+    )
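+
+# Note: on the training path above, LigerFusedLinearCrossEntropyLoss consumes
+# the hidden states and the lm_head weight directly and fuses the output
+# projection with the cross entropy computation, so the full
+# (batch * seq_len, vocab_size) logits tensor is never materialized; logits
+# are only computed on the non-training path.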
diff --git a/src/axolotl/utils/config/__init__.py b/src/axolotl/utils/config/__init__.py
index ed165e89c..82436e8d7 100644
--- a/src/axolotl/utils/config/__init__.py
+++ b/src/axolotl/utils/config/__init__.py
@@ -8,11 +8,14 @@ from typing import Optional
 import torch
 from transformers.utils import is_torch_bf16_gpu_available
 
+from axolotl.integrations.config import merge_input_args
 from axolotl.utils.bench import log_gpu_memory_usage
+from axolotl.utils.config.models.input.v0_4_1 import SUPPORTED_METRICS
 from axolotl.utils.config.models.input.v0_4_1 import (
-    SUPPORTED_METRICS,
-    AxolotlConfigWCapabilities,
-    AxolotlInputConfig,
+    AxolotlConfigWCapabilities as AxolotlConfigWCapabilitiesBase,
+)
+from axolotl.utils.config.models.input.v0_4_1 import (
+    AxolotlInputConfig as AxolotlInputConfigBase,
 )
 from axolotl.utils.dict import DictDefault
 from axolotl.utils.models import load_model_config
@@ -207,6 +210,15 @@ def normalize_cfg_datasets(cfg):
 
 
 def validate_config(cfg: DictDefault, capabilities: Optional[dict] = None):
+    AxolotlConfigWCapabilities = AxolotlConfigWCapabilitiesBase
+    AxolotlInputConfig = AxolotlInputConfigBase
+
+    if cfg.plugins:
+        (
+            AxolotlConfigWCapabilities,  # pylint: disable=invalid-name
+            AxolotlInputConfig,  # pylint: disable=invalid-name
+        ) = merge_input_args()
+
     if capabilities:
         return DictDefault(
             dict(
diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index 8d24524a2..6261ce20f 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -308,10 +308,17 @@ def load_model(
     """
     Load a model for a given configuration and tokenizer.
     """
+    base_model = cfg.base_model
     model_type = cfg.type_of_model
    model_config = load_model_config(cfg)
 
+    # load any patches from plugins
+    from axolotl.integrations.base import PluginManager
+
+    plugin_manager = PluginManager.get_instance()
+    plugin_manager.pre_model_load(cfg)
+
     # TODO refactor as a kwarg
     load_in_8bit = cfg.load_in_8bit
 
diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py
index 99c10c655..f4e1fc6cb 100644
--- a/src/axolotl/utils/trainer.py
+++ b/src/axolotl/utils/trainer.py
@@ -217,6 +217,24 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
             desc="Dropping Long Sequences",
         )
 
+    # drop samples where the number of tokens with labels not equal to -100 is zero
+    def drop_no_trainable_tokens(sample):
+        return np.sum(np.array(sample["labels"]) != -100) > 0
+
+    train_dataset = train_dataset.filter(
+        drop_no_trainable_tokens,
+        num_proc=cfg.dataset_processes,
+        load_from_cache_file=not cfg.is_preprocess,
+        desc="Drop Samples with Zero Trainable Tokens",
+    )
+    if eval_dataset:
+        eval_dataset = eval_dataset.filter(
+            drop_no_trainable_tokens,
+            num_proc=cfg.dataset_processes,
+            load_from_cache_file=not cfg.is_preprocess,
+            desc="Drop Samples with Zero Trainable Tokens",
+        )
+
     if cfg.group_by_length:
         train_dataset = train_dataset.map(
             add_length,