From 08fc7de87e79f38c367f6776c5111b40a914062e Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Thu, 2 Apr 2026 17:46:46 -0400 Subject: [PATCH] gemma4 support (#3574) * gemma4 support * fixes * chore: lint --- examples/gemma4/26b-a4b-moe-qlora.yaml | 104 ++ requirements.txt | 2 +- src/axolotl/core/trainers/base.py | 31 +- src/axolotl/integrations/kernels/README.md | 23 +- src/axolotl/integrations/kernels/args.py | 12 +- src/axolotl/integrations/kernels/constants.py | 44 + .../libs/scattermoe_lora/gemma4_experts.py | 235 ++++ .../kernels/libs/sonicmoe/gemma4_experts.py | 106 ++ .../kernels/libs/sonicmoe/routing.py | 73 ++ src/axolotl/integrations/kernels/plugin.py | 52 +- src/axolotl/monkeypatch/lora_kernels.py | 76 +- .../monkeypatch/models/qwen3_next/modeling.py | 41 +- .../chat_templates/templates/gemma4.jinja | 271 +++++ src/axolotl/utils/schemas/enums.py | 1 + src/axolotl/utils/schemas/trl.py | 4 +- tests/integrations/test_gemma4_moe.py | 1052 +++++++++++++++++ 16 files changed, 2082 insertions(+), 45 deletions(-) create mode 100644 examples/gemma4/26b-a4b-moe-qlora.yaml create mode 100644 src/axolotl/integrations/kernels/libs/scattermoe_lora/gemma4_experts.py create mode 100644 src/axolotl/integrations/kernels/libs/sonicmoe/gemma4_experts.py create mode 100644 src/axolotl/utils/chat_templates/templates/gemma4.jinja create mode 100644 tests/integrations/test_gemma4_moe.py diff --git a/examples/gemma4/26b-a4b-moe-qlora.yaml b/examples/gemma4/26b-a4b-moe-qlora.yaml new file mode 100644 index 000000000..0972b93f6 --- /dev/null +++ b/examples/gemma4/26b-a4b-moe-qlora.yaml @@ -0,0 +1,104 @@ +# Gemma 4 26B-A4B MoE QLoRA with ScatterMoE kernels +# +# Validated: 50 steps on FineTome-100k, loss 7.4 -> 2.4, single RTX 5090 (32GB) +# +# Key notes: +# - Flash Attention 2 is NOT supported (global_head_dim=512 > FA2 max of 256). +# Use sdp_attention instead. +# - Gemma 4 is multimodal (text+vision+audio). 
For text-only SFT, restrict +# LoRA to the text backbone via lora_target_linear_modules regex. +# - MoE experts use `experts_implementation: scattermoe` — Gemma 4 embeds MoE +# directly in the decoder layer (no SparseMoeBlock), so we register ScatterMoE +# via the transformers ExpertsInterface. +# - Expert LoRA targets are `experts.gate_up_proj` / `experts.down_proj` +# (no `mlp.` prefix, unlike Qwen/Mixtral). +# - micro_batch_size: 1 fits 2048 seq_len on 32GB GPU with SDP attention. +# Use micro_batch_size: 4 with 1024 seq_len, or on 48GB+ GPUs. + +base_model: google/gemma-4-26B-A4B + +plugins: + - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin + - axolotl.integrations.kernels.KernelsPlugin + - axolotl.integrations.liger.LigerPlugin +use_kernels: true +use_scattermoe: true +experts_implementation: scattermoe +torch_compile: false +liger_layer_norm: true +liger_rope: true +liger_rms_norm: true +liger_glu_activation: true +liger_rms_norm_gated: true +strict: false + +chat_template: gemma4 +datasets: + - path: mlabonne/FineTome-100k + type: chat_template + split: train[:10%] + field_messages: conversations + message_property_mappings: + role: from + content: value +val_set_size: 0.05 +output_dir: ./outputs/gemma4-26b-a4b-qlora + +sequence_len: 2048 +sample_packing: true + +load_in_4bit: true +quantize_moe_experts: true +adapter: qlora +lora_r: 16 +lora_alpha: 32 +lora_dropout: 0 + +# Restrict LoRA to text backbone only (skip vision/audio encoders). +# lora_target_modules is intentionally empty — all module targeting is done +# via regex in lora_target_linear_modules below. 
+lora_target_modules: [] +lora_target_linear_modules: + - language_model\.model\.layers\.\d+\.self_attn\.(q|k|v|o)_proj + +# MoE expert LoRA (3D Parameter tensors, not nn.Linear) +lora_target_parameters: + - experts.gate_up_proj + - experts.down_proj + +lora_mlp_kernel: false +lora_qkv_kernel: false +lora_o_kernel: false + +bnb_config_kwargs: + bnb_4bit_use_double_quant: true + +wandb_project: gemma4-qlora +wandb_entity: +wandb_watch: +wandb_name: +wandb_log_model: + +gradient_accumulation_steps: 4 +micro_batch_size: 1 +num_epochs: 1 +optimizer: adamw_torch_8bit +lr_scheduler: cosine +learning_rate: 0.0002 + +bf16: auto +tf32: true + +gradient_checkpointing: true +activation_offloading: true +logging_steps: 1 + +# FA2 not supported — Gemma4 global_head_dim=512 exceeds FA2 max of 256 +flash_attention: false +sdp_attention: true + +warmup_ratio: 0.1 +evals_per_epoch: 4 +saves_per_epoch: 1 +weight_decay: 0.0 +special_tokens: diff --git a/requirements.txt b/requirements.txt index fb429df90..446febb83 100644 --- a/requirements.txt +++ b/requirements.txt @@ -12,7 +12,7 @@ packaging==26.0 huggingface_hub>=1.1.7 peft>=0.18.1 tokenizers>=0.22.1 -transformers==5.4.0 +transformers==5.5.0 accelerate==1.13.0 datasets==4.5.0 deepspeed>=0.18.6,<0.19.0 diff --git a/src/axolotl/core/trainers/base.py b/src/axolotl/core/trainers/base.py index 5bc44a1dd..6beff8055 100644 --- a/src/axolotl/core/trainers/base.py +++ b/src/axolotl/core/trainers/base.py @@ -381,6 +381,15 @@ class AxolotlTrainer( # Store per-step trainable tokens for throughput calculation self.state.tokens["trainable_tokens"] = trainable_tokens.detach().cpu() + # Gemma4 requires mm_token_type_ids during training (even for text-only). + # Inject zeros (= text token type) when not provided by the data collator. 
+ if ( + "mm_token_type_ids" not in inputs + and "input_ids" in inputs + and getattr(getattr(model, "config", None), "model_type", None) == "gemma4" + ): + inputs["mm_token_type_ids"] = torch.zeros_like(inputs["input_ids"]) + if self.args.orpo_alpha: return self.orpo_compute_loss( model, @@ -508,12 +517,24 @@ class AxolotlTrainer( ) # Perform a single forward pass + forward_kwargs = { + "input_ids": concat_inputs["input_ids"], + "attention_mask": concat_inputs["attention_mask"], + "labels": concat_inputs["labels"], + } + # Gemma4 requires mm_token_type_ids during training (even for text-only) + if ( + getattr(getattr(model, "config", None), "model_type", None) == "gemma4" + and "mm_token_type_ids" not in concat_inputs + ): + forward_kwargs["mm_token_type_ids"] = torch.zeros_like( + concat_inputs["input_ids"] + ) + elif "mm_token_type_ids" in concat_inputs: + forward_kwargs["mm_token_type_ids"] = concat_inputs["mm_token_type_ids"] + outputs = model( - **{ - "input_ids": concat_inputs["input_ids"], - "attention_mask": concat_inputs["attention_mask"], - "labels": concat_inputs["labels"], - }, + **forward_kwargs, output_hidden_states=True, ) diff --git a/src/axolotl/integrations/kernels/README.md b/src/axolotl/integrations/kernels/README.md index a852cd6cf..9293c1727 100644 --- a/src/axolotl/integrations/kernels/README.md +++ b/src/axolotl/integrations/kernels/README.md @@ -28,7 +28,7 @@ use_scattermoe: true use_sonicmoe: true ``` -**Important:** Setting `experts_implementation` is incompatible with custom kernel options. +**Important:** Setting `experts_implementation` to `batched_mm` or `grouped_mm` is incompatible with custom kernel options. The exception is `experts_implementation: scattermoe`, which is used for models like Gemma 4 that embed MoE directly in the decoder layer (no SparseMoeBlock) and dispatch through the transformers `ExpertsInterface`. 
### SonicMoE installation @@ -63,7 +63,7 @@ Both paths use the shared `resolve_moe_block_classes` utility in `constants.py` ## Model Support Matrix -All models use the **SwiGLU** activation (`act_fn(gate) * up`). Neither kernel currently supports non-SwiGLU MoE architectures. +Most models use the **SwiGLU** activation (`silu(gate) * up`). Gemma 4 uses **GEGLU** (`gelu(gate) * up`). ScatterMoE supports any gated activation (activation is applied in Python between kernel calls). SonicMoE supports SwiGLU, GEGLU, and REGLU via its `ActivationType` enum. ### Routing strategies @@ -76,6 +76,7 @@ All models use the **SwiGLU** activation (`act_fn(gate) * up`). Neither kernel c | softmax → bias correction → topk | Softmax, bias via `gate.moe_statics`, topk, gather from original probs, clamp-based renorm | No | Yes | | softmax → group_limited_greedy | Softmax, group selection (max per group), topk, scale only (no renorm) | No | Yes | | softmax → topk via gate.wg | Softmax, gate weight at `gate.wg.weight` (not `gate.weight`), always renormalize | No | Yes | +| softmax → topk + per_expert_scale | RMSNorm → scale → proj → softmax → topk → renorm → per-expert learned scales | Yes | Yes | | fused topk → softmax | Routing + expert computation fused in a single kernel | No | Planned | ### Per-model support @@ -102,10 +103,13 @@ All models use the **SwiGLU** activation (`act_fn(gate) * up`). Neither kernel c | `ernie4_5_moe` | ERNIE 4.5 MoE | softmax → bias → topk | No | **Yes** | | `deepseek_v2` | DeepSeek-V2 | softmax → group_limited_greedy | No | **Yes** | | `hunyuan_v1_moe` | HunYuan V1 MoE | softmax → topk (gate.wg) | No | **Yes** | +| `gemma4_text` | Gemma 4 (26B-A4B) | softmax → topk + per_expert_scale | **Yes**\*\* | **Yes**\*\* | | `gpt_oss` | GPT-OSS | fused topk → softmax | No | Planned | \* `glm4_moe_lite` with ScatterMoE may have issues — see Limitations. 
+\*\* Gemma 4 uses `experts_implementation: scattermoe` path (registered via `ExpertsInterface`) instead of SparseMoeBlock patching, since Gemma 4 embeds MoE directly in its decoder layer (no separate SparseMoeBlock). See the [Gemma 4 section](#gemma-4) below. + ### Feature comparison | Feature | ScatterMoE | SonicMoE | @@ -131,6 +135,21 @@ Both kernels handle shared experts identically. Shared expert attribute names ar If `shared_expert_gate` exists, sigmoid gating is applied to the shared expert contribution before adding it to the routed output. PEFT wraps shared expert linear layers with standard LoRA — no special handling is needed. +## Gemma 4 + +Gemma 4 (e.g. `google/gemma-4-26B-A4B`) has a unique hybrid MoE architecture: + +- **No SparseMoeBlock**: MoE is embedded directly in the decoder layer alongside a dense MLP. Both run in parallel and their outputs are summed. +- **Custom router** (`Gemma4TextRouter`): RMSNorm → learned scale → linear projection → softmax → top-k → renormalization → per-expert learned scales. +- **GEGLU activation**: Uses `gelu_pytorch_tanh` (not SiLU/SwiGLU like most other MoE models). +- **128 experts, top-k=8** for the 26B-A4B variant. + +Because there is no SparseMoeBlock class to patch, Gemma 4 uses a different integration path: we register `"scattermoe"` as a custom implementation in the transformers `ExpertsInterface`, and set `experts_implementation: scattermoe` in the config. The `@use_experts_implementation` decorator on `Gemma4TextExperts` then dispatches to our ScatterMoE kernel automatically. The router is untouched — it runs as-is. + +**Important limitations:** +- **Flash Attention 2 is not supported** — Gemma 4 uses `global_head_dim: 512` for full attention layers, which exceeds FA2's maximum head dimension of 256. Use `sdp_attention: true` instead. +- **Multimodal model**: Gemma 4 includes vision and audio encoders. 
For text-only SFT, use `lora_target_linear_modules` with a regex to restrict LoRA to the text backbone (e.g. `language_model\.model\.layers\.\d+\.self_attn\.(q|k|v|o)_proj`). + ## Limitations - **ScatterMoE + GLM4-MoE Lite**: ScatterMoE does not work reliably for GLM 4.7 Flash (`glm4_moe_lite`). diff --git a/src/axolotl/integrations/kernels/args.py b/src/axolotl/integrations/kernels/args.py index 7c9e23b6c..3afeb79c3 100644 --- a/src/axolotl/integrations/kernels/args.py +++ b/src/axolotl/integrations/kernels/args.py @@ -34,12 +34,20 @@ class KernelsArgs(BaseModel): @classmethod def check_experts_implementation(cls, data): experts_implementation = data.get("experts_implementation") + use_scattermoe = data.get("use_scattermoe", False) if experts_implementation is None: # transformers may default to batched_mm when unset data["experts_implementation"] = "eager" - elif experts_implementation != "eager": + elif experts_implementation == "scattermoe" and not use_scattermoe: LOG.warning( - "`experts_implementation` must be set to 'eager' to use this. Automatically setting it to 'eager'." + "`experts_implementation='scattermoe'` requires `use_scattermoe: true`. " + "Automatically setting to 'eager'." + ) + data["experts_implementation"] = "eager" + elif experts_implementation not in ("eager", "scattermoe"): + LOG.warning( + f"`experts_implementation={experts_implementation!r}` is not compatible with " + f"custom MoE kernels. Automatically setting to 'eager'." 
) data["experts_implementation"] = "eager" diff --git a/src/axolotl/integrations/kernels/constants.py b/src/axolotl/integrations/kernels/constants.py index a03761484..5239c9877 100644 --- a/src/axolotl/integrations/kernels/constants.py +++ b/src/axolotl/integrations/kernels/constants.py @@ -11,6 +11,7 @@ Models with custom routing (see sonicmoe/routing.py for implementations): - ernie4_5_moe: softmax→bias correction→topk (softmax_bias_topk_routing) - deepseek_v2: softmax→group_limited_greedy (softmax_group_limited_topk_routing) - hunyuan_v1_moe: softmax→topk via gate.wg (softmax_topk_wg_routing) +- gemma4_text: RMSNorm→scale→proj→softmax→topk→renorm→per_expert_scale (experts-level patch) """ import importlib @@ -53,6 +54,49 @@ SPARSE_MOE_BLOCK = { } +# Models where MoE is NOT in a separate SparseMoeBlock but embedded in the +# decoder layer. For these, we patch the Experts class forward directly +# (same signature: hidden_states, top_k_index, top_k_weights -> Tensor). +# Routing stays untouched — the original model router runs as-is. +EXPERTS_ONLY_BLOCK = { + # gemma4: hybrid MLP+MoE in decoder layer, custom Gemma4TextRouter, + # no SparseMoeBlock. Experts use @use_experts_implementation with + # standard 3D param layout (gate_up_proj [E, 2*I, H], down_proj [E, H, I]). + "gemma4_text": "Gemma4TextExperts", +} + + +def resolve_experts_class(model_type: str): + """Resolve the Experts class for models that need experts-level patching. + + Returns the class, or None if the model uses SparseMoeBlock-level patching. 
+ """ + entry = EXPERTS_ONLY_BLOCK.get(model_type) + if entry is None: + return None + + module_path = f"transformers.models.{model_type}.modeling_{model_type}" + try: + module = importlib.import_module(module_path) + except ModuleNotFoundError: + if model_type.endswith("_text"): + parent_type = model_type.removesuffix("_text") + module_path = f"transformers.models.{parent_type}.modeling_{parent_type}" + module = importlib.import_module(module_path) + else: + raise + + cls = getattr(module, entry, None) + if cls is None: + raise ValueError(f"Could not find class '{entry}' in '{module_path}'") + return cls + + +def is_experts_only_model(model_type: str) -> bool: + """Check if a model type requires experts-level (not block-level) patching.""" + return model_type in EXPERTS_ONLY_BLOCK + + def resolve_moe_block_classes(model_type: str): """Resolve all MoE block classes from transformers for the given model type. diff --git a/src/axolotl/integrations/kernels/libs/scattermoe_lora/gemma4_experts.py b/src/axolotl/integrations/kernels/libs/scattermoe_lora/gemma4_experts.py new file mode 100644 index 000000000..66623e017 --- /dev/null +++ b/src/axolotl/integrations/kernels/libs/scattermoe_lora/gemma4_experts.py @@ -0,0 +1,235 @@ +""" +ScatterMoE-accelerated experts forward for Gemma4. + +Gemma4 has no separate SparseMoeBlock — MoE is embedded in the decoder layer. +The decoder layer handles routing (Gemma4TextRouter) and calls +``experts(hidden_states, top_k_index, top_k_weights)`` directly. + +This module registers a ``"scattermoe"`` implementation in the transformers +``ExpertsInterface``, which the ``@use_experts_implementation`` decorator +dispatches to when ``config._experts_implementation == "scattermoe"``. + +This is the clean way to hook into transformers' MoE dispatch — no +monkeypatching required. 
Works for Gemma4 and any future model that uses +``@use_experts_implementation`` with the standard forward signature +``(hidden_states, top_k_index, top_k_weights) -> Tensor``. +""" + +import torch + +from .parallel_experts import flatten_sort_count, parallel_linear +from .parallel_linear_lora import get_lora_params_from_wrapper, parallel_linear_lora + + +def _has_peft_wrapper(module): + """Check if a module's parameter has been wrapped by PEFT ParamWrapper.""" + try: + from peft.tuners.param_wrapper import ParamWrapper + + for attr in ("gate_up_proj", "down_proj"): + param = getattr(module, attr, None) + if isinstance(param, ParamWrapper): + return True + except ImportError: + pass + return False + + +def _unwrap_experts_lora(experts): + """Extract base weights and LoRA params from a PEFT-wrapped Experts module. + + Returns: + (base_experts, gup_lora, down_lora) where each lora is + (lora_A, lora_B, scaling) or None. + """ + try: + from peft.tuners.param_wrapper import ParamWrapper + except ImportError: + return experts, None, None + + if not isinstance(getattr(experts, "gate_up_proj", None), ParamWrapper): + return experts, None, None + + base_experts = experts + gup_lora = None + down_lora = None + + gup_param = experts.gate_up_proj + if isinstance(gup_param, ParamWrapper): + lora_A, lora_B, scaling = get_lora_params_from_wrapper(gup_param) + if lora_A is not None: + num_experts = experts.num_experts + rank = lora_A.shape[0] // num_experts + from .layers import peft_lora_to_scattermoe + + sm_A, sm_B = peft_lora_to_scattermoe(lora_A, lora_B, num_experts, rank) + gup_lora = (sm_A, sm_B, scaling) + + down_param = experts.down_proj + if isinstance(down_param, ParamWrapper): + lora_A, lora_B, scaling = get_lora_params_from_wrapper(down_param) + if lora_A is not None: + num_experts = experts.num_experts + rank = lora_A.shape[0] // num_experts + from .layers import peft_lora_to_scattermoe + + sm_A, sm_B = peft_lora_to_scattermoe(lora_A, lora_B, num_experts, rank) + 
down_lora = (sm_A, sm_B, scaling) + + return base_experts, gup_lora, down_lora + + +def _get_base_param(param): + """Get the base tensor from a PEFT ParamWrapper or regular Parameter.""" + try: + from peft.tuners.param_wrapper import ParamWrapper + + while isinstance(param, ParamWrapper): + param = param.original_parameter + except ImportError: + pass + return param + + +def _parallel_linear_maybe_lora( + x, + weight, + top_k, + sorted_expert_idxs, + sorted_scattered_idxs, + expert_offsets, + lora_tuple, + grouped_in, + grouped_out, + gates=None, +): + """Call parallel_linear or parallel_linear_lora depending on whether LoRA is active.""" + if lora_tuple is not None: + lora_A, lora_B, scaling = lora_tuple + return parallel_linear_lora( + x, + weight, + top_k, + sorted_expert_idxs, + sorted_scattered_idxs, + expert_offsets, + lora_A, + lora_B, + scaling, + grouped_in=grouped_in, + grouped_out=grouped_out, + gates=gates, + ) + return parallel_linear( + x, + weight, + top_k, + sorted_expert_idxs, + sorted_scattered_idxs, + expert_offsets, + grouped_in=grouped_in, + grouped_out=grouped_out, + gates=gates, + ) + + +def scattermoe_experts_forward( + self, + hidden_states: torch.Tensor, + top_k_index: torch.Tensor, + top_k_weights: torch.Tensor, +) -> torch.Tensor: + """ScatterMoE-accelerated experts forward. 
+ + Drop-in replacement for the standard Experts forward signature used by + ``@use_experts_implementation``-decorated classes (Gemma4, Mixtral, etc.): + ``(hidden_states [T, H], top_k_index [T, K], top_k_weights [T, K]) -> [T, H]`` + """ + K = top_k_index.shape[1] + + routing_weights = top_k_weights.to(hidden_states.dtype) + sorted_expert_idxs, sorted_scattered_idxs, expert_offsets = flatten_sort_count( + top_k_index, num_experts=self.num_experts + ) + + # Get base weights (unwrap PEFT if needed) + gate_up_weight = _get_base_param(self.gate_up_proj).transpose(2, 1) + down_weight = _get_base_param(self.down_proj).transpose(2, 1) + + # Extract LoRA params if PEFT is active + gup_lora, down_lora = None, None + if _has_peft_wrapper(self): + _, gup_lora, down_lora = _unwrap_experts_lora(self) + + # Gate-up projection (with optional LoRA) + gates_h = _parallel_linear_maybe_lora( + hidden_states, + gate_up_weight, + K, + sorted_expert_idxs, + sorted_scattered_idxs, + expert_offsets, + gup_lora, + grouped_in=False, + grouped_out=True, + ) + gates, h = gates_h.chunk(2, dim=-1) + h = self.act_fn(gates) * h + + # Down projection (with optional LoRA + routing weights) + output = _parallel_linear_maybe_lora( + h, + down_weight, + 1, + sorted_expert_idxs, + sorted_scattered_idxs, + expert_offsets, + down_lora, + grouped_in=True, + grouped_out=False, + gates=routing_weights, + ) + + return output + + +def register_scattermoe_experts(): + """Register ``"scattermoe"`` in the transformers ExpertsInterface. + + After calling this, any model with ``@use_experts_implementation`` will + dispatch to ScatterMoE when ``config._experts_implementation == "scattermoe"``. + + Also patches ``get_correct_experts_implementation`` to accept ``"scattermoe"`` + as a valid value (transformers hardcodes an allowlist). + """ + from transformers.integrations.moe import ALL_EXPERTS_FUNCTIONS + from transformers.modeling_utils import PreTrainedModel + + # 1. 
Register the forward function in the global interface + ALL_EXPERTS_FUNCTIONS.register("scattermoe", scattermoe_experts_forward) + + # 2. Patch the validation to accept "scattermoe" + _original_get_correct = PreTrainedModel.get_correct_experts_implementation + + def _patched_get_correct(self_model, requested_experts: str | None) -> str: + if requested_experts == "scattermoe": + return "scattermoe" + return _original_get_correct(self_model, requested_experts) + + PreTrainedModel.get_correct_experts_implementation = _patched_get_correct + + +# Legacy monkeypatch approach (kept for backward compat with existing tests) +def patch_gemma4_scattermoe(): + """Monkeypatch Gemma4TextExperts.forward with ScatterMoE kernel.""" + from axolotl.integrations.kernels.constants import resolve_experts_class + + experts_cls = resolve_experts_class("gemma4_text") + if experts_cls is None: + raise ValueError("Could not resolve Gemma4TextExperts class") + + if hasattr(experts_cls, "_original_forward"): + return # already patched + + experts_cls._original_forward = experts_cls.forward + experts_cls.forward = scattermoe_experts_forward diff --git a/src/axolotl/integrations/kernels/libs/sonicmoe/gemma4_experts.py b/src/axolotl/integrations/kernels/libs/sonicmoe/gemma4_experts.py new file mode 100644 index 000000000..a4025dd84 --- /dev/null +++ b/src/axolotl/integrations/kernels/libs/sonicmoe/gemma4_experts.py @@ -0,0 +1,106 @@ +""" +SonicMoE-accelerated experts forward for Gemma4. + +Gemma4 has no separate SparseMoeBlock — MoE is embedded in the decoder layer. +This module provides a drop-in replacement for ``Gemma4TextExperts.forward`` +that uses SonicMoE kernels while preserving the original call signature. +""" + +import torch + +from .lora import has_lora, materialize_expert_lora, unwrap_experts_lora + + +def _get_expert_weights_gemma4(experts_module): + """Extract expert weights from Gemma4TextExperts, applying LoRA if active. 
def _get_expert_weights_gemma4(experts_module):
    """Extract expert weights from Gemma4TextExperts, applying LoRA if active.

    Returns:
        (gate_up_weight, down_weight) in SonicMoE layout [dim, dim, E].
    """
    if has_lora(experts_module):
        base_experts, lora_dict = unwrap_experts_lora(experts_module)
        gate_up = materialize_expert_lora(
            base_experts.gate_up_proj, lora_dict.get("gate_up_proj")
        )
        down = materialize_expert_lora(
            base_experts.down_proj, lora_dict.get("down_proj")
        )
    else:
        gate_up = experts_module.gate_up_proj
        down = experts_module.down_proj

    # Permute to SonicMoE layout:
    #   gate_up: [E, 2*I, H] -> [2*I, H, E]
    #   down:    [E, H, I]   -> [H, I, E]
    return gate_up.permute(1, 2, 0), down.permute(1, 2, 0)


def gemma4_sonicmoe_experts_forward(
    self,
    hidden_states: torch.Tensor,
    top_k_index: torch.Tensor,
    top_k_weights: torch.Tensor,
) -> torch.Tensor:
    """SonicMoE-accelerated replacement for Gemma4TextExperts.forward.

    Same signature as the original: (hidden_states [T, H], top_k_index [T, K],
    top_k_weights [T, K]) -> output [T, H].

    Raises:
        RuntimeError: if no CUDA device is available.
    """
    from sonicmoe import moe_general_routing_inputs
    from sonicmoe.enums import ActivationType

    # Fail before materialising LoRA weights or casting dtypes.
    if not torch.cuda.is_available():
        raise RuntimeError("SonicMoE requires CUDA. No CUDA device available.")

    T, _ = hidden_states.shape
    K = top_k_index.shape[1]
    E = self.num_experts

    # Convert routing outputs to SonicMoE's flat format
    # Token indices sorted ascending (required by SonicMoE)
    token_indices = (
        torch.arange(T, device=hidden_states.device, dtype=torch.int32)
        .unsqueeze(1)
        .expand(T, K)
    )
    flat_scores = top_k_weights.to(torch.float32).reshape(-1)  # [T*K]
    flat_token_idx = token_indices.reshape(-1)  # [T*K]
    flat_expert_idx = top_k_index.to(torch.int32).reshape(-1)  # [T*K]

    # Get weights (with LoRA materialization if needed)
    gate_up_weight, down_weight = _get_expert_weights_gemma4(self)
    gate_up_weight = gate_up_weight.to(hidden_states.dtype)
    down_weight = down_weight.to(hidden_states.dtype)

    cuda_stream = torch.cuda.current_stream().cuda_stream

    output, _ = moe_general_routing_inputs(
        hidden_states,
        flat_scores,
        flat_token_idx,
        flat_expert_idx,
        gate_up_weight,
        None,  # b1 (no gate/up bias)
        down_weight,
        None,  # b2 (no down bias)
        E,
        cuda_stream,
        ActivationType.GEGLU,
        False,  # is_inference_mode
    )

    return output


def patch_gemma4_sonicmoe():
    """Monkeypatch Gemma4TextExperts.forward with SonicMoE kernel.

    Raises:
        ValueError: if the Gemma4TextExperts class cannot be resolved.
    """
    from axolotl.integrations.kernels.constants import resolve_experts_class

    experts_cls = resolve_experts_class("gemma4_text")
    if experts_cls is None:
        raise ValueError("Could not resolve Gemma4TextExperts class")

    # Compare against our own forward rather than the shared
    # ``_original_forward`` marker, which another kernel patch may have set.
    if experts_cls.forward is gemma4_sonicmoe_experts_forward:
        return  # already patched

    # Preserve the first-seen forward; do not overwrite it with another patch.
    if not hasattr(experts_cls, "_original_forward"):
        experts_cls._original_forward = experts_cls.forward
    experts_cls.forward = gemma4_sonicmoe_experts_forward
@@ -66,6 +67,8 @@ def get_model_moe_config(model_type: str): return softmax_bias_topk_routing, ActivationType.SWIGLU, "gate" elif model_type in ("hunyuan_v1_moe",): return softmax_topk_wg_routing, ActivationType.SWIGLU, "gate" + elif model_type in ("gemma4_text",): + return gemma4_routing, ActivationType.GEGLU, "router" # Fused topk -> softmax path (routing_fn=None): # elif model_type in ("gpt_oss",): # # NOTE: gpt_oss has a router bias which moe_TC_softmax_topk_layer @@ -501,3 +504,73 @@ def softmax_topk_wg_routing( flat_expert_idx = top_indices.to(torch.int32).reshape(-1) # [T*K] return flat_scores, flat_token_idx, flat_expert_idx, router_logits + + +def gemma4_routing( + hidden_states: torch.Tensor, moe_block +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Gemma4-style routing: RMSNorm → scale → proj → softmax → topk → renorm → per_expert_scale. + + Gemma4's router (``Gemma4TextRouter``) has a unique structure: + 1. RMSNorm (without learnable scale) on hidden states + 2. Multiply by ``scale * hidden_size**-0.5`` + 3. Linear projection to expert scores + 4. Softmax → topk + 5. Normalize top-k weights to sum to 1 + 6. Multiply by per-expert learned scales + + The router lives at ``moe_block.router`` (not ``moe_block.gate``). + LoRA on the router targets ``router.proj`` (nn.Linear). 
+ + Args: + hidden_states: [T, H] flattened token representations + moe_block: MoE block module (accesses moe_block.router) + + Returns: + router_scores: [T*K] flattened scores (float32) + token_indices: [T*K] which token each entry belongs to (int32), sorted ascending + expert_indices: [T*K] which expert (int32) + router_logits: [T, E] original logits for aux loss + """ + router = moe_block.router + + # Unwrap PEFT LoRA on router.proj (the nn.Linear) + _, proj_weight, proj_lora_delta = unwrap_gate_lora(router.proj) + + T, _ = hidden_states.shape + K = router.top_k if hasattr(router, "top_k") else router.config.top_k_experts + + # Reproduce Gemma4TextRouter.forward: + # 1. RMSNorm (no scale) + scale param * hidden_size**-0.5 + normed = router.norm(hidden_states) + scaled = normed * router.scale * router.scalar_root_size + + # 2. Project to expert scores + router_logits = F.linear(scaled.float(), proj_weight.float()) # [T, E] + if proj_lora_delta is not None: + router_logits = router_logits + F.linear( + scaled.float(), proj_lora_delta.float() + ) + + # 3. Softmax → topk + router_probs = F.softmax(router_logits, dim=-1, dtype=torch.float32) # [T, E] + top_values, top_indices = torch.topk(router_probs, K, dim=-1) # [T, K] + + # 4. Normalize top-k weights + top_values = top_values / top_values.sum(dim=-1, keepdim=True) + + # 5. 
Per-expert scale + top_values = top_values * router.per_expert_scale[top_indices] + + # Flatten for moe_general_routing_inputs + token_indices = ( + torch.arange(T, device=hidden_states.device, dtype=torch.int32) + .unsqueeze(1) + .expand(T, K) + ) + + flat_scores = top_values.to(torch.float32).reshape(-1) # [T*K] + flat_token_idx = token_indices.reshape(-1) # [T*K] + flat_expert_idx = top_indices.to(torch.int32).reshape(-1) # [T*K] + + return flat_scores, flat_token_idx, flat_expert_idx, router_logits diff --git a/src/axolotl/integrations/kernels/plugin.py b/src/axolotl/integrations/kernels/plugin.py index 4ab22bfce..e9291c9c1 100644 --- a/src/axolotl/integrations/kernels/plugin.py +++ b/src/axolotl/integrations/kernels/plugin.py @@ -61,20 +61,31 @@ class KernelsPlugin(BasePlugin): return "axolotl.integrations.kernels.KernelsArgs" def pre_model_load(self, cfg): - from axolotl.integrations.kernels.constants import SPARSE_MOE_BLOCK + from axolotl.integrations.kernels.constants import ( + SPARSE_MOE_BLOCK, + is_experts_only_model, + ) # Prefer text backbone type for VLMs, but fall back to base type # when the text type isn't in the supported mapping (e.g. qwen3_5_moe_text) moe_model_type = cfg.model_config_type_text or cfg.model_config_type if ( moe_model_type not in SPARSE_MOE_BLOCK + and not is_experts_only_model(moe_model_type) and cfg.model_config_type in SPARSE_MOE_BLOCK ): moe_model_type = cfg.model_config_type if cfg.use_scattermoe: self._register_kernels() - self._kernelize_model(moe_model_type) + if is_experts_only_model(moe_model_type): + # Models like Gemma4 where MoE is embedded in the decoder layer + # — register ScatterMoE in the ExpertsInterface so that + # @use_experts_implementation dispatches to it. 
+ self._register_experts_interface() + cfg.experts_implementation = "scattermoe" + else: + self._kernelize_model(moe_model_type) elif cfg.use_sonicmoe: if not importlib.util.find_spec("sonicmoe"): raise RuntimeError( @@ -84,14 +95,24 @@ class KernelsPlugin(BasePlugin): _check_sonicmoe_gpu_compat() - from axolotl.integrations.kernels.libs.sonicmoe import patch_sonicmoe + if is_experts_only_model(moe_model_type): + from axolotl.integrations.kernels.libs.sonicmoe.gemma4_experts import ( + patch_gemma4_sonicmoe, + ) - LOG.info(f"Applying SonicMoE patches for model type: {moe_model_type}") - patch_sonicmoe( - moe_model_type, - torch_compile=bool(getattr(cfg, "torch_compile", False)), - base_model_type=cfg.model_config_type, - ) + LOG.info( + f"Applying SonicMoE experts-level patch for model type: {moe_model_type}" + ) + patch_gemma4_sonicmoe() + else: + from axolotl.integrations.kernels.libs.sonicmoe import patch_sonicmoe + + LOG.info(f"Applying SonicMoE patches for model type: {moe_model_type}") + patch_sonicmoe( + moe_model_type, + torch_compile=bool(getattr(cfg, "torch_compile", False)), + base_model_type=cfg.model_config_type, + ) def _register_kernels(self): from kernels import ( @@ -139,3 +160,16 @@ class KernelsPlugin(BasePlugin): replace_kernel_forward_from_hub( model_moe_cls, "HFScatterMoEParallelExperts" ) + + def _register_experts_interface(self): + """Register ScatterMoE in the transformers ExpertsInterface. + + This allows @use_experts_implementation-decorated Experts classes + to dispatch to ScatterMoE when config._experts_implementation == "scattermoe". 
+ """ + from axolotl.integrations.kernels.libs.scattermoe_lora.gemma4_experts import ( + register_scattermoe_experts, + ) + + register_scattermoe_experts() + LOG.info("Registered 'scattermoe' in transformers ExpertsInterface") diff --git a/src/axolotl/monkeypatch/lora_kernels.py b/src/axolotl/monkeypatch/lora_kernels.py index c5d552c03..d569d5925 100644 --- a/src/axolotl/monkeypatch/lora_kernels.py +++ b/src/axolotl/monkeypatch/lora_kernels.py @@ -73,6 +73,44 @@ QKV_PATCHES = [ query_states = self.q_norm(query_states.view(hidden_shape)).transpose(1, 2) key_states = self.k_norm(key_states.view(hidden_shape)).transpose(1, 2) value_states = value_states.view(hidden_shape).transpose(1, 2) +""".lstrip("\n"), + ), + # Gemma4: norm between proj and transpose, RoPE between norm and transpose, + # conditional KV sharing (is_kv_shared_layer), v_proj may be None (attention_k_eq_v). + # We only fuse the projection calls; norms, RoPE, and KV sharing stay as-is. + ( + """ + query_states = self.q_proj(hidden_states).view(hidden_shape) + query_states = self.q_norm(query_states) + query_states = apply_rotary_pos_emb(query_states, cos, sin, unsqueeze_dim=2) + query_states = query_states.transpose(1, 2) + + # For layers with shared KV (from kv sharing point onwards), we reuse the same keys/values states as the last non-sharing layer + if self.is_kv_shared_layer and past_key_values is not None: + key_states, value_states = past_key_values.shared_layers[self.kv_shared_layer_index] + # Device of past layer may be different from current one + key_states = key_states.to(query_states.device) + value_states = value_states.to(query_states.device) + else: + key_states = self.k_proj(hidden_states).view(hidden_shape) + value_states = self.v_proj(hidden_states).view(hidden_shape) if self.v_proj is not None else key_states +""".lstrip("\n"), + """ + query_states, key_states, value_states = self.apply_qkv(hidden_states) + query_states = query_states.view(hidden_shape) + query_states = 
self.q_norm(query_states) + query_states = apply_rotary_pos_emb(query_states, cos, sin, unsqueeze_dim=2) + query_states = query_states.transpose(1, 2) + + # For layers with shared KV (from kv sharing point onwards), we reuse the same keys/values states as the last non-sharing layer + if self.is_kv_shared_layer and past_key_values is not None: + key_states, value_states = past_key_values.shared_layers[self.kv_shared_layer_index] + # Device of past layer may be different from current one + key_states = key_states.to(query_states.device) + value_states = value_states.to(query_states.device) + else: + key_states = key_states.view(hidden_shape) + value_states = value_states.view(hidden_shape) if self.v_proj is not None else key_states """.lstrip("\n"), ), ] @@ -113,6 +151,23 @@ def original_apply_qkv( return query_states, key_states, value_states +def original_apply_qkv_optional_v( + self: nn.Module, hidden_states: torch.Tensor +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """QKV projection for models where v_proj may be None (e.g. Gemma4 attention_k_eq_v). + + When v_proj is None, key_states are reused as value_states. + """ + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + if self.v_proj is not None: + value_states = self.v_proj(hidden_states) + else: + value_states = key_states + + return query_states, key_states, value_states + + def original_apply_o(self: nn.Module, hidden_states: torch.Tensor) -> torch.Tensor: """ Original implementation of output projection without optimizations. 
@@ -183,6 +238,11 @@ def get_attention_cls_from_config(cfg: DictDefault) -> Type[nn.Module]: return Gemma3Attention + if model_type in ("gemma4", "gemma4_text"): + from transformers.models.gemma4.modeling_gemma4 import Gemma4TextAttention + + return Gemma4TextAttention + try: # Dynamically import the module and attention class module_path = f"transformers.models.{model_type}.modeling_{model_type}" @@ -410,14 +470,24 @@ def apply_lora_kernel_patches( # Add QKV, O fallback implementations to start # These will be overwritten later (if some conditions apply) for self_attn in find_self_attn_in_layer(layer): - self_attn.apply_qkv = types.MethodType(original_apply_qkv, self_attn) + # Use v_proj-optional fallback for models where v_proj can be None + # (e.g. Gemma4 with attention_k_eq_v=True) + if getattr(self_attn, "v_proj", None) is None: + self_attn.apply_qkv = types.MethodType( + original_apply_qkv_optional_v, self_attn + ) + else: + self_attn.apply_qkv = types.MethodType(original_apply_qkv, self_attn) self_attn.apply_o = types.MethodType(original_apply_o, self_attn) if cfg.lora_qkv_kernel: # Query, key, value patching + # Filter out None projections (e.g. 
Gemma4 v_proj when attention_k_eq_v=True) + proj_names = ["q_proj", "k_proj", "v_proj"] layer_modules = [ - getattr(self_attn, linear_proj) - for linear_proj in ["q_proj", "k_proj", "v_proj"] + getattr(self_attn, name) + for name in proj_names + if getattr(self_attn, name, None) is not None ] can_patch_qkv = all( hasattr(module, "lora_A") for module in layer_modules diff --git a/src/axolotl/monkeypatch/models/qwen3_next/modeling.py b/src/axolotl/monkeypatch/models/qwen3_next/modeling.py index 48570ba42..fb4cb1bc7 100644 --- a/src/axolotl/monkeypatch/models/qwen3_next/modeling.py +++ b/src/axolotl/monkeypatch/models/qwen3_next/modeling.py @@ -111,7 +111,6 @@ def patch_qwen3_next_gateddelta_layer(): """Patch Qwen3NextGatedDeltaNet to parse cu_seqlens and pass to chunk_gated_delta_rule""" try: from transformers.models.qwen3_next.modeling_qwen3_next import ( - Qwen3NextDynamicCache, Qwen3NextGatedDeltaNet, apply_mask_to_padding_states, ) @@ -125,8 +124,7 @@ def patch_qwen3_next_gateddelta_layer(): def patched_gated_delta_net_forward( self, hidden_states: torch.Tensor, - cache_params: Optional[Qwen3NextDynamicCache] = None, - cache_position: Optional[torch.LongTensor] = None, + cache_params=None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, ): @@ -137,9 +135,8 @@ def patch_qwen3_next_gateddelta_layer(): use_precomputed_states = ( cache_params is not None - and cache_params.has_previous_state + and cache_params.has_previous_state(self.layer_idx) and seq_len == 1 - and cache_position is not None ) # Compute cu_seqlens early for use by both causal_conv1d and chunk_gated_delta_rule @@ -148,9 +145,9 @@ def patch_qwen3_next_gateddelta_layer(): cu_seqlens = get_cu_seqlens(position_ids=position_ids) # getting projected states from cache if it exists - if cache_params is not None: - conv_state = cache_params.conv_states[self.layer_idx] - recurrent_state = cache_params.recurrent_states[self.layer_idx] + if use_precomputed_states: 
+ conv_state = cache_params.layers[self.layer_idx].conv_states + recurrent_state = cache_params.layers[self.layer_idx].recurrent_states projected_states_qkvz = self.in_proj_qkvz(hidden_states) projected_states_ba = self.in_proj_ba(hidden_states) @@ -162,10 +159,9 @@ def patch_qwen3_next_gateddelta_layer(): ) mixed_qkv = torch.cat((query, key, value), dim=-1) # [B, T, D] + mixed_qkv = mixed_qkv.transpose(1, 2) # [B, D, T] if use_precomputed_states: - # Inference single-token path: causal_conv1d_update expects [B, D, T] - mixed_qkv = mixed_qkv.transpose(1, 2) mixed_qkv = self.causal_conv1d_update( mixed_qkv, conv_state, @@ -173,19 +169,17 @@ def patch_qwen3_next_gateddelta_layer(): self.conv1d.bias, self.activation, ) - mixed_qkv = mixed_qkv.transpose(1, 2) else: if cache_params is not None: - # Cache state expects [B, D, T] for the inference update path - mixed_qkv_t = mixed_qkv.transpose(1, 2) conv_state = F.pad( - mixed_qkv_t, - (self.conv_kernel_size - mixed_qkv_t.shape[-1], 0), + mixed_qkv, + (self.conv_kernel_size - mixed_qkv.shape[-1], 0), ) - cache_params.conv_states[self.layer_idx] = conv_state + cache_params.update_conv_state(conv_state, self.layer_idx) if fla_causal_conv1d is not None: # FLA Triton causal_conv1d: [B, T, D] in/out, with cu_seqlens support + mixed_qkv = mixed_qkv.transpose(1, 2) # [B, T, D] for FLA mixed_qkv, _ = fla_causal_conv1d( x=mixed_qkv, weight=self.conv1d.weight.squeeze(1), @@ -193,6 +187,15 @@ def patch_qwen3_next_gateddelta_layer(): activation=self.activation, cu_seqlens=cu_seqlens, ) + mixed_qkv = mixed_qkv.transpose(1, 2) # back to [B, D, T] + elif self.causal_conv1d_fn is not None: + mixed_qkv = self.causal_conv1d_fn( + x=mixed_qkv, + weight=self.conv1d.weight.squeeze(1), + bias=self.conv1d.bias, + activation=self.activation, + seq_idx=None, + ) else: # PyTorch fallback (no cu_seqlens support) if cu_seqlens is not None and cu_seqlens.shape[0] > batch_size + 1: @@ -203,11 +206,9 @@ def patch_qwen3_next_gateddelta_layer(): 
LOG.warning_once( "FLA causal_conv1d not available. Falling back to PyTorch conv1d." ) - mixed_qkv = mixed_qkv.transpose(1, 2) mixed_qkv = F.silu(self.conv1d(mixed_qkv)[:, :, :seq_len]) - mixed_qkv = mixed_qkv.transpose(1, 2) - # mixed_qkv is [B, T, D] in all paths + mixed_qkv = mixed_qkv.transpose(1, 2) # [B, T, D] query, key, value = torch.split( mixed_qkv, [ @@ -255,7 +256,7 @@ def patch_qwen3_next_gateddelta_layer(): # Update cache if cache_params is not None: - cache_params.recurrent_states[self.layer_idx] = last_recurrent_state + cache_params.update_recurrent_state(last_recurrent_state, self.layer_idx) z_shape_og = z.shape # reshape input data into 2D tensor diff --git a/src/axolotl/utils/chat_templates/templates/gemma4.jinja b/src/axolotl/utils/chat_templates/templates/gemma4.jinja new file mode 100644 index 000000000..780957c94 --- /dev/null +++ b/src/axolotl/utils/chat_templates/templates/gemma4.jinja @@ -0,0 +1,271 @@ +{%- macro format_parameters(properties, required) -%} + {%- set standard_keys = ['description', 'type', 'properties', 'required', 'nullable'] -%} + {%- set ns = namespace(found_first=false) -%} + {%- for key, value in properties | dictsort -%} + {%- set add_comma = false -%} + {%- if key not in standard_keys -%} + {%- if ns.found_first %},{% endif -%} + {%- set ns.found_first = true -%} + {{ key }}:{ + {%- if value['description'] -%} + description:<|"|>{{ value['description'] }}<|"|> + {%- set add_comma = true -%} + {%- endif -%} + {%- if value['nullable'] %} + {%- if add_comma %},{%- else -%} {%- set add_comma = true -%} {% endif -%} + nullable:true + {%- endif -%} + {%- if value['type'] | upper == 'STRING' -%} + {%- if value['enum'] -%} + {%- if add_comma %},{%- else -%} {%- set add_comma = true -%} {% endif -%} + enum:{{ format_argument(value['enum']) }} + {%- endif -%} + {%- elif value['type'] | upper == 'OBJECT' -%} + ,properties:{ + {%- if value['properties'] is defined and value['properties'] is mapping -%} + {{- 
format_parameters(value['properties'], value['required'] | default([])) -}} + {%- elif value is mapping -%} + {{- format_parameters(value, value['required'] | default([])) -}} + {%- endif -%} + } + {%- if value['required'] -%} + ,required:[ + {%- for item in value['required'] | default([]) -%} + <|"|>{{- item -}}<|"|> + {%- if not loop.last %},{% endif -%} + {%- endfor -%} + ] + {%- endif -%} + {%- elif value['type'] | upper == 'ARRAY' -%} + {%- if value['items'] is mapping and value['items'] -%} + ,items:{ + {%- set ns_items = namespace(found_first=false) -%} + {%- for item_key, item_value in value['items'] | dictsort -%} + {%- if item_value is not none -%} + {%- if ns_items.found_first %},{% endif -%} + {%- set ns_items.found_first = true -%} + {%- if item_key == 'properties' -%} + properties:{ + {%- if item_value is mapping -%} + {{- format_parameters(item_value, value['items']['required'] | default([])) -}} + {%- endif -%} + } + {%- elif item_key == 'required' -%} + required:[ + {%- for req_item in item_value -%} + <|"|>{{- req_item -}}<|"|> + {%- if not loop.last %},{% endif -%} + {%- endfor -%} + ] + {%- elif item_key == 'type' -%} + {%- if item_value is string -%} + type:{{ format_argument(item_value | upper) }} + {%- else -%} + type:{{ format_argument(item_value | map('upper') | list) }} + {%- endif -%} + {%- else -%} + {{ item_key }}:{{ format_argument(item_value) }} + {%- endif -%} + {%- endif -%} + {%- endfor -%} + } + {%- endif -%} + {%- endif -%} + {%- if add_comma %},{%- else -%} {%- set add_comma = true -%} {% endif -%} + type:<|"|>{{ value['type'] | upper }}<|"|>} + {%- endif -%} + {%- endfor -%} +{%- endmacro -%} +{%- macro format_function_declaration(tool_data) -%} + declaration:{{- tool_data['function']['name'] -}}{description:<|"|>{{- tool_data['function']['description'] -}}<|"|> + {%- set params = tool_data['function']['parameters'] -%} + {%- if params -%} + ,parameters:{ + {%- if params['properties'] -%} + properties:{ {{- 
format_parameters(params['properties'], params['required']) -}} }, + {%- endif -%} + {%- if params['required'] -%} + required:[ + {%- for item in params['required'] -%} + <|"|>{{- item -}}<|"|> + {{- ',' if not loop.last -}} + {%- endfor -%} + ], + {%- endif -%} + {%- if params['type'] -%} + type:<|"|>{{- params['type'] | upper -}}<|"|>} + {%- endif -%} + {%- endif -%} + {%- if 'response' in tool_data['function'] -%} + {%- set response_declaration = tool_data['function']['response'] -%} + ,response:{ + {%- if response_declaration['description'] -%} + description:<|"|>{{- response_declaration['description'] -}}<|"|>, + {%- endif -%} + {%- if response_declaration['type'] | upper == 'OBJECT' -%} + type:<|"|>{{- response_declaration['type'] | upper -}}<|"|>} + {%- endif -%} + {%- endif -%} + } +{%- endmacro -%} +{%- macro format_argument(argument, escape_keys=True) -%} + {%- if argument is string -%} + {{- '<|"|>' + argument + '<|"|>' -}} + {%- elif argument is boolean -%} + {{- 'true' if argument else 'false' -}} + {%- elif argument is mapping -%} + {{- '{' -}} + {%- set ns = namespace(found_first=false) -%} + {%- for key, value in argument | dictsort -%} + {%- if ns.found_first %},{% endif -%} + {%- set ns.found_first = true -%} + {%- if escape_keys -%} + {{- '<|"|>' + key + '<|"|>' -}} + {%- else -%} + {{- key -}} + {%- endif -%} + :{{- format_argument(value, escape_keys=escape_keys) -}} + {%- endfor -%} + {{- '}' -}} + {%- elif argument is sequence -%} + {{- '[' -}} + {%- for item in argument -%} + {{- format_argument(item, escape_keys=escape_keys) -}} + {%- if not loop.last %},{% endif -%} + {%- endfor -%} + {{- ']' -}} + {%- else -%} + {{- argument -}} + {%- endif -%} +{%- endmacro -%} +{#- Removes '<|channel>...' thinking blocks from model output. + Splits on the end token '', then checks each part for the start + token '<|channel>' and keeps only the text before it. 
-#} +{%- macro strip_thinking(text) -%} + {%- set ns = namespace(cleaned='') -%} + {%- for part in text.split('') -%} + {%- if '<|channel>' in part -%} + {%- set ns.cleaned = ns.cleaned + part.split('<|channel>')[0] -%} + {%- else -%} + {%- set ns.cleaned = ns.cleaned + part -%} + {%- endif -%} + {%- endfor -%} + {{- ns.cleaned | trim -}} +{%- endmacro -%} + +{%- set ns = namespace(prev_message_type=None) -%} +{%- set loop_messages = messages -%} +{{ bos_token }} +{#- Handle System/Tool Definitions Block -#} +{%- if (enable_thinking is defined and enable_thinking) or tools or messages[0]['role'] in ['system', 'developer'] -%} + {{- '<|turn>system\n' -}} + + {#- Inject Thinking token at the very top of the FIRST system turn -#} + {%- if enable_thinking is defined and enable_thinking -%} + {{- '<|think|>' -}} + {%- set ns.prev_message_type = 'think' -%} + {%- endif -%} + + {%- if messages[0]['role'] in ['system', 'developer'] -%} + {{- messages[0]['content'] | trim -}} + {%- set loop_messages = messages[1:] -%} + {%- endif -%} + + {%- if tools -%} + {%- for tool in tools %} + {{- '<|tool>' -}} + {{- format_function_declaration(tool) | trim -}} + {{- '' -}} + {%- endfor %} + {%- set ns.prev_message_type = 'tool' -%} + {%- endif -%} + + {{- '\n' -}} +{%- endif %} + +{#- Loop through messages -#} +{%- for message in loop_messages -%} + {#- Reset so only special message types (tool_call, image, etc.) influence + the generation prompt formatting below. Plain text leaves it as None. 
-#} + {%- set ns.prev_message_type = None -%} + {%- set role = 'model' if message['role'] == 'assistant' else message['role'] -%} + {{- '<|turn>' + role + '\n' }} + + {%- if message['tool_calls'] -%} + {%- for tool_call in message['tool_calls'] -%} + {%- set function = tool_call['function'] -%} + {{- '<|tool_call>call:' + function['name'] + '{' -}} + {%- if function['arguments'] is mapping -%} + {%- set ns_args = namespace(found_first=false) -%} + {%- for key, value in function['arguments'] | dictsort -%} + {%- if ns_args.found_first %},{% endif -%} + {%- set ns_args.found_first = true -%} + {{- key -}}:{{- format_argument(value, escape_keys=False) -}} + {%- endfor -%} + {%- elif function['arguments'] is string -%} + {{- function['arguments'] -}} + {%- endif -%} + {{- '}' -}} + {%- endfor -%} + {%- set ns.prev_message_type = 'tool_call' -%} + {%- endif -%} + + {%- if message['tool_responses'] -%} + {#- Tool Response handling -#} + {%- for tool_response in message['tool_responses'] -%} + {{- '<|tool_response>' -}} + {%- if tool_response['response'] is mapping -%} + {{- 'response:' + tool_response['name'] | default('unknown') + '{' -}} + {%- for key, value in tool_response['response'] | dictsort -%} + {{- key -}}:{{- format_argument(value, escape_keys=False) -}} + {%- if not loop.last %},{% endif -%} + {%- endfor -%} + {{- '}' -}} + {%- else -%} + {{- 'response:' + tool_response['name'] | default('unknown') + '{value:' + format_argument(tool_response['response'], escape_keys=False) + '}' -}} + {%- endif -%} + {{- '' -}} + {%- endfor -%} + {%- set ns.prev_message_type = 'tool_response' -%} + {%- endif -%} + + {%- if message['content'] is string -%} + {%- if role == 'model' -%} + {{- strip_thinking(message['content']) -}} + {%- else -%} + {{- message['content'] | trim -}} + {%- endif -%} + {%- elif message['content'] is sequence -%} + {%- for item in message['content'] -%} + {%- if item['type'] == 'text' -%} + {%- if role == 'model' -%} + {{- 
strip_thinking(item['text']) -}} + {%- else -%} + {{- item['text'] | trim -}} + {%- endif -%} + {%- elif item['type'] == 'image' -%} + {{- '\n\n<|image|>\n\n' -}} + {%- set ns.prev_message_type = 'image' -%} + {%- elif item['type'] == 'audio' -%} + {{- '<|audio|>' -}} + {%- set ns.prev_message_type = 'audio' -%} + {%- elif item['type'] == 'video' -%} + {{- '\n\n<|video|>\n\n' -}} + {%- set ns.prev_message_type = 'video' -%} + {%- endif -%} + {%- endfor -%} + {%- endif -%} + + {%- if not (message['tool_responses'] and not message['content']) -%} + {{- '\n' -}} + {%- endif -%} +{%- endfor -%} + +{%- if add_generation_prompt -%} + {%- if ns.prev_message_type != 'tool_response' -%} + {{- '<|turn>model\n' -}} + {%- endif -%} + {%- if not enable_thinking | default(false) -%} + {{- '<|channel>thought\n' -}} + {%- endif -%} +{%- endif -%} diff --git a/src/axolotl/utils/schemas/enums.py b/src/axolotl/utils/schemas/enums.py index 4b759237e..d4ff27ac9 100644 --- a/src/axolotl/utils/schemas/enums.py +++ b/src/axolotl/utils/schemas/enums.py @@ -72,6 +72,7 @@ class ChatTemplate(str, Enum): qwen2_vl = "qwen2_vl" gemma3 = "gemma3" gemma3n = "gemma3n" + gemma4 = "gemma4" command_a = "command_a" command_a_tool_use = "command_a_tool_use" command_a_rag = "command_a_rag" diff --git a/src/axolotl/utils/schemas/trl.py b/src/axolotl/utils/schemas/trl.py index cd6a9c57a..a36242162 100644 --- a/src/axolotl/utils/schemas/trl.py +++ b/src/axolotl/utils/schemas/trl.py @@ -69,9 +69,7 @@ class TRLConfig(BaseModel): generation_batch_size: int | None = Field( default=None, json_schema_extra={ - "description": "Batch size for generation. Controls how many unique " - "prompts are generated per step. For full DP utilization, set to " - "num_generations * data_parallel_size (or a multiple thereof)." + "description": "Batch size for generation. Controls how many unique prompts are generated per step. Should be num_generations * data_parallel_size for full DP utilization." 
}, ) num_generations: int | None = Field( diff --git a/tests/integrations/test_gemma4_moe.py b/tests/integrations/test_gemma4_moe.py new file mode 100644 index 000000000..412d49b2f --- /dev/null +++ b/tests/integrations/test_gemma4_moe.py @@ -0,0 +1,1052 @@ +"""Validation tests for Gemma 4 MoE compatibility with ScatterMoE and SonicMoE. + +Gemma 4 has a unique MoE architecture: +- No separate SparseMoeBlock — MoE is embedded in the decoder layer +- Hybrid MLP+MoE: dense MLP runs in parallel with sparse MoE, outputs summed +- Custom router (Gemma4TextRouter): RMSNorm → scale → proj → softmax → topk → renorm → per_expert_scale +- Router is `self.router` (not `self.gate`) +- Experts use standard 3D param layout with @use_experts_implementation + +These tests validate that: +1. ScatterMoE kernels produce correct output for Gemma4 expert layout +2. ScatterMoE + LoRA produces correct output +3. SonicMoE integration code handles Gemma4 routing correctly +4. Weight layouts are compatible +""" + +import pytest +import torch +import torch.nn.functional as F +from torch import nn + +# ============================================================================ +# Gemma4 reference implementation (extracted from transformers) +# ============================================================================ + + +class Gemma4RMSNorm(nn.Module): + def __init__(self, dim, eps=1e-6, with_scale=True): + super().__init__() + self.eps = eps + self.with_scale = with_scale + if with_scale: + self.weight = nn.Parameter(torch.ones(dim)) + + def forward(self, x): + variance = x.float().pow(2).mean(-1, keepdim=True) + x = x * torch.rsqrt(variance + self.eps) + if self.with_scale: + return (self.weight * x).to(x.dtype) + return x.to(x.dtype) + + +class Gemma4TextRouter(nn.Module): + def __init__(self, hidden_size, num_experts, top_k, eps=1e-6): + super().__init__() + self.hidden_size = hidden_size + self.num_experts = num_experts + self.top_k = top_k + self.scalar_root_size = hidden_size**-0.5 
+ self.eps = eps + + self.norm = Gemma4RMSNorm(hidden_size, eps=eps, with_scale=False) + self.proj = nn.Linear(hidden_size, num_experts, bias=False) + self.scale = nn.Parameter(torch.ones(hidden_size)) + self.per_expert_scale = nn.Parameter(torch.ones(num_experts)) + + def forward(self, hidden_states): + hidden_states = self.norm(hidden_states) + hidden_states = hidden_states * self.scale * self.scalar_root_size + + expert_scores = self.proj(hidden_states.to(self.proj.weight.dtype)) + router_probabilities = F.softmax(expert_scores, dim=-1) + + top_k_weights, top_k_index = torch.topk( + router_probabilities, k=self.top_k, dim=-1 + ) + + top_k_weights /= top_k_weights.sum(dim=-1, keepdim=True) + top_k_weights = top_k_weights * self.per_expert_scale[top_k_index] + + return router_probabilities, top_k_weights, top_k_index + + +class Gemma4TextExperts(nn.Module): + def __init__(self, num_experts, hidden_size, intermediate_size, act_fn): + super().__init__() + self.num_experts = num_experts + self.hidden_dim = hidden_size + self.intermediate_dim = intermediate_size + self.gate_up_proj = nn.Parameter( + torch.empty(num_experts, 2 * intermediate_size, hidden_size) + ) + self.down_proj = nn.Parameter( + torch.empty(num_experts, hidden_size, intermediate_size) + ) + self.act_fn = act_fn + + def forward(self, hidden_states, top_k_index, top_k_weights): + final_hidden_states = torch.zeros_like(hidden_states) + with torch.no_grad(): + expert_mask = F.one_hot(top_k_index, num_classes=self.num_experts) + expert_mask = expert_mask.permute(2, 1, 0) + expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() + + for expert_idx in expert_hit: + expert_idx = expert_idx[0] + if expert_idx == self.num_experts: + continue + top_k_pos, token_idx = torch.where(expert_mask[expert_idx]) + current_state = hidden_states[token_idx] + gate, up = F.linear(current_state, self.gate_up_proj[expert_idx]).chunk( + 2, dim=-1 + ) + current_hidden_states = self.act_fn(gate) * up + 
current_hidden_states = F.linear( + current_hidden_states, self.down_proj[expert_idx] + ) + current_hidden_states = ( + current_hidden_states * top_k_weights[token_idx, top_k_pos, None] + ) + final_hidden_states.index_add_( + 0, token_idx, current_hidden_states.to(final_hidden_states.dtype) + ) + + return final_hidden_states + + +# ============================================================================ +# Test fixtures +# ============================================================================ + + +@pytest.fixture +def gemma4_config(): + """Small Gemma4 MoE config for testing.""" + return { + "hidden_size": 128, + "num_experts": 8, + "top_k": 2, + "intermediate_size": 64, + "eps": 1e-6, + } + + +@pytest.fixture +def device(): + if not torch.cuda.is_available(): + pytest.skip("CUDA required") + return torch.device("cuda:0") + + +@pytest.fixture +def gemma4_moe_layer(gemma4_config, device): + """Create a Gemma4 MoE layer (router + experts) on GPU.""" + from transformers.activations import ACT2FN + + act_fn = ACT2FN["gelu_pytorch_tanh"] + + router = Gemma4TextRouter( + hidden_size=gemma4_config["hidden_size"], + num_experts=gemma4_config["num_experts"], + top_k=gemma4_config["top_k"], + eps=gemma4_config["eps"], + ) + experts = Gemma4TextExperts( + num_experts=gemma4_config["num_experts"], + hidden_size=gemma4_config["hidden_size"], + intermediate_size=gemma4_config["intermediate_size"], + act_fn=act_fn, + ) + + # Initialize weights + nn.init.kaiming_uniform_(experts.gate_up_proj) + nn.init.kaiming_uniform_(experts.down_proj) + nn.init.normal_(router.proj.weight, std=0.01) + + router = router.to(device).to(torch.bfloat16) + experts = experts.to(device).to(torch.bfloat16) + + return router, experts + + +# ============================================================================ +# ScatterMoE Tests +# ============================================================================ + + +class TestGemma4ScatterMoE: + """Test ScatterMoE kernel compatibility with 
Gemma4 expert layout.""" + + def test_scattermoe_experts_match_reference( + self, gemma4_moe_layer, gemma4_config, device + ): + """ScatterMoE kernel output matches reference expert computation.""" + from transformers.activations import ACT2FN + + from axolotl.integrations.kernels.libs.scattermoe_lora.parallel_experts import ( + flatten_sort_count, + parallel_linear, + ) + + router, experts = gemma4_moe_layer + act_fn = ACT2FN["gelu_pytorch_tanh"] + T = 16 # num tokens + H = gemma4_config["hidden_size"] + K = gemma4_config["top_k"] + E = gemma4_config["num_experts"] + + hidden_states = torch.randn(T, H, device=device, dtype=torch.bfloat16) + + # Reference forward + _, top_k_weights, top_k_index = router(hidden_states) + ref_output = experts(hidden_states, top_k_index, top_k_weights) + + # ScatterMoE forward + routing_weights = top_k_weights.to(hidden_states.dtype) + sorted_expert_idxs, sorted_scattered_idxs, expert_offsets = flatten_sort_count( + top_k_index, num_experts=E + ) + + # gate_up_proj is [E, 2*inter, H], ScatterMoE expects transposed: [E, H, 2*I] + gate_up_weight = experts.gate_up_proj.transpose(2, 1) + down_weight = experts.down_proj.transpose(2, 1) + + gates_h = parallel_linear( + hidden_states, + gate_up_weight, + K, + sorted_expert_idxs, + sorted_scattered_idxs, + expert_offsets, + grouped_in=False, + grouped_out=True, + ) + gates, h = gates_h.chunk(2, dim=-1) + h = act_fn(gates) * h + + scatter_output = parallel_linear( + h, + down_weight, + 1, + sorted_expert_idxs, + sorted_scattered_idxs, + expert_offsets, + grouped_in=True, + grouped_out=False, + gates=routing_weights, + ) + + # Allow bf16 tolerance + torch.testing.assert_close(scatter_output, ref_output, atol=1e-2, rtol=1e-2) + + def test_scattermoe_with_lora(self, gemma4_moe_layer, gemma4_config, device): + """ScatterMoE + LoRA kernel matches reference LoRA computation.""" + from transformers.activations import ACT2FN + + from axolotl.integrations.kernels.libs.scattermoe_lora.parallel_experts 
import ( + flatten_sort_count, + parallel_linear, + ) + from axolotl.integrations.kernels.libs.scattermoe_lora.parallel_linear_lora import ( + parallel_linear_lora, + ) + + router, experts = gemma4_moe_layer + act_fn = ACT2FN["gelu_pytorch_tanh"] + T = 16 + H = gemma4_config["hidden_size"] + K = gemma4_config["top_k"] + E = gemma4_config["num_experts"] + inter = gemma4_config["intermediate_size"] + rank = 4 + scaling = 0.5 + + hidden_states = torch.randn(T, H, device=device, dtype=torch.bfloat16) + + # Create LoRA weights for gate_up_proj + # ScatterMoE layout: A=[r*E, K], B=[N, r*E] + lora_A_gup = ( + torch.randn(rank * E, H, device=device, dtype=torch.bfloat16) * 0.01 + ) + lora_B_gup = ( + torch.randn(2 * inter, rank * E, device=device, dtype=torch.bfloat16) * 0.01 + ) + + # Reference: manual LoRA application per expert + _, top_k_weights, top_k_index = router(hidden_states) + ref_output = torch.zeros(T, H, device=device, dtype=torch.bfloat16) + with torch.no_grad(): + expert_mask = F.one_hot(top_k_index, num_classes=E).permute(2, 1, 0) + expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() + + for eidx in expert_hit: + eidx = eidx[0] + if eidx == E: + continue + top_k_pos, token_idx = torch.where(expert_mask[eidx]) + current_state = hidden_states[token_idx] + + # Base gate_up + LoRA delta + base_out = F.linear(current_state, experts.gate_up_proj[eidx]) + lora_a_slice = lora_A_gup[eidx * rank : (eidx + 1) * rank, :] + lora_b_slice = lora_B_gup[:, eidx * rank : (eidx + 1) * rank] + lora_delta = ( + F.linear(F.linear(current_state, lora_a_slice), lora_b_slice) * scaling + ) + combined = base_out + lora_delta + + gate, up = combined.chunk(2, dim=-1) + h = act_fn(gate) * up + h = F.linear(h, experts.down_proj[eidx]) + h = h * top_k_weights[token_idx, top_k_pos, None] + ref_output.index_add_(0, token_idx, h.to(ref_output.dtype)) + + # ScatterMoE LoRA forward + routing_weights = top_k_weights.to(hidden_states.dtype) + sorted_expert_idxs, 
sorted_scattered_idxs, expert_offsets = flatten_sort_count( + top_k_index, num_experts=E + ) + + gate_up_weight = experts.gate_up_proj.transpose(2, 1) + down_weight = experts.down_proj.transpose(2, 1) + + gates_h = parallel_linear_lora( + hidden_states, + gate_up_weight, + K, + sorted_expert_idxs, + sorted_scattered_idxs, + expert_offsets, + lora_A_gup, + lora_B_gup, + scaling, + grouped_in=False, + grouped_out=True, + ) + gates, h = gates_h.chunk(2, dim=-1) + h = act_fn(gates) * h + + scatter_output = parallel_linear( + h, + down_weight, + 1, + sorted_expert_idxs, + sorted_scattered_idxs, + expert_offsets, + grouped_in=True, + grouped_out=False, + gates=routing_weights, + ) + + torch.testing.assert_close(scatter_output, ref_output, atol=5e-2, rtol=5e-2) + + def test_gemma4_routing_correctness(self, gemma4_moe_layer, gemma4_config, device): + """Gemma4 custom routing (norm+scale+per_expert_scale) produces valid outputs.""" + router, _ = gemma4_moe_layer + T = 32 + H = gemma4_config["hidden_size"] + K = gemma4_config["top_k"] + E = gemma4_config["num_experts"] + + hidden_states = torch.randn(T, H, device=device, dtype=torch.bfloat16) + router_probs, top_k_weights, top_k_index = router(hidden_states) + + # Check shapes + assert router_probs.shape == (T, E) + assert top_k_weights.shape == (T, K) + assert top_k_index.shape == (T, K) + + # Router probs should be valid probability distribution + assert (router_probs >= 0).all() + assert torch.allclose( + router_probs.sum(dim=-1), + torch.ones(T, device=device, dtype=router_probs.dtype), + atol=1e-3, + ) + + # Top-k indices should be valid expert indices + assert (top_k_index >= 0).all() + assert (top_k_index < E).all() + + # Top-k weights should be non-negative (per_expert_scale can change sign though) + # Just verify finite + assert top_k_weights.isfinite().all() + + def test_scattermoe_gradients_flow(self, gemma4_moe_layer, gemma4_config, device): + """Verify gradients flow through ScatterMoE kernels for Gemma4.""" + 
from transformers.activations import ACT2FN + + from axolotl.integrations.kernels.libs.scattermoe_lora.parallel_experts import ( + flatten_sort_count, + parallel_linear, + ) + + router, experts = gemma4_moe_layer + + # Enable grad for expert weights + experts.gate_up_proj.requires_grad_(True) + experts.down_proj.requires_grad_(True) + + act_fn = ACT2FN["gelu_pytorch_tanh"] + T = 16 + H = gemma4_config["hidden_size"] + K = gemma4_config["top_k"] + E = gemma4_config["num_experts"] + + hidden_states = torch.randn(T, H, device=device, dtype=torch.bfloat16) + + with torch.no_grad(): + _, top_k_weights, top_k_index = router(hidden_states) + + routing_weights = top_k_weights.to(hidden_states.dtype) + sorted_expert_idxs, sorted_scattered_idxs, expert_offsets = flatten_sort_count( + top_k_index, num_experts=E + ) + + gate_up_weight = experts.gate_up_proj.transpose(2, 1) + down_weight = experts.down_proj.transpose(2, 1) + + gates_h = parallel_linear( + hidden_states, + gate_up_weight, + K, + sorted_expert_idxs, + sorted_scattered_idxs, + expert_offsets, + grouped_in=False, + grouped_out=True, + ) + gates, h = gates_h.chunk(2, dim=-1) + h = act_fn(gates) * h + + output = parallel_linear( + h, + down_weight, + 1, + sorted_expert_idxs, + sorted_scattered_idxs, + expert_offsets, + grouped_in=True, + grouped_out=False, + gates=routing_weights, + ) + + loss = output.sum() + loss.backward() + + assert experts.gate_up_proj.grad is not None + assert experts.down_proj.grad is not None + assert experts.gate_up_proj.grad.isfinite().all() + assert experts.down_proj.grad.isfinite().all() + + +# ============================================================================ +# SonicMoE Tests +# ============================================================================ + + +def _can_import_sonicmoe(): + try: + from sonicmoe.enums import ActivationType # noqa: F401 + + return True + except Exception: + return False + + +class TestGemma4SonicMoE: + """Test SonicMoE compatibility with Gemma4. 
+ + SonicMoE requires Hopper/Blackwell GPU. Tests that need sonicmoe + import are skipped on unsupported GPUs. + """ + + @pytest.mark.skipif( + not _can_import_sonicmoe(), + reason="sonicmoe requires Hopper/Blackwell GPU", + ) + def test_gemma4_routing_function_config(self, gemma4_config): + """Gemma4 is registered with correct routing config.""" + from axolotl.integrations.kernels.libs.sonicmoe.routing import ( + get_model_moe_config, + ) + + routing_fn, activation, router_attr = get_model_moe_config("gemma4_text") + + assert router_attr == "router" + assert routing_fn is not None + assert routing_fn.__name__ == "gemma4_routing" + + from sonicmoe.enums import ActivationType + + assert activation == ActivationType.GEGLU + + @pytest.mark.skipif( + not _can_import_sonicmoe(), + reason="sonicmoe requires Hopper/Blackwell GPU", + ) + def test_gemma4_routing_matches_reference(self, gemma4_config): + """Routing function output matches reference Gemma4TextRouter.""" + from axolotl.integrations.kernels.libs.sonicmoe.routing import ( + get_model_moe_config, + ) + + routing_fn, _, _ = get_model_moe_config("gemma4_text") + H = gemma4_config["hidden_size"] + E = gemma4_config["num_experts"] + K = gemma4_config["top_k"] + T = 16 + + router = Gemma4TextRouter(H, E, K) + nn.init.normal_(router.proj.weight, std=0.01) + + class MockGemma4MoeBlock: + pass + + mock_block = MockGemma4MoeBlock() + mock_block.router = router + + hidden_states = torch.randn(T, H) + + # Reference + _ref_probs, ref_weights, ref_indices = router(hidden_states) + + # Routing function + flat_scores, flat_token_idx, flat_expert_idx, router_logits = routing_fn( + hidden_states, mock_block + ) + + # Check shapes + assert flat_scores.shape == (T * K,) + assert flat_token_idx.shape == (T * K,) + assert flat_expert_idx.shape == (T * K,) + assert router_logits.shape == (T, E) + + # Reconstruct per-token routing from flat output and compare + for t in range(T): + mask = flat_token_idx == t + assert mask.sum() == K, 
f"Token {t} should have {K} entries" + + flat_experts_for_t = flat_expert_idx[mask].sort().values + ref_experts_for_t = ref_indices[t].sort().values.to(torch.int32) + assert torch.equal(flat_experts_for_t, ref_experts_for_t), ( + f"Token {t}: experts mismatch" + ) + + # Verify scores match reference per-token + for t in range(T): + mask = flat_token_idx == t + flat_experts_t = flat_expert_idx[mask] + flat_scores_t = flat_scores[mask] + + sort_idx = flat_experts_t.argsort() + flat_scores_sorted = flat_scores_t[sort_idx] + + ref_sort_idx = ref_indices[t].argsort() + ref_scores_sorted = ref_weights[t][ref_sort_idx].float() + + torch.testing.assert_close( + flat_scores_sorted, ref_scores_sorted, atol=1e-4, rtol=1e-4 + ) + + def test_gemma4_weight_layout_compatible(self, gemma4_config): + """Verify Gemma4 expert weight layout is compatible with SonicMoE.""" + E = gemma4_config["num_experts"] + H = gemma4_config["hidden_size"] + inter = gemma4_config["intermediate_size"] + + gate_up_proj = torch.randn(E, 2 * inter, H) + down_proj = torch.randn(E, H, inter) + + # SonicMoE expects [dim, dim, E] (experts last) + gate_up_sonic = gate_up_proj.permute(1, 2, 0) + down_sonic = down_proj.permute(1, 2, 0) + + assert gate_up_sonic.shape == (2 * inter, H, E) + assert down_sonic.shape == (H, inter, E) + + # Verify roundtrip + recovered_gate_up = gate_up_sonic.permute(2, 0, 1) + assert torch.equal(gate_up_proj, recovered_gate_up) + + def test_gemma4_is_experts_only_model(self): + """Verify gemma4_text is recognized as experts-only model.""" + from axolotl.integrations.kernels.constants import ( + is_experts_only_model, + resolve_experts_class, + ) + + assert is_experts_only_model("gemma4_text") + cls = resolve_experts_class("gemma4_text") + assert cls is not None + assert cls.__name__ == "Gemma4TextExperts" + + def test_gemma4_not_in_sparse_moe_block(self): + """Verify gemma4_text is NOT in SPARSE_MOE_BLOCK (has no SparseMoeBlock).""" + from axolotl.integrations.kernels.constants 
import SPARSE_MOE_BLOCK + + assert "gemma4_text" not in SPARSE_MOE_BLOCK + + +# ============================================================================ +# Integration Tests (full layer with real model config) +# ============================================================================ + + +class TestGemma4FullLayerIntegration: + """Test with realistic Gemma4 config (26B-A4B dimensions, single layer).""" + + @pytest.fixture + def real_config(self): + return { + "hidden_size": 2816, + "num_experts": 128, + "top_k": 8, + "intermediate_size": 704, + "eps": 1e-6, + } + + def test_scattermoe_real_dimensions(self, real_config, device): + """ScatterMoE works with real Gemma4-26B-A4B expert dimensions.""" + from transformers.activations import ACT2FN + + from axolotl.integrations.kernels.libs.scattermoe_lora.parallel_experts import ( + flatten_sort_count, + parallel_linear, + ) + + act_fn = ACT2FN["gelu_pytorch_tanh"] + H = real_config["hidden_size"] + E = real_config["num_experts"] + K = real_config["top_k"] + inter = real_config["intermediate_size"] + T = 32 + + # Create experts on GPU + gate_up_proj = ( + torch.randn(E, 2 * inter, H, device=device, dtype=torch.bfloat16) * 0.01 + ) + down_proj = torch.randn(E, H, inter, device=device, dtype=torch.bfloat16) * 0.01 + hidden_states = torch.randn(T, H, device=device, dtype=torch.bfloat16) + + # Simulate routing (random valid assignment) + top_k_index = torch.stack( + [torch.randperm(E, device=device)[:K] for _ in range(T)] + ) + top_k_weights = torch.softmax( + torch.randn(T, K, device=device, dtype=torch.bfloat16), dim=-1 + ) + + # Reference + ref_output = torch.zeros(T, H, device=device, dtype=torch.bfloat16) + with torch.no_grad(): + expert_mask = F.one_hot(top_k_index, num_classes=E).permute(2, 1, 0) + expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() + + for eidx in expert_hit: + eidx = eidx[0] + top_k_pos, token_idx = torch.where(expert_mask[eidx]) + current_state = hidden_states[token_idx] 
+ gate, up = F.linear(current_state, gate_up_proj[eidx]).chunk(2, dim=-1) + h = act_fn(gate) * up + h = F.linear(h, down_proj[eidx]) + h = h * top_k_weights[token_idx, top_k_pos, None] + ref_output.index_add_(0, token_idx, h.to(ref_output.dtype)) + + # ScatterMoE + routing_weights = top_k_weights.to(hidden_states.dtype) + sorted_expert_idxs, sorted_scattered_idxs, expert_offsets = flatten_sort_count( + top_k_index, num_experts=E + ) + + gates_h = parallel_linear( + hidden_states, + gate_up_proj.transpose(2, 1), + K, + sorted_expert_idxs, + sorted_scattered_idxs, + expert_offsets, + grouped_in=False, + grouped_out=True, + ) + gates, h = gates_h.chunk(2, dim=-1) + h = act_fn(gates) * h + + scatter_output = parallel_linear( + h, + down_proj.transpose(2, 1), + 1, + sorted_expert_idxs, + sorted_scattered_idxs, + expert_offsets, + grouped_in=True, + grouped_out=False, + gates=routing_weights, + ) + + torch.testing.assert_close(scatter_output, ref_output, atol=5e-2, rtol=5e-2) + + def test_scattermoe_lora_real_dimensions(self, real_config, device): + """ScatterMoE + LoRA works with real Gemma4-26B-A4B dimensions.""" + from transformers.activations import ACT2FN + + from axolotl.integrations.kernels.libs.scattermoe_lora.parallel_experts import ( + flatten_sort_count, + parallel_linear, + ) + from axolotl.integrations.kernels.libs.scattermoe_lora.parallel_linear_lora import ( + parallel_linear_lora, + ) + + act_fn = ACT2FN["gelu_pytorch_tanh"] + H = real_config["hidden_size"] + E = real_config["num_experts"] + K = real_config["top_k"] + inter = real_config["intermediate_size"] + T = 32 + rank = 8 + scaling = 0.5 + + gate_up_proj = ( + torch.randn(E, 2 * inter, H, device=device, dtype=torch.bfloat16) * 0.01 + ) + down_proj = torch.randn(E, H, inter, device=device, dtype=torch.bfloat16) * 0.01 + lora_A = torch.randn(rank * E, H, device=device, dtype=torch.bfloat16) * 0.01 + lora_B = ( + torch.randn(2 * inter, rank * E, device=device, dtype=torch.bfloat16) * 0.01 + ) + 
hidden_states = torch.randn(T, H, device=device, dtype=torch.bfloat16) + + # Random routing + top_k_index = torch.stack( + [torch.randperm(E, device=device)[:K] for _ in range(T)] + ) + top_k_weights = torch.softmax( + torch.randn(T, K, device=device, dtype=torch.bfloat16), dim=-1 + ) + + routing_weights = top_k_weights.to(hidden_states.dtype) + sorted_expert_idxs, sorted_scattered_idxs, expert_offsets = flatten_sort_count( + top_k_index, num_experts=E + ) + + # ScatterMoE + LoRA on gate_up + gates_h = parallel_linear_lora( + hidden_states, + gate_up_proj.transpose(2, 1), + K, + sorted_expert_idxs, + sorted_scattered_idxs, + expert_offsets, + lora_A, + lora_B, + scaling, + grouped_in=False, + grouped_out=True, + ) + gates, h = gates_h.chunk(2, dim=-1) + h = act_fn(gates) * h + + output = parallel_linear( + h, + down_proj.transpose(2, 1), + 1, + sorted_expert_idxs, + sorted_scattered_idxs, + expert_offsets, + grouped_in=True, + grouped_out=False, + gates=routing_weights, + ) + + # Basic sanity: output should be finite and right shape + assert output.shape == (T, H) + assert output.isfinite().all() + + +class TestExpertsInterfaceIntegration: + """Test the ExpertsInterface registration (the clean transformers hook).""" + + @staticmethod + def _make_gemma4_config(): + from transformers.models.gemma4.configuration_gemma4 import Gemma4TextConfig + + return Gemma4TextConfig( + hidden_size=128, + num_experts=8, + top_k_experts=2, + moe_intermediate_size=64, + hidden_activation="gelu_pytorch_tanh", + num_hidden_layers=2, + num_attention_heads=2, + num_key_value_heads=1, + head_dim=64, + intermediate_size=256, + enable_moe_block=True, + ) + + def test_register_scattermoe_in_experts_interface(self): + """register_scattermoe_experts adds 'scattermoe' to the global interface.""" + from transformers.integrations.moe import ALL_EXPERTS_FUNCTIONS + + from axolotl.integrations.kernels.libs.scattermoe_lora.gemma4_experts import ( + register_scattermoe_experts, + 
scattermoe_experts_forward, + ) + + register_scattermoe_experts() + + assert "scattermoe" in ALL_EXPERTS_FUNCTIONS + assert ALL_EXPERTS_FUNCTIONS["scattermoe"] is scattermoe_experts_forward + + def test_experts_implementation_dispatches_to_scattermoe(self, device): + """Setting config._experts_implementation='scattermoe' dispatches correctly.""" + from transformers.models.gemma4.modeling_gemma4 import ( + Gemma4TextExperts as HFGemma4TextExperts, + ) + + from axolotl.integrations.kernels.libs.scattermoe_lora.gemma4_experts import ( + register_scattermoe_experts, + ) + + register_scattermoe_experts() + + cfg = self._make_gemma4_config() + cfg._experts_implementation = "scattermoe" + + with torch.device("meta"): + hf_experts = HFGemma4TextExperts(cfg) + + hf_experts = hf_experts.to_empty(device=device) + nn.init.kaiming_uniform_(hf_experts.gate_up_proj) + nn.init.kaiming_uniform_(hf_experts.down_proj) + hf_experts = hf_experts.to(torch.bfloat16) + + T, K = 16, 2 + hidden_states = torch.randn(T, 128, device=device, dtype=torch.bfloat16) + top_k_index = torch.stack( + [torch.randperm(8, device=device)[:K] for _ in range(T)] + ) + top_k_weights = torch.softmax( + torch.randn(T, K, device=device, dtype=torch.bfloat16), dim=-1 + ) + + # Get reference output with eager implementation + cfg_eager = self._make_gemma4_config() + cfg_eager._experts_implementation = "eager" + with torch.device("meta"): + eager_experts = HFGemma4TextExperts(cfg_eager) + eager_experts = eager_experts.to_empty(device=device).to(torch.bfloat16) + # Copy weights from scattermoe experts + eager_experts.gate_up_proj.data.copy_(hf_experts.gate_up_proj.data) + eager_experts.down_proj.data.copy_(hf_experts.down_proj.data) + + ref_output = eager_experts(hidden_states, top_k_index, top_k_weights) + + # ScatterMoE dispatched output + scatter_output = hf_experts(hidden_states, top_k_index, top_k_weights) + + torch.testing.assert_close(scatter_output, ref_output, atol=1e-2, rtol=1e-2) + + def 
test_validation_accepts_scattermoe(self): + """get_correct_experts_implementation accepts 'scattermoe' after registration.""" + from transformers.modeling_utils import PreTrainedModel + + from axolotl.integrations.kernels.libs.scattermoe_lora.gemma4_experts import ( + register_scattermoe_experts, + ) + + register_scattermoe_experts() + + # Should not raise + result = PreTrainedModel.get_correct_experts_implementation(None, "scattermoe") + assert result == "scattermoe" + + def test_eager_still_works_after_registration(self, device): + """Registering scattermoe doesn't break eager dispatch.""" + from transformers.models.gemma4.modeling_gemma4 import ( + Gemma4TextExperts as HFGemma4TextExperts, + ) + + from axolotl.integrations.kernels.libs.scattermoe_lora.gemma4_experts import ( + register_scattermoe_experts, + ) + + register_scattermoe_experts() + + cfg = self._make_gemma4_config() + cfg._experts_implementation = "eager" + + with torch.device("meta"): + hf_experts = HFGemma4TextExperts(cfg) + + hf_experts = hf_experts.to_empty(device=device) + nn.init.kaiming_uniform_(hf_experts.gate_up_proj) + nn.init.kaiming_uniform_(hf_experts.down_proj) + hf_experts = hf_experts.to(torch.bfloat16) + + T, K = 16, 2 + hidden_states = torch.randn(T, 128, device=device, dtype=torch.bfloat16) + top_k_index = torch.stack( + [torch.randperm(8, device=device)[:K] for _ in range(T)] + ) + top_k_weights = torch.softmax( + torch.randn(T, K, device=device, dtype=torch.bfloat16), dim=-1 + ) + + # Should use eager (original) forward without error + output = hf_experts(hidden_states, top_k_index, top_k_weights) + assert output.shape == (T, 128) + assert output.isfinite().all() + + +class TestScatterMoEExpertsInterfaceMultiModel: + """Test that the registered scattermoe ExpertsInterface works across model types. + + All @use_experts_implementation Experts classes share the same layout: + gate_up_proj [E, 2*inter, H], down_proj [E, H, inter], forward(hidden_states, top_k_index, top_k_weights). 
+ """ + + MODEL_EXPERTS = [ + ( + "transformers.models.gemma4.modeling_gemma4", + "Gemma4TextExperts", + { + "hidden_size": 128, + "num_experts": 8, + "moe_intermediate_size": 64, + "hidden_activation": "gelu_pytorch_tanh", + "top_k_experts": 2, + "num_hidden_layers": 2, + "num_attention_heads": 2, + "num_key_value_heads": 1, + "head_dim": 64, + "intermediate_size": 256, + "enable_moe_block": True, + }, + "transformers.models.gemma4.configuration_gemma4.Gemma4TextConfig", + ), + ( + "transformers.models.qwen3_moe.modeling_qwen3_moe", + "Qwen3MoeExperts", + { + "hidden_size": 128, + "num_experts": 8, + "moe_intermediate_size": 64, + "hidden_act": "silu", + "num_experts_per_tok": 2, + "num_hidden_layers": 2, + "num_attention_heads": 2, + "num_key_value_heads": 1, + "intermediate_size": 256, + }, + "transformers.models.qwen3_moe.configuration_qwen3_moe.Qwen3MoeConfig", + ), + ( + "transformers.models.olmoe.modeling_olmoe", + "OlmoeExperts", + { + "hidden_size": 128, + "num_experts": 8, + "intermediate_size": 64, + "hidden_act": "silu", + "num_experts_per_tok": 2, + "num_hidden_layers": 2, + "num_attention_heads": 2, + "num_key_value_heads": 1, + }, + "transformers.models.olmoe.configuration_olmoe.OlmoeConfig", + ), + ( + "transformers.models.mixtral.modeling_mixtral", + "MixtralExperts", + { + "hidden_size": 128, + "num_local_experts": 8, + "intermediate_size": 64, + "hidden_act": "silu", + "num_experts_per_tok": 2, + "num_hidden_layers": 2, + "num_attention_heads": 2, + "num_key_value_heads": 1, + }, + "transformers.models.mixtral.configuration_mixtral.MixtralConfig", + ), + ] + + @pytest.fixture( + params=[m[1] for m in MODEL_EXPERTS], ids=[m[1] for m in MODEL_EXPERTS] + ) + def model_setup(self, request, device): + """Create an Experts instance for each model type.""" + import importlib + + from axolotl.integrations.kernels.libs.scattermoe_lora.gemma4_experts import ( + register_scattermoe_experts, + ) + + register_scattermoe_experts() + + for module_path, 
cls_name, cfg_kwargs, config_cls_path in self.MODEL_EXPERTS: + if cls_name == request.param: + # Import config class + config_module, config_class = config_cls_path.rsplit(".", 1) + config_cls = getattr( + importlib.import_module(config_module), config_class + ) + cfg = config_cls(**cfg_kwargs) + + # Import experts class + module = importlib.import_module(module_path) + experts_cls = getattr(module, cls_name) + + # Create eager reference + cfg_eager = config_cls(**cfg_kwargs) + cfg_eager._experts_implementation = "eager" + with torch.device("meta"): + eager = experts_cls(cfg_eager) + eager = eager.to_empty(device=device).to(torch.bfloat16) + nn.init.kaiming_uniform_(eager.gate_up_proj) + nn.init.kaiming_uniform_(eager.down_proj) + + # Create scattermoe version with same weights + cfg._experts_implementation = "scattermoe" + with torch.device("meta"): + scatter = experts_cls(cfg) + scatter = scatter.to_empty(device=device).to(torch.bfloat16) + scatter.gate_up_proj.data.copy_(eager.gate_up_proj.data) + scatter.down_proj.data.copy_(eager.down_proj.data) + + return ( + cls_name, + eager, + scatter, + cfg_kwargs.get( + "num_experts", cfg_kwargs.get("num_local_experts", 8) + ), + ) + + def test_scattermoe_matches_eager(self, model_setup, device): + """ScatterMoE ExpertsInterface output matches eager for each model type.""" + cls_name, eager, scatter, num_experts = model_setup + T, K = 16, 2 + + hidden_states = torch.randn(T, 128, device=device, dtype=torch.bfloat16) + top_k_index = torch.stack( + [torch.randperm(num_experts, device=device)[:K] for _ in range(T)] + ) + top_k_weights = torch.softmax( + torch.randn(T, K, device=device, dtype=torch.bfloat16), dim=-1 + ) + + ref_output = eager(hidden_states, top_k_index, top_k_weights) + scatter_output = scatter(hidden_states, top_k_index, top_k_weights) + + torch.testing.assert_close( + scatter_output, + ref_output, + atol=1e-2, + rtol=1e-2, + msg=f"{cls_name}: ScatterMoE output doesn't match eager", + )