liger support for qwen 3.5 and fused rmsnorm+gated (#3531) [skip ci]

* liger support for qwen 3.5 and fused rmsnorm+gated

* support for qwen 3.5 moe

* fix version ref

* fixups for PR code review
This commit is contained in:
Wing Lian
2026-03-22 13:19:21 -04:00
committed by GitHub
parent 5b2e3f00ce
commit a67392c427
6 changed files with 970 additions and 0 deletions

View File

@@ -30,6 +30,15 @@ class LigerArgs(BaseModel):
liger_rope: bool | None = None
liger_rms_norm: bool | None = None
liger_rms_norm_gated: bool | None = Field(
default=None,
json_schema_extra={
"description": (
"Enables fused RMSNorm+SiLU gate Triton kernel for models with "
"gated RMSNorm (e.g. Qwen3.5 / Qwen3.5 MoE linear attention layers)."
)
},
)
liger_layer_norm: bool | None = None
liger_swiglu: bool | None = None
liger_glu_activation: bool | None = None

View File

@@ -0,0 +1,175 @@
"""
Liger FLCE for Qwen3.5. Based on transformers v5.3.0.
"""
import sys
from copy import deepcopy
from typing import Optional, Union
import torch
from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
from transformers.cache_utils import Cache
from transformers.modeling_outputs import CausalLMOutputWithPast
def lce_forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Cache] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
**kwargs,
) -> CausalLMOutputWithPast:
r"""
Args:
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
logits_to_keep (`int` or `torch.Tensor`, *optional*):
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
This is useful when using packed tensor format (single dimension for batch and sequence length).
Returns:
"""
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
cache_position=cache_position,
**kwargs,
)
hidden_states = outputs[0]
logits = None
loss = None
# if in training mode, don't materialize logits
if self.training and (labels is not None):
loss = LigerForCausalLMLoss(
hidden_states=hidden_states,
lm_head_weight=self.lm_head.weight,
labels=labels,
hidden_size=self.config.hidden_size,
**kwargs,
)
else: # if in inference mode materialize logits
slice_indices = (
slice(-logits_to_keep, None)
if isinstance(logits_to_keep, int)
else logits_to_keep
)
logits = self.lm_head(hidden_states[:, slice_indices, :])
if labels is not None:
loss = self.loss_function(
logits=logits,
labels=labels,
vocab_size=self.config.vocab_size,
**kwargs,
)
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
def apply_liger_kernel_to_qwen3_5(
cross_entropy: bool = False,
fused_linear_cross_entropy: bool = False,
rms_norm: bool = False,
rms_norm_gated: bool = False,
glu_activation: bool = False,
layer_norm: bool = False,
**kwargs,
) -> None:
"""
Apply Liger kernels to replace original implementation in HuggingFace Qwen3.5 models.
Note: Qwen3_5RMSNorm uses zero-init weight with offset 1.0 (like Gemma),
so we use LigerRMSNorm with offset=1.0 and init_fn="zeros".
Args:
cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
fused_linear_cross_entropy (bool):
Whether to apply Liger's fused linear cross entropy loss. Default is False.
`cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
rms_norm (bool): Whether to apply Liger's RMSNorm. Default is False.
rms_norm_gated (bool): Whether to apply fused RMSNorm+SiLU gate kernel for
Qwen3_5RMSNormGated (used in linear attention layers). Default is False.
glu_activation (bool): Whether to apply Liger's SwiGLU MLP. Default is False.
layer_norm (bool): Whether to apply Liger's LayerNorm. Default is False.
"""
import transformers.models.qwen3_5.modeling_qwen3_5 # noqa: F401
from liger_kernel.transformers.functional import liger_cross_entropy
from liger_kernel.transformers.layer_norm import LigerLayerNorm
from liger_kernel.transformers.rms_norm import LigerRMSNorm
from liger_kernel.transformers.swiglu import LigerSwiGLUMLP
assert not (cross_entropy and fused_linear_cross_entropy), (
"cross_entropy and fused_linear_cross_entropy cannot both be True."
)
modeling_qwen3_5 = sys.modules["transformers.models.qwen3_5.modeling_qwen3_5"]
if rms_norm:
# Qwen3_5RMSNorm uses zero-init weight with `output * (1.0 + weight)` pattern
class LigerRMSNormForQwen3_5(LigerRMSNorm):
def __init__(self, dim, eps=1e-6, **kwargs):
super().__init__(
dim,
eps=eps,
offset=1.0,
casting_mode="gemma",
init_fn="zeros",
in_place=False,
)
modeling_qwen3_5.Qwen3_5RMSNorm = LigerRMSNormForQwen3_5
if rms_norm_gated:
from axolotl.kernels.rms_norm_gated import FusedRMSNormGated
modeling_qwen3_5.Qwen3_5RMSNormGated = FusedRMSNormGated
if glu_activation:
def _liger_swiglu_mlp_wrapper(config, intermediate_size=None, **kwargs):
"""Accepts intermediate_size to pass to LigerSwiGLUMLP"""
config = deepcopy(config)
if intermediate_size is not None:
config.intermediate_size = intermediate_size
return LigerSwiGLUMLP(config, **kwargs)
modeling_qwen3_5.Qwen3_5MLP = _liger_swiglu_mlp_wrapper
if layer_norm:
modeling_qwen3_5.nn.LayerNorm = LigerLayerNorm
if cross_entropy:
from transformers.loss.loss_utils import nn
nn.functional.cross_entropy = liger_cross_entropy
if fused_linear_cross_entropy:
modeling_qwen3_5.Qwen3_5ForCausalLM.forward = lce_forward

