liger support for qwen 3.5 and fused rmsnorm+gated (#3531) [skip ci]
* liger support for qwen 3.5 and fused rmsnorm+gated * support for qwen 3.5 moe * fix version ref * fixups for PR code review
This commit is contained in:
@@ -30,6 +30,15 @@ class LigerArgs(BaseModel):
|
||||
|
||||
liger_rope: bool | None = None
|
||||
liger_rms_norm: bool | None = None
|
||||
liger_rms_norm_gated: bool | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": (
|
||||
"Enables fused RMSNorm+SiLU gate Triton kernel for models with "
|
||||
"gated RMSNorm (e.g. Qwen3.5 / Qwen3.5 MoE linear attention layers)."
|
||||
)
|
||||
},
|
||||
)
|
||||
liger_layer_norm: bool | None = None
|
||||
liger_swiglu: bool | None = None
|
||||
liger_glu_activation: bool | None = None
|
||||
|
||||
175
src/axolotl/integrations/liger/models/qwen3_5.py
Normal file
175
src/axolotl/integrations/liger/models/qwen3_5.py
Normal file
@@ -0,0 +1,175 @@
|
||||
"""
|
||||
Liger FLCE for Qwen3.5. Based on transformers v5.3.0.
|
||||
"""
|
||||
|
||||
import sys
|
||||
from copy import deepcopy
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
|
||||
from transformers.cache_utils import Cache
|
||||
from transformers.modeling_outputs import CausalLMOutputWithPast
|
||||
|
||||
|
||||
def lce_forward(
|
||||
self,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[Cache] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
logits_to_keep: Union[int, torch.Tensor] = 0,
|
||||
**kwargs,
|
||||
) -> CausalLMOutputWithPast:
|
||||
r"""
|
||||
Args:
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
||||
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
||||
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
||||
|
||||
logits_to_keep (`int` or `torch.Tensor`, *optional*):
|
||||
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
|
||||
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
|
||||
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
|
||||
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
|
||||
This is useful when using packed tensor format (single dimension for batch and sequence length).
|
||||
|
||||
Returns:
|
||||
"""
|
||||
|
||||
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
||||
outputs = self.model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
cache_position=cache_position,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
hidden_states = outputs[0]
|
||||
|
||||
logits = None
|
||||
loss = None
|
||||
# if in training mode, don't materialize logits
|
||||
if self.training and (labels is not None):
|
||||
loss = LigerForCausalLMLoss(
|
||||
hidden_states=hidden_states,
|
||||
lm_head_weight=self.lm_head.weight,
|
||||
labels=labels,
|
||||
hidden_size=self.config.hidden_size,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
else: # if in inference mode materialize logits
|
||||
slice_indices = (
|
||||
slice(-logits_to_keep, None)
|
||||
if isinstance(logits_to_keep, int)
|
||||
else logits_to_keep
|
||||
)
|
||||
logits = self.lm_head(hidden_states[:, slice_indices, :])
|
||||
if labels is not None:
|
||||
loss = self.loss_function(
|
||||
logits=logits,
|
||||
labels=labels,
|
||||
vocab_size=self.config.vocab_size,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
return CausalLMOutputWithPast(
|
||||
loss=loss,
|
||||
logits=logits,
|
||||
past_key_values=outputs.past_key_values,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
|
||||
|
||||
def apply_liger_kernel_to_qwen3_5(
|
||||
cross_entropy: bool = False,
|
||||
fused_linear_cross_entropy: bool = False,
|
||||
rms_norm: bool = False,
|
||||
rms_norm_gated: bool = False,
|
||||
glu_activation: bool = False,
|
||||
layer_norm: bool = False,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
"""
|
||||
Apply Liger kernels to replace original implementation in HuggingFace Qwen3.5 models.
|
||||
|
||||
Note: Qwen3_5RMSNorm uses zero-init weight with offset 1.0 (like Gemma),
|
||||
so we use LigerRMSNorm with offset=1.0 and init_fn="zeros".
|
||||
|
||||
Args:
|
||||
cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
|
||||
fused_linear_cross_entropy (bool):
|
||||
Whether to apply Liger's fused linear cross entropy loss. Default is False.
|
||||
`cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
|
||||
If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
|
||||
rms_norm (bool): Whether to apply Liger's RMSNorm. Default is False.
|
||||
rms_norm_gated (bool): Whether to apply fused RMSNorm+SiLU gate kernel for
|
||||
Qwen3_5RMSNormGated (used in linear attention layers). Default is False.
|
||||
glu_activation (bool): Whether to apply Liger's SwiGLU MLP. Default is False.
|
||||
layer_norm (bool): Whether to apply Liger's LayerNorm. Default is False.
|
||||
"""
|
||||
|
||||
import transformers.models.qwen3_5.modeling_qwen3_5 # noqa: F401
|
||||
from liger_kernel.transformers.functional import liger_cross_entropy
|
||||
from liger_kernel.transformers.layer_norm import LigerLayerNorm
|
||||
from liger_kernel.transformers.rms_norm import LigerRMSNorm
|
||||
from liger_kernel.transformers.swiglu import LigerSwiGLUMLP
|
||||
|
||||
assert not (cross_entropy and fused_linear_cross_entropy), (
|
||||
"cross_entropy and fused_linear_cross_entropy cannot both be True."
|
||||
)
|
||||
|
||||
modeling_qwen3_5 = sys.modules["transformers.models.qwen3_5.modeling_qwen3_5"]
|
||||
|
||||
if rms_norm:
|
||||
# Qwen3_5RMSNorm uses zero-init weight with `output * (1.0 + weight)` pattern
|
||||
class LigerRMSNormForQwen3_5(LigerRMSNorm):
|
||||
def __init__(self, dim, eps=1e-6, **kwargs):
|
||||
super().__init__(
|
||||
dim,
|
||||
eps=eps,
|
||||
offset=1.0,
|
||||
casting_mode="gemma",
|
||||
init_fn="zeros",
|
||||
in_place=False,
|
||||
)
|
||||
|
||||
modeling_qwen3_5.Qwen3_5RMSNorm = LigerRMSNormForQwen3_5
|
||||
|
||||
if rms_norm_gated:
|
||||
from axolotl.kernels.rms_norm_gated import FusedRMSNormGated
|
||||
|
||||
modeling_qwen3_5.Qwen3_5RMSNormGated = FusedRMSNormGated
|
||||
|
||||
if glu_activation:
|
||||
|
||||
def _liger_swiglu_mlp_wrapper(config, intermediate_size=None, **kwargs):
|
||||
"""Accepts intermediate_size to pass to LigerSwiGLUMLP"""
|
||||
config = deepcopy(config)
|
||||
if intermediate_size is not None:
|
||||
config.intermediate_size = intermediate_size
|
||||
return LigerSwiGLUMLP(config, **kwargs)
|
||||
|
||||
modeling_qwen3_5.Qwen3_5MLP = _liger_swiglu_mlp_wrapper
|
||||
|
||||
if layer_norm:
|
||||
modeling_qwen3_5.nn.LayerNorm = LigerLayerNorm
|
||||
|
||||
if cross_entropy:
|
||||
from transformers.loss.loss_utils import nn
|
||||
|
||||
nn.functional.cross_entropy = liger_cross_entropy
|
||||
|
||||
if fused_linear_cross_entropy:
|
||||
modeling_qwen3_5.Qwen3_5ForCausalLM.forward = lce_forward
|
||||
198
src/axolotl/integrations/liger/models/qwen3_5_moe.py
Normal file
198
src/axolotl/integrations/liger/models/qwen3_5_moe.py
Normal file
@@ -0,0 +1,198 @@
|
||||
"""
|
||||
Liger FLCE for Qwen3.5 MoE. Based on transformers v5.3.0.
|
||||
"""
|
||||
|
||||
import sys
|
||||
from copy import deepcopy
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
|
||||
from transformers.modeling_outputs import MoeCausalLMOutputWithPast
|
||||
|
||||
|
||||
def lce_forward(
|
||||
self,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values=None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_router_logits: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
logits_to_keep: Union[int, torch.Tensor] = 0,
|
||||
**kwargs,
|
||||
) -> MoeCausalLMOutputWithPast:
|
||||
r"""
|
||||
Args:
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
||||
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
||||
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
||||
|
||||
logits_to_keep (`int` or `torch.Tensor`, *optional*):
|
||||
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
|
||||
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
|
||||
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
|
||||
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
|
||||
This is useful when using packed tensor format (single dimension for batch and sequence length).
|
||||
|
||||
Returns:
|
||||
"""
|
||||
from transformers.models.qwen3_5_moe.modeling_qwen3_5_moe import (
|
||||
load_balancing_loss_func,
|
||||
)
|
||||
|
||||
output_router_logits = (
|
||||
output_router_logits
|
||||
if output_router_logits is not None
|
||||
else self.config.output_router_logits
|
||||
)
|
||||
|
||||
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
||||
outputs = self.model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_router_logits=output_router_logits,
|
||||
cache_position=cache_position,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
hidden_states = outputs[0]
|
||||
|
||||
logits = None
|
||||
loss = None
|
||||
# if in training mode, don't materialize logits
|
||||
if self.training and (labels is not None):
|
||||
loss = LigerForCausalLMLoss(
|
||||
hidden_states=hidden_states,
|
||||
lm_head_weight=self.lm_head.weight,
|
||||
labels=labels,
|
||||
hidden_size=self.config.hidden_size,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
else: # if in inference mode materialize logits
|
||||
slice_indices = (
|
||||
slice(-logits_to_keep, None)
|
||||
if isinstance(logits_to_keep, int)
|
||||
else logits_to_keep
|
||||
)
|
||||
logits = self.lm_head(hidden_states[:, slice_indices, :])
|
||||
if labels is not None:
|
||||
loss = self.loss_function(
|
||||
logits,
|
||||
labels,
|
||||
self.vocab_size,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
aux_loss = None
|
||||
if output_router_logits:
|
||||
aux_loss = load_balancing_loss_func(
|
||||
outputs.router_logits,
|
||||
self.num_experts,
|
||||
self.num_experts_per_tok,
|
||||
attention_mask,
|
||||
)
|
||||
if labels is not None:
|
||||
loss += self.router_aux_loss_coef * aux_loss.to(loss.device)
|
||||
|
||||
return MoeCausalLMOutputWithPast(
|
||||
loss=loss,
|
||||
aux_loss=aux_loss,
|
||||
logits=logits,
|
||||
past_key_values=outputs.past_key_values,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
router_logits=outputs.router_logits,
|
||||
)
|
||||
|
||||
|
||||
def apply_liger_kernel_to_qwen3_5_moe(
|
||||
cross_entropy: bool = False,
|
||||
fused_linear_cross_entropy: bool = False,
|
||||
rms_norm: bool = False,
|
||||
rms_norm_gated: bool = False,
|
||||
glu_activation: bool = False,
|
||||
layer_norm: bool = False,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
"""
|
||||
Apply Liger kernels to replace original implementation in HuggingFace Qwen3.5 MoE models.
|
||||
|
||||
Note: Qwen3_5MoeRMSNorm uses zero-init weight with offset 1.0 (like Gemma),
|
||||
so we use LigerRMSNorm with offset=1.0 and init_fn="zeros".
|
||||
|
||||
Args:
|
||||
cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
|
||||
fused_linear_cross_entropy (bool):
|
||||
Whether to apply Liger's fused linear cross entropy loss. Default is False.
|
||||
`cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
|
||||
If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
|
||||
rms_norm (bool): Whether to apply Liger's RMSNorm. Default is False.
|
||||
rms_norm_gated (bool): Whether to apply fused RMSNorm+SiLU gate kernel for
|
||||
Qwen3_5MoeRMSNormGated (used in linear attention layers). Default is False.
|
||||
glu_activation (bool): Whether to apply Liger's SwiGLU MLP. Default is False.
|
||||
layer_norm (bool): Whether to apply Liger's LayerNorm. Default is False.
|
||||
"""
|
||||
|
||||
import transformers.models.qwen3_5_moe.modeling_qwen3_5_moe # noqa: F401
|
||||
from liger_kernel.transformers.functional import liger_cross_entropy
|
||||
from liger_kernel.transformers.layer_norm import LigerLayerNorm
|
||||
from liger_kernel.transformers.rms_norm import LigerRMSNorm
|
||||
from liger_kernel.transformers.swiglu import LigerSwiGLUMLP
|
||||
|
||||
assert not (cross_entropy and fused_linear_cross_entropy), (
|
||||
"cross_entropy and fused_linear_cross_entropy cannot both be True."
|
||||
)
|
||||
|
||||
modeling_mod = sys.modules["transformers.models.qwen3_5_moe.modeling_qwen3_5_moe"]
|
||||
|
||||
if rms_norm:
|
||||
# Qwen3_5MoeRMSNorm uses zero-init weight with `output * (1.0 + weight)` pattern
|
||||
class LigerRMSNormForQwen3_5Moe(LigerRMSNorm):
|
||||
def __init__(self, dim, eps=1e-6, **kwargs):
|
||||
super().__init__(
|
||||
dim,
|
||||
eps=eps,
|
||||
offset=1.0,
|
||||
casting_mode="gemma",
|
||||
init_fn="zeros",
|
||||
in_place=False,
|
||||
)
|
||||
|
||||
modeling_mod.Qwen3_5MoeRMSNorm = LigerRMSNormForQwen3_5Moe
|
||||
|
||||
if rms_norm_gated:
|
||||
from axolotl.kernels.rms_norm_gated import FusedRMSNormGated
|
||||
|
||||
modeling_mod.Qwen3_5MoeRMSNormGated = FusedRMSNormGated
|
||||
|
||||
if glu_activation:
|
||||
|
||||
def _liger_swiglu_mlp_wrapper(config, intermediate_size=None, **kwargs):
|
||||
"""Accepts intermediate_size to pass to LigerSwiGLUMLP"""
|
||||
config = deepcopy(config)
|
||||
if intermediate_size is not None:
|
||||
config.intermediate_size = intermediate_size
|
||||
return LigerSwiGLUMLP(config, **kwargs)
|
||||
|
||||
modeling_mod.Qwen3_5MoeMLP = _liger_swiglu_mlp_wrapper
|
||||
|
||||
if layer_norm:
|
||||
modeling_mod.nn.LayerNorm = LigerLayerNorm
|
||||
|
||||
if cross_entropy:
|
||||
from transformers.loss.loss_utils import nn
|
||||
|
||||
nn.functional.cross_entropy = liger_cross_entropy
|
||||
|
||||
if fused_linear_cross_entropy:
|
||||
modeling_mod.Qwen3_5MoeForCausalLM.forward = lce_forward
|
||||
@@ -174,6 +174,19 @@ class LigerPlugin(BasePlugin):
|
||||
rms_norm=cfg.liger_rms_norm,
|
||||
layer_norm=cfg.liger_layer_norm,
|
||||
)
|
||||
elif cfg.model_config_type == "qwen3_5":
|
||||
from axolotl.integrations.liger.models.qwen3_5 import (
|
||||
apply_liger_kernel_to_qwen3_5,
|
||||
)
|
||||
|
||||
apply_liger_kernel_to_qwen3_5(
|
||||
cross_entropy=cfg.liger_cross_entropy,
|
||||
fused_linear_cross_entropy=cfg.liger_fused_linear_cross_entropy,
|
||||
glu_activation=cfg.liger_glu_activation,
|
||||
rms_norm=cfg.liger_rms_norm,
|
||||
rms_norm_gated=getattr(cfg, "liger_rms_norm_gated", False),
|
||||
layer_norm=cfg.liger_layer_norm,
|
||||
)
|
||||
elif cfg.model_config_type == "qwen3_moe":
|
||||
from axolotl.integrations.liger.models.qwen3_moe import (
|
||||
apply_liger_kernel_to_qwen3_moe,
|
||||
@@ -186,6 +199,19 @@ class LigerPlugin(BasePlugin):
|
||||
rms_norm=cfg.liger_rms_norm,
|
||||
layer_norm=cfg.liger_layer_norm,
|
||||
)
|
||||
elif cfg.model_config_type == "qwen3_5_moe":
|
||||
from axolotl.integrations.liger.models.qwen3_5_moe import (
|
||||
apply_liger_kernel_to_qwen3_5_moe,
|
||||
)
|
||||
|
||||
apply_liger_kernel_to_qwen3_5_moe(
|
||||
cross_entropy=cfg.liger_cross_entropy,
|
||||
fused_linear_cross_entropy=cfg.liger_fused_linear_cross_entropy,
|
||||
glu_activation=cfg.liger_glu_activation,
|
||||
rms_norm=cfg.liger_rms_norm,
|
||||
rms_norm_gated=getattr(cfg, "liger_rms_norm_gated", False),
|
||||
layer_norm=cfg.liger_layer_norm,
|
||||
)
|
||||
elif cfg.model_config_type == "granitemoe":
|
||||
from liger_kernel.transformers import apply_liger_kernel_to_granite
|
||||
|
||||
|
||||
333
src/axolotl/kernels/rms_norm_gated.py
Normal file
333
src/axolotl/kernels/rms_norm_gated.py
Normal file
@@ -0,0 +1,333 @@
|
||||
"""
|
||||
Fused RMSNorm + SiLU Gate Triton kernel.
|
||||
|
||||
Computes: Y = (W + offset) * RMSNorm(X) * silu(G)
|
||||
where RMSNorm(X) = X / sqrt(mean(X^2) + eps)
|
||||
and silu(G) = G * sigmoid(G)
|
||||
|
||||
Used by Qwen3.5's GatedDeltaNet linear attention layers (Qwen3_5RMSNormGated).
|
||||
"""
|
||||
|
||||
import math
|
||||
import operator
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
from liger_kernel.ops.utils import (
|
||||
calculate_settings,
|
||||
compare_version,
|
||||
ensure_contiguous,
|
||||
torch_to_triton_dtype,
|
||||
)
|
||||
from liger_kernel.utils import is_npu_available
|
||||
|
||||
if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
|
||||
try:
|
||||
from triton.language.extra.libdevice import rsqrt
|
||||
except ModuleNotFoundError:
|
||||
from triton.language.extra.cuda.libdevice import rsqrt
|
||||
else:
|
||||
from triton.language.math import rsqrt
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _rms_norm_gated_forward_kernel(
|
||||
Y_ptr,
|
||||
Y_row_stride,
|
||||
X_ptr,
|
||||
X_row_stride,
|
||||
G_ptr,
|
||||
G_row_stride,
|
||||
W_ptr,
|
||||
W_row_stride,
|
||||
RSTD_ptr,
|
||||
RSTD_row_stride,
|
||||
n_cols,
|
||||
eps,
|
||||
offset,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
):
|
||||
"""
|
||||
Y = (W + offset) * (X / RMS(X)) * silu(G)
|
||||
|
||||
All computation done in fp32 (Gemma-style), result cast to input dtype.
|
||||
"""
|
||||
row_idx = tl.program_id(0).to(tl.int64)
|
||||
col_offsets = tl.arange(0, BLOCK_SIZE)
|
||||
mask = col_offsets < n_cols
|
||||
|
||||
X_row = tl.load(X_ptr + row_idx * X_row_stride + col_offsets, mask=mask, other=0)
|
||||
G_row = tl.load(G_ptr + row_idx * G_row_stride + col_offsets, mask=mask, other=0)
|
||||
W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0)
|
||||
|
||||
X_row_dtype = X_row.dtype
|
||||
|
||||
# Cast everything to fp32
|
||||
X_fp32 = X_row.to(tl.float32)
|
||||
G_fp32 = G_row.to(tl.float32)
|
||||
W_fp32 = W_row.to(tl.float32)
|
||||
|
||||
# RMS norm
|
||||
mean_sq = tl.sum(X_fp32 * X_fp32, axis=0) / n_cols
|
||||
rstd = rsqrt(mean_sq + eps)
|
||||
tl.store(RSTD_ptr + row_idx * RSTD_row_stride, rstd)
|
||||
|
||||
X_norm = X_fp32 * rstd
|
||||
|
||||
# SiLU gate: silu(G) = G * sigmoid(G)
|
||||
sig_G = tl.sigmoid(G_fp32)
|
||||
silu_G = G_fp32 * sig_G
|
||||
|
||||
# Fused output
|
||||
Y_row = (offset + W_fp32) * X_norm * silu_G
|
||||
|
||||
tl.store(
|
||||
Y_ptr + row_idx * Y_row_stride + col_offsets,
|
||||
Y_row.to(X_row_dtype),
|
||||
mask=mask,
|
||||
)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _rms_norm_gated_backward_kernel(
|
||||
dY_ptr,
|
||||
dY_row_stride,
|
||||
dX_ptr,
|
||||
dX_row_stride,
|
||||
dG_ptr,
|
||||
dG_row_stride,
|
||||
X_ptr,
|
||||
X_row_stride,
|
||||
X_dtype: tl.constexpr,
|
||||
G_ptr,
|
||||
G_row_stride,
|
||||
W_ptr,
|
||||
W_row_stride,
|
||||
RSTD_ptr,
|
||||
RSTD_row_stride,
|
||||
dW_ptr,
|
||||
dW_row_stride,
|
||||
n_rows,
|
||||
n_cols,
|
||||
offset,
|
||||
rows_per_program,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
):
|
||||
"""
|
||||
Backward for Y = (W + offset) * (X * RSTD) * silu(G)
|
||||
|
||||
dW = sum_batch(dY * X_norm * silu(G))
|
||||
dG = dY * (W + offset) * X_norm * silu'(G)
|
||||
where silu'(G) = sigmoid(G) * (1 + G * (1 - sigmoid(G)))
|
||||
dX = RSTD * (m - (1/N) * RSTD^2 * dot(m, X) * X)
|
||||
where m = dY * (W + offset) * silu(G)
|
||||
"""
|
||||
row_block_id = tl.program_id(0).to(tl.int64)
|
||||
row_start = row_block_id * rows_per_program
|
||||
row_end = min((row_block_id + 1) * rows_per_program, n_rows)
|
||||
col_offsets = tl.arange(0, BLOCK_SIZE)
|
||||
mask = col_offsets < n_cols
|
||||
|
||||
dW_acc = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
|
||||
|
||||
W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0.0)
|
||||
W_row = W_row.to(tl.float32) + offset
|
||||
|
||||
for row_idx in range(row_start, row_end):
|
||||
dY_row = tl.load(
|
||||
dY_ptr + row_idx * dY_row_stride + col_offsets, mask=mask, other=0.0
|
||||
)
|
||||
X_row = tl.load(
|
||||
X_ptr + row_idx * X_row_stride + col_offsets, mask=mask, other=0.0
|
||||
)
|
||||
G_row = tl.load(
|
||||
G_ptr + row_idx * G_row_stride + col_offsets, mask=mask, other=0.0
|
||||
)
|
||||
rstd_row = tl.load(RSTD_ptr + row_idx * RSTD_row_stride)
|
||||
|
||||
# Cast to fp32
|
||||
dY_fp32 = dY_row.to(tl.float32)
|
||||
X_fp32 = X_row.to(tl.float32)
|
||||
G_fp32 = G_row.to(tl.float32)
|
||||
|
||||
# Recompute intermediates
|
||||
X_norm = X_fp32 * rstd_row
|
||||
sig_G = tl.sigmoid(G_fp32)
|
||||
silu_G = G_fp32 * sig_G
|
||||
|
||||
# dW: accumulate dY * X_norm * silu(G)
|
||||
dW_acc += dY_fp32 * X_norm * silu_G
|
||||
|
||||
# dG: dY * (W + offset) * X_norm * silu'(G)
|
||||
# silu'(G) = sigmoid(G) * (1 + G * (1 - sigmoid(G)))
|
||||
silu_prime_G = sig_G * (1.0 + G_fp32 * (1.0 - sig_G))
|
||||
dG_row = dY_fp32 * W_row * X_norm * silu_prime_G
|
||||
tl.store(
|
||||
dG_ptr + row_idx * dG_row_stride + col_offsets,
|
||||
dG_row.to(X_dtype),
|
||||
mask=mask,
|
||||
)
|
||||
|
||||
# dX: standard RMSNorm backward with effective gradient m = dY * W * silu(G)
|
||||
m = dY_fp32 * W_row * silu_G
|
||||
dX_row = rstd_row * m
|
||||
dX_row += rstd_row * (
|
||||
-(1.0 / n_cols) * rstd_row * rstd_row * tl.sum(m * X_fp32, axis=0) * X_fp32
|
||||
)
|
||||
tl.store(
|
||||
dX_ptr + row_idx * dX_row_stride + col_offsets,
|
||||
dX_row.to(X_dtype),
|
||||
mask=mask,
|
||||
)
|
||||
|
||||
tl.store(
|
||||
dW_ptr + row_block_id * dW_row_stride + col_offsets,
|
||||
dW_acc,
|
||||
mask=mask,
|
||||
)
|
||||
|
||||
|
||||
def rms_norm_gated_forward(X, G, W, eps, offset):
|
||||
shape = X.shape
|
||||
dim = shape[-1]
|
||||
X = X.view(-1, dim)
|
||||
G = G.view(-1, dim)
|
||||
n_rows, n_cols = X.shape
|
||||
BLOCK_SIZE, num_warps = calculate_settings(n_cols)
|
||||
|
||||
Y = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device)
|
||||
RSTD = torch.empty(n_rows, dtype=torch.float32, device=X.device)
|
||||
|
||||
assert X.shape[1] == W.shape[0], (
|
||||
f"Incompatible hidden size: X.shape[1]={X.shape[1]} vs W.shape[0]={W.shape[0]}"
|
||||
)
|
||||
assert X.shape == G.shape, (
|
||||
f"X and G must have same shape, got {X.shape} and {G.shape}"
|
||||
)
|
||||
|
||||
_rms_norm_gated_forward_kernel[(n_rows,)](
|
||||
Y,
|
||||
Y.stride(0),
|
||||
X,
|
||||
X.stride(0),
|
||||
G,
|
||||
G.stride(0),
|
||||
W,
|
||||
W.stride(0),
|
||||
RSTD,
|
||||
RSTD.stride(0),
|
||||
n_cols,
|
||||
eps,
|
||||
offset,
|
||||
BLOCK_SIZE=BLOCK_SIZE,
|
||||
num_warps=num_warps,
|
||||
)
|
||||
return Y.view(*shape), X, G, RSTD, BLOCK_SIZE, num_warps
|
||||
|
||||
|
||||
def rms_norm_gated_backward(dY, X, G, W, RSTD, offset, BLOCK_SIZE, num_warps):
|
||||
shape = dY.shape
|
||||
dim = shape[-1]
|
||||
dY = dY.view(-1, dim)
|
||||
n_rows, n_cols = dY.shape
|
||||
|
||||
sm_count = torch.cuda.get_device_properties(X.device).multi_processor_count
|
||||
|
||||
_dW = torch.empty((sm_count, n_cols), dtype=torch.float32, device=W.device)
|
||||
dX = torch.empty_like(dY)
|
||||
dG = torch.empty_like(dY)
|
||||
|
||||
rows_per_program = math.ceil(n_rows / sm_count)
|
||||
grid = (sm_count,)
|
||||
|
||||
_rms_norm_gated_backward_kernel[grid](
|
||||
dY,
|
||||
dY.stride(0),
|
||||
dX,
|
||||
dX.stride(0),
|
||||
dG,
|
||||
dG.stride(0),
|
||||
X,
|
||||
X.stride(0),
|
||||
torch_to_triton_dtype[X.dtype],
|
||||
G,
|
||||
G.stride(0),
|
||||
W,
|
||||
W.stride(0),
|
||||
RSTD,
|
||||
RSTD.stride(0),
|
||||
_dW,
|
||||
_dW.stride(0),
|
||||
n_rows,
|
||||
n_cols,
|
||||
offset,
|
||||
rows_per_program,
|
||||
BLOCK_SIZE=BLOCK_SIZE,
|
||||
num_warps=num_warps,
|
||||
)
|
||||
|
||||
dX = dX.view(*shape)
|
||||
dG = dG.view(*shape)
|
||||
dW = _dW.sum(dim=0).to(W.dtype)
|
||||
return dX, dG, dW
|
||||
|
||||
|
||||
class FusedRMSNormGatedFunction(torch.autograd.Function):
|
||||
@staticmethod
|
||||
@ensure_contiguous
|
||||
def forward(ctx, X, G, W, eps, offset=0.0):
|
||||
"""
|
||||
X: (B, T, H) or (BxT, H) — input hidden states
|
||||
G: (B, T, H) or (BxT, H) — gate tensor
|
||||
W: (H,) — weight parameter
|
||||
"""
|
||||
Y, X, G, RSTD, BLOCK_SIZE, num_warps = rms_norm_gated_forward(
|
||||
X, G, W, eps, offset
|
||||
)
|
||||
ctx.offset = offset
|
||||
ctx.BLOCK_SIZE = BLOCK_SIZE
|
||||
ctx.num_warps = num_warps
|
||||
ctx.save_for_backward(X, G, W, RSTD)
|
||||
return Y
|
||||
|
||||
@staticmethod
|
||||
@ensure_contiguous
|
||||
def backward(ctx, dY):
|
||||
X, G, W, RSTD = ctx.saved_tensors
|
||||
dX, dG, dW = rms_norm_gated_backward(
|
||||
dY, X, G, W, RSTD, ctx.offset, ctx.BLOCK_SIZE, ctx.num_warps
|
||||
)
|
||||
return dX, dG, dW, None, None
|
||||
|
||||
|
||||
class FusedRMSNormGated(torch.nn.Module):
|
||||
"""
|
||||
Fused RMSNorm + SiLU Gate.
|
||||
|
||||
Computes: Y = W * RMSNorm(X) * silu(G)
|
||||
|
||||
Drop-in replacement for Qwen3_5RMSNormGated with matching
|
||||
init signature: __init__(hidden_size, eps=1e-6, **kwargs)
|
||||
and forward signature: forward(hidden_states, gate=None)
|
||||
"""
|
||||
|
||||
def __init__(self, hidden_size, eps=1e-6, offset=0.0, **kwargs):
|
||||
super().__init__()
|
||||
self.weight = torch.nn.Parameter(torch.ones(hidden_size))
|
||||
self.variance_epsilon = eps
|
||||
self.offset = offset
|
||||
|
||||
def forward(self, hidden_states, gate=None):
|
||||
if gate is None:
|
||||
raise ValueError("FusedRMSNormGated requires a gate tensor")
|
||||
if hidden_states.device.type != "cuda":
|
||||
raise ValueError(
|
||||
f"FusedRMSNormGated requires CUDA tensors, got device={hidden_states.device}"
|
||||
)
|
||||
return FusedRMSNormGatedFunction.apply(
|
||||
hidden_states, gate, self.weight, self.variance_epsilon, self.offset
|
||||
)
|
||||
|
||||
def extra_repr(self):
|
||||
return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
|
||||
229
tests/kernels/test_rms_norm_gated.py
Normal file
229
tests/kernels/test_rms_norm_gated.py
Normal file
@@ -0,0 +1,229 @@
|
||||
"""
|
||||
Correctness tests for fused RMSNorm + SiLU Gate kernel.
|
||||
|
||||
Tests against the eager Qwen3_5RMSNormGated implementation.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
pytest.importorskip("triton", reason="triton required for fused kernels")
|
||||
|
||||
if not torch.cuda.is_available():
|
||||
pytest.skip("CUDA required for fused kernel tests", allow_module_level=True)
|
||||
|
||||
from axolotl.kernels.rms_norm_gated import FusedRMSNormGated
|
||||
|
||||
|
||||
class EagerRMSNormGated(torch.nn.Module):
|
||||
"""Reference implementation matching Qwen3_5RMSNormGated exactly."""
|
||||
|
||||
def __init__(self, hidden_size, eps=1e-6):
|
||||
super().__init__()
|
||||
self.weight = torch.nn.Parameter(torch.ones(hidden_size))
|
||||
self.variance_epsilon = eps
|
||||
|
||||
def forward(self, hidden_states, gate=None):
|
||||
input_dtype = hidden_states.dtype
|
||||
hidden_states = hidden_states.to(torch.float32)
|
||||
variance = hidden_states.pow(2).mean(-1, keepdim=True)
|
||||
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
||||
hidden_states = self.weight * hidden_states.to(input_dtype)
|
||||
hidden_states = hidden_states * F.silu(gate.to(torch.float32))
|
||||
return hidden_states.to(input_dtype)
|
||||
|
||||
|
||||
def _sync_weights(eager_mod, fused_mod):
|
||||
"""Copy weights from eager to fused module."""
|
||||
fused_mod.weight.data.copy_(eager_mod.weight.data)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16, torch.float16])
|
||||
@pytest.mark.parametrize(
|
||||
"shape",
|
||||
[
|
||||
(2, 128, 256),
|
||||
(4, 64, 512),
|
||||
(1, 32, 1024),
|
||||
(2, 16, 2560), # Qwen3.5-4B hidden_size
|
||||
(2, 16, 4096), # Qwen3.5-9B hidden_size
|
||||
(1, 8, 5120), # Qwen3.5-27B hidden_size
|
||||
(4, 16, 2048), # Qwen3.5-35B-A3B (MoE) hidden_size
|
||||
(4, 16, 3072), # Qwen3.5-122B-A10B (MoE) hidden_size
|
||||
],
|
||||
)
|
||||
class TestRMSNormGatedForward:
|
||||
def test_output_matches_eager(self, dtype, shape):
|
||||
torch.manual_seed(42)
|
||||
B, T, H = shape
|
||||
X = torch.randn(B, T, H, dtype=dtype, device="cuda")
|
||||
G = torch.randn(B, T, H, dtype=dtype, device="cuda")
|
||||
|
||||
eager = EagerRMSNormGated(H).to(dtype=dtype, device="cuda")
|
||||
fused = FusedRMSNormGated(H).to(dtype=dtype, device="cuda")
|
||||
_sync_weights(eager, fused)
|
||||
|
||||
y_eager = eager(X, gate=G)
|
||||
y_fused = fused(X, gate=G)
|
||||
|
||||
if dtype == torch.float32:
|
||||
torch.testing.assert_close(y_fused, y_eager, atol=1e-5, rtol=1e-5)
|
||||
else:
|
||||
torch.testing.assert_close(y_fused, y_eager, atol=1e-2, rtol=1e-2)
|
||||
|
||||
def test_output_shape(self, dtype, shape):
|
||||
B, T, H = shape
|
||||
X = torch.randn(B, T, H, dtype=dtype, device="cuda")
|
||||
G = torch.randn(B, T, H, dtype=dtype, device="cuda")
|
||||
|
||||
fused = FusedRMSNormGated(H).to(dtype=dtype, device="cuda")
|
||||
y = fused(X, gate=G)
|
||||
assert y.shape == (B, T, H)
|
||||
assert y.dtype == dtype
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16, torch.float16])
|
||||
@pytest.mark.parametrize(
|
||||
"shape",
|
||||
[
|
||||
(2, 32, 256),
|
||||
(2, 16, 512),
|
||||
(2, 16, 2560), # Qwen3.5-4B
|
||||
(1, 8, 4096), # Qwen3.5-9B
|
||||
(1, 8, 5120), # Qwen3.5-27B
|
||||
(2, 16, 2048), # Qwen3.5-35B-A3B (MoE)
|
||||
(2, 16, 3072), # Qwen3.5-122B-A10B (MoE)
|
||||
],
|
||||
)
|
||||
class TestRMSNormGatedBackward:
|
||||
def test_grad_x(self, dtype, shape):
|
||||
torch.manual_seed(42)
|
||||
B, T, H = shape
|
||||
X = torch.randn(B, T, H, dtype=dtype, device="cuda", requires_grad=True)
|
||||
G = torch.randn(B, T, H, dtype=dtype, device="cuda", requires_grad=True)
|
||||
X_ref = X.detach().clone().requires_grad_(True)
|
||||
G_ref = G.detach().clone().requires_grad_(True)
|
||||
|
||||
eager = EagerRMSNormGated(H).to(dtype=dtype, device="cuda")
|
||||
fused = FusedRMSNormGated(H).to(dtype=dtype, device="cuda")
|
||||
_sync_weights(eager, fused)
|
||||
|
||||
y_eager = eager(X_ref, gate=G_ref)
|
||||
y_fused = fused(X, gate=G)
|
||||
|
||||
grad_out = torch.randn_like(y_eager)
|
||||
y_eager.backward(grad_out)
|
||||
y_fused.backward(grad_out)
|
||||
|
||||
if dtype == torch.float32:
|
||||
atol, rtol = 1e-4, 1e-4
|
||||
else:
|
||||
atol, rtol = 5e-2, 5e-2
|
||||
|
||||
torch.testing.assert_close(X.grad, X_ref.grad, atol=atol, rtol=rtol)
|
||||
|
||||
def test_grad_gate(self, dtype, shape):
|
||||
torch.manual_seed(42)
|
||||
B, T, H = shape
|
||||
X = torch.randn(B, T, H, dtype=dtype, device="cuda", requires_grad=True)
|
||||
G = torch.randn(B, T, H, dtype=dtype, device="cuda", requires_grad=True)
|
||||
X_ref = X.detach().clone().requires_grad_(True)
|
||||
G_ref = G.detach().clone().requires_grad_(True)
|
||||
|
||||
eager = EagerRMSNormGated(H).to(dtype=dtype, device="cuda")
|
||||
fused = FusedRMSNormGated(H).to(dtype=dtype, device="cuda")
|
||||
_sync_weights(eager, fused)
|
||||
|
||||
y_eager = eager(X_ref, gate=G_ref)
|
||||
y_fused = fused(X, gate=G)
|
||||
|
||||
grad_out = torch.randn_like(y_eager)
|
||||
y_eager.backward(grad_out)
|
||||
y_fused.backward(grad_out)
|
||||
|
||||
if dtype == torch.float32:
|
||||
atol, rtol = 1e-4, 1e-4
|
||||
else:
|
||||
atol, rtol = 5e-2, 5e-2
|
||||
|
||||
torch.testing.assert_close(G.grad, G_ref.grad, atol=atol, rtol=rtol)
|
||||
|
||||
def test_grad_weight(self, dtype, shape):
|
||||
torch.manual_seed(42)
|
||||
B, T, H = shape
|
||||
X = torch.randn(B, T, H, dtype=dtype, device="cuda", requires_grad=True)
|
||||
G = torch.randn(B, T, H, dtype=dtype, device="cuda", requires_grad=True)
|
||||
X_ref = X.detach().clone().requires_grad_(True)
|
||||
G_ref = G.detach().clone().requires_grad_(True)
|
||||
|
||||
eager = EagerRMSNormGated(H).to(dtype=dtype, device="cuda")
|
||||
fused = FusedRMSNormGated(H).to(dtype=dtype, device="cuda")
|
||||
_sync_weights(eager, fused)
|
||||
|
||||
y_eager = eager(X_ref, gate=G_ref)
|
||||
y_fused = fused(X, gate=G)
|
||||
|
||||
grad_out = torch.randn_like(y_eager)
|
||||
y_eager.backward(grad_out)
|
||||
y_fused.backward(grad_out)
|
||||
|
||||
if dtype == torch.float32:
|
||||
atol, rtol = 1e-4, 1e-4
|
||||
else:
|
||||
atol, rtol = 5e-2, 5e-2
|
||||
|
||||
torch.testing.assert_close(
|
||||
fused.weight.grad, eager.weight.grad, atol=atol, rtol=rtol
|
||||
)
|
||||
|
||||
|
||||
class TestRMSNormGatedEdgeCases:
|
||||
def test_gate_none_raises(self):
|
||||
fused = FusedRMSNormGated(256).cuda()
|
||||
X = torch.randn(2, 4, 256, device="cuda")
|
||||
with pytest.raises(ValueError, match="requires a gate tensor"):
|
||||
fused(X, gate=None)
|
||||
|
||||
def test_2d_input(self):
|
||||
"""Test with (BxT, H) shaped input instead of (B, T, H)."""
|
||||
torch.manual_seed(42)
|
||||
H = 512
|
||||
X = torch.randn(64, H, dtype=torch.bfloat16, device="cuda", requires_grad=True)
|
||||
G = torch.randn(64, H, dtype=torch.bfloat16, device="cuda", requires_grad=True)
|
||||
X_ref = X.detach().clone().requires_grad_(True)
|
||||
G_ref = G.detach().clone().requires_grad_(True)
|
||||
|
||||
eager = EagerRMSNormGated(H).to(dtype=torch.bfloat16, device="cuda")
|
||||
fused = FusedRMSNormGated(H).to(dtype=torch.bfloat16, device="cuda")
|
||||
_sync_weights(eager, fused)
|
||||
|
||||
y_eager = eager(X_ref, gate=G_ref)
|
||||
y_fused = fused(X, gate=G)
|
||||
|
||||
torch.testing.assert_close(y_fused, y_eager, atol=1e-2, rtol=1e-2)
|
||||
|
||||
grad_out = torch.randn_like(y_eager)
|
||||
y_eager.backward(grad_out)
|
||||
y_fused.backward(grad_out)
|
||||
|
||||
torch.testing.assert_close(X.grad, X_ref.grad, atol=5e-2, rtol=5e-2)
|
||||
torch.testing.assert_close(G.grad, G_ref.grad, atol=5e-2, rtol=5e-2)
|
||||
|
||||
def test_random_weight_init(self):
|
||||
"""Test with non-default weight values."""
|
||||
torch.manual_seed(123)
|
||||
H = 256
|
||||
X = torch.randn(2, 16, H, dtype=torch.bfloat16, device="cuda")
|
||||
G = torch.randn(2, 16, H, dtype=torch.bfloat16, device="cuda")
|
||||
|
||||
eager = EagerRMSNormGated(H).to(dtype=torch.bfloat16, device="cuda")
|
||||
# Randomize weights
|
||||
eager.weight.data = torch.randn_like(eager.weight.data)
|
||||
|
||||
fused = FusedRMSNormGated(H).to(dtype=torch.bfloat16, device="cuda")
|
||||
_sync_weights(eager, fused)
|
||||
|
||||
y_eager = eager(X, gate=G)
|
||||
y_fused = fused(X, gate=G)
|
||||
torch.testing.assert_close(y_fused, y_eager, atol=1e-2, rtol=1e-2)
|
||||
Reference in New Issue
Block a user