From fbe1b504dadda617b6acbe001403ea391d88581b Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Wed, 30 Jul 2025 08:21:03 -0400
Subject: [PATCH] add custom modeling for gemma3 using liger fused add rms

Adds an opt-in AxolotlModelingPlugin that, when use_liger_fused_rms_add
is set, replaces the Gemma3 decoder layer with a variant whose
pre-feedforward norm is Liger's fused add + RMSNorm, folding the
post-attention residual add into the norm kernel.
---
 src/axolotl/integrations/modeling/__init__.py |  11 ++
 src/axolotl/integrations/modeling/args.py     |  13 +++
 .../integrations/modeling/gemma3/__init__.py  |   9 ++
 .../modeling/gemma3/modeling_gemma3.py        | 107 ++++++++++++++++++
 src/axolotl/integrations/modeling/plugin.py   |  20 ++++
 5 files changed, 160 insertions(+)
 create mode 100644 src/axolotl/integrations/modeling/__init__.py
 create mode 100644 src/axolotl/integrations/modeling/args.py
 create mode 100644 src/axolotl/integrations/modeling/gemma3/__init__.py
 create mode 100644 src/axolotl/integrations/modeling/gemma3/modeling_gemma3.py
 create mode 100644 src/axolotl/integrations/modeling/plugin.py
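
Reviewer note: the only functional change relative to the stock HF
Gemma3DecoderLayer is the pre_feedforward_layernorm, which becomes a
fused add + RMSNorm so the post-attention residual add and the norm run
in one kernel. As a plain-PyTorch reference for what the fused op is
assumed to compute (the function name and signature below are
illustrative, not Liger's API; "gemma" casting is assumed to mean
float32 compute with (1 + weight) scaling, matching Gemma3RMSNorm):

    import torch

    def fused_add_rms_norm_reference(
        hidden_states: torch.Tensor,  # post-attention-norm output
        residual: torch.Tensor,       # running residual stream
        weight: torch.Tensor,         # learned scale, shape (hidden_size,)
        eps: float = 1e-6,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        # the add that the stock layer performs as a separate
        # "hidden_states = residual + hidden_states" step
        summed = residual + hidden_states
        # Gemma-style RMSNorm: normalize in float32, scale by (1 + weight)
        x = summed.float()
        x = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
        out = x * (1.0 + weight.float())
        return out.type_as(hidden_states), summed

There is a correctness subtlety here: the patched forward keeps the
final "hidden_states = residual + hidden_states", so after the fused
call the residual tensor must already hold the post-attention sum.
Since the call site assigns only one return value, this relies on
LigerFusedAddRMSNorm updating the residual tensor in place, an assumed
property of the kernel rather than something visible in this diff.
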
diff --git a/src/axolotl/integrations/modeling/__init__.py b/src/axolotl/integrations/modeling/__init__.py
new file mode 100644
index 000000000..704bb6a5c
--- /dev/null
+++ b/src/axolotl/integrations/modeling/__init__.py
@@ -0,0 +1,11 @@
+"""
+Axolotl custom modeling module
+"""
+
+from .args import AxolotlModelingArgs
+from .plugin import AxolotlModelingPlugin
+
+__all__ = [
+    "AxolotlModelingArgs",
+    "AxolotlModelingPlugin",
+]
diff --git a/src/axolotl/integrations/modeling/args.py b/src/axolotl/integrations/modeling/args.py
new file mode 100644
index 000000000..b0b0096c8
--- /dev/null
+++ b/src/axolotl/integrations/modeling/args.py
@@ -0,0 +1,13 @@
+"""
+Args for using Axolotl custom modeling
+"""
+
+from pydantic import BaseModel
+
+
+class AxolotlModelingArgs(BaseModel):
+    """
+    Args for using Axolotl custom modeling
+    """
+
+    use_liger_fused_rms_add: bool = False
diff --git a/src/axolotl/integrations/modeling/gemma3/__init__.py b/src/axolotl/integrations/modeling/gemma3/__init__.py
new file mode 100644
index 000000000..2c43f4ed9
--- /dev/null
+++ b/src/axolotl/integrations/modeling/gemma3/__init__.py
@@ -0,0 +1,9 @@
+"""
+Gemma3 modeling
+"""
+
+from .modeling_gemma3 import patch_gemma3
+
+__all__ = [
+    "patch_gemma3",
+]
diff --git a/src/axolotl/integrations/modeling/gemma3/modeling_gemma3.py b/src/axolotl/integrations/modeling/gemma3/modeling_gemma3.py
new file mode 100644
index 000000000..ef8c6444f
--- /dev/null
+++ b/src/axolotl/integrations/modeling/gemma3/modeling_gemma3.py
@@ -0,0 +1,107 @@
+"""
+Gemma3 custom decoder layer using liger fused add rms norm kernels
+"""
+
+import sys
+
+import torch
+from liger_kernel.transformers.fused_add_rms_norm import LigerFusedAddRMSNorm
+from transformers import Cache, GradientCheckpointingLayer
+from transformers.models.gemma3.configuration_gemma3 import Gemma3TextConfig
+from transformers.models.gemma3.modeling_gemma3 import (
+    Gemma3Attention,
+    Gemma3MLP,
+    Gemma3RMSNorm,
+)
+
+
+class Gemma3AddRMSNorm(LigerFusedAddRMSNorm):
+    """
+    Fused add rms norm
+    """
+
+    def __init__(self, hidden_size: int, eps: float = 1e-6):
+        super().__init__(hidden_size, eps, casting_mode="gemma")
+
+
+class Gemma3DecoderLayer(GradientCheckpointingLayer):
+    """
+    Gemma3 decoder layer using liger fused add rms norm
+    """
+
+    def __init__(self, config: Gemma3TextConfig, layer_idx: int):
+        super().__init__()
+        self.config = config
+        self.hidden_size = config.hidden_size
+        self.layer_idx = layer_idx
+        self.attention_type = config.layer_types[layer_idx]
+        self.self_attn = Gemma3Attention(config=config, layer_idx=layer_idx)
+        self.mlp = Gemma3MLP(config)
+        self.input_layernorm = Gemma3RMSNorm(self.hidden_size, eps=config.rms_norm_eps)
+        self.post_attention_layernorm = Gemma3RMSNorm(
+            self.hidden_size, eps=config.rms_norm_eps
+        )
+        self.pre_feedforward_layernorm = Gemma3AddRMSNorm(
+            self.hidden_size, eps=config.rms_norm_eps
+        )
+        self.post_feedforward_layernorm = Gemma3RMSNorm(
+            self.hidden_size, eps=config.rms_norm_eps
+        )
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        position_embeddings_global: torch.Tensor,
+        position_embeddings_local: torch.Tensor,
+        attention_mask: torch.Tensor | None = None,
+        position_ids: torch.LongTensor | None = None,
+        past_key_value: Cache | None = None,
+        output_attentions: bool | None = False,
+        use_cache: bool | None = False,
+        cache_position: torch.LongTensor | None = None,
+        **kwargs,
+    ) -> tuple[
+        torch.FloatTensor, tuple[torch.FloatTensor, torch.FloatTensor | None] | None
+    ]:
+        residual = hidden_states
+
+        hidden_states = self.input_layernorm(hidden_states)
+
+        # apply global RoPE to non-sliding layer only
+        if self.self_attn.is_sliding:
+            position_embeddings = position_embeddings_local
+        else:
+            position_embeddings = position_embeddings_global
+
+        hidden_states, self_attn_weights = self.self_attn(
+            hidden_states=hidden_states,
+            position_embeddings=position_embeddings,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_value=past_key_value,
+            output_attentions=output_attentions,
+            use_cache=use_cache,
+            cache_position=cache_position,
+            **kwargs,
+        )
+        hidden_states = self.post_attention_layernorm(hidden_states)
+        hidden_states = self.pre_feedforward_layernorm(hidden_states, residual)  # fused residual add + norm
+        hidden_states = self.mlp(hidden_states)
+        hidden_states = self.post_feedforward_layernorm(hidden_states)
+        hidden_states = residual + hidden_states  # relies on the fused norm having updated residual in place
+
+        outputs = (hidden_states,)
+
+        if output_attentions:
+            outputs += (self_attn_weights,)  # type: ignore
+
+        return outputs  # type: ignore
+
+
+def patch_gemma3():
+    import transformers.models.gemma3.modeling_gemma3
+
+    transformers.models.gemma3.modeling_gemma3.Gemma3DecoderLayer = Gemma3DecoderLayer
+    sys.modules["transformers.models.gemma3.modeling_gemma3"].Gemma3DecoderLayer = (
+        Gemma3DecoderLayer
+    )
diff --git a/src/axolotl/integrations/modeling/plugin.py b/src/axolotl/integrations/modeling/plugin.py
new file mode 100644
index 000000000..9f4a07720
--- /dev/null
+++ b/src/axolotl/integrations/modeling/plugin.py
@@ -0,0 +1,20 @@
+"""
+Axolotl custom modeling plugin
+"""
+
+from axolotl.integrations.base import BasePlugin
+
+
+class AxolotlModelingPlugin(BasePlugin):
+    """
+    Axolotl custom modeling plugin
+    """
+
+    def get_input_args(self) -> str | None:
+        return "axolotl.integrations.modeling.AxolotlModelingArgs"
+
+    def register(self, cfg):
+        if cfg.use_liger_fused_rms_add:
+            from .gemma3 import patch_gemma3
+
+            patch_gemma3()
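
Usage sketch (not part of the patch): the flag defined in args.py gates
the patch, and the plugin is expected to be loaded through axolotl's
plugins config mechanism. Applying the patch by hand looks like the
following; the checkpoint name is illustrative, and patch_gemma3() must
run before the model is instantiated so Gemma3TextModel resolves the
replacement class from the patched module namespace:

    from transformers import AutoModelForCausalLM

    from axolotl.integrations.modeling.gemma3 import patch_gemma3

    # Config-driven equivalent (assumed axolotl YAML wiring):
    #   plugins:
    #     - axolotl.integrations.modeling.AxolotlModelingPlugin
    #   use_liger_fused_rms_add: true
    patch_gemma3()

    model = AutoModelForCausalLM.from_pretrained("google/gemma-3-1b-it")

    # Decoder layers should now come from the axolotl module, with the
    # fused pre-feedforward norm in place.
    layer = model.model.layers[0]
    assert type(layer).__module__.startswith("axolotl.integrations.modeling")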
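
A quick numerical sanity check of the reordering, using the reference
function from the note above in place of the Liger kernel and the stock
Gemma3RMSNorm as ground truth (shapes are arbitrary):

    import torch
    from transformers.models.gemma3.modeling_gemma3 import Gemma3RMSNorm

    torch.manual_seed(0)
    h = torch.randn(2, 4, 8)  # post-attention-norm output
    r = torch.randn(2, 4, 8)  # residual stream
    norm = Gemma3RMSNorm(8, eps=1e-6)
    norm.weight.data.normal_()  # non-trivial scale so the check is meaningful

    fused_out, summed = fused_add_rms_norm_reference(h, r, norm.weight, eps=1e-6)
    unfused_out = norm(r + h)  # stock ordering: add first, then norm

    assert torch.equal(summed, r + h)
    assert torch.allclose(fused_out, unfused_out, atol=1e-6)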