View File

@@ -0,0 +1,198 @@
"""
Liger FLCE for Qwen3.5 MoE. Based on transformers v5.3.0.
"""
import sys
from copy import deepcopy
from typing import Optional, Union
import torch
from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
from transformers.modeling_outputs import MoeCausalLMOutputWithPast
def lce_forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values=None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_router_logits: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
**kwargs,
) -> MoeCausalLMOutputWithPast:
r"""
Args:
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
logits_to_keep (`int` or `torch.Tensor`, *optional*):
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
This is useful when using packed tensor format (single dimension for batch and sequence length).
Returns:
"""
from transformers.models.qwen3_5_moe.modeling_qwen3_5_moe import (
load_balancing_loss_func,
)
output_router_logits = (
output_router_logits
if output_router_logits is not None
else self.config.output_router_logits
)
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_router_logits=output_router_logits,
cache_position=cache_position,
**kwargs,
)
hidden_states = outputs[0]
logits = None
loss = None
# if in training mode, don't materialize logits
if self.training and (labels is not None):
loss = LigerForCausalLMLoss(
hidden_states=hidden_states,
lm_head_weight=self.lm_head.weight,
labels=labels,
hidden_size=self.config.hidden_size,
**kwargs,
)
else: # if in inference mode materialize logits
slice_indices = (
slice(-logits_to_keep, None)
if isinstance(logits_to_keep, int)
else logits_to_keep
)
logits = self.lm_head(hidden_states[:, slice_indices, :])
if labels is not None:
loss = self.loss_function(
logits,
labels,
self.vocab_size,
**kwargs,
)
aux_loss = None
if output_router_logits:
aux_loss = load_balancing_loss_func(
outputs.router_logits,
self.num_experts,
self.num_experts_per_tok,
attention_mask,
)
if labels is not None:
loss += self.router_aux_loss_coef * aux_loss.to(loss.device)
return MoeCausalLMOutputWithPast(
loss=loss,
aux_loss=aux_loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
router_logits=outputs.router_logits,
)
def apply_liger_kernel_to_qwen3_5_moe(
cross_entropy: bool = False,
fused_linear_cross_entropy: bool = False,
rms_norm: bool = False,
rms_norm_gated: bool = False,
glu_activation: bool = False,
layer_norm: bool = False,
**kwargs,
) -> None:
"""
Apply Liger kernels to replace original implementation in HuggingFace Qwen3.5 MoE models.
Note: Qwen3_5MoeRMSNorm uses zero-init weight with offset 1.0 (like Gemma),
so we use LigerRMSNorm with offset=1.0 and init_fn="zeros".
Args:
cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
fused_linear_cross_entropy (bool):
Whether to apply Liger's fused linear cross entropy loss. Default is False.
`cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
rms_norm (bool): Whether to apply Liger's RMSNorm. Default is False.
rms_norm_gated (bool): Whether to apply fused RMSNorm+SiLU gate kernel for
Qwen3_5MoeRMSNormGated (used in linear attention layers). Default is False.
glu_activation (bool): Whether to apply Liger's SwiGLU MLP. Default is False.
layer_norm (bool): Whether to apply Liger's LayerNorm. Default is False.
"""
import transformers.models.qwen3_5_moe.modeling_qwen3_5_moe # noqa: F401
from liger_kernel.transformers.functional import liger_cross_entropy
from liger_kernel.transformers.layer_norm import LigerLayerNorm
from liger_kernel.transformers.rms_norm import LigerRMSNorm
from liger_kernel.transformers.swiglu import LigerSwiGLUMLP
assert not (cross_entropy and fused_linear_cross_entropy), (
"cross_entropy and fused_linear_cross_entropy cannot both be True."
)
modeling_mod = sys.modules["transformers.models.qwen3_5_moe.modeling_qwen3_5_moe"]
if rms_norm:
# Qwen3_5MoeRMSNorm uses zero-init weight with `output * (1.0 + weight)` pattern
class LigerRMSNormForQwen3_5Moe(LigerRMSNorm):
def __init__(self, dim, eps=1e-6, **kwargs):
super().__init__(
dim,
eps=eps,
offset=1.0,
casting_mode="gemma",
init_fn="zeros",
in_place=False,
)
modeling_mod.Qwen3_5MoeRMSNorm = LigerRMSNormForQwen3_5Moe
if rms_norm_gated:
from axolotl.kernels.rms_norm_gated import FusedRMSNormGated
modeling_mod.Qwen3_5MoeRMSNormGated = FusedRMSNormGated
if glu_activation:
def _liger_swiglu_mlp_wrapper(config, intermediate_size=None, **kwargs):
"""Accepts intermediate_size to pass to LigerSwiGLUMLP"""
config = deepcopy(config)
if intermediate_size is not None:
config.intermediate_size = intermediate_size
return LigerSwiGLUMLP(config, **kwargs)
modeling_mod.Qwen3_5MoeMLP = _liger_swiglu_mlp_wrapper
if layer_norm:
modeling_mod.nn.LayerNorm = LigerLayerNorm
if cross_entropy:
from transformers.loss.loss_utils import nn
nn.functional.cross_entropy = liger_cross_entropy
if fused_linear_cross_entropy:
modeling_mod.Qwen3_5MoeForCausalLM.forward = lce_forward

