Gemma4 fixes and profiler (#3591)

Wing Lian
2026-04-10 16:46:17 -04:00
committed by GitHub
parent 315cdeede9
commit 29fa4dedbb
10 changed files with 1926 additions and 1 deletion


@@ -435,6 +435,23 @@ class AxolotlTrainer(
                num_items_in_batch=num_items_in_batch,
            )

        # Gemma4ForConditionalGeneration computes loss with a manual
        # nn.CrossEntropyLoss() that bypasses proper num_items_in_batch
        # normalization and does redundant attention_mask filtering.
        # Compute loss externally using the standard loss_function instead.
        if _model_type == "gemma4" and "labels" in inputs:
            labels = inputs.pop("labels")
            outputs = model(**inputs)
            logits = outputs.logits
            unwrapped = self.accelerator.unwrap_model(model)
            vocab_size = unwrapped.config.get_text_config().vocab_size
            loss = unwrapped.loss_function(
                logits, labels, vocab_size, num_items_in_batch=num_items_in_batch
            )
            if return_outputs:
                return loss, outputs
            return loss
        return super().compute_loss(
            model,
            inputs,
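
For readers unfamiliar with the num_items_in_batch convention: the standard transformers loss sums token losses and divides once by the number of unmasked tokens across the whole accumulated batch, whereas a bare nn.CrossEntropyLoss() takes a per-micro-batch mean, which skews gradients under gradient accumulation. A minimal sketch of the distinction (not transformers' exact implementation):

import torch.nn.functional as F

def causal_lm_loss(logits, labels, vocab_size, num_items_in_batch=None):
    # Shift so that tokens < n predict token n, as in standard causal LM loss.
    shift_logits = logits[..., :-1, :].contiguous().view(-1, vocab_size)
    shift_labels = labels[..., 1:].contiguous().view(-1)
    if num_items_in_batch is None:
        # Per-micro-batch mean: micro-batches with fewer unmasked tokens
        # get disproportionately large per-token gradients.
        return F.cross_entropy(shift_logits, shift_labels, ignore_index=-100)
    # Sum over tokens, then normalize once by the global token count.
    loss = F.cross_entropy(
        shift_logits, shift_labels, ignore_index=-100, reduction="sum"
    )
    return loss / num_items_in_batch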


@@ -222,6 +222,56 @@ class LigerPlugin(BasePlugin):
                rms_norm=cfg.liger_rms_norm,
                swiglu=cfg.liger_glu_activation,
            )
        elif cfg.model_config_type in ("gemma4", "gemma4_text"):
            # Gemma4: offset=0 (NOT 1 like Gemma3); in_place=False required for
            # gradient checkpointing compatibility; RoPE incompatible (separate q/k).
            from liger_kernel.transformers.geglu import LigerGEGLUMLP
            from transformers.models.gemma4 import modeling_gemma4

            if cfg.liger_rms_norm:
                _OrigGemma4RMSNorm = modeling_gemma4.Gemma4RMSNorm

                class _LigerGemma4RMSNorm(LigerRMSNorm):
                    """LigerRMSNorm for Gemma4 with in_place=False and with_scale support."""

                    def __new__(cls, dim, eps=1e-6, with_scale=True):
                        if not with_scale:
                            return _OrigGemma4RMSNorm(dim, eps, with_scale=False)
                        return super().__new__(cls)

                    def __init__(self, dim, eps=1e-6, with_scale=True):
                        if not with_scale:
                            return
                        # offset=0.0 (standard), in_place=False (gradient checkpointing safe)
                        super().__init__(
                            dim, eps, offset=0.0, casting_mode="llama", in_place=False
                        )

                modeling_gemma4.Gemma4RMSNorm = _LigerGemma4RMSNorm

            if cfg.liger_glu_activation:

                class _LigerGemma4MLP(LigerGEGLUMLP):
                    def __init__(self, config, layer_idx=None):
                        super().__init__(config)

                modeling_gemma4.Gemma4TextMLP = _LigerGemma4MLP

            if cfg.liger_rope:
                LOG.warning(
                    "Liger RoPE is not compatible with Gemma4 (separate q/k application). Skipping."
                )
            if cfg.liger_layer_norm:
                modeling_gemma4.nn.LayerNorm = LigerLayerNorm
            if cfg.liger_cross_entropy:
                modeling_gemma4.nn.CrossEntropyLoss = LigerCrossEntropyLoss
            if cfg.liger_fused_linear_cross_entropy:
                LOG.warning(
                    "Liger fused linear cross entropy is not compatible with Gemma4. Skipping."
                )
            LOG.info(
                f"Applied Liger kernels for gemma4: "
                f"rms_norm={cfg.liger_rms_norm}, glu={cfg.liger_glu_activation}, "
                f"rope=False (incompatible), layer_norm={cfg.liger_layer_norm}"
            )
        elif cfg.liger_fused_linear_cross_entropy:
            try:
                from .models.base import patch_lce_forward
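
The __new__ override in _LigerGemma4RMSNorm is a general trick for falling back to the unpatched class at construction time. A standalone sketch of the pattern, with hypothetical SlowNorm/FastNorm stand-ins (not real classes from liger_kernel or transformers):

class SlowNorm:  # stands in for the original module class
    def __init__(self, dim, eps=1e-6):
        self.dim, self.eps = dim, eps

class FastNorm:  # stands in for the fused-kernel class
    def __init__(self, dim, eps=1e-6):
        self.dim, self.eps = dim, eps

class DispatchNorm(FastNorm):
    """Builds FastNorm by default, falls back to SlowNorm on request."""

    def __new__(cls, dim, eps=1e-6, use_fast=True):
        if not use_fast:
            # The returned object is not an instance of cls, so Python
            # skips DispatchNorm.__init__ entirely for this branch.
            return SlowNorm(dim, eps)
        return super().__new__(cls)

    def __init__(self, dim, eps=1e-6, use_fast=True):
        if not use_fast:
            return  # defensive; never reached for the SlowNorm branch
        super().__init__(dim, eps)

assert isinstance(DispatchNorm(8, use_fast=False), SlowNorm)
assert isinstance(DispatchNorm(8), FastNorm)

The fallback path presumably matters because some call sites construct the norm with with_scale=False, where the scaled Liger kernel does not apply.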


@@ -112,6 +112,47 @@ QKV_PATCHES = [
        else:
            key_states = key_states.view(hidden_shape)
            value_states = value_states.view(hidden_shape) if self.v_proj is not None else key_states
""".lstrip("\n"),
    ),
    # Gemma4 (transformers >= 5.6): shared_kv_states parameter replaces
    # past_key_values.shared_layers, and v_norm added after k_norm.
    (
        """
        query_states = self.q_proj(hidden_states).view(hidden_shape)
        query_states = self.q_norm(query_states)
        query_states = apply_rotary_pos_emb(query_states, cos, sin, unsqueeze_dim=2)
        query_states = query_states.transpose(1, 2)
        # For layers with shared KV (from kv sharing point onwards), we reuse the same keys/values states as the last non-sharing layer.
        # We cannot simply reuse the cached state if we have a Cache, as sliding layers will not remember the full states in their Cache
        # once we are past the sliding window - so we always use `shared_kv_states` instead, even when past_key_values is not None
        if self.is_kv_shared_layer:
            key_states, value_states = shared_kv_states[self.kv_shared_layer_index]
            # Device of past layer may be different from current one
            key_states = key_states.to(query_states.device)
            value_states = value_states.to(query_states.device)
        else:
            key_states = self.k_proj(hidden_states).view(hidden_shape)
            value_states = self.v_proj(hidden_states).view(hidden_shape) if self.v_proj is not None else key_states
""".lstrip("\n"),
        """
        query_states, key_states, value_states = self.apply_qkv(hidden_states)
        query_states = query_states.view(hidden_shape)
        query_states = self.q_norm(query_states)
        query_states = apply_rotary_pos_emb(query_states, cos, sin, unsqueeze_dim=2)
        query_states = query_states.transpose(1, 2)
        # For layers with shared KV (from kv sharing point onwards), we reuse the same keys/values states as the last non-sharing layer.
        # We cannot simply reuse the cached state if we have a Cache, as sliding layers will not remember the full states in their Cache
        # once we are past the sliding window - so we always use `shared_kv_states` instead, even when past_key_values is not None
        if self.is_kv_shared_layer:
            key_states, value_states = shared_kv_states[self.kv_shared_layer_index]
            # Device of past layer may be different from current one
            key_states = key_states.to(query_states.device)
            value_states = value_states.to(query_states.device)
        else:
            key_states = key_states.view(hidden_shape)
            value_states = value_states.view(hidden_shape) if self.v_proj is not None else key_states
""".lstrip("\n"),
    ),
]
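
These pairs are exact-match source rewrites. A minimal sketch of how an (old, new) pair like the ones above can be applied to a function by rewriting its source and re-exec'ing it; the rebuild_with_patch helper is hypothetical and illustrative only, not axolotl's actual patching machinery:

import inspect
import textwrap

def rebuild_with_patch(fn, old: str, new: str, extra_globals=None):
    # Fetch and dedent the current source, swap in the patched fragment.
    source = textwrap.dedent(inspect.getsource(fn))
    if old not in source:
        raise ValueError("pattern not found; upstream source may have changed")
    namespace = dict(fn.__globals__)
    if extra_globals:
        namespace.update(extra_globals)
    # Recompile the patched source and pull out the rebuilt function.
    exec(compile(source.replace(old, new), "<patched>", "exec"), namespace)
    return namespace[fn.__name__]

# Tiny usage example on a stand-in function:
def forward(x):
    return x + 1

forward = rebuild_with_patch(forward, "x + 1", "x + 2")
assert forward(1) == 3

Raising when the old fragment is missing is the important design choice: it makes the patch fail loudly when a transformers upgrade changes the attention forward, rather than silently training with unpatched code.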