diff --git a/src/axolotl/integrations/liger/args.py b/src/axolotl/integrations/liger/args.py
index eb7a6c59b..a5f88ffe2 100644
--- a/src/axolotl/integrations/liger/args.py
+++ b/src/axolotl/integrations/liger/args.py
@@ -30,6 +30,15 @@ class LigerArgs(BaseModel):
     liger_rope: bool | None = None
     liger_rms_norm: bool | None = None
+    liger_rms_norm_gated: bool | None = Field(
+        default=None,
+        json_schema_extra={
+            "description": (
+                "Enables fused RMSNorm+SiLU gate Triton kernel for models with "
+                "gated RMSNorm (e.g. Qwen3.5 / Qwen3.5 MoE linear attention layers)."
+            )
+        },
+    )
     liger_layer_norm: bool | None = None
     liger_swiglu: bool | None = None
     liger_glu_activation: bool | None = None
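Note: a quick sanity check of the new flag. This is a minimal sketch assuming `LigerArgs` is importable exactly as shown in the hunk above; it only confirms the field round-trips through Pydantic validation like the other Liger toggles.

```python
# Minimal sketch (assumes LigerArgs is importable from the module in this diff).
from axolotl.integrations.liger.args import LigerArgs

args = LigerArgs(liger_rms_norm=True, liger_rms_norm_gated=True)
assert args.liger_rms_norm_gated is True
# Unset flags default to None, matching the other Liger toggles.
assert args.liger_swiglu is None
```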
diff --git a/src/axolotl/integrations/liger/models/qwen3_5.py b/src/axolotl/integrations/liger/models/qwen3_5.py
new file mode 100644
index 000000000..ee4b9b1c1
--- /dev/null
+++ b/src/axolotl/integrations/liger/models/qwen3_5.py
@@ -0,0 +1,175 @@
+"""
+Liger FLCE for Qwen3.5. Based on transformers v5.3.0.
+"""
+
+import sys
+from copy import deepcopy
+from typing import Optional, Union
+
+import torch
+from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
+from transformers.cache_utils import Cache
+from transformers.modeling_outputs import CausalLMOutputWithPast
+
+
+def lce_forward(
+    self,
+    input_ids: Optional[torch.LongTensor] = None,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_values: Optional[Cache] = None,
+    inputs_embeds: Optional[torch.FloatTensor] = None,
+    labels: Optional[torch.LongTensor] = None,
+    use_cache: Optional[bool] = None,
+    cache_position: Optional[torch.LongTensor] = None,
+    logits_to_keep: Union[int, torch.Tensor] = 0,
+    **kwargs,
+) -> CausalLMOutputWithPast:
+    r"""
+    Args:
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+            (masked); the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+        logits_to_keep (`int` or `torch.Tensor`, *optional*):
+            If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
+            `input_ids` (special case). Only the last token's logits are needed for generation, and computing them
+            only for that token saves memory, which becomes quite significant for long sequences or large
+            vocabulary sizes. If a `torch.Tensor`, it must be 1D, corresponding to the indices to keep in the
+            sequence length dimension. This is useful when using packed tensor format (single dimension for batch
+            and sequence length).
+
+    Returns:
+    """
+
+    # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
+    outputs = self.model(
+        input_ids=input_ids,
+        attention_mask=attention_mask,
+        position_ids=position_ids,
+        past_key_values=past_key_values,
+        inputs_embeds=inputs_embeds,
+        use_cache=use_cache,
+        cache_position=cache_position,
+        **kwargs,
+    )
+
+    hidden_states = outputs[0]
+
+    logits = None
+    loss = None
+    # if in training mode, don't materialize logits
+    if self.training and (labels is not None):
+        loss = LigerForCausalLMLoss(
+            hidden_states=hidden_states,
+            lm_head_weight=self.lm_head.weight,
+            labels=labels,
+            hidden_size=self.config.hidden_size,
+            **kwargs,
+        )
+
+    else:  # if in inference mode, materialize logits
+        slice_indices = (
+            slice(-logits_to_keep, None)
+            if isinstance(logits_to_keep, int)
+            else logits_to_keep
+        )
+        logits = self.lm_head(hidden_states[:, slice_indices, :])
+        if labels is not None:
+            loss = self.loss_function(
+                logits=logits,
+                labels=labels,
+                vocab_size=self.config.vocab_size,
+                **kwargs,
+            )
+
+    return CausalLMOutputWithPast(
+        loss=loss,
+        logits=logits,
+        past_key_values=outputs.past_key_values,
+        hidden_states=outputs.hidden_states,
+        attentions=outputs.attentions,
+    )
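Note: for context on why skipping logit materialization matters, the logits tensor alone is batch × seq × vocab. A back-of-the-envelope sketch; the batch/sequence/vocab numbers here are illustrative assumptions, not values from this PR.

```python
# Illustrative arithmetic only: logits memory for a hypothetical Qwen-sized run.
batch, seq_len, vocab = 4, 8192, 152_064  # vocab size is an assumption here
bytes_per_el = 2  # bf16
logits_gib = batch * seq_len * vocab * bytes_per_el / 2**30
print(f"{logits_gib:.1f} GiB")  # ~9.3 GiB that the fused loss never allocates
```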
+
+
+def apply_liger_kernel_to_qwen3_5(
+    cross_entropy: bool = False,
+    fused_linear_cross_entropy: bool = False,
+    rms_norm: bool = False,
+    rms_norm_gated: bool = False,
+    glu_activation: bool = False,
+    layer_norm: bool = False,
+    **kwargs,
+) -> None:
+    """
+    Apply Liger kernels to replace the original implementations in HuggingFace Qwen3.5 models.
+
+    Note: Qwen3_5RMSNorm uses a zero-initialized weight with offset 1.0 (like Gemma),
+    so we use LigerRMSNorm with offset=1.0 and init_fn="zeros".
+
+    Args:
+        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+        fused_linear_cross_entropy (bool):
+            Whether to apply Liger's fused linear cross entropy loss. Default is False.
+            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+            If `fused_linear_cross_entropy` is True, the logits will not be materialized,
+            which is more memory efficient.
+        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is False.
+        rms_norm_gated (bool): Whether to apply the fused RMSNorm+SiLU gate kernel for
+            Qwen3_5RMSNormGated (used in linear attention layers). Default is False.
+        glu_activation (bool): Whether to apply Liger's SwiGLU MLP. Default is False.
+        layer_norm (bool): Whether to apply Liger's LayerNorm. Default is False.
+    """
+
+    import transformers.models.qwen3_5.modeling_qwen3_5  # noqa: F401
+    from liger_kernel.transformers.functional import liger_cross_entropy
+    from liger_kernel.transformers.layer_norm import LigerLayerNorm
+    from liger_kernel.transformers.rms_norm import LigerRMSNorm
+    from liger_kernel.transformers.swiglu import LigerSwiGLUMLP
+
+    assert not (cross_entropy and fused_linear_cross_entropy), (
+        "cross_entropy and fused_linear_cross_entropy cannot both be True."
+    )
+
+    modeling_qwen3_5 = sys.modules["transformers.models.qwen3_5.modeling_qwen3_5"]
+
+    if rms_norm:
+        # Qwen3_5RMSNorm uses a zero-init weight with the `output * (1.0 + weight)` pattern
+        class LigerRMSNormForQwen3_5(LigerRMSNorm):
+            def __init__(self, dim, eps=1e-6, **kwargs):
+                super().__init__(
+                    dim,
+                    eps=eps,
+                    offset=1.0,
+                    casting_mode="gemma",
+                    init_fn="zeros",
+                    in_place=False,
+                )
+
+        modeling_qwen3_5.Qwen3_5RMSNorm = LigerRMSNormForQwen3_5
+
+    if rms_norm_gated:
+        from axolotl.kernels.rms_norm_gated import FusedRMSNormGated
+
+        modeling_qwen3_5.Qwen3_5RMSNormGated = FusedRMSNormGated
+
+    if glu_activation:
+
+        def _liger_swiglu_mlp_wrapper(config, intermediate_size=None, **kwargs):
+            """Accepts intermediate_size to pass to LigerSwiGLUMLP"""
+            config = deepcopy(config)
+            if intermediate_size is not None:
+                config.intermediate_size = intermediate_size
+            return LigerSwiGLUMLP(config, **kwargs)
+
+        modeling_qwen3_5.Qwen3_5MLP = _liger_swiglu_mlp_wrapper
+
+    if layer_norm:
+        modeling_qwen3_5.nn.LayerNorm = LigerLayerNorm
+
+    if cross_entropy:
+        from transformers.loss.loss_utils import nn
+
+        nn.functional.cross_entropy = liger_cross_entropy
+
+    if fused_linear_cross_entropy:
+        modeling_qwen3_5.Qwen3_5ForCausalLM.forward = lce_forward
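Note: call order matters here. The patch rebinds names in the `modeling_qwen3_5` module namespace, so it has to run before the model instantiates those layers. A hedged usage sketch; the model ID is illustrative, not taken from this PR.

```python
# Hedged sketch: patch first, then load. The model ID below is hypothetical.
from transformers import AutoModelForCausalLM

from axolotl.integrations.liger.models.qwen3_5 import apply_liger_kernel_to_qwen3_5

apply_liger_kernel_to_qwen3_5(
    fused_linear_cross_entropy=True,
    rms_norm=True,
    rms_norm_gated=True,
)
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3.5-4B")  # hypothetical ID
```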
diff --git a/src/axolotl/integrations/liger/models/qwen3_5_moe.py b/src/axolotl/integrations/liger/models/qwen3_5_moe.py
new file mode 100644
index 000000000..b10b34ad5
--- /dev/null
+++ b/src/axolotl/integrations/liger/models/qwen3_5_moe.py
@@ -0,0 +1,198 @@
+"""
+Liger FLCE for Qwen3.5 MoE. Based on transformers v5.3.0.
+"""
+
+import sys
+from copy import deepcopy
+from typing import Optional, Union
+
+import torch
+from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
+from transformers.modeling_outputs import MoeCausalLMOutputWithPast
+
+
+def lce_forward(
+    self,
+    input_ids: Optional[torch.LongTensor] = None,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_values=None,
+    inputs_embeds: Optional[torch.FloatTensor] = None,
+    labels: Optional[torch.LongTensor] = None,
+    use_cache: Optional[bool] = None,
+    output_router_logits: Optional[bool] = None,
+    cache_position: Optional[torch.LongTensor] = None,
+    logits_to_keep: Union[int, torch.Tensor] = 0,
+    **kwargs,
+) -> MoeCausalLMOutputWithPast:
+    r"""
+    Args:
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+            (masked); the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+        logits_to_keep (`int` or `torch.Tensor`, *optional*):
+            If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
+            `input_ids` (special case). Only the last token's logits are needed for generation, and computing them
+            only for that token saves memory, which becomes quite significant for long sequences or large
+            vocabulary sizes. If a `torch.Tensor`, it must be 1D, corresponding to the indices to keep in the
+            sequence length dimension. This is useful when using packed tensor format (single dimension for batch
+            and sequence length).
+
+    Returns:
+    """
+    from transformers.models.qwen3_5_moe.modeling_qwen3_5_moe import (
+        load_balancing_loss_func,
+    )
+
+    output_router_logits = (
+        output_router_logits
+        if output_router_logits is not None
+        else self.config.output_router_logits
+    )
+
+    # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
+    outputs = self.model(
+        input_ids=input_ids,
+        attention_mask=attention_mask,
+        position_ids=position_ids,
+        past_key_values=past_key_values,
+        inputs_embeds=inputs_embeds,
+        use_cache=use_cache,
+        output_router_logits=output_router_logits,
+        cache_position=cache_position,
+        **kwargs,
+    )
+
+    hidden_states = outputs[0]
+
+    logits = None
+    loss = None
+    # if in training mode, don't materialize logits
+    if self.training and (labels is not None):
+        loss = LigerForCausalLMLoss(
+            hidden_states=hidden_states,
+            lm_head_weight=self.lm_head.weight,
+            labels=labels,
+            hidden_size=self.config.hidden_size,
+            **kwargs,
+        )
+
+    else:  # if in inference mode, materialize logits
+        slice_indices = (
+            slice(-logits_to_keep, None)
+            if isinstance(logits_to_keep, int)
+            else logits_to_keep
+        )
+        logits = self.lm_head(hidden_states[:, slice_indices, :])
+        if labels is not None:
+            loss = self.loss_function(
+                logits,
+                labels,
+                self.vocab_size,
+                **kwargs,
+            )
+
+    aux_loss = None
+    if output_router_logits:
+        aux_loss = load_balancing_loss_func(
+            outputs.router_logits,
+            self.num_experts,
+            self.num_experts_per_tok,
+            attention_mask,
+        )
+        if labels is not None:
+            loss += self.router_aux_loss_coef * aux_loss.to(loss.device)
+
+    return MoeCausalLMOutputWithPast(
+        loss=loss,
+        aux_loss=aux_loss,
+        logits=logits,
+        past_key_values=outputs.past_key_values,
+        hidden_states=outputs.hidden_states,
+        attentions=outputs.attentions,
+        router_logits=outputs.router_logits,
+    )
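Note: the aux-loss handling above follows the usual MoE recipe. In equation form, with alpha being `router_aux_loss_coef`:

```latex
% Total loss as combined in lce_forward above: the (fused) cross-entropy loss
% plus the coefficient-weighted load-balancing auxiliary loss.
\[
\mathcal{L} \;=\; \mathcal{L}_{\text{CE}} \;+\; \alpha\,\mathcal{L}_{\text{aux}},
\qquad \alpha = \texttt{router\_aux\_loss\_coef}
\]
```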
+ """ + + import transformers.models.qwen3_5_moe.modeling_qwen3_5_moe # noqa: F401 + from liger_kernel.transformers.functional import liger_cross_entropy + from liger_kernel.transformers.layer_norm import LigerLayerNorm + from liger_kernel.transformers.rms_norm import LigerRMSNorm + from liger_kernel.transformers.swiglu import LigerSwiGLUMLP + + assert not (cross_entropy and fused_linear_cross_entropy), ( + "cross_entropy and fused_linear_cross_entropy cannot both be True." + ) + + modeling_mod = sys.modules["transformers.models.qwen3_5_moe.modeling_qwen3_5_moe"] + + if rms_norm: + # Qwen3_5MoeRMSNorm uses zero-init weight with `output * (1.0 + weight)` pattern + class LigerRMSNormForQwen3_5Moe(LigerRMSNorm): + def __init__(self, dim, eps=1e-6, **kwargs): + super().__init__( + dim, + eps=eps, + offset=1.0, + casting_mode="gemma", + init_fn="zeros", + in_place=False, + ) + + modeling_mod.Qwen3_5MoeRMSNorm = LigerRMSNormForQwen3_5Moe + + if rms_norm_gated: + from axolotl.kernels.rms_norm_gated import FusedRMSNormGated + + modeling_mod.Qwen3_5MoeRMSNormGated = FusedRMSNormGated + + if glu_activation: + + def _liger_swiglu_mlp_wrapper(config, intermediate_size=None, **kwargs): + """Accepts intermediate_size to pass to LigerSwiGLUMLP""" + config = deepcopy(config) + if intermediate_size is not None: + config.intermediate_size = intermediate_size + return LigerSwiGLUMLP(config, **kwargs) + + modeling_mod.Qwen3_5MoeMLP = _liger_swiglu_mlp_wrapper + + if layer_norm: + modeling_mod.nn.LayerNorm = LigerLayerNorm + + if cross_entropy: + from transformers.loss.loss_utils import nn + + nn.functional.cross_entropy = liger_cross_entropy + + if fused_linear_cross_entropy: + modeling_mod.Qwen3_5MoeForCausalLM.forward = lce_forward diff --git a/src/axolotl/integrations/liger/plugin.py b/src/axolotl/integrations/liger/plugin.py index cfd652872..d56109570 100644 --- a/src/axolotl/integrations/liger/plugin.py +++ b/src/axolotl/integrations/liger/plugin.py @@ -174,6 +174,19 @@ class LigerPlugin(BasePlugin): rms_norm=cfg.liger_rms_norm, layer_norm=cfg.liger_layer_norm, ) + elif cfg.model_config_type == "qwen3_5": + from axolotl.integrations.liger.models.qwen3_5 import ( + apply_liger_kernel_to_qwen3_5, + ) + + apply_liger_kernel_to_qwen3_5( + cross_entropy=cfg.liger_cross_entropy, + fused_linear_cross_entropy=cfg.liger_fused_linear_cross_entropy, + glu_activation=cfg.liger_glu_activation, + rms_norm=cfg.liger_rms_norm, + rms_norm_gated=getattr(cfg, "liger_rms_norm_gated", False), + layer_norm=cfg.liger_layer_norm, + ) elif cfg.model_config_type == "qwen3_moe": from axolotl.integrations.liger.models.qwen3_moe import ( apply_liger_kernel_to_qwen3_moe, @@ -186,6 +199,19 @@ class LigerPlugin(BasePlugin): rms_norm=cfg.liger_rms_norm, layer_norm=cfg.liger_layer_norm, ) + elif cfg.model_config_type == "qwen3_5_moe": + from axolotl.integrations.liger.models.qwen3_5_moe import ( + apply_liger_kernel_to_qwen3_5_moe, + ) + + apply_liger_kernel_to_qwen3_5_moe( + cross_entropy=cfg.liger_cross_entropy, + fused_linear_cross_entropy=cfg.liger_fused_linear_cross_entropy, + glu_activation=cfg.liger_glu_activation, + rms_norm=cfg.liger_rms_norm, + rms_norm_gated=getattr(cfg, "liger_rms_norm_gated", False), + layer_norm=cfg.liger_layer_norm, + ) elif cfg.model_config_type == "granitemoe": from liger_kernel.transformers import apply_liger_kernel_to_granite diff --git a/src/axolotl/kernels/rms_norm_gated.py b/src/axolotl/kernels/rms_norm_gated.py new file mode 100644 index 000000000..6a5ff81bf --- /dev/null +++ 
diff --git a/src/axolotl/kernels/rms_norm_gated.py b/src/axolotl/kernels/rms_norm_gated.py
new file mode 100644
index 000000000..6a5ff81bf
--- /dev/null
+++ b/src/axolotl/kernels/rms_norm_gated.py
@@ -0,0 +1,333 @@
+"""
+Fused RMSNorm + SiLU Gate Triton kernel.
+
+Computes: Y = (W + offset) * RMSNorm(X) * silu(G)
+where RMSNorm(X) = X / sqrt(mean(X^2) + eps)
+and silu(G) = G * sigmoid(G)
+
+Used by Qwen3.5's GatedDeltaNet linear attention layers (Qwen3_5RMSNormGated).
+"""
+
+import math
+import operator
+
+import torch
+import triton
+import triton.language as tl
+from liger_kernel.ops.utils import (
+    calculate_settings,
+    compare_version,
+    ensure_contiguous,
+    torch_to_triton_dtype,
+)
+from liger_kernel.utils import is_npu_available
+
+if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
+    try:
+        from triton.language.extra.libdevice import rsqrt
+    except ModuleNotFoundError:
+        from triton.language.extra.cuda.libdevice import rsqrt
+else:
+    from triton.language.math import rsqrt
+
+
+@triton.jit
+def _rms_norm_gated_forward_kernel(
+    Y_ptr,
+    Y_row_stride,
+    X_ptr,
+    X_row_stride,
+    G_ptr,
+    G_row_stride,
+    W_ptr,
+    W_row_stride,
+    RSTD_ptr,
+    RSTD_row_stride,
+    n_cols,
+    eps,
+    offset,
+    BLOCK_SIZE: tl.constexpr,
+):
+    """
+    Y = (W + offset) * (X / RMS(X)) * silu(G)
+
+    All computation done in fp32 (Gemma-style), result cast to input dtype.
+    """
+    row_idx = tl.program_id(0).to(tl.int64)
+    col_offsets = tl.arange(0, BLOCK_SIZE)
+    mask = col_offsets < n_cols
+
+    X_row = tl.load(X_ptr + row_idx * X_row_stride + col_offsets, mask=mask, other=0)
+    G_row = tl.load(G_ptr + row_idx * G_row_stride + col_offsets, mask=mask, other=0)
+    W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0)
+
+    X_row_dtype = X_row.dtype
+
+    # Cast everything to fp32
+    X_fp32 = X_row.to(tl.float32)
+    G_fp32 = G_row.to(tl.float32)
+    W_fp32 = W_row.to(tl.float32)
+
+    # RMS norm
+    mean_sq = tl.sum(X_fp32 * X_fp32, axis=0) / n_cols
+    rstd = rsqrt(mean_sq + eps)
+    tl.store(RSTD_ptr + row_idx * RSTD_row_stride, rstd)
+
+    X_norm = X_fp32 * rstd
+
+    # SiLU gate: silu(G) = G * sigmoid(G)
+    sig_G = tl.sigmoid(G_fp32)
+    silu_G = G_fp32 * sig_G
+
+    # Fused output
+    Y_row = (offset + W_fp32) * X_norm * silu_G
+
+    tl.store(
+        Y_ptr + row_idx * Y_row_stride + col_offsets,
+        Y_row.to(X_row_dtype),
+        mask=mask,
+    )
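Note: in symbols, the forward kernel computes the following per row, where H is `n_cols` and delta is the `offset` argument:

```latex
% Per-row forward computation, all in fp32, matching the kernel above.
\[
r = \Bigl(\tfrac{1}{H}\sum_{i=1}^{H} x_i^{2} + \varepsilon\Bigr)^{-1/2},
\qquad
y_i = (w_i + \delta)\,(x_i r)\,\underbrace{g_i\,\sigma(g_i)}_{\mathrm{silu}(g_i)}
\]
% r (the reciprocal RMS) is stored per row so the backward pass can reuse it.
```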
+
+
+@triton.jit
+def _rms_norm_gated_backward_kernel(
+    dY_ptr,
+    dY_row_stride,
+    dX_ptr,
+    dX_row_stride,
+    dG_ptr,
+    dG_row_stride,
+    X_ptr,
+    X_row_stride,
+    X_dtype: tl.constexpr,
+    G_ptr,
+    G_row_stride,
+    W_ptr,
+    W_row_stride,
+    RSTD_ptr,
+    RSTD_row_stride,
+    dW_ptr,
+    dW_row_stride,
+    n_rows,
+    n_cols,
+    offset,
+    rows_per_program,
+    BLOCK_SIZE: tl.constexpr,
+):
+    """
+    Backward for Y = (W + offset) * (X * RSTD) * silu(G)
+
+    dW = sum_batch(dY * X_norm * silu(G))
+    dG = dY * (W + offset) * X_norm * silu'(G)
+        where silu'(G) = sigmoid(G) * (1 + G * (1 - sigmoid(G)))
+    dX = RSTD * (m - (1/N) * RSTD^2 * dot(m, X) * X)
+        where m = dY * (W + offset) * silu(G)
+    """
+    row_block_id = tl.program_id(0).to(tl.int64)
+    row_start = row_block_id * rows_per_program
+    row_end = min((row_block_id + 1) * rows_per_program, n_rows)
+    col_offsets = tl.arange(0, BLOCK_SIZE)
+    mask = col_offsets < n_cols
+
+    dW_acc = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
+
+    W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0.0)
+    W_row = W_row.to(tl.float32) + offset
+
+    for row_idx in range(row_start, row_end):
+        dY_row = tl.load(
+            dY_ptr + row_idx * dY_row_stride + col_offsets, mask=mask, other=0.0
+        )
+        X_row = tl.load(
+            X_ptr + row_idx * X_row_stride + col_offsets, mask=mask, other=0.0
+        )
+        G_row = tl.load(
+            G_ptr + row_idx * G_row_stride + col_offsets, mask=mask, other=0.0
+        )
+        rstd_row = tl.load(RSTD_ptr + row_idx * RSTD_row_stride)
+
+        # Cast to fp32
+        dY_fp32 = dY_row.to(tl.float32)
+        X_fp32 = X_row.to(tl.float32)
+        G_fp32 = G_row.to(tl.float32)
+
+        # Recompute intermediates
+        X_norm = X_fp32 * rstd_row
+        sig_G = tl.sigmoid(G_fp32)
+        silu_G = G_fp32 * sig_G
+
+        # dW: accumulate dY * X_norm * silu(G)
+        dW_acc += dY_fp32 * X_norm * silu_G
+
+        # dG: dY * (W + offset) * X_norm * silu'(G)
+        # silu'(G) = sigmoid(G) * (1 + G * (1 - sigmoid(G)))
+        silu_prime_G = sig_G * (1.0 + G_fp32 * (1.0 - sig_G))
+        dG_row = dY_fp32 * W_row * X_norm * silu_prime_G
+        tl.store(
+            dG_ptr + row_idx * dG_row_stride + col_offsets,
+            dG_row.to(X_dtype),
+            mask=mask,
+        )
+
+        # dX: standard RMSNorm backward with effective gradient m = dY * W * silu(G)
+        m = dY_fp32 * W_row * silu_G
+        dX_row = rstd_row * m
+        dX_row += rstd_row * (
+            -(1.0 / n_cols) * rstd_row * rstd_row * tl.sum(m * X_fp32, axis=0) * X_fp32
+        )
+        tl.store(
+            dX_ptr + row_idx * dX_row_stride + col_offsets,
+            dX_row.to(X_dtype),
+            mask=mask,
+        )
+
+    tl.store(
+        dW_ptr + row_block_id * dW_row_stride + col_offsets,
+        dW_acc,
+        mask=mask,
+    )
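Note: written out, the gradients this kernel implements are the following, with m being the effective upstream gradient defined in the docstring above and r the saved reciprocal RMS:

```latex
% Gradients implemented by the backward kernel.
\[
m_i = dy_i\,(w_i + \delta)\,\mathrm{silu}(g_i), \qquad
\frac{\partial L}{\partial g_i}
  = dy_i\,(w_i + \delta)\,(x_i r)\,\sigma(g_i)\bigl(1 + g_i(1 - \sigma(g_i))\bigr)
\]
\[
\frac{\partial L}{\partial x_i}
  = r\,m_i - \frac{r^{3}}{H}\Bigl(\sum_{j} m_j x_j\Bigr) x_i,
\qquad
\frac{\partial L}{\partial w_i}
  = \sum_{\text{rows}} dy_i\,(x_i r)\,\mathrm{silu}(g_i)
\]
```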
+
+
+def rms_norm_gated_forward(X, G, W, eps, offset):
+    shape = X.shape
+    dim = shape[-1]
+    X = X.view(-1, dim)
+    G = G.view(-1, dim)
+    n_rows, n_cols = X.shape
+    BLOCK_SIZE, num_warps = calculate_settings(n_cols)
+
+    Y = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device)
+    RSTD = torch.empty(n_rows, dtype=torch.float32, device=X.device)
+
+    assert X.shape[1] == W.shape[0], (
+        f"Incompatible hidden size: X.shape[1]={X.shape[1]} vs W.shape[0]={W.shape[0]}"
+    )
+    assert X.shape == G.shape, (
+        f"X and G must have same shape, got {X.shape} and {G.shape}"
+    )
+
+    _rms_norm_gated_forward_kernel[(n_rows,)](
+        Y,
+        Y.stride(0),
+        X,
+        X.stride(0),
+        G,
+        G.stride(0),
+        W,
+        W.stride(0),
+        RSTD,
+        RSTD.stride(0),
+        n_cols,
+        eps,
+        offset,
+        BLOCK_SIZE=BLOCK_SIZE,
+        num_warps=num_warps,
+    )
+    return Y.view(*shape), X, G, RSTD, BLOCK_SIZE, num_warps
+
+
+def rms_norm_gated_backward(dY, X, G, W, RSTD, offset, BLOCK_SIZE, num_warps):
+    shape = dY.shape
+    dim = shape[-1]
+    dY = dY.view(-1, dim)
+    n_rows, n_cols = dY.shape
+
+    sm_count = torch.cuda.get_device_properties(X.device).multi_processor_count
+
+    _dW = torch.empty((sm_count, n_cols), dtype=torch.float32, device=W.device)
+    dX = torch.empty_like(dY)
+    dG = torch.empty_like(dY)
+
+    rows_per_program = math.ceil(n_rows / sm_count)
+    grid = (sm_count,)
+
+    _rms_norm_gated_backward_kernel[grid](
+        dY,
+        dY.stride(0),
+        dX,
+        dX.stride(0),
+        dG,
+        dG.stride(0),
+        X,
+        X.stride(0),
+        torch_to_triton_dtype[X.dtype],
+        G,
+        G.stride(0),
+        W,
+        W.stride(0),
+        RSTD,
+        RSTD.stride(0),
+        _dW,
+        _dW.stride(0),
+        n_rows,
+        n_cols,
+        offset,
+        rows_per_program,
+        BLOCK_SIZE=BLOCK_SIZE,
+        num_warps=num_warps,
+    )
+
+    dX = dX.view(*shape)
+    dG = dG.view(*shape)
+    dW = _dW.sum(dim=0).to(W.dtype)
+    return dX, dG, dW
+
+
+class FusedRMSNormGatedFunction(torch.autograd.Function):
+    @staticmethod
+    @ensure_contiguous
+    def forward(ctx, X, G, W, eps, offset=0.0):
+        """
+        X: (B, T, H) or (BxT, H) — input hidden states
+        G: (B, T, H) or (BxT, H) — gate tensor
+        W: (H,) — weight parameter
+        """
+        Y, X, G, RSTD, BLOCK_SIZE, num_warps = rms_norm_gated_forward(
+            X, G, W, eps, offset
+        )
+        ctx.offset = offset
+        ctx.BLOCK_SIZE = BLOCK_SIZE
+        ctx.num_warps = num_warps
+        ctx.save_for_backward(X, G, W, RSTD)
+        return Y
+
+    @staticmethod
+    @ensure_contiguous
+    def backward(ctx, dY):
+        X, G, W, RSTD = ctx.saved_tensors
+        dX, dG, dW = rms_norm_gated_backward(
+            dY, X, G, W, RSTD, ctx.offset, ctx.BLOCK_SIZE, ctx.num_warps
+        )
+        return dX, dG, dW, None, None
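Note: a minimal usage sketch of the autograd function, assuming a CUDA device is available. The `None, None` tail of `backward` lines up with the non-tensor `eps` and `offset` inputs.

```python
# Minimal sketch: forward + backward through the fused autograd function,
# with offset=0.0 matching the module default below.
import torch

from axolotl.kernels.rms_norm_gated import FusedRMSNormGatedFunction

x = torch.randn(8, 512, device="cuda", requires_grad=True)
g = torch.randn(8, 512, device="cuda", requires_grad=True)
w = torch.randn(512, device="cuda", requires_grad=True)
y = FusedRMSNormGatedFunction.apply(x, g, w, 1e-6, 0.0)
y.sum().backward()  # populates x.grad, g.grad, w.grad via the Triton backward
```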
+
+
+class FusedRMSNormGated(torch.nn.Module):
+    """
+    Fused RMSNorm + SiLU Gate.
+
+    Computes: Y = (W + offset) * RMSNorm(X) * silu(G), with offset defaulting to 0.0.
+
+    Drop-in replacement for Qwen3_5RMSNormGated with matching
+    init signature: __init__(hidden_size, eps=1e-6, **kwargs)
+    and forward signature: forward(hidden_states, gate=None).
+    """
+
+    def __init__(self, hidden_size, eps=1e-6, offset=0.0, **kwargs):
+        super().__init__()
+        self.weight = torch.nn.Parameter(torch.ones(hidden_size))
+        self.variance_epsilon = eps
+        self.offset = offset
+
+    def forward(self, hidden_states, gate=None):
+        if gate is None:
+            raise ValueError("FusedRMSNormGated requires a gate tensor")
+        if hidden_states.device.type != "cuda":
+            raise ValueError(
+                f"FusedRMSNormGated requires CUDA tensors, got device={hidden_states.device}"
+            )
+        return FusedRMSNormGatedFunction.apply(
+            hidden_states, gate, self.weight, self.variance_epsilon, self.offset
+        )
+
+    def extra_repr(self):
+        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
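Note: a hedged sketch of the module used as a drop-in gated norm (CUDA only, by design); shapes and dtype are illustrative.

```python
# Hedged sketch: the module as a drop-in gated norm on CUDA.
import torch

from axolotl.kernels.rms_norm_gated import FusedRMSNormGated

norm = FusedRMSNormGated(hidden_size=2560).to(device="cuda", dtype=torch.bfloat16)
hidden = torch.randn(2, 16, 2560, device="cuda", dtype=torch.bfloat16)
gate = torch.randn_like(hidden)
out = norm(hidden, gate=gate)  # same shape and dtype as `hidden`
```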
gate tensor"): + fused(X, gate=None) + + def test_2d_input(self): + """Test with (BxT, H) shaped input instead of (B, T, H).""" + torch.manual_seed(42) + H = 512 + X = torch.randn(64, H, dtype=torch.bfloat16, device="cuda", requires_grad=True) + G = torch.randn(64, H, dtype=torch.bfloat16, device="cuda", requires_grad=True) + X_ref = X.detach().clone().requires_grad_(True) + G_ref = G.detach().clone().requires_grad_(True) + + eager = EagerRMSNormGated(H).to(dtype=torch.bfloat16, device="cuda") + fused = FusedRMSNormGated(H).to(dtype=torch.bfloat16, device="cuda") + _sync_weights(eager, fused) + + y_eager = eager(X_ref, gate=G_ref) + y_fused = fused(X, gate=G) + + torch.testing.assert_close(y_fused, y_eager, atol=1e-2, rtol=1e-2) + + grad_out = torch.randn_like(y_eager) + y_eager.backward(grad_out) + y_fused.backward(grad_out) + + torch.testing.assert_close(X.grad, X_ref.grad, atol=5e-2, rtol=5e-2) + torch.testing.assert_close(G.grad, G_ref.grad, atol=5e-2, rtol=5e-2) + + def test_random_weight_init(self): + """Test with non-default weight values.""" + torch.manual_seed(123) + H = 256 + X = torch.randn(2, 16, H, dtype=torch.bfloat16, device="cuda") + G = torch.randn(2, 16, H, dtype=torch.bfloat16, device="cuda") + + eager = EagerRMSNormGated(H).to(dtype=torch.bfloat16, device="cuda") + # Randomize weights + eager.weight.data = torch.randn_like(eager.weight.data) + + fused = FusedRMSNormGated(H).to(dtype=torch.bfloat16, device="cuda") + _sync_weights(eager, fused) + + y_eager = eager(X, gate=G) + y_fused = fused(X, gate=G) + torch.testing.assert_close(y_fused, y_eager, atol=1e-2, rtol=1e-2)