fix residuals and add llama support

This commit is contained in:
Wing Lian
2025-07-30 10:22:38 -04:00
parent fbe1b504da
commit dfa14f87ab
4 changed files with 109 additions and 2 deletions

View File

@@ -21,7 +21,7 @@ class Gemma3AddRMSNorm(LigerFusedAddRMSNorm):
"""
def __init__(self, hidden_size: int, eps: float = 1e-6):
super().__init__(hidden_size, eps, casting_mode="gemma")
super().__init__(hidden_size, eps, offset=1.0, casting_mode="gemma")
class Gemma3DecoderLayer(GradientCheckpointingLayer):
@@ -63,6 +63,7 @@ class Gemma3DecoderLayer(GradientCheckpointingLayer):
) -> tuple[
torch.FloatTensor, tuple[torch.FloatTensor, torch.FloatTensor | None] | None
]:
# pylint: disable=duplicate-code
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
@@ -85,7 +86,9 @@ class Gemma3DecoderLayer(GradientCheckpointingLayer):
**kwargs,
)
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.pre_feedforward_layernorm(hidden_states, residual)
hidden_states, residual = self.pre_feedforward_layernorm(
hidden_states, residual
)
hidden_states = self.mlp(hidden_states)
hidden_states = self.post_feedforward_layernorm(hidden_states)
hidden_states = residual + hidden_states

View File

@@ -0,0 +1,9 @@
"""
Llama modeling
"""
from modeling_llama import patch_llama
__all__ = [
"patch_llama",
]

View File

@@ -0,0 +1,93 @@
"""
Custom modeling for Llama for fused rms add kernels
"""
import sys
import torch
from liger_kernel.transformers.fused_add_rms_norm import LigerFusedAddRMSNorm
from transformers import Cache, GradientCheckpointingLayer
from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import (
LlamaAttention,
LlamaMLP,
LlamaRMSNorm,
)
from transformers.processing_utils import Unpack
from transformers.utils import TransformersKwargs
class LlamaAddRMSNorm(LigerFusedAddRMSNorm):
"""
Fused add rms norm
"""
def __init__(self, hidden_size: int, eps: float = 1e-6):
super().__init__(hidden_size, eps, casting_mode="llama")
class LlamaDecoderLayer(GradientCheckpointingLayer):
"""
Llama decoder layer using liger fused add rms norm
"""
def __init__(self, config: LlamaConfig, layer_idx: int):
super().__init__()
self.hidden_size = config.hidden_size
self.self_attn = LlamaAttention(config=config, layer_idx=layer_idx)
self.mlp = LlamaMLP(config)
self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = LlamaAddRMSNorm(
config.hidden_size, eps=config.rms_norm_eps
)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: torch.Tensor | None = None,
position_ids: torch.LongTensor | None = None,
past_key_value: Cache | None = None,
output_attentions: bool | None = False,
use_cache: bool | None = False,
cache_position: torch.LongTensor | None = None,
position_embeddings: (
tuple[torch.Tensor, torch.Tensor] | None
) = None, # necessary, but kept here for BC
**kwargs: Unpack[TransformersKwargs],
) -> tuple[torch.FloatTensor, tuple[torch.FloatTensor, torch.FloatTensor] | None]:
# pylint: disable=duplicate-code
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
# Self Attention
hidden_states, self_attn_weights = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
**kwargs,
)
hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
outputs = (hidden_states,)
if output_attentions:
outputs += (self_attn_weights,) # type: ignore
return outputs # type: ignore
def patch_llama():
import transformers.models.llama.modeling_llama
transformers.models.llama.modeling_llama.LlamaDecoderLayer = LlamaDecoderLayer
sys.modules["transformers.models.llama.modeling_llama"].LlamaDecoderLayer = (
LlamaDecoderLayer
)

View File

@@ -16,5 +16,7 @@ class AxolotlModelingPlugin(BasePlugin):
def register(self, cfg): # pylint: disable=unused-argument
if cfg.use_liger_fused_rms_add:
from .gemma3 import patch_gemma3
from .llama import patch_llama
patch_gemma3()
patch_llama()