Compare commits

...

4 Commits

Author     SHA1        Message                                                    Date
Wing Lian  08aa74e418  fix llama modeling                                         2025-07-30 11:37:58 -04:00
Wing Lian  dfa14f87ab  fix residuals and add llama support                        2025-07-30 10:22:38 -04:00
Wing Lian  fbe1b504da  add custom modeling for gemma3 using liger fused add rms  2025-07-30 08:21:03 -04:00
Wing Lian  5b8370969c  actually call the register method on plugins               2025-07-30 08:05:25 -04:00
8 changed files with 263 additions and 0 deletions

View File

@@ -159,6 +159,9 @@ def plugin_set_cfg(cfg: DictDefault):
    if cfg.get("plugins"):
        plugin_manager = PluginManager.get_instance()
        plugin_manager.cfg = cfg
        # now that we have the finalized cfg, register the plugins individually
        for plugin in plugin_manager.plugins.values():
            plugin.register(cfg)


def load_cfg(
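For context, a rough sketch of what this registration pass amounts to once the config is finalized. Only the names visible in the hunk and in the files below come from this PR; the DictDefault and PluginManager import paths are assumptions, and the sketch assumes the plugins listed in cfg were already loaded into the manager earlier in the CLI flow.

# Hypothetical sketch of the registration pass, not the actual load_cfg flow
from axolotl.integrations.base import PluginManager  # assumed import path
from axolotl.utils.dict import DictDefault  # assumed import path

cfg = DictDefault(
    {
        "plugins": ["axolotl.integrations.modeling.AxolotlModelingPlugin"],
        "use_liger_fused_rms_add": True,
    }
)

plugin_manager = PluginManager.get_instance()
plugin_manager.cfg = cfg
# register runs only after cfg is final, so each plugin sees every resolved option
for plugin in plugin_manager.plugins.values():
    plugin.register(cfg)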

View File

@@ -0,0 +1,11 @@
"""
Axolotl custom modeling module
"""

from .args import AxolotlModelingArgs
from .plugin import AxolotlModelingPlugin

__all__ = [
    "AxolotlModelingArgs",
    "AxolotlModelingPlugin",
]

View File

@@ -0,0 +1,13 @@
"""
Args for using Axolotl custom modeling
"""

from pydantic import BaseModel


class AxolotlModelingArgs(BaseModel):
    """
    Args for using Axolotl custom modeling
    """

    use_liger_fused_rms_add: bool = False
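A quick sanity check of how this args model behaves under pydantic validation (a minimal sketch; how axolotl merges the model into the run config via get_input_args is not shown here):

from pydantic import ValidationError

assert AxolotlModelingArgs().use_liger_fused_rms_add is False  # off by default
assert AxolotlModelingArgs(use_liger_fused_rms_add=True).use_liger_fused_rms_add is True

try:
    AxolotlModelingArgs(use_liger_fused_rms_add="not-a-bool")
except ValidationError:
    pass  # non-boolean values are rejected at config validation time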

View File

@@ -0,0 +1,9 @@
"""
Gemma3 modeling
"""

from .modeling_gemma3 import patch_gemma3

__all__ = [
    "patch_gemma3",
]

View File

@@ -0,0 +1,110 @@
"""
Gemma3 custom decoder layer using liger fused add rms norm kernels
"""

import sys

import torch
from liger_kernel.transformers.fused_add_rms_norm import LigerFusedAddRMSNorm
from transformers import Cache, GradientCheckpointingLayer
from transformers.models.gemma3.configuration_gemma3 import Gemma3TextConfig
from transformers.models.gemma3.modeling_gemma3 import (
    Gemma3Attention,
    Gemma3MLP,
    Gemma3RMSNorm,
)


class Gemma3AddRMSNorm(LigerFusedAddRMSNorm):
    """
    Fused add rms norm
    """

    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__(hidden_size, eps, offset=1.0, casting_mode="gemma")


class Gemma3DecoderLayer(GradientCheckpointingLayer):
    """
    Gemma3 decoder layer using liger fused add rms norm
    """

    def __init__(self, config: Gemma3TextConfig, layer_idx: int):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.layer_idx = layer_idx
        self.attention_type = config.layer_types[layer_idx]
        self.self_attn = Gemma3Attention(config=config, layer_idx=layer_idx)
        self.mlp = Gemma3MLP(config)
        self.input_layernorm = Gemma3RMSNorm(self.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = Gemma3RMSNorm(
            self.hidden_size, eps=config.rms_norm_eps
        )
        self.pre_feedforward_layernorm = Gemma3AddRMSNorm(
            self.hidden_size, eps=config.rms_norm_eps
        )
        self.post_feedforward_layernorm = Gemma3RMSNorm(
            self.hidden_size, eps=config.rms_norm_eps
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings_global: torch.Tensor,
        position_embeddings_local: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_value: Cache | None = None,
        output_attentions: bool | None = False,
        use_cache: bool | None = False,
        cache_position: torch.LongTensor | None = None,
        **kwargs,
    ) -> tuple[
        torch.FloatTensor, tuple[torch.FloatTensor, torch.FloatTensor | None] | None
    ]:
        # pylint: disable=duplicate-code
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)

        # apply global RoPE to non-sliding layer only
        if self.self_attn.is_sliding:
            position_embeddings = position_embeddings_local
        else:
            position_embeddings = position_embeddings_global

        hidden_states, self_attn_weights = self.self_attn(
            hidden_states=hidden_states,
            position_embeddings=position_embeddings,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
            **kwargs,
        )
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states, residual = self.pre_feedforward_layernorm(
            hidden_states, residual
        )
        hidden_states = self.mlp(hidden_states)
        hidden_states = self.post_feedforward_layernorm(hidden_states)
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (self_attn_weights,)  # type: ignore

        return outputs  # type: ignore


def patch_gemma3():
    import transformers.models.gemma3.modeling_gemma3

    transformers.models.gemma3.modeling_gemma3.Gemma3DecoderLayer = Gemma3DecoderLayer
    sys.modules["transformers.models.gemma3.modeling_gemma3"].Gemma3DecoderLayer = (
        Gemma3DecoderLayer
    )
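The change above folds the residual add and the pre-feedforward RMSNorm into one Liger kernel call that returns both the normalized activations and the new residual. Below is a plain-PyTorch reference for what that call is assumed to compute, inferred from how the decoder layer consumes its two return values rather than from the Liger source; offset=1.0 mirrors Gemma's (1 + weight) scaling and the float32 upcast mirrors casting_mode="gemma".

import torch


def fused_add_rms_norm_reference(hidden_states, residual, weight, eps=1e-6, offset=1.0):
    # unfused stand-in for `normed, new_residual = norm(hidden_states, residual)`
    added = hidden_states + residual                      # residual add
    upcast = added.float()                                # gemma-style float32 compute (assumed)
    variance = upcast.pow(2).mean(-1, keepdim=True)
    normed = upcast * torch.rsqrt(variance + eps)
    normed = (offset + weight.float()) * normed           # gemma scales by (1 + weight)
    return normed.type_as(hidden_states), added           # `added` feeds the next residual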

View File

@@ -0,0 +1,9 @@
"""
Llama modeling
"""

from .modeling_llama import patch_llama

__all__ = [
    "patch_llama",
]

View File

@@ -0,0 +1,86 @@
"""
Custom modeling for Llama for fused rms add kernels
"""

import sys

import torch
from liger_kernel.transformers.fused_add_rms_norm import LigerFusedAddRMSNorm
from transformers import Cache, GradientCheckpointingLayer
from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import (
    LlamaAttention,
    LlamaMLP,
    LlamaRMSNorm,
)
from transformers.processing_utils import Unpack
from transformers.utils import TransformersKwargs


class LlamaAddRMSNorm(LigerFusedAddRMSNorm):
    """
    Fused add rms norm
    """

    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__(hidden_size, eps, casting_mode="llama")


class LlamaDecoderLayer(GradientCheckpointingLayer):
    """
    Llama decoder layer using liger fused add rms norm
    """

    def __init__(self, config: LlamaConfig, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = LlamaAttention(config=config, layer_idx=layer_idx)
        self.mlp = LlamaMLP(config)
        self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = LlamaAddRMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_value: Cache | None = None,
        use_cache: bool | None = False,
        cache_position: torch.LongTensor | None = None,
        position_embeddings: (
            tuple[torch.Tensor, torch.Tensor] | None
        ) = None,  # necessary, but kept here for BC
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[torch.FloatTensor, tuple[torch.FloatTensor, torch.FloatTensor] | None]:
        # pylint: disable=duplicate-code
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)

        # Self Attention
        hidden_states, _ = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            use_cache=use_cache,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
            **kwargs,
        )
        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states
        return hidden_states


def patch_llama():
    import transformers.models.llama.modeling_llama

    transformers.models.llama.modeling_llama.LlamaDecoderLayer = LlamaDecoderLayer
    sys.modules["transformers.models.llama.modeling_llama"].LlamaDecoderLayer = (
        LlamaDecoderLayer
    )
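A hedged usage sketch for the patch itself; the checkpoint name is only an example. The ordering matters: patch_llama() has to run before the model is constructed, because LlamaModel looks up LlamaDecoderLayer from the patched module when it builds its layers.

from axolotl.integrations.modeling.llama import patch_llama

patch_llama()  # swap in the fused-add-RMSNorm decoder layer first

from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.1-8B")  # example checkpoint
print(type(model.model.layers[0]).__module__)  # should resolve to the axolotl module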

View File

@@ -0,0 +1,22 @@
"""
Axolotl custom modeling plugin
"""

from axolotl.integrations.base import BasePlugin


class AxolotlModelingPlugin(BasePlugin):
    """
    Axolotl custom modeling plugin
    """

    def get_input_args(self) -> str | None:
        return "axolotl.integrations.modeling.AxolotlModelingArgs"

    def register(self, cfg):  # pylint: disable=unused-argument
        if cfg.use_liger_fused_rms_add:
            from .gemma3 import patch_gemma3
            from .llama import patch_llama

            patch_gemma3()
            patch_llama()
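Finally, a small sketch of the plugin-level effect: when the flag is on, register() swaps both decoder layers in place, and with the flag off the stock modeling code is left untouched. The DictDefault import path is an assumption; everything else is named in this PR.

import transformers.models.llama.modeling_llama as hf_llama

from axolotl.integrations.modeling import AxolotlModelingPlugin
from axolotl.utils.dict import DictDefault  # assumed import path

plugin = AxolotlModelingPlugin()
plugin.register(DictDefault({"use_liger_fused_rms_add": True}))

print(hf_llama.LlamaDecoderLayer.__module__)  # now the axolotl fused variant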