Apply generic fused liger CE, CCE, and tiled MLP for arbitrary models (#2908)

* Apply generic fused liger ce for unknown models

* fix deepseek liger modeling

* generic cce and config tiled mlp to use original mlp and auto detect compute params

* fix weight and lint

* update warnings

* address PR feedback

* use lookup for model class prefixes

* revert inadvertent change to flash attn version

* remove un-needed pylint annotations

* fix import
Wing Lian
2025-07-15 22:40:41 -04:00
committed by GitHub
parent 942005f526
commit 2c408b5c5e
10 changed files with 475 additions and 179 deletions

View File

@@ -19,11 +19,13 @@ Cut Cross Entropy is an optimized implementation of cross entropy loss
from Apple's ML team.
"""
import importlib
from functools import partial
import torch
from axolotl.integrations.base import BasePlugin
from axolotl.utils import get_pytorch_version
from axolotl.utils.callbacks.models import get_causal_lm_model_cls_prefix
from axolotl.utils.logging import get_logger
from .args import CutCrossEntropyArgs  # pylint: disable=unused-import  # noqa: F401
@@ -84,6 +86,7 @@ class CutCrossEntropyPlugin(BasePlugin):
"""Apply cut cross entropy before model loading if enabled."""
if cfg.cut_cross_entropy:
self._check_requirements()
self.patch_llama_like(cfg.model_config_type)
from cut_cross_entropy.transformers.patch import cce_patch
@@ -93,3 +96,48 @@ class CutCrossEntropyPlugin(BasePlugin):
# The patch checks model_type internally
cce_patch(cfg.model_config_type)
def patch_llama_like(
self,
model_type: str,
) -> None:
"""
Generic patch for model architectures whose causal LM forward is similar to Llama's
"""
from cut_cross_entropy.transformers.patch import PATCH_FNS
def patch_generic(
maybe_model, patch_options, model_type: str
): # pylint: disable=unused-argument
import cut_cross_entropy.transformers.llama
from cut_cross_entropy.transformers.llama import cce_forward
try:
# Dynamically import the module and CausalLM class
module_path = f"transformers.models.{model_type}.modeling_{model_type}"
model_cls_prefix, _ = get_causal_lm_model_cls_prefix(model_type)
module = __import__(
module_path, fromlist=[f"{model_cls_prefix}ForCausalLM"]
)
model_cls = getattr(module, f"{model_cls_prefix}ForCausalLM")
cut_cross_entropy.transformers.llama._PATCH_OPTS = ( # pylint: disable=protected-access
patch_options
)
model_cls.forward = cce_forward
# pylint: disable=duplicate-code
except (ImportError, AttributeError) as e:
raise RuntimeError(
f"Could not import ForCausalLM class for model_type: {model_type}. "
f"Error: {str(e)}"
) from e
if model_type not in PATCH_FNS:
LOG.warning_once(
"Setting up generic cce patch for model type: %s", model_type
)
LOG.warning_once(
f"Generic Cut Cross Entropy + {model_type} support is experimental and may not work as expected."
)
PATCH_FNS[model_type] = partial(patch_generic, model_type=model_type)
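For context, a minimal sketch of how this registration is consumed: `cce_patch` looks up the model type in `PATCH_FNS` and invokes the registered function. The sketch below uses a hypothetical `my_model` type and a simplified registry, since the real dispatch lives inside `cut_cross_entropy.transformers.patch` and may differ in detail.

```python
from functools import partial

# Simplified stand-in for cut_cross_entropy's PATCH_FNS registry.
PATCH_FNS = {}

def patch_generic(maybe_model, patch_options, model_type):
    # The real patch swaps model_cls.forward for cce_forward here.
    print(f"patching {model_type} with options {patch_options}")

# Registration, as done at the end of patch_llama_like above:
PATCH_FNS["my_model"] = partial(patch_generic, model_type="my_model")

# cce_patch("my_model") then effectively does:
PATCH_FNS["my_model"](None, patch_options={"impl": "cce"})
```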

View File

@@ -22,6 +22,8 @@ except ImportError:
TransformersKwargs,
)
from axolotl.utils.callbacks.models import get_causal_lm_model_cls_prefix
def kldiv_forward_llama_like(
self,
@@ -97,7 +99,7 @@ def kldiv_forward_llama_like(
def apply_kernel(model_type):
# Dynamically import the module and attention class
module_path = f"transformers.models.{model_type}.modeling_{model_type}"
model_cls_prefix = "".join([part.capitalize() for part in model_type.split("_")])
model_cls_prefix, _ = get_causal_lm_model_cls_prefix(model_type)
module = __import__(module_path, fromlist=[f"{model_cls_prefix}ForCausalLM"])
model_cls = getattr(module, f"{model_cls_prefix}ForCausalLM")
model_cls.forward = kldiv_forward_llama_like
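The switch from naive capitalization to `get_causal_lm_model_cls_prefix` matters for model types whose class names are not plain capitalizations of the config type. A quick illustration, assuming `transformers` is installed:

```python
from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES

model_type = "gpt_neox"
# The old approach produces a class name that does not exist:
naive = "".join(part.capitalize() for part in model_type.split("_"))
print(naive)  # GptNeox -- transformers has no GptNeoxForCausalLM
# The auto-mapping knows the real class name:
print(MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[model_type])  # GPTNeoXForCausalLM
```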

View File

@@ -18,170 +18,10 @@ Module for the Plugin for LIGER integration with Axolotl.
Liger Kernel is the collection of Triton-native kernels for LLM Training.
It is designed to be performant, correct, and light-weight.
"""
import inspect
import sys
from .args import LigerArgs
from .plugin import LigerPlugin
from axolotl.integrations.base import BasePlugin
from axolotl.utils.logging import get_logger
from .args import LigerArgs  # pylint: disable=unused-import  # noqa: F401
from .utils import patch_with_compile_disable
LOG = get_logger(__name__)
class LigerPlugin(BasePlugin):
"""
Plugin for LIGER integration with Axolotl.
"""
def get_input_args(self):
return "axolotl.integrations.liger.LigerArgs"
def pre_model_load(self, cfg):
if cfg.torch_compile:
# torch compile will unnecessarily attempt to optimize the triton kernel unless explicitly disabled
import liger_kernel.ops.fused_linear_cross_entropy
patch_with_compile_disable(
liger_kernel.ops.fused_linear_cross_entropy,
"fused_linear_cross_entropy_forward",
)
patch_with_compile_disable(
liger_kernel.ops.fused_linear_cross_entropy,
"fused_linear_cross_entropy_backward",
)
from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
from liger_kernel.transformers.functional import liger_cross_entropy
from liger_kernel.transformers.layer_norm import LigerLayerNorm
from liger_kernel.transformers.monkey_patch import MODEL_TYPE_TO_APPLY_LIGER_FN
from liger_kernel.transformers.rms_norm import LigerRMSNorm
from liger_kernel.transformers.rope import liger_rotary_pos_emb
from liger_kernel.transformers.swiglu import LigerSwiGLUMLP
if cfg.liger_cross_entropy and cfg.liger_fused_linear_cross_entropy:
raise ValueError(
"Cannot have both `liger_cross_entropy` and `liger_fused_linear_cross_entropy` set."
)
if cfg.model_config_type in MODEL_TYPE_TO_APPLY_LIGER_FN:
apply_liger_fn = MODEL_TYPE_TO_APPLY_LIGER_FN[cfg.model_config_type]
liger_fn_sig = inspect.signature(apply_liger_fn)
kwargs = {}
if "rope" in liger_fn_sig.parameters:
kwargs["rope"] = cfg.liger_rope
if "cross_entropy" in liger_fn_sig.parameters:
kwargs["cross_entropy"] = cfg.liger_cross_entropy
if "fused_linear_cross_entropy" in liger_fn_sig.parameters:
kwargs["fused_linear_cross_entropy"] = (
cfg.liger_fused_linear_cross_entropy
)
if "rms_norm" in liger_fn_sig.parameters:
kwargs["rms_norm"] = cfg.liger_rms_norm
if "layer_norm" in liger_fn_sig.parameters:
kwargs["layer_norm"] = cfg.liger_layer_norm
if "geglu" in liger_fn_sig.parameters:
kwargs["geglu"] = cfg.liger_glu_activation
elif "swiglu" in liger_fn_sig.parameters:
kwargs["swiglu"] = cfg.liger_glu_activation
LOG.info(f"Applying LIGER to {cfg.model_config_type} with kwargs: {kwargs}")
apply_liger_fn(**kwargs)
elif cfg.model_config_type == "jamba":
from transformers.models.jamba import modeling_jamba
from .models.jamba import lce_forward as jamba_lce_forward
if cfg.liger_rope:
modeling_jamba.apply_rotary_pos_emb = liger_rotary_pos_emb
if cfg.liger_rms_norm:
modeling_jamba.JambaRMSNorm = LigerRMSNorm
if cfg.liger_glu_activation:
modeling_jamba.JambaMLP = LigerSwiGLUMLP
if cfg.liger_layer_norm:
modeling_jamba.nn.LayerNorm = LigerLayerNorm
if cfg.liger_cross_entropy:
from transformers.loss.loss_utils import nn
nn.functional.cross_entropy = liger_cross_entropy
if cfg.liger_fused_linear_cross_entropy:
modeling_jamba.JambaForCausalLM.forward = jamba_lce_forward
elif cfg.model_config_type == "deepseek_v2":
from accelerate import init_empty_weights
from transformers import AutoModelForCausalLM
with init_empty_weights():
model = AutoModelForCausalLM.from_pretrained(
cfg.base_model, trust_remote_code=cfg.trust_remote_code or False
)
modeling_mod = sys.modules[model.__class__.__module__]
from .models.deepseekv2 import lce_forward as deepseekv2_lce_forward
if cfg.liger_rope:
# The DeepseekV2 version of RoPE is different than upstream LLaMA.
# See https://github.com/linkedin/Liger-Kernel/issues/129#issuecomment-2313763528
LOG.warning("Fused liger_rope is not supported for DeepseekV2.")
if cfg.liger_glu_activation:
LOG.warning("liger_glu_activation is not supported for DeepseekV2.")
if cfg.liger_rms_norm:
modeling_mod.DeepseekV2RMSNorm = LigerRMSNorm
if cfg.liger_glu_activation:
modeling_mod.DeepseekV2MLP.forward = LigerSwiGLUMLP.forward
if cfg.liger_layer_norm:
modeling_mod.DeepseekV2MLP.forward = LigerLayerNorm.forward
if cfg.liger_cross_entropy:
# We do not patch `nn.functional.cross_entropy` for DeepseekV2 as it still uses
# nn.CrossEntropyLoss in the forward method.
modeling_mod.CrossEntropyLoss = LigerCrossEntropyLoss
if cfg.liger_fused_linear_cross_entropy:
modeling_mod.DeepseekV2ForCausalLM.forward = deepseekv2_lce_forward
elif cfg.model_config_type == "llama4":
from axolotl.integrations.liger.models.llama4 import (
apply_liger_kernel_to_llama4,
)
apply_liger_kernel_to_llama4(
cross_entropy=cfg.liger_cross_entropy,
fused_linear_cross_entropy=cfg.liger_fused_linear_cross_entropy,
glu_activation=cfg.liger_glu_activation,
rms_norm=cfg.liger_rms_norm,
layer_norm=cfg.liger_layer_norm,
)
elif cfg.model_config_type == "qwen3":
from axolotl.integrations.liger.models.qwen3 import (
apply_liger_kernel_to_qwen3,
)
apply_liger_kernel_to_qwen3(
cross_entropy=cfg.liger_cross_entropy,
fused_linear_cross_entropy=cfg.liger_fused_linear_cross_entropy,
glu_activation=cfg.liger_glu_activation,
rms_norm=cfg.liger_rms_norm,
layer_norm=cfg.liger_layer_norm,
)
elif cfg.model_config_type == "qwen3_moe":
from axolotl.integrations.liger.models.qwen3_moe import (
apply_liger_kernel_to_qwen3_moe,
)
apply_liger_kernel_to_qwen3_moe(
cross_entropy=cfg.liger_cross_entropy,
fused_linear_cross_entropy=cfg.liger_fused_linear_cross_entropy,
glu_activation=cfg.liger_glu_activation,
rms_norm=cfg.liger_rms_norm,
layer_norm=cfg.liger_layer_norm,
)
elif cfg.model_config_type == "granitemoe":
from liger_kernel.transformers import apply_liger_kernel_to_granite
apply_liger_kernel_to_granite(
rope=cfg.liger_rope,
cross_entropy=cfg.liger_cross_entropy,
fused_linear_cross_entropy=cfg.liger_fused_linear_cross_entropy,
rms_norm=cfg.liger_rms_norm,
swiglu=cfg.liger_glu_activation,
)
else:
LOG.warning(
f"Unsupported model config type: {cfg.model_config_type}. Liger not applied."
)
__all__ = [
"LigerArgs",
"LigerPlugin",
]
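The `inspect.signature` filtering above is what lets a single call site drive the various `apply_liger_kernel_to_*` functions, which accept differing keyword sets. A standalone sketch of the pattern, with a hypothetical target function:

```python
import inspect

def apply_liger_kernel_example(rope=True, rms_norm=True):  # hypothetical target
    print(f"rope={rope}, rms_norm={rms_norm}")

options = {"rope": False, "rms_norm": True, "layer_norm": True}
sig = inspect.signature(apply_liger_kernel_example)
# Forward only the options the target function actually accepts:
kwargs = {k: v for k, v in options.items() if k in sig.parameters}
apply_liger_kernel_example(**kwargs)  # layer_norm is silently dropped
```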

View File

@@ -0,0 +1,189 @@
"""
Generic FLCE (fused linear cross-entropy) patch for untested models similar to Llama
"""
from typing import Optional, Tuple, Union
import torch
from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
from liger_kernel.transformers.trainer.orpo_trainer import _FSDPForwardRedirection
from liger_kernel.utils import PEFT_AVAILABLE
from peft.utils import ModulesToSaveWrapper
from torch.distributed.fsdp import FullyShardedDataParallel
from transformers.modeling_outputs import CausalLMOutputWithPast
from axolotl.utils.callbacks.models import get_causal_lm_model_cls_prefix
def lce_forward(
self,
*args,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
labels: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
skip_logits: Optional[bool] = None,
**kwargs,
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
Args:
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
logits_to_keep (`int` or `torch.Tensor`, *optional*):
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
This is useful when using packed tensor format (single dimension for batch and sequence length).
"""
# pylint: disable=duplicate-code
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.model(
*args,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
**kwargs,
)
hidden_states = outputs[0]
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
slice_indices = (
slice(-logits_to_keep, None)
if isinstance(logits_to_keep, int)
else logits_to_keep
)
kept_hidden_states = hidden_states[:, slice_indices, :]
shift_labels = kwargs.pop("shift_labels", None)
logits = None
loss = None
# if in training mode, don't materialize logits
if skip_logits and labels is None and shift_labels is None:
raise ValueError("skip_logits is True, but labels and shift_labels are None")
if skip_logits is None:
# By default, if in training mode, don't materialize logits
skip_logits = self.training and (labels is not None or shift_labels is not None)
if skip_logits:
loss = lce_maybe_trainable_lm_head(
self,
hidden_states=kept_hidden_states,
hidden_size=self.config.hidden_size,
labels=labels,
shift_labels=shift_labels,
**kwargs,
)
else:
logits = self.lm_head(kept_hidden_states)
if labels is not None:
loss = self.loss_function(
logits=logits,
labels=labels,
vocab_size=self.config.vocab_size,
**kwargs,
)
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
def lce_maybe_trainable_lm_head(
self, hidden_states, hidden_size, labels, shift_labels, **loss_kwargs
):
lm_head = self.lm_head
# Unwrap the module if lm_head has been added as trainable module in PEFT LoRA configuration,
# i.e. listed in the modules_to_save field of LoraConfig, so the lm_head weights are read
# from the unwrapped module.
# See https://huggingface.co/docs/peft/package_reference/lora for reference.
if PEFT_AVAILABLE and isinstance(lm_head, ModulesToSaveWrapper):
lm_head = lm_head.modules_to_save.default
# If FSDP is used and lm_head is trainable, e.g., during full fine-tuning or with LoRA,
# reading the lm_head module weights and calling the kernel must be done within FSDP forward pass
# so the module's entire parameters are summoned and kept in memory during the kernel execution.
if isinstance(lm_head, FullyShardedDataParallel):
return _FSDPForwardRedirection()(
lm_head,
_liger_for_causal_lm_loss,
lm_head.module,
hidden_states,
hidden_size,
labels,
shift_labels,
**loss_kwargs,
)
# FSDP is not used so we can read the lm_head weights and call the kernel directly
return _liger_for_causal_lm_loss(
lm_head=self.lm_head,
hidden_states=hidden_states,
hidden_size=hidden_size,
labels=labels,
shift_labels=shift_labels,
**loss_kwargs,
)
def _liger_for_causal_lm_loss(
lm_head, hidden_states, hidden_size, labels, shift_labels, **loss_kwargs
):
return LigerForCausalLMLoss(
hidden_states=hidden_states,
lm_head_weight=lm_head.weight,
labels=labels,
hidden_size=hidden_size,
shift_labels=shift_labels,
**loss_kwargs,
)
def patch_lce_forward(
model_type,
):
try:
# Dynamically import the module and MLP class
module_path = f"transformers.models.{model_type}.modeling_{model_type}"
model_cls_prefix, _ = get_causal_lm_model_cls_prefix(model_type)
module = __import__(module_path, fromlist=[f"{model_cls_prefix}ForCausalLM"])
model_cls = getattr(module, f"{model_cls_prefix}ForCausalLM")
model_cls.forward = lce_forward
# pylint: disable=duplicate-code
except (ImportError, AttributeError) as e:
raise RuntimeError(
f"Could not import ForCausalLM class for model_type: {model_type}. "
f"Error: {str(e)}"
) from e
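To make the `skip_logits` default concrete: logits are only skipped (and the fused loss kernel used) when training with labels present. A runnable restatement of that branch:

```python
def resolve_skip_logits(skip_logits, training, labels, shift_labels):
    # Mirrors the branch in lce_forward above.
    if skip_logits and labels is None and shift_labels is None:
        raise ValueError("skip_logits is True, but labels and shift_labels are None")
    if skip_logits is None:
        # By default, skip materializing logits only in training mode with labels.
        skip_logits = training and (labels is not None or shift_labels is not None)
    return skip_logits

assert resolve_skip_logits(None, training=True, labels=[0], shift_labels=None)
assert not resolve_skip_logits(None, training=False, labels=[0], shift_labels=None)
```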

View File

@@ -0,0 +1,182 @@
"""
Liger-Kernel Plugin for Axolotl
"""
import inspect
import sys
from axolotl.integrations.base import BasePlugin
from axolotl.utils.logging import get_logger
from .models.base import patch_lce_forward
from .utils import patch_with_compile_disable
LOG = get_logger(__name__)
class LigerPlugin(BasePlugin):
"""
Plugin for LIGER integration with Axolotl.
"""
def get_input_args(self):
return "axolotl.integrations.liger.LigerArgs"
def pre_model_load(self, cfg):
if cfg.torch_compile:
# torch compile will unnecessarily attempt to optimize the triton kernel unless explicitly disabled
import liger_kernel.ops.fused_linear_cross_entropy
patch_with_compile_disable(
liger_kernel.ops.fused_linear_cross_entropy,
"fused_linear_cross_entropy_forward",
)
patch_with_compile_disable(
liger_kernel.ops.fused_linear_cross_entropy,
"fused_linear_cross_entropy_backward",
)
from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
from liger_kernel.transformers.functional import liger_cross_entropy
from liger_kernel.transformers.layer_norm import LigerLayerNorm
from liger_kernel.transformers.monkey_patch import MODEL_TYPE_TO_APPLY_LIGER_FN
from liger_kernel.transformers.rms_norm import LigerRMSNorm
from liger_kernel.transformers.rope import liger_rotary_pos_emb
from liger_kernel.transformers.swiglu import LigerSwiGLUMLP
if cfg.liger_cross_entropy and cfg.liger_fused_linear_cross_entropy:
raise ValueError(
"Cannot have both `liger_cross_entropy` and `liger_fused_linear_cross_entropy` set."
)
if cfg.model_config_type in MODEL_TYPE_TO_APPLY_LIGER_FN:
apply_liger_fn = MODEL_TYPE_TO_APPLY_LIGER_FN[cfg.model_config_type]
liger_fn_sig = inspect.signature(apply_liger_fn)
kwargs = {}
if "rope" in liger_fn_sig.parameters:
kwargs["rope"] = cfg.liger_rope
if "cross_entropy" in liger_fn_sig.parameters:
kwargs["cross_entropy"] = cfg.liger_cross_entropy
if "fused_linear_cross_entropy" in liger_fn_sig.parameters:
kwargs["fused_linear_cross_entropy"] = (
cfg.liger_fused_linear_cross_entropy
)
if "rms_norm" in liger_fn_sig.parameters:
kwargs["rms_norm"] = cfg.liger_rms_norm
if "layer_norm" in liger_fn_sig.parameters:
kwargs["layer_norm"] = cfg.liger_layer_norm
if "geglu" in liger_fn_sig.parameters:
kwargs["geglu"] = cfg.liger_glu_activation
elif "swiglu" in liger_fn_sig.parameters:
kwargs["swiglu"] = cfg.liger_glu_activation
LOG.info(f"Applying LIGER to {cfg.model_config_type} with kwargs: {kwargs}")
apply_liger_fn(**kwargs)
elif cfg.model_config_type == "jamba":
from transformers.models.jamba import modeling_jamba
from .models.jamba import lce_forward as jamba_lce_forward
if cfg.liger_rope:
modeling_jamba.apply_rotary_pos_emb = liger_rotary_pos_emb
if cfg.liger_rms_norm:
modeling_jamba.JambaRMSNorm = LigerRMSNorm
if cfg.liger_glu_activation:
modeling_jamba.JambaMLP = LigerSwiGLUMLP
if cfg.liger_layer_norm:
modeling_jamba.nn.LayerNorm = LigerLayerNorm
if cfg.liger_cross_entropy:
from transformers.loss.loss_utils import nn
nn.functional.cross_entropy = liger_cross_entropy
if cfg.liger_fused_linear_cross_entropy:
modeling_jamba.JambaForCausalLM.forward = jamba_lce_forward
elif cfg.model_config_type == "deepseek_v2":
from accelerate import init_empty_weights
from transformers import AutoModelForCausalLM
with init_empty_weights():
model = AutoModelForCausalLM.from_pretrained(
cfg.base_model, trust_remote_code=cfg.trust_remote_code or False
)
modeling_mod = sys.modules[model.__class__.__module__]
from .models.deepseekv2 import lce_forward as deepseekv2_lce_forward
if cfg.liger_rope:
# The DeepseekV2 version of RoPE is different than upstream LLaMA.
# See https://github.com/linkedin/Liger-Kernel/issues/129#issuecomment-2313763528
LOG.warning("Fused liger_rope is not supported for DeepseekV2.")
if cfg.liger_rms_norm:
modeling_mod.DeepseekV2RMSNorm = LigerRMSNorm
if cfg.liger_glu_activation:
modeling_mod.DeepseekV2MLP.forward = LigerSwiGLUMLP.forward
if cfg.liger_layer_norm:
LOG.warning("liger_layer_norm is not supported for DeepseekV2.")
if cfg.liger_cross_entropy:
# We do not patch `nn.functional.cross_entropy` for DeepseekV2 as it still uses
# nn.CrossEntropyLoss in the forward method.
modeling_mod.CrossEntropyLoss = LigerCrossEntropyLoss
if cfg.liger_fused_linear_cross_entropy:
modeling_mod.DeepseekV2ForCausalLM.forward = deepseekv2_lce_forward
elif cfg.model_config_type == "llama4":
from axolotl.integrations.liger.models.llama4 import (
apply_liger_kernel_to_llama4,
)
apply_liger_kernel_to_llama4(
cross_entropy=cfg.liger_cross_entropy,
fused_linear_cross_entropy=cfg.liger_fused_linear_cross_entropy,
glu_activation=cfg.liger_glu_activation,
rms_norm=cfg.liger_rms_norm,
layer_norm=cfg.liger_layer_norm,
)
elif cfg.model_config_type == "qwen3":
from axolotl.integrations.liger.models.qwen3 import (
apply_liger_kernel_to_qwen3,
)
apply_liger_kernel_to_qwen3(
cross_entropy=cfg.liger_cross_entropy,
fused_linear_cross_entropy=cfg.liger_fused_linear_cross_entropy,
glu_activation=cfg.liger_glu_activation,
rms_norm=cfg.liger_rms_norm,
layer_norm=cfg.liger_layer_norm,
)
elif cfg.model_config_type == "qwen3_moe":
from axolotl.integrations.liger.models.qwen3_moe import (
apply_liger_kernel_to_qwen3_moe,
)
apply_liger_kernel_to_qwen3_moe(
cross_entropy=cfg.liger_cross_entropy,
fused_linear_cross_entropy=cfg.liger_fused_linear_cross_entropy,
glu_activation=cfg.liger_glu_activation,
rms_norm=cfg.liger_rms_norm,
layer_norm=cfg.liger_layer_norm,
)
elif cfg.model_config_type == "granitemoe":
from liger_kernel.transformers import apply_liger_kernel_to_granite
apply_liger_kernel_to_granite(
rope=cfg.liger_rope,
cross_entropy=cfg.liger_cross_entropy,
fused_linear_cross_entropy=cfg.liger_fused_linear_cross_entropy,
rms_norm=cfg.liger_rms_norm,
swiglu=cfg.liger_glu_activation,
)
elif cfg.liger_fused_linear_cross_entropy:
try:
patch_lce_forward(cfg.model_config_type)
LOG.warning_once(
f"Applied ONLY liger_fused_linear_cross_entropy genericpatches for model type: {cfg.model_config_type}"
)
LOG.warning_once(
f"Liger + {cfg.model_config_type} generic FLCE support is experimental and may not work as expected."
)
except RuntimeError:
LOG.warning(
f"Unsupported model config type: {cfg.model_config_type}. Liger not applied."
)
else:
LOG.warning(
f"Unsupported model config type: {cfg.model_config_type}. Liger not applied."
)
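For reference, a minimal sketch of what `patch_with_compile_disable` presumably does (the actual helper lives in `.utils` and is not shown in this diff): it wraps a module-level function so `torch.compile` treats it as an opaque call instead of tracing into the Triton kernel.

```python
import torch

def patch_with_compile_disable_sketch(module, fn_name):
    # Replace module.fn_name with a version torch.compile will not trace into.
    fn = getattr(module, fn_name)
    setattr(module, fn_name, torch.compiler.disable(fn))
```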

View File

@@ -272,7 +272,11 @@ class PatchManager:
if self.cfg.tiled_mlp:
from axolotl.monkeypatch.tiled_mlp import patch_tiled_mlp
patch_tiled_mlp(model_type, cfg_num_shards=self.cfg.tiled_mlp_num_shards)
patch_tiled_mlp(
model_type,
use_original_mlp=self.cfg.tiled_mlp_use_original_mlp,
cfg_num_shards=self.cfg.tiled_mlp_num_shards,
)
def _patch_attention(self):
"""Apply attention-specific patches based on model type."""

View File

@@ -18,6 +18,7 @@ from axolotl.kernels.lora import (
apply_lora_qkv,
)
from axolotl.monkeypatch.utils import detab_code
from axolotl.utils.callbacks.models import get_causal_lm_model_cls_prefix
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
@@ -153,9 +154,7 @@ def get_attention_cls_from_config(cfg: DictDefault) -> Type[nn.Module]:
try:
# Dynamically import the module and attention class
module_path = f"transformers.models.{model_type}.modeling_{model_type}"
model_cls_prefix = "".join(
[part.capitalize() for part in model_type.split("_")]
)
model_cls_prefix, _ = get_causal_lm_model_cls_prefix(model_type)
module = __import__(module_path, fromlist=[f"{model_cls_prefix}Attention"])
attention_cls = getattr(module, f"{model_cls_prefix}Attention")

View File

@@ -6,6 +6,8 @@ import os
import torch
import torch.distributed as dist
from axolotl.utils.callbacks.models import get_causal_lm_model_cls_prefix
def patch_tiled_mlp(model_type, use_original_mlp=False, cfg_num_shards=None):
from deepspeed.runtime.sequence_parallel.ulysses_sp import TiledMLP
@@ -13,9 +15,7 @@ def patch_tiled_mlp(model_type, use_original_mlp=False, cfg_num_shards=None):
try:
# Dynamically import the module and MLP class
module_path = f"transformers.models.{model_type}.modeling_{model_type}"
model_cls_prefix = "".join(
[part.capitalize() for part in model_type.split("_")]
)
model_cls_prefix, _ = get_causal_lm_model_cls_prefix(model_type)
module = __import__(module_path, fromlist=[f"{model_cls_prefix}MLP"])
mlp_cls = getattr(module, f"{model_cls_prefix}MLP")
@@ -45,11 +45,12 @@ def patch_tiled_mlp(model_type, use_original_mlp=False, cfg_num_shards=None):
else:
num_shards = cfg_num_shards
compute_params = [
self.down_proj.weight,
self.gate_proj.weight,
self.up_proj.weight,
]
if not self._compute_params: # pylint: disable=protected-access
self._compute_params = [ # pylint: disable=protected-access
p for p in self.parameters() if p.requires_grad
]
compute_params = self._compute_params # pylint: disable=protected-access
down_res = TiledMLP.apply(
mlp_forward,
@@ -61,6 +62,7 @@ def patch_tiled_mlp(model_type, use_original_mlp=False, cfg_num_shards=None):
return down_res
mlp_cls.forward = tiled_mlp_forward
mlp_cls._compute_params = [] # pylint: disable=protected-access
except (ImportError, AttributeError) as e:
raise RuntimeError(
f"Could not import MLP class for model_type: {model_type}. "

View File

@@ -0,0 +1,23 @@
"""Helper functions for model classes"""
from typing import Tuple
from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
def get_causal_lm_model_cls_prefix(model_type: str) -> Tuple[str, str]:
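"""Return the (class-name prefix, CausalLM class name) for a given model_type."""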
if model_type in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
causal_lm_cls = MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[model_type]
causal_lm_cls_prefix = causal_lm_cls
for suffix in [
"ForCausalLM",
"ForConditionalGeneration",
"LMHeadModel",
"GenerationDecoder",
]:
causal_lm_cls_prefix = causal_lm_cls_prefix.replace(suffix, "")
return causal_lm_cls_prefix, causal_lm_cls
causal_lm_cls_prefix = "".join(
[part.capitalize() for part in model_type.split("_")]
)
return causal_lm_cls_prefix, f"{causal_lm_cls_prefix}ForCausalLM"
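Example usage, grounded in `transformers`' auto-mapping (`gpt2` maps to `GPT2LMHeadModel`, so suffix stripping yields the `GPT2` prefix; `my_model` is a hypothetical unknown type):

```python
print(get_causal_lm_model_cls_prefix("llama"))     # ("Llama", "LlamaForCausalLM")
print(get_causal_lm_model_cls_prefix("gpt2"))      # ("GPT2", "GPT2LMHeadModel")
# Unknown types fall back to capitalization:
print(get_causal_lm_model_cls_prefix("my_model"))  # ("MyModel", "MyModelForCausalLM")
```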

View File

@@ -576,6 +576,13 @@ class AxolotlInputConfig(
},
)
tiled_mlp_use_original_mlp: bool | None = Field(
default=None,
json_schema_extra={
"description": "Whether to use original mlp for ALST tiled mlp. Otherwise uses a generic MLP based on llama."
},
)
llama4_linearized_experts: bool | None = None
deepspeed: str | dict[str, Any] | None = Field(
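A hypothetical config fragment exercising the new option alongside the existing tiled-MLP settings (YAML keys shown as a Python dict for consistency with the other examples):

```python
cfg_fragment = {
    "tiled_mlp": True,                   # enable ALST tiled MLP patching
    "tiled_mlp_use_original_mlp": True,  # reuse the model's own MLP forward
    "tiled_mlp_num_shards": None,        # auto-detected when left unset
}
```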