View File

@@ -174,6 +174,19 @@ class LigerPlugin(BasePlugin):
rms_norm=cfg.liger_rms_norm,
layer_norm=cfg.liger_layer_norm,
)
elif cfg.model_config_type == "qwen3_5":
from axolotl.integrations.liger.models.qwen3_5 import (
apply_liger_kernel_to_qwen3_5,
)
apply_liger_kernel_to_qwen3_5(
cross_entropy=cfg.liger_cross_entropy,
fused_linear_cross_entropy=cfg.liger_fused_linear_cross_entropy,
glu_activation=cfg.liger_glu_activation,
rms_norm=cfg.liger_rms_norm,
rms_norm_gated=getattr(cfg, "liger_rms_norm_gated", False),
layer_norm=cfg.liger_layer_norm,
)
elif cfg.model_config_type == "qwen3_moe":
from axolotl.integrations.liger.models.qwen3_moe import (
apply_liger_kernel_to_qwen3_moe,
@@ -186,6 +199,19 @@ class LigerPlugin(BasePlugin):
rms_norm=cfg.liger_rms_norm,
layer_norm=cfg.liger_layer_norm,
)
elif cfg.model_config_type == "qwen3_5_moe":
from axolotl.integrations.liger.models.qwen3_5_moe import (
apply_liger_kernel_to_qwen3_5_moe,
)
apply_liger_kernel_to_qwen3_5_moe(
cross_entropy=cfg.liger_cross_entropy,
fused_linear_cross_entropy=cfg.liger_fused_linear_cross_entropy,
glu_activation=cfg.liger_glu_activation,
rms_norm=cfg.liger_rms_norm,
rms_norm_gated=getattr(cfg, "liger_rms_norm_gated", False),
layer_norm=cfg.liger_layer_norm,
)
elif cfg.model_config_type == "granitemoe":
from liger_kernel.transformers import apply_liger_kernel_to_granite

