diff --git a/src/axolotl/integrations/modeling/gemma3/modeling_gemma3.py b/src/axolotl/integrations/modeling/gemma3/modeling_gemma3.py
index ef8c6444f..f215e8f47 100644
--- a/src/axolotl/integrations/modeling/gemma3/modeling_gemma3.py
+++ b/src/axolotl/integrations/modeling/gemma3/modeling_gemma3.py
@@ -21,7 +21,7 @@ class Gemma3AddRMSNorm(LigerFusedAddRMSNorm):
     """
 
     def __init__(self, hidden_size: int, eps: float = 1e-6):
-        super().__init__(hidden_size, eps, casting_mode="gemma")
+        super().__init__(hidden_size, eps, offset=1.0, casting_mode="gemma")
 
 
 class Gemma3DecoderLayer(GradientCheckpointingLayer):
@@ -63,6 +63,7 @@ class Gemma3DecoderLayer(GradientCheckpointingLayer):
     ) -> tuple[
         torch.FloatTensor, tuple[torch.FloatTensor, torch.FloatTensor | None] | None
     ]:
+        # pylint: disable=duplicate-code
         residual = hidden_states
 
         hidden_states = self.input_layernorm(hidden_states)
@@ -85,7 +86,9 @@ class Gemma3DecoderLayer(GradientCheckpointingLayer):
             **kwargs,
         )
         hidden_states = self.post_attention_layernorm(hidden_states)
-        hidden_states = self.pre_feedforward_layernorm(hidden_states, residual)
+        hidden_states, residual = self.pre_feedforward_layernorm(
+            hidden_states, residual
+        )
         hidden_states = self.mlp(hidden_states)
         hidden_states = self.post_feedforward_layernorm(hidden_states)
         hidden_states = residual + hidden_states
diff --git a/src/axolotl/integrations/modeling/llama/__init__.py b/src/axolotl/integrations/modeling/llama/__init__.py
new file mode 100644
index 000000000..aa8a59863
--- /dev/null
+++ b/src/axolotl/integrations/modeling/llama/__init__.py
@@ -0,0 +1,9 @@
+"""
+Llama modeling
+"""
+
+from .modeling_llama import patch_llama
+
+__all__ = [
+    "patch_llama",
+]
diff --git a/src/axolotl/integrations/modeling/llama/modeling_llama.py b/src/axolotl/integrations/modeling/llama/modeling_llama.py
new file mode 100644
index 000000000..b2df51242
--- /dev/null
+++ b/src/axolotl/integrations/modeling/llama/modeling_llama.py
@@ -0,0 +1,93 @@
+"""
+Custom modeling for Llama using fused add RMSNorm kernels
+"""
+
+import sys
+
+import torch
+from liger_kernel.transformers.fused_add_rms_norm import LigerFusedAddRMSNorm
+from transformers import Cache, GradientCheckpointingLayer
+from transformers.models.llama.configuration_llama import LlamaConfig
+from transformers.models.llama.modeling_llama import (
+    LlamaAttention,
+    LlamaMLP,
+    LlamaRMSNorm,
+)
+from transformers.processing_utils import Unpack
+from transformers.utils import TransformersKwargs
+
+
+class LlamaAddRMSNorm(LigerFusedAddRMSNorm):
+    """
+    Fused add RMSNorm
+    """
+
+    def __init__(self, hidden_size: int, eps: float = 1e-6):
+        super().__init__(hidden_size, eps, casting_mode="llama")
+
+
+class LlamaDecoderLayer(GradientCheckpointingLayer):
+    """
+    Llama decoder layer using liger fused add RMSNorm
+    """
+
+    def __init__(self, config: LlamaConfig, layer_idx: int):
+        super().__init__()
+        self.hidden_size = config.hidden_size
+
+        self.self_attn = LlamaAttention(config=config, layer_idx=layer_idx)
+
+        self.mlp = LlamaMLP(config)
+        self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.post_attention_layernorm = LlamaAddRMSNorm(
+            config.hidden_size, eps=config.rms_norm_eps
+        )
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: torch.Tensor | None = None,
+        position_ids: torch.LongTensor | None = None,
+        past_key_value: Cache | None = None,
+        output_attentions: bool | None = False,
+        use_cache: bool | None = False,
+        cache_position: torch.LongTensor | None = None,
+        position_embeddings: (
+            tuple[torch.Tensor, torch.Tensor] | None
+        ) = None,  # necessary, but kept here for BC
+        **kwargs: Unpack[TransformersKwargs],
+    ) -> tuple[torch.FloatTensor, tuple[torch.FloatTensor, torch.FloatTensor] | None]:
+        # pylint: disable=duplicate-code
+        residual = hidden_states
+        hidden_states = self.input_layernorm(hidden_states)
+
+        # Self Attention
+        hidden_states, self_attn_weights = self.self_attn(
+            hidden_states=hidden_states,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_value=past_key_value,
+            output_attentions=output_attentions,
+            use_cache=use_cache,
+            cache_position=cache_position,
+            position_embeddings=position_embeddings,
+            **kwargs,
+        )
+        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
+        hidden_states = self.mlp(hidden_states)
+        hidden_states = residual + hidden_states
+
+        outputs = (hidden_states,)
+        if output_attentions:
+            outputs += (self_attn_weights,)  # type: ignore
+
+        return outputs  # type: ignore
+
+
+def patch_llama():
+    import transformers.models.llama.modeling_llama
+
+    transformers.models.llama.modeling_llama.LlamaDecoderLayer = LlamaDecoderLayer
+    sys.modules["transformers.models.llama.modeling_llama"].LlamaDecoderLayer = (
+        LlamaDecoderLayer
+    )
diff --git a/src/axolotl/integrations/modeling/plugin.py b/src/axolotl/integrations/modeling/plugin.py
index 9f4a07720..a4a2532c2 100644
--- a/src/axolotl/integrations/modeling/plugin.py
+++ b/src/axolotl/integrations/modeling/plugin.py
@@ -16,5 +16,7 @@ class AxolotlModelingPlugin(BasePlugin):
     def register(self, cfg):  # pylint: disable=unused-argument
         if cfg.use_liger_fused_rms_add:
            from .gemma3 import patch_gemma3
+            from .llama import patch_llama
 
             patch_gemma3()
+            patch_llama()
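
For reviewers, a minimal pure-PyTorch sketch of the add + RMSNorm fusion these decoder layers now call into. This is a hypothetical reference, not Liger's implementation: LigerFusedAddRMSNorm does the same math in a single Triton kernel. It only illustrates the two-output contract the new call sites rely on (e.g. `hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)` above) and the `offset=1.0` now passed for Gemma3, whose RMSNorm scales by `(1 + weight)`.

import torch


def add_rms_norm_reference(
    hidden_states: torch.Tensor,
    residual: torch.Tensor,
    weight: torch.Tensor,
    eps: float = 1e-6,
    offset: float = 0.0,  # Gemma3 passes offset=1.0, i.e. scale by (1 + weight)
) -> tuple[torch.Tensor, torch.Tensor]:
    # Fuse the skip connection into the norm: the summed activations serve as
    # both the norm input and the residual handed to the next sub-block,
    # which is why the fused layer returns two tensors.
    new_residual = hidden_states + residual
    variance = new_residual.pow(2).mean(-1, keepdim=True)
    normed = new_residual * torch.rsqrt(variance + eps)
    return normed * (offset + weight), new_residual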