diff --git a/src/axolotl/integrations/cut_cross_entropy/__init__.py b/src/axolotl/integrations/cut_cross_entropy/__init__.py index a2f0d52d7..6c47097b7 100644 --- a/src/axolotl/integrations/cut_cross_entropy/__init__.py +++ b/src/axolotl/integrations/cut_cross_entropy/__init__.py @@ -19,11 +19,13 @@ Cut Cross Entropy is an optimized implementation of cross entropy loss from Apple's ML team. """ import importlib +from functools import partial import torch from axolotl.integrations.base import BasePlugin from axolotl.utils import get_pytorch_version +from axolotl.utils.callbacks.models import get_causal_lm_model_cls_prefix from axolotl.utils.logging import get_logger from .args import CutCrossEntropyArgs # pylint: disable=unused-import. # noqa: F401 @@ -84,6 +86,7 @@ class CutCrossEntropyPlugin(BasePlugin): """Apply cut cross entropy before model loading if enabled.""" if cfg.cut_cross_entropy: self._check_requirements() + self.patch_llama_like(cfg.model_config_type) from cut_cross_entropy.transformers.patch import cce_patch @@ -93,3 +96,48 @@ class CutCrossEntropyPlugin(BasePlugin): # The patch checks model_type internally cce_patch(cfg.model_config_type) + + def patch_llama_like( + self, + model_type: str, + ) -> None: + """ + Generic patch for model architectures whose causal LM is llama-like + """ + from cut_cross_entropy.transformers.patch import PATCH_FNS + + def patch_generic( + maybe_model, patch_options, model_type: str + ): # pylint: disable=unused-argument + import cut_cross_entropy.transformers.llama + from cut_cross_entropy.transformers.llama import cce_forward + + try: + # Dynamically import the module and CausalLM class + module_path = f"transformers.models.{model_type}.modeling_{model_type}" + model_cls_prefix, _ = get_causal_lm_model_cls_prefix(model_type) + module = __import__( + module_path, fromlist=[f"{model_cls_prefix}ForCausalLM"] + ) + model_cls = getattr(module, f"{model_cls_prefix}ForCausalLM") + + cut_cross_entropy.transformers.llama._PATCH_OPTS = ( # pylint: disable=protected-access + patch_options + ) + + model_cls.forward = cce_forward + # pylint: disable=duplicate-code + except (ImportError, AttributeError) as e: + raise RuntimeError( + f"Could not import ForCausalLM class for model_type: {model_type}. " + f"Error: {str(e)}" + ) from e + + if model_type not in PATCH_FNS: + LOG.warning_once( + "Setting up generic cce patch for model type: %s", model_type + ) + LOG.warning_once( + f"Generic Cut Cross Entropy + {model_type} support is experimental and may not work as expected."
+ ) + PATCH_FNS[model_type] = partial(patch_generic, model_type=model_type) diff --git a/src/axolotl/integrations/kd/kernels/models.py b/src/axolotl/integrations/kd/kernels/models.py index 6a8b6da1c..4319f5f7d 100644 --- a/src/axolotl/integrations/kd/kernels/models.py +++ b/src/axolotl/integrations/kd/kernels/models.py @@ -22,6 +22,8 @@ except ImportError: TransformersKwargs, ) +from axolotl.utils.callbacks.models import get_causal_lm_model_cls_prefix + def kldiv_forward_llama_like( self, @@ -97,7 +99,7 @@ def kldiv_forward_llama_like( def apply_kernel(model_type): # Dynamically import the module and attention class module_path = f"transformers.models.{model_type}.modeling_{model_type}" - model_cls_prefix = "".join([part.capitalize() for part in model_type.split("_")]) + model_cls_prefix, _ = get_causal_lm_model_cls_prefix(model_type) module = __import__(module_path, fromlist=[f"{model_cls_prefix}ForCausalLM"]) model_cls = getattr(module, f"{model_cls_prefix}ForCausalLM") model_cls.forward = kldiv_forward_llama_like diff --git a/src/axolotl/integrations/liger/__init__.py b/src/axolotl/integrations/liger/__init__.py index 8de94c78b..86d56be80 100644 --- a/src/axolotl/integrations/liger/__init__.py +++ b/src/axolotl/integrations/liger/__init__.py @@ -18,170 +18,10 @@ Module for the Plugin for LIGER integraton with Axolotl. Liger Kernel is the collection of Triton-native kernels for LLM Training. It is designed to be performant, correct, and light-weight. """ -import inspect -import sys +from .args import LigerArgs +from .plugin import LigerPlugin -from axolotl.integrations.base import BasePlugin -from axolotl.utils.logging import get_logger - -from .args import LigerArgs # pylint: disable=unused-import. # noqa: F401 -from .utils import patch_with_compile_disable - -LOG = get_logger(__name__) - - -class LigerPlugin(BasePlugin): - """ - Plugin for LIGER integraton with Axolotl. - """ - - def get_input_args(self): - return "axolotl.integrations.liger.LigerArgs" - - def pre_model_load(self, cfg): - if cfg.torch_compile: - # torch compile will unnecessarily attempt to optimize the triton kernel unless explicitly disabled - import liger_kernel.ops.fused_linear_cross_entropy - - patch_with_compile_disable( - liger_kernel.ops.fused_linear_cross_entropy, - "fused_linear_cross_entropy_forward", - ) - patch_with_compile_disable( - liger_kernel.ops.fused_linear_cross_entropy, - "fused_linear_cross_entropy_backward", - ) - from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss - from liger_kernel.transformers.functional import liger_cross_entropy - from liger_kernel.transformers.layer_norm import LigerLayerNorm - from liger_kernel.transformers.monkey_patch import MODEL_TYPE_TO_APPLY_LIGER_FN - from liger_kernel.transformers.rms_norm import LigerRMSNorm - from liger_kernel.transformers.rope import liger_rotary_pos_emb - from liger_kernel.transformers.swiglu import LigerSwiGLUMLP - - if cfg.liger_cross_entropy and cfg.liger_fused_linear_cross_entropy: - raise ValueError( - "Cannot have both `liger_cross_entropy` and `liger_fused_linear_cross_entropy` set." 
- ) - - if cfg.model_config_type in MODEL_TYPE_TO_APPLY_LIGER_FN: - apply_liger_fn = MODEL_TYPE_TO_APPLY_LIGER_FN[cfg.model_config_type] - liger_fn_sig = inspect.signature(apply_liger_fn) - kwargs = {} - if "rope" in liger_fn_sig.parameters: - kwargs["rope"] = cfg.liger_rope - if "cross_entropy" in liger_fn_sig.parameters: - kwargs["cross_entropy"] = cfg.liger_cross_entropy - if "fused_linear_cross_entropy" in liger_fn_sig.parameters: - kwargs["fused_linear_cross_entropy"] = ( - cfg.liger_fused_linear_cross_entropy - ) - if "rms_norm" in liger_fn_sig.parameters: - kwargs["rms_norm"] = cfg.liger_rms_norm - if "layer_norm" in liger_fn_sig.parameters: - kwargs["layer_norm"] = cfg.liger_layer_norm - if "geglu" in liger_fn_sig.parameters: - kwargs["geglu"] = cfg.liger_glu_activation - elif "swiglu" in liger_fn_sig.parameters: - kwargs["swiglu"] = cfg.liger_glu_activation - LOG.info(f"Applying LIGER to {cfg.model_config_type} with kwargs: {kwargs}") - apply_liger_fn(**kwargs) - elif cfg.model_config_type == "jamba": - from transformers.models.jamba import modeling_jamba - - from .models.jamba import lce_forward as jamba_lce_forward - - if cfg.liger_rope: - modeling_jamba.apply_rotary_pos_emb = liger_rotary_pos_emb - if cfg.liger_rms_norm: - modeling_jamba.JambaRMSNorm = LigerRMSNorm - if cfg.liger_glu_activation: - modeling_jamba.JambaMLP = LigerSwiGLUMLP - if cfg.liger_layer_norm: - modeling_jamba.nn.LayerNorm = LigerLayerNorm - if cfg.liger_cross_entropy: - from transformers.loss.loss_utils import nn - - nn.functional.cross_entropy = liger_cross_entropy - if cfg.liger_fused_linear_cross_entropy: - modeling_jamba.JambaForCausalLM.forward = jamba_lce_forward - elif cfg.model_config_type == "deepseek_v2": - from accelerate import init_empty_weights - from transformers import AutoModelForCausalLM - - with init_empty_weights(): - model = AutoModelForCausalLM.from_pretrained( - cfg.base_model, trust_remote_code=cfg.trust_remote_code or False - ) - modeling_mod = sys.modules[model.__class__.__module__] - - from .models.deepseekv2 import lce_forward as deepseekv2_lce_forward - - if cfg.liger_rope: - # The DeepseekV2 version of RoPE is different than upstream LLaMA. - # See https://github.com/linkedin/Liger-Kernel/issues/129#issuecomment-2313763528 - LOG.warning("Fused liger_rope is not supported for DeepseekV2.") - if cfg.liger_glu_activation: - LOG.warning("liger_glu_activation is not supported for DeepseekV2.") - if cfg.liger_rms_norm: - modeling_mod.DeepseekV2RMSNorm = LigerRMSNorm - if cfg.liger_glu_activation: - modeling_mod.DeepseekV2MLP.forward = LigerSwiGLUMLP.forward - if cfg.liger_layer_norm: - modeling_mod.DeepseekV2MLP.forward = LigerLayerNorm.forward - if cfg.liger_cross_entropy: - # We do not patch `nn.functional.cross_entropy` for DeepseekV2 as it still uses - # nn.CrossEntropyLoss in the forward method. 
- modeling_mod.CrossEntropyLoss = LigerCrossEntropyLoss - if cfg.liger_fused_linear_cross_entropy: - modeling_mod.DeepseekV2ForCausalLM.forward = deepseekv2_lce_forward - elif cfg.model_config_type == "llama4": - from axolotl.integrations.liger.models.llama4 import ( - apply_liger_kernel_to_llama4, - ) - - apply_liger_kernel_to_llama4( - cross_entropy=cfg.liger_cross_entropy, - fused_linear_cross_entropy=cfg.liger_fused_linear_cross_entropy, - glu_activation=cfg.liger_glu_activation, - rms_norm=cfg.liger_rms_norm, - layer_norm=cfg.liger_layer_norm, - ) - elif cfg.model_config_type == "qwen3": - from axolotl.integrations.liger.models.qwen3 import ( - apply_liger_kernel_to_qwen3, - ) - - apply_liger_kernel_to_qwen3( - cross_entropy=cfg.liger_cross_entropy, - fused_linear_cross_entropy=cfg.liger_fused_linear_cross_entropy, - glu_activation=cfg.liger_glu_activation, - rms_norm=cfg.liger_rms_norm, - layer_norm=cfg.liger_layer_norm, - ) - elif cfg.model_config_type == "qwen3_moe": - from axolotl.integrations.liger.models.qwen3_moe import ( - apply_liger_kernel_to_qwen3_moe, - ) - - apply_liger_kernel_to_qwen3_moe( - cross_entropy=cfg.liger_cross_entropy, - fused_linear_cross_entropy=cfg.liger_fused_linear_cross_entropy, - glu_activation=cfg.liger_glu_activation, - rms_norm=cfg.liger_rms_norm, - layer_norm=cfg.liger_layer_norm, - ) - elif cfg.model_config_type == "granitemoe": - from liger_kernel.transformers import apply_liger_kernel_to_granite - - apply_liger_kernel_to_granite( - rope=cfg.liger_rope, - cross_entropy=cfg.liger_cross_entropy, - fused_linear_cross_entropy=cfg.liger_fused_linear_cross_entropy, - rms_norm=cfg.liger_rms_norm, - swiglu=cfg.liger_glu_activation, - ) - else: - LOG.warning( - f"Unsupported model config type: {cfg.model_config_type}. Liger not applied." - ) +__all__ = [ + "LigerArgs", + "LigerPlugin", +] diff --git a/src/axolotl/integrations/liger/models/base.py b/src/axolotl/integrations/liger/models/base.py new file mode 100644 index 000000000..f3cf4299a --- /dev/null +++ b/src/axolotl/integrations/liger/models/base.py @@ -0,0 +1,192 @@ +""" +Generic FLCE patch for untested models similar to Llama +""" + +from typing import Optional, Tuple, Union + +import torch +from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss +from liger_kernel.transformers.trainer.orpo_trainer import _FSDPForwardRedirection +from liger_kernel.utils import PEFT_AVAILABLE +from torch.distributed.fsdp import FullyShardedDataParallel +from transformers.modeling_outputs import CausalLMOutputWithPast + +from axolotl.utils.callbacks.models import get_causal_lm_model_cls_prefix + +if PEFT_AVAILABLE: + # guarded so the module stays importable without peft; usage below is gated on PEFT_AVAILABLE too + from peft.utils import ModulesToSaveWrapper + + +def lce_forward( + self, + *args, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + labels: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + skip_logits: Optional[bool] = None, + **kwargs, +) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+ + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). + """ + + # pylint: disable=duplicate-code + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + *args, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + **kwargs, + ) + + hidden_states = outputs[0] + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = ( + slice(-logits_to_keep, None) + if isinstance(logits_to_keep, int) + else logits_to_keep + ) + kept_hidden_states = hidden_states[:, slice_indices, :] + + shift_labels = kwargs.pop("shift_labels", None) + logits = None + loss = None + + # skipping logits requires labels (or shift_labels) to compute the loss + if skip_logits and labels is None and shift_labels is None: + raise ValueError("skip_logits is True, but labels and shift_labels are None") + + if skip_logits is None: + # By default, if in training mode, don't materialize logits + skip_logits = self.training and (labels is not None or shift_labels is not None) + + if skip_logits: + loss = lce_maybe_trainable_lm_head( + self, + hidden_states=kept_hidden_states, + hidden_size=self.config.hidden_size, + labels=labels, + shift_labels=shift_labels, + **kwargs, + ) + + else: + logits = self.lm_head(kept_hidden_states) + if labels is not None: + loss = self.loss_function( + logits=logits, + labels=labels, + vocab_size=self.config.vocab_size, + **kwargs, + ) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +def lce_maybe_trainable_lm_head( + self, hidden_states, hidden_size, labels, shift_labels, **loss_kwargs +): + lm_head = self.lm_head + + # Unwrap the module if lm_head has been added as a trainable module in the PEFT LoRA configuration, + # i.e. listed in the modules_to_save field of LoraConfig, so the lm_head weights are read + # from the unwrapped module. + # See https://huggingface.co/docs/peft/package_reference/lora for reference. + if PEFT_AVAILABLE and isinstance(lm_head, ModulesToSaveWrapper): + lm_head = lm_head.modules_to_save.default + + # If FSDP is used and lm_head is trainable, e.g., during full fine-tuning or with LoRA, + # reading the lm_head module weights and calling the kernel must be done within the FSDP forward pass + # so the module's entire parameters are summoned and kept in memory during the kernel execution.
+ if isinstance(lm_head, FullyShardedDataParallel): + return _FSDPForwardRedirection()( + lm_head, + _liger_for_causal_lm_loss, + lm_head.module, + hidden_states, + hidden_size, + labels, + shift_labels, + **loss_kwargs, + ) + + # FSDP is not used so we can read the (possibly unwrapped) lm_head weights and call the kernel directly + return _liger_for_causal_lm_loss( + lm_head=lm_head, + hidden_states=hidden_states, + hidden_size=hidden_size, + labels=labels, + shift_labels=shift_labels, + **loss_kwargs, + ) + + +def _liger_for_causal_lm_loss( + lm_head, hidden_states, hidden_size, labels, shift_labels, **loss_kwargs +): + return LigerForCausalLMLoss( + hidden_states=hidden_states, + lm_head_weight=lm_head.weight, + labels=labels, + hidden_size=hidden_size, + shift_labels=shift_labels, + **loss_kwargs, + ) + + +def patch_lce_forward( + model_type, +): + try: + # Dynamically import the module and ForCausalLM class + module_path = f"transformers.models.{model_type}.modeling_{model_type}" + model_cls_prefix, _ = get_causal_lm_model_cls_prefix(model_type) + module = __import__(module_path, fromlist=[f"{model_cls_prefix}ForCausalLM"]) + model_cls = getattr(module, f"{model_cls_prefix}ForCausalLM") + + model_cls.forward = lce_forward + # pylint: disable=duplicate-code + except (ImportError, AttributeError) as e: + raise RuntimeError( + f"Could not import ForCausalLM class for model_type: {model_type}. " + f"Error: {str(e)}" + ) from e diff --git a/src/axolotl/integrations/liger/plugin.py b/src/axolotl/integrations/liger/plugin.py new file mode 100644 index 000000000..89f7c37b7 --- /dev/null +++ b/src/axolotl/integrations/liger/plugin.py @@ -0,0 +1,182 @@ +""" +Liger-Kernel Plugin for Axolotl +""" + +import inspect +import sys + +from axolotl.integrations.base import BasePlugin +from axolotl.utils.logging import get_logger + +from .models.base import patch_lce_forward +from .utils import patch_with_compile_disable + +LOG = get_logger(__name__) + + +class LigerPlugin(BasePlugin): + """ + Plugin for LIGER integration with Axolotl. + """ + + def get_input_args(self): + return "axolotl.integrations.liger.LigerArgs" + + def pre_model_load(self, cfg): + if cfg.torch_compile: + # torch compile will unnecessarily attempt to optimize the triton kernel unless explicitly disabled + import liger_kernel.ops.fused_linear_cross_entropy + + patch_with_compile_disable( + liger_kernel.ops.fused_linear_cross_entropy, + "fused_linear_cross_entropy_forward", + ) + patch_with_compile_disable( + liger_kernel.ops.fused_linear_cross_entropy, + "fused_linear_cross_entropy_backward", + ) + from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss + from liger_kernel.transformers.functional import liger_cross_entropy + from liger_kernel.transformers.layer_norm import LigerLayerNorm + from liger_kernel.transformers.monkey_patch import MODEL_TYPE_TO_APPLY_LIGER_FN + from liger_kernel.transformers.rms_norm import LigerRMSNorm + from liger_kernel.transformers.rope import liger_rotary_pos_emb + from liger_kernel.transformers.swiglu import LigerSwiGLUMLP + + if cfg.liger_cross_entropy and cfg.liger_fused_linear_cross_entropy: + raise ValueError( + "Cannot have both `liger_cross_entropy` and `liger_fused_linear_cross_entropy` set."
+ ) + + if cfg.model_config_type in MODEL_TYPE_TO_APPLY_LIGER_FN: + apply_liger_fn = MODEL_TYPE_TO_APPLY_LIGER_FN[cfg.model_config_type] + liger_fn_sig = inspect.signature(apply_liger_fn) + kwargs = {} + if "rope" in liger_fn_sig.parameters: + kwargs["rope"] = cfg.liger_rope + if "cross_entropy" in liger_fn_sig.parameters: + kwargs["cross_entropy"] = cfg.liger_cross_entropy + if "fused_linear_cross_entropy" in liger_fn_sig.parameters: + kwargs["fused_linear_cross_entropy"] = ( + cfg.liger_fused_linear_cross_entropy + ) + if "rms_norm" in liger_fn_sig.parameters: + kwargs["rms_norm"] = cfg.liger_rms_norm + if "layer_norm" in liger_fn_sig.parameters: + kwargs["layer_norm"] = cfg.liger_layer_norm + if "geglu" in liger_fn_sig.parameters: + kwargs["geglu"] = cfg.liger_glu_activation + elif "swiglu" in liger_fn_sig.parameters: + kwargs["swiglu"] = cfg.liger_glu_activation + LOG.info(f"Applying LIGER to {cfg.model_config_type} with kwargs: {kwargs}") + apply_liger_fn(**kwargs) + elif cfg.model_config_type == "jamba": + from transformers.models.jamba import modeling_jamba + + from .models.jamba import lce_forward as jamba_lce_forward + + if cfg.liger_rope: + modeling_jamba.apply_rotary_pos_emb = liger_rotary_pos_emb + if cfg.liger_rms_norm: + modeling_jamba.JambaRMSNorm = LigerRMSNorm + if cfg.liger_glu_activation: + modeling_jamba.JambaMLP = LigerSwiGLUMLP + if cfg.liger_layer_norm: + modeling_jamba.nn.LayerNorm = LigerLayerNorm + if cfg.liger_cross_entropy: + from transformers.loss.loss_utils import nn + + nn.functional.cross_entropy = liger_cross_entropy + if cfg.liger_fused_linear_cross_entropy: + modeling_jamba.JambaForCausalLM.forward = jamba_lce_forward + elif cfg.model_config_type == "deepseek_v2": + from accelerate import init_empty_weights + from transformers import AutoModelForCausalLM + + with init_empty_weights(): + model = AutoModelForCausalLM.from_pretrained( + cfg.base_model, trust_remote_code=cfg.trust_remote_code or False + ) + modeling_mod = sys.modules[model.__class__.__module__] + + from .models.deepseekv2 import lce_forward as deepseekv2_lce_forward + + if cfg.liger_rope: + # The DeepseekV2 version of RoPE is different than upstream LLaMA. + # See https://github.com/linkedin/Liger-Kernel/issues/129#issuecomment-2313763528 + LOG.warning("Fused liger_rope is not supported for DeepseekV2.") + if cfg.liger_rms_norm: + modeling_mod.DeepseekV2RMSNorm = LigerRMSNorm + if cfg.liger_glu_activation: + modeling_mod.DeepseekV2MLP.forward = LigerSwiGLUMLP.forward + if cfg.liger_layer_norm: + LOG.warning("liger_layer_norm is not supported for DeepseekV2.") + if cfg.liger_cross_entropy: + # We do not patch `nn.functional.cross_entropy` for DeepseekV2 as it still uses + # nn.CrossEntropyLoss in the forward method. 
+ modeling_mod.CrossEntropyLoss = LigerCrossEntropyLoss + if cfg.liger_fused_linear_cross_entropy: + modeling_mod.DeepseekV2ForCausalLM.forward = deepseekv2_lce_forward + elif cfg.model_config_type == "llama4": + from axolotl.integrations.liger.models.llama4 import ( + apply_liger_kernel_to_llama4, + ) + + apply_liger_kernel_to_llama4( + cross_entropy=cfg.liger_cross_entropy, + fused_linear_cross_entropy=cfg.liger_fused_linear_cross_entropy, + glu_activation=cfg.liger_glu_activation, + rms_norm=cfg.liger_rms_norm, + layer_norm=cfg.liger_layer_norm, + ) + elif cfg.model_config_type == "qwen3": + from axolotl.integrations.liger.models.qwen3 import ( + apply_liger_kernel_to_qwen3, + ) + + apply_liger_kernel_to_qwen3( + cross_entropy=cfg.liger_cross_entropy, + fused_linear_cross_entropy=cfg.liger_fused_linear_cross_entropy, + glu_activation=cfg.liger_glu_activation, + rms_norm=cfg.liger_rms_norm, + layer_norm=cfg.liger_layer_norm, + ) + elif cfg.model_config_type == "qwen3_moe": + from axolotl.integrations.liger.models.qwen3_moe import ( + apply_liger_kernel_to_qwen3_moe, + ) + + apply_liger_kernel_to_qwen3_moe( + cross_entropy=cfg.liger_cross_entropy, + fused_linear_cross_entropy=cfg.liger_fused_linear_cross_entropy, + glu_activation=cfg.liger_glu_activation, + rms_norm=cfg.liger_rms_norm, + layer_norm=cfg.liger_layer_norm, + ) + elif cfg.model_config_type == "granitemoe": + from liger_kernel.transformers import apply_liger_kernel_to_granite + + apply_liger_kernel_to_granite( + rope=cfg.liger_rope, + cross_entropy=cfg.liger_cross_entropy, + fused_linear_cross_entropy=cfg.liger_fused_linear_cross_entropy, + rms_norm=cfg.liger_rms_norm, + swiglu=cfg.liger_glu_activation, + ) + elif cfg.liger_fused_linear_cross_entropy: + try: + patch_lce_forward(cfg.model_config_type) + LOG.warning_once( + f"Applied ONLY the generic liger_fused_linear_cross_entropy patch for model type: {cfg.model_config_type}" + ) + LOG.warning_once( + f"Liger + {cfg.model_config_type} generic FLCE support is experimental and may not work as expected." + ) + except RuntimeError: + LOG.warning( + f"Unsupported model config type: {cfg.model_config_type}. Liger not applied." + ) + else: + LOG.warning( + f"Unsupported model config type: {cfg.model_config_type}. Liger not applied."
+ ) diff --git a/src/axolotl/loaders/patch_manager.py b/src/axolotl/loaders/patch_manager.py index 84e6b33de..f346c56e0 100644 --- a/src/axolotl/loaders/patch_manager.py +++ b/src/axolotl/loaders/patch_manager.py @@ -272,7 +272,11 @@ class PatchManager: if self.cfg.tiled_mlp: from axolotl.monkeypatch.tiled_mlp import patch_tiled_mlp - patch_tiled_mlp(model_type, cfg_num_shards=self.cfg.tiled_mlp_num_shards) + patch_tiled_mlp( + model_type, + use_original_mlp=self.cfg.tiled_mlp_use_original_mlp, + cfg_num_shards=self.cfg.tiled_mlp_num_shards, + ) def _patch_attention(self): """Apply attention-specific patches based on model type.""" diff --git a/src/axolotl/monkeypatch/lora_kernels.py b/src/axolotl/monkeypatch/lora_kernels.py index 586412dd7..4702ad19d 100644 --- a/src/axolotl/monkeypatch/lora_kernels.py +++ b/src/axolotl/monkeypatch/lora_kernels.py @@ -18,6 +18,7 @@ from axolotl.kernels.lora import ( apply_lora_qkv, ) from axolotl.monkeypatch.utils import detab_code +from axolotl.utils.callbacks.models import get_causal_lm_model_cls_prefix from axolotl.utils.dict import DictDefault from axolotl.utils.logging import get_logger @@ -153,9 +154,7 @@ def get_attention_cls_from_config(cfg: DictDefault) -> Type[nn.Module]: try: # Dynamically import the module and attention class module_path = f"transformers.models.{model_type}.modeling_{model_type}" - model_cls_prefix = "".join( - [part.capitalize() for part in model_type.split("_")] - ) + model_cls_prefix, _ = get_causal_lm_model_cls_prefix(model_type) module = __import__(module_path, fromlist=[f"{model_cls_prefix}Attention"]) attention_cls = getattr(module, f"{model_cls_prefix}Attention") diff --git a/src/axolotl/monkeypatch/tiled_mlp.py b/src/axolotl/monkeypatch/tiled_mlp.py index 99a10df9c..3818c6b35 100644 --- a/src/axolotl/monkeypatch/tiled_mlp.py +++ b/src/axolotl/monkeypatch/tiled_mlp.py @@ -6,6 +6,8 @@ import os import torch import torch.distributed as dist +from axolotl.utils.callbacks.models import get_causal_lm_model_cls_prefix + def patch_tiled_mlp(model_type, use_original_mlp=False, cfg_num_shards=None): from deepspeed.runtime.sequence_parallel.ulysses_sp import TiledMLP @@ -13,9 +15,7 @@ def patch_tiled_mlp(model_type, use_original_mlp=False, cfg_num_shards=None): try: # Dynamically import the module and MLP class module_path = f"transformers.models.{model_type}.modeling_{model_type}" - model_cls_prefix = "".join( - [part.capitalize() for part in model_type.split("_")] - ) + model_cls_prefix, _ = get_causal_lm_model_cls_prefix(model_type) module = __import__(module_path, fromlist=[f"{model_cls_prefix}MLP"]) mlp_cls = getattr(module, f"{model_cls_prefix}MLP") @@ -45,11 +45,12 @@ def patch_tiled_mlp(model_type, use_original_mlp=False, cfg_num_shards=None): else: num_shards = cfg_num_shards - compute_params = [ - self.down_proj.weight, - self.gate_proj.weight, - self.up_proj.weight, - ] + if not self._compute_params: # pylint: disable=protected-access + self._compute_params = [ # pylint: disable=protected-access + p for p in self.parameters() if p.requires_grad + ] + + compute_params = self._compute_params # pylint: disable=protected-access down_res = TiledMLP.apply( mlp_forward, @@ -61,6 +62,7 @@ def patch_tiled_mlp(model_type, use_original_mlp=False, cfg_num_shards=None): return down_res mlp_cls.forward = tiled_mlp_forward + mlp_cls._compute_params = [] # pylint: disable=protected-access except (ImportError, AttributeError) as e: raise RuntimeError( f"Could not import MLP class for model_type: {model_type}. 
" diff --git a/src/axolotl/utils/callbacks/models.py b/src/axolotl/utils/callbacks/models.py new file mode 100644 index 000000000..5a20d70d9 --- /dev/null +++ b/src/axolotl/utils/callbacks/models.py @@ -0,0 +1,23 @@ +"""Helper functions for model classes""" + +from typing import Tuple + +from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES + + +def get_causal_lm_model_cls_prefix(model_type: str) -> Tuple[str, str]: + if model_type in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES: + causal_lm_cls = MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[model_type] + causal_lm_cls_prefix = causal_lm_cls + for suffix in [ + "ForCausalLM", + "ForConditionalGeneration", + "LMHeadModel", + "GenerationDecoder", + ]: + causal_lm_cls_prefix = causal_lm_cls_prefix.replace(suffix, "") + return causal_lm_cls_prefix, causal_lm_cls + causal_lm_cls_prefix = "".join( + [part.capitalize() for part in model_type.split("_")] + ) + return causal_lm_cls_prefix, f"{causal_lm_cls_prefix}ForCausalLM" diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py index e20cdaf47..06212a27f 100644 --- a/src/axolotl/utils/schemas/config.py +++ b/src/axolotl/utils/schemas/config.py @@ -576,6 +576,13 @@ class AxolotlInputConfig( }, ) + tiled_mlp_use_original_mlp: bool | None = Field( + default=None, + json_schema_extra={ + "description": "Whether to use original mlp for ALST tiled mlp. Otherwise uses a generic MLP based on llama." + }, + ) + llama4_linearized_experts: bool | None = None deepspeed: str | dict[str, Any] | None = Field(