View File

@@ -0,0 +1,333 @@
"""
Fused RMSNorm + SiLU Gate Triton kernel.
Computes: Y = (W + offset) * RMSNorm(X) * silu(G)
where RMSNorm(X) = X / sqrt(mean(X^2) + eps)
and silu(G) = G * sigmoid(G)
Used by Qwen3.5's GatedDeltaNet linear attention layers (Qwen3_5RMSNormGated).
"""
import math
import operator
import torch
import triton
import triton.language as tl
from liger_kernel.ops.utils import (
calculate_settings,
compare_version,
ensure_contiguous,
torch_to_triton_dtype,
)
from liger_kernel.utils import is_npu_available
if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
try:
from triton.language.extra.libdevice import rsqrt
except ModuleNotFoundError:
from triton.language.extra.cuda.libdevice import rsqrt
else:
from triton.language.math import rsqrt
@triton.jit
def _rms_norm_gated_forward_kernel(
Y_ptr,
Y_row_stride,
X_ptr,
X_row_stride,
G_ptr,
G_row_stride,
W_ptr,
W_row_stride,
RSTD_ptr,
RSTD_row_stride,
n_cols,
eps,
offset,
BLOCK_SIZE: tl.constexpr,
):
"""
Y = (W + offset) * (X / RMS(X)) * silu(G)
All computation done in fp32 (Gemma-style), result cast to input dtype.
"""
row_idx = tl.program_id(0).to(tl.int64)
col_offsets = tl.arange(0, BLOCK_SIZE)
mask = col_offsets < n_cols
X_row = tl.load(X_ptr + row_idx * X_row_stride + col_offsets, mask=mask, other=0)
G_row = tl.load(G_ptr + row_idx * G_row_stride + col_offsets, mask=mask, other=0)
W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0)
X_row_dtype = X_row.dtype
# Cast everything to fp32
X_fp32 = X_row.to(tl.float32)
G_fp32 = G_row.to(tl.float32)
W_fp32 = W_row.to(tl.float32)
# RMS norm
mean_sq = tl.sum(X_fp32 * X_fp32, axis=0) / n_cols
rstd = rsqrt(mean_sq + eps)
tl.store(RSTD_ptr + row_idx * RSTD_row_stride, rstd)
X_norm = X_fp32 * rstd
# SiLU gate: silu(G) = G * sigmoid(G)
sig_G = tl.sigmoid(G_fp32)
silu_G = G_fp32 * sig_G
# Fused output
Y_row = (offset + W_fp32) * X_norm * silu_G
tl.store(
Y_ptr + row_idx * Y_row_stride + col_offsets,
Y_row.to(X_row_dtype),
mask=mask,
)
@triton.jit
def _rms_norm_gated_backward_kernel(
dY_ptr,
dY_row_stride,
dX_ptr,
dX_row_stride,
dG_ptr,
dG_row_stride,
X_ptr,
X_row_stride,
X_dtype: tl.constexpr,
G_ptr,
G_row_stride,
W_ptr,
W_row_stride,
RSTD_ptr,
RSTD_row_stride,
dW_ptr,
dW_row_stride,
n_rows,
n_cols,
offset,
rows_per_program,
BLOCK_SIZE: tl.constexpr,
):
"""
Backward for Y = (W + offset) * (X * RSTD) * silu(G)
dW = sum_batch(dY * X_norm * silu(G))
dG = dY * (W + offset) * X_norm * silu'(G)
where silu'(G) = sigmoid(G) * (1 + G * (1 - sigmoid(G)))
dX = RSTD * (m - (1/N) * RSTD^2 * dot(m, X) * X)
where m = dY * (W + offset) * silu(G)
"""
row_block_id = tl.program_id(0).to(tl.int64)
row_start = row_block_id * rows_per_program
row_end = min((row_block_id + 1) * rows_per_program, n_rows)
col_offsets = tl.arange(0, BLOCK_SIZE)
mask = col_offsets < n_cols
dW_acc = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0.0)
W_row = W_row.to(tl.float32) + offset
for row_idx in range(row_start, row_end):
dY_row = tl.load(
dY_ptr + row_idx * dY_row_stride + col_offsets, mask=mask, other=0.0
)
X_row = tl.load(
X_ptr + row_idx * X_row_stride + col_offsets, mask=mask, other=0.0
)
G_row = tl.load(
G_ptr + row_idx * G_row_stride + col_offsets, mask=mask, other=0.0
)
rstd_row = tl.load(RSTD_ptr + row_idx * RSTD_row_stride)
# Cast to fp32
dY_fp32 = dY_row.to(tl.float32)
X_fp32 = X_row.to(tl.float32)
G_fp32 = G_row.to(tl.float32)
# Recompute intermediates
X_norm = X_fp32 * rstd_row
sig_G = tl.sigmoid(G_fp32)
silu_G = G_fp32 * sig_G
# dW: accumulate dY * X_norm * silu(G)
dW_acc += dY_fp32 * X_norm * silu_G
# dG: dY * (W + offset) * X_norm * silu'(G)
# silu'(G) = sigmoid(G) * (1 + G * (1 - sigmoid(G)))
silu_prime_G = sig_G * (1.0 + G_fp32 * (1.0 - sig_G))
dG_row = dY_fp32 * W_row * X_norm * silu_prime_G
tl.store(
dG_ptr + row_idx * dG_row_stride + col_offsets,
dG_row.to(X_dtype),
mask=mask,
)
# dX: standard RMSNorm backward with effective gradient m = dY * W * silu(G)
m = dY_fp32 * W_row * silu_G
dX_row = rstd_row * m
dX_row += rstd_row * (
-(1.0 / n_cols) * rstd_row * rstd_row * tl.sum(m * X_fp32, axis=0) * X_fp32
)
tl.store(
dX_ptr + row_idx * dX_row_stride + col_offsets,
dX_row.to(X_dtype),
mask=mask,
)
tl.store(
dW_ptr + row_block_id * dW_row_stride + col_offsets,
dW_acc,
mask=mask,
)
def rms_norm_gated_forward(X, G, W, eps, offset):
shape = X.shape
dim = shape[-1]
X = X.view(-1, dim)
G = G.view(-1, dim)
n_rows, n_cols = X.shape
BLOCK_SIZE, num_warps = calculate_settings(n_cols)
Y = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device)
RSTD = torch.empty(n_rows, dtype=torch.float32, device=X.device)
assert X.shape[1] == W.shape[0], (
f"Incompatible hidden size: X.shape[1]={X.shape[1]} vs W.shape[0]={W.shape[0]}"
)
assert X.shape == G.shape, (
f"X and G must have same shape, got {X.shape} and {G.shape}"
)
_rms_norm_gated_forward_kernel[(n_rows,)](
Y,
Y.stride(0),
X,
X.stride(0),
G,
G.stride(0),
W,
W.stride(0),
RSTD,
RSTD.stride(0),
n_cols,
eps,
offset,
BLOCK_SIZE=BLOCK_SIZE,
num_warps=num_warps,
)
return Y.view(*shape), X, G, RSTD, BLOCK_SIZE, num_warps
def rms_norm_gated_backward(dY, X, G, W, RSTD, offset, BLOCK_SIZE, num_warps):
shape = dY.shape
dim = shape[-1]
dY = dY.view(-1, dim)
n_rows, n_cols = dY.shape
sm_count = torch.cuda.get_device_properties(X.device).multi_processor_count
_dW = torch.empty((sm_count, n_cols), dtype=torch.float32, device=W.device)
dX = torch.empty_like(dY)
dG = torch.empty_like(dY)
rows_per_program = math.ceil(n_rows / sm_count)
grid = (sm_count,)
_rms_norm_gated_backward_kernel[grid](
dY,
dY.stride(0),
dX,
dX.stride(0),
dG,
dG.stride(0),
X,
X.stride(0),
torch_to_triton_dtype[X.dtype],
G,
G.stride(0),
W,
W.stride(0),
RSTD,
RSTD.stride(0),
_dW,
_dW.stride(0),
n_rows,
n_cols,
offset,
rows_per_program,
BLOCK_SIZE=BLOCK_SIZE,
num_warps=num_warps,
)
dX = dX.view(*shape)
dG = dG.view(*shape)
dW = _dW.sum(dim=0).to(W.dtype)
return dX, dG, dW
class FusedRMSNormGatedFunction(torch.autograd.Function):
@staticmethod
@ensure_contiguous
def forward(ctx, X, G, W, eps, offset=0.0):
"""
X: (B, T, H) or (BxT, H) — input hidden states
G: (B, T, H) or (BxT, H) — gate tensor
W: (H,) — weight parameter
"""
Y, X, G, RSTD, BLOCK_SIZE, num_warps = rms_norm_gated_forward(
X, G, W, eps, offset
)
ctx.offset = offset
ctx.BLOCK_SIZE = BLOCK_SIZE
ctx.num_warps = num_warps
ctx.save_for_backward(X, G, W, RSTD)
return Y
@staticmethod
@ensure_contiguous
def backward(ctx, dY):
X, G, W, RSTD = ctx.saved_tensors
dX, dG, dW = rms_norm_gated_backward(
dY, X, G, W, RSTD, ctx.offset, ctx.BLOCK_SIZE, ctx.num_warps
)
return dX, dG, dW, None, None
class FusedRMSNormGated(torch.nn.Module):
"""
Fused RMSNorm + SiLU Gate.
Computes: Y = W * RMSNorm(X) * silu(G)
Drop-in replacement for Qwen3_5RMSNormGated with matching
init signature: __init__(hidden_size, eps=1e-6, **kwargs)
and forward signature: forward(hidden_states, gate=None)
"""
def __init__(self, hidden_size, eps=1e-6, offset=0.0, **kwargs):
super().__init__()
self.weight = torch.nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
self.offset = offset
def forward(self, hidden_states, gate=None):
if gate is None:
raise ValueError("FusedRMSNormGated requires a gate tensor")
if hidden_states.device.type != "cuda":
raise ValueError(
f"FusedRMSNormGated requires CUDA tensors, got device={hidden_states.device}"
)
return FusedRMSNormGatedFunction.apply(
hidden_states, gate, self.weight, self.variance_epsilon, self.offset
)
def extra_repr(self):
return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"

View File

@@ -0,0 +1,229 @@
"""
Correctness tests for fused RMSNorm + SiLU Gate kernel.
Tests against the eager Qwen3_5RMSNormGated implementation.
"""
import pytest
import torch
import torch.nn.functional as F
pytest.importorskip("triton", reason="triton required for fused kernels")
if not torch.cuda.is_available():
pytest.skip("CUDA required for fused kernel tests", allow_module_level=True)
from axolotl.kernels.rms_norm_gated import FusedRMSNormGated
class EagerRMSNormGated(torch.nn.Module):
"""Reference implementation matching Qwen3_5RMSNormGated exactly."""
def __init__(self, hidden_size, eps=1e-6):
super().__init__()
self.weight = torch.nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
def forward(self, hidden_states, gate=None):
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
hidden_states = self.weight * hidden_states.to(input_dtype)
hidden_states = hidden_states * F.silu(gate.to(torch.float32))
return hidden_states.to(input_dtype)
def _sync_weights(eager_mod, fused_mod):
"""Copy weights from eager to fused module."""
fused_mod.weight.data.copy_(eager_mod.weight.data)
@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16, torch.float16])
@pytest.mark.parametrize(
"shape",
[
(2, 128, 256),
(4, 64, 512),
(1, 32, 1024),
(2, 16, 2560), # Qwen3.5-4B hidden_size
(2, 16, 4096), # Qwen3.5-9B hidden_size
(1, 8, 5120), # Qwen3.5-27B hidden_size
(4, 16, 2048), # Qwen3.5-35B-A3B (MoE) hidden_size
(4, 16, 3072), # Qwen3.5-122B-A10B (MoE) hidden_size
],
)
class TestRMSNormGatedForward:
def test_output_matches_eager(self, dtype, shape):
torch.manual_seed(42)
B, T, H = shape
X = torch.randn(B, T, H, dtype=dtype, device="cuda")
G = torch.randn(B, T, H, dtype=dtype, device="cuda")
eager = EagerRMSNormGated(H).to(dtype=dtype, device="cuda")
fused = FusedRMSNormGated(H).to(dtype=dtype, device="cuda")
_sync_weights(eager, fused)
y_eager = eager(X, gate=G)
y_fused = fused(X, gate=G)
if dtype == torch.float32:
torch.testing.assert_close(y_fused, y_eager, atol=1e-5, rtol=1e-5)
else:
torch.testing.assert_close(y_fused, y_eager, atol=1e-2, rtol=1e-2)
def test_output_shape(self, dtype, shape):
B, T, H = shape
X = torch.randn(B, T, H, dtype=dtype, device="cuda")
G = torch.randn(B, T, H, dtype=dtype, device="cuda")
fused = FusedRMSNormGated(H).to(dtype=dtype, device="cuda")
y = fused(X, gate=G)
assert y.shape == (B, T, H)
assert y.dtype == dtype
@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16, torch.float16])
@pytest.mark.parametrize(
"shape",
[
(2, 32, 256),
(2, 16, 512),
(2, 16, 2560), # Qwen3.5-4B
(1, 8, 4096), # Qwen3.5-9B
(1, 8, 5120), # Qwen3.5-27B
(2, 16, 2048), # Qwen3.5-35B-A3B (MoE)
(2, 16, 3072), # Qwen3.5-122B-A10B (MoE)
],
)
class TestRMSNormGatedBackward:
def test_grad_x(self, dtype, shape):
torch.manual_seed(42)
B, T, H = shape
X = torch.randn(B, T, H, dtype=dtype, device="cuda", requires_grad=True)
G = torch.randn(B, T, H, dtype=dtype, device="cuda", requires_grad=True)
X_ref = X.detach().clone().requires_grad_(True)
G_ref = G.detach().clone().requires_grad_(True)
eager = EagerRMSNormGated(H).to(dtype=dtype, device="cuda")
fused = FusedRMSNormGated(H).to(dtype=dtype, device="cuda")
_sync_weights(eager, fused)
y_eager = eager(X_ref, gate=G_ref)
y_fused = fused(X, gate=G)
grad_out = torch.randn_like(y_eager)
y_eager.backward(grad_out)
y_fused.backward(grad_out)
if dtype == torch.float32:
atol, rtol = 1e-4, 1e-4
else:
atol, rtol = 5e-2, 5e-2
torch.testing.assert_close(X.grad, X_ref.grad, atol=atol, rtol=rtol)
def test_grad_gate(self, dtype, shape):
torch.manual_seed(42)
B, T, H = shape
X = torch.randn(B, T, H, dtype=dtype, device="cuda", requires_grad=True)
G = torch.randn(B, T, H, dtype=dtype, device="cuda", requires_grad=True)
X_ref = X.detach().clone().requires_grad_(True)
G_ref = G.detach().clone().requires_grad_(True)
eager = EagerRMSNormGated(H).to(dtype=dtype, device="cuda")
fused = FusedRMSNormGated(H).to(dtype=dtype, device="cuda")
_sync_weights(eager, fused)
y_eager = eager(X_ref, gate=G_ref)
y_fused = fused(X, gate=G)
grad_out = torch.randn_like(y_eager)
y_eager.backward(grad_out)
y_fused.backward(grad_out)
if dtype == torch.float32:
atol, rtol = 1e-4, 1e-4
else:
atol, rtol = 5e-2, 5e-2
torch.testing.assert_close(G.grad, G_ref.grad, atol=atol, rtol=rtol)
def test_grad_weight(self, dtype, shape):
torch.manual_seed(42)
B, T, H = shape
X = torch.randn(B, T, H, dtype=dtype, device="cuda", requires_grad=True)
G = torch.randn(B, T, H, dtype=dtype, device="cuda", requires_grad=True)
X_ref = X.detach().clone().requires_grad_(True)
G_ref = G.detach().clone().requires_grad_(True)
eager = EagerRMSNormGated(H).to(dtype=dtype, device="cuda")
fused = FusedRMSNormGated(H).to(dtype=dtype, device="cuda")
_sync_weights(eager, fused)
y_eager = eager(X_ref, gate=G_ref)
y_fused = fused(X, gate=G)
grad_out = torch.randn_like(y_eager)
y_eager.backward(grad_out)
y_fused.backward(grad_out)
if dtype == torch.float32:
atol, rtol = 1e-4, 1e-4
else:
atol, rtol = 5e-2, 5e-2
torch.testing.assert_close(
fused.weight.grad, eager.weight.grad, atol=atol, rtol=rtol
)
class TestRMSNormGatedEdgeCases:
def test_gate_none_raises(self):
fused = FusedRMSNormGated(256).cuda()
X = torch.randn(2, 4, 256, device="cuda")
with pytest.raises(ValueError, match="requires a gate tensor"):
fused(X, gate=None)
def test_2d_input(self):
"""Test with (BxT, H) shaped input instead of (B, T, H)."""
torch.manual_seed(42)
H = 512
X = torch.randn(64, H, dtype=torch.bfloat16, device="cuda", requires_grad=True)
G = torch.randn(64, H, dtype=torch.bfloat16, device="cuda", requires_grad=True)
X_ref = X.detach().clone().requires_grad_(True)
G_ref = G.detach().clone().requires_grad_(True)
eager = EagerRMSNormGated(H).to(dtype=torch.bfloat16, device="cuda")
fused = FusedRMSNormGated(H).to(dtype=torch.bfloat16, device="cuda")
_sync_weights(eager, fused)
y_eager = eager(X_ref, gate=G_ref)
y_fused = fused(X, gate=G)
torch.testing.assert_close(y_fused, y_eager, atol=1e-2, rtol=1e-2)
grad_out = torch.randn_like(y_eager)
y_eager.backward(grad_out)
y_fused.backward(grad_out)
torch.testing.assert_close(X.grad, X_ref.grad, atol=5e-2, rtol=5e-2)
torch.testing.assert_close(G.grad, G_ref.grad, atol=5e-2, rtol=5e-2)
def test_random_weight_init(self):
"""Test with non-default weight values."""
torch.manual_seed(123)
H = 256
X = torch.randn(2, 16, H, dtype=torch.bfloat16, device="cuda")
G = torch.randn(2, 16, H, dtype=torch.bfloat16, device="cuda")
eager = EagerRMSNormGated(H).to(dtype=torch.bfloat16, device="cuda")
# Randomize weights
eager.weight.data = torch.randn_like(eager.weight.data)
fused = FusedRMSNormGated(H).to(dtype=torch.bfloat16, device="cuda")
_sync_weights(eager, fused)
y_eager = eager(X, gate=G)
y_fused = fused(X, gate=G)
torch.testing.assert_close(y_fused, y_eager, atol=1e-2, rtol=1e-2)