Compare commits
1 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
08fc7de87e |
104
examples/gemma4/26b-a4b-moe-qlora.yaml
Normal file
104
examples/gemma4/26b-a4b-moe-qlora.yaml
Normal file
@@ -0,0 +1,104 @@
|
||||
# Gemma 4 26B-A4B MoE QLoRA with ScatterMoE kernels
|
||||
#
|
||||
# Validated: 50 steps on FineTome-100k, loss 7.4 -> 2.4, single RTX 5090 (32GB)
|
||||
#
|
||||
# Key notes:
|
||||
# - Flash Attention 2 is NOT supported (global_head_dim=512 > FA2 max of 256).
|
||||
# Use sdp_attention instead.
|
||||
# - Gemma 4 is multimodal (text+vision+audio). For text-only SFT, restrict
|
||||
# LoRA to the text backbone via lora_target_linear_modules regex.
|
||||
# - MoE experts use `experts_implementation: scattermoe` — Gemma 4 embeds MoE
|
||||
# directly in the decoder layer (no SparseMoeBlock), so we register ScatterMoE
|
||||
# via the transformers ExpertsInterface.
|
||||
# - Expert LoRA targets are `experts.gate_up_proj` / `experts.down_proj`
|
||||
# (no `mlp.` prefix, unlike Qwen/Mixtral).
|
||||
# - micro_batch_size: 1 fits 2048 seq_len on 32GB GPU with SDP attention.
|
||||
# Use micro_batch_size: 4 with 1024 seq_len, or on 48GB+ GPUs.
|
||||
|
||||
base_model: google/gemma-4-26B-A4B
|
||||
|
||||
plugins:
|
||||
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||
- axolotl.integrations.kernels.KernelsPlugin
|
||||
- axolotl.integrations.liger.LigerPlugin
|
||||
use_kernels: true
|
||||
use_scattermoe: true
|
||||
experts_implementation: scattermoe
|
||||
torch_compile: false
|
||||
liger_layer_norm: true
|
||||
liger_rope: true
|
||||
liger_rms_norm: true
|
||||
liger_glu_activation: true
|
||||
liger_rms_norm_gated: true
|
||||
strict: false
|
||||
|
||||
chat_template: gemma4
|
||||
datasets:
|
||||
- path: mlabonne/FineTome-100k
|
||||
type: chat_template
|
||||
split: train[:10%]
|
||||
field_messages: conversations
|
||||
message_property_mappings:
|
||||
role: from
|
||||
content: value
|
||||
val_set_size: 0.05
|
||||
output_dir: ./outputs/gemma4-26b-a4b-qlora
|
||||
|
||||
sequence_len: 2048
|
||||
sample_packing: true
|
||||
|
||||
load_in_4bit: true
|
||||
quantize_moe_experts: true
|
||||
adapter: qlora
|
||||
lora_r: 16
|
||||
lora_alpha: 32
|
||||
lora_dropout: 0
|
||||
|
||||
# Restrict LoRA to text backbone only (skip vision/audio encoders).
|
||||
# lora_target_modules is intentionally empty — all module targeting is done
|
||||
# via regex in lora_target_linear_modules below.
|
||||
lora_target_modules: []
|
||||
lora_target_linear_modules:
|
||||
- language_model\.model\.layers\.\d+\.self_attn\.(q|k|v|o)_proj
|
||||
|
||||
# MoE expert LoRA (3D Parameter tensors, not nn.Linear)
|
||||
lora_target_parameters:
|
||||
- experts.gate_up_proj
|
||||
- experts.down_proj
|
||||
|
||||
lora_mlp_kernel: false
|
||||
lora_qkv_kernel: false
|
||||
lora_o_kernel: false
|
||||
|
||||
bnb_config_kwargs:
|
||||
bnb_4bit_use_double_quant: true
|
||||
|
||||
wandb_project: gemma4-qlora
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 1
|
||||
num_epochs: 1
|
||||
optimizer: adamw_torch_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
bf16: auto
|
||||
tf32: true
|
||||
|
||||
gradient_checkpointing: true
|
||||
activation_offloading: true
|
||||
logging_steps: 1
|
||||
|
||||
# FA2 not supported — Gemma4 global_head_dim=512 exceeds FA2 max of 256
|
||||
flash_attention: false
|
||||
sdp_attention: true
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 4
|
||||
saves_per_epoch: 1
|
||||
weight_decay: 0.0
|
||||
special_tokens:
|
||||
@@ -12,7 +12,7 @@ packaging==26.0
|
||||
huggingface_hub>=1.1.7
|
||||
peft>=0.18.1
|
||||
tokenizers>=0.22.1
|
||||
transformers==5.4.0
|
||||
transformers==5.5.0
|
||||
accelerate==1.13.0
|
||||
datasets==4.5.0
|
||||
deepspeed>=0.18.6,<0.19.0
|
||||
|
||||
@@ -381,6 +381,15 @@ class AxolotlTrainer(
|
||||
# Store per-step trainable tokens for throughput calculation
|
||||
self.state.tokens["trainable_tokens"] = trainable_tokens.detach().cpu()
|
||||
|
||||
# Gemma4 requires mm_token_type_ids during training (even for text-only).
|
||||
# Inject zeros (= text token type) when not provided by the data collator.
|
||||
if (
|
||||
"mm_token_type_ids" not in inputs
|
||||
and "input_ids" in inputs
|
||||
and getattr(getattr(model, "config", None), "model_type", None) == "gemma4"
|
||||
):
|
||||
inputs["mm_token_type_ids"] = torch.zeros_like(inputs["input_ids"])
|
||||
|
||||
if self.args.orpo_alpha:
|
||||
return self.orpo_compute_loss(
|
||||
model,
|
||||
@@ -508,12 +517,24 @@ class AxolotlTrainer(
|
||||
)
|
||||
|
||||
# Perform a single forward pass
|
||||
forward_kwargs = {
|
||||
"input_ids": concat_inputs["input_ids"],
|
||||
"attention_mask": concat_inputs["attention_mask"],
|
||||
"labels": concat_inputs["labels"],
|
||||
}
|
||||
# Gemma4 requires mm_token_type_ids during training (even for text-only)
|
||||
if (
|
||||
getattr(getattr(model, "config", None), "model_type", None) == "gemma4"
|
||||
and "mm_token_type_ids" not in concat_inputs
|
||||
):
|
||||
forward_kwargs["mm_token_type_ids"] = torch.zeros_like(
|
||||
concat_inputs["input_ids"]
|
||||
)
|
||||
elif "mm_token_type_ids" in concat_inputs:
|
||||
forward_kwargs["mm_token_type_ids"] = concat_inputs["mm_token_type_ids"]
|
||||
|
||||
outputs = model(
|
||||
**{
|
||||
"input_ids": concat_inputs["input_ids"],
|
||||
"attention_mask": concat_inputs["attention_mask"],
|
||||
"labels": concat_inputs["labels"],
|
||||
},
|
||||
**forward_kwargs,
|
||||
output_hidden_states=True,
|
||||
)
|
||||
|
||||
|
||||
@@ -28,7 +28,7 @@ use_scattermoe: true
|
||||
use_sonicmoe: true
|
||||
```
|
||||
|
||||
**Important:** Setting `experts_implementation` is incompatible with custom kernel options.
|
||||
**Important:** Setting `experts_implementation` to `batched_mm` or `grouped_mm` is incompatible with custom kernel options. The exception is `experts_implementation: scattermoe`, which is used for models like Gemma 4 that embed MoE directly in the decoder layer (no SparseMoeBlock) and dispatch through the transformers `ExpertsInterface`.
|
||||
|
||||
### SonicMoE installation
|
||||
|
||||
@@ -63,7 +63,7 @@ Both paths use the shared `resolve_moe_block_classes` utility in `constants.py`
|
||||
|
||||
## Model Support Matrix
|
||||
|
||||
All models use the **SwiGLU** activation (`act_fn(gate) * up`). Neither kernel currently supports non-SwiGLU MoE architectures.
|
||||
Most models use the **SwiGLU** activation (`silu(gate) * up`). Gemma 4 uses **GEGLU** (`gelu(gate) * up`). ScatterMoE supports any gated activation (activation is applied in Python between kernel calls). SonicMoE supports SwiGLU, GEGLU, and REGLU via its `ActivationType` enum.
|
||||
|
||||
### Routing strategies
|
||||
|
||||
@@ -76,6 +76,7 @@ All models use the **SwiGLU** activation (`act_fn(gate) * up`). Neither kernel c
|
||||
| softmax → bias correction → topk | Softmax, bias via `gate.moe_statics`, topk, gather from original probs, clamp-based renorm | No | Yes |
|
||||
| softmax → group_limited_greedy | Softmax, group selection (max per group), topk, scale only (no renorm) | No | Yes |
|
||||
| softmax → topk via gate.wg | Softmax, gate weight at `gate.wg.weight` (not `gate.weight`), always renormalize | No | Yes |
|
||||
| softmax → topk + per_expert_scale | RMSNorm → scale → proj → softmax → topk → renorm → per-expert learned scales | Yes | Yes |
|
||||
| fused topk → softmax | Routing + expert computation fused in a single kernel | No | Planned |
|
||||
|
||||
### Per-model support
|
||||
@@ -102,10 +103,13 @@ All models use the **SwiGLU** activation (`act_fn(gate) * up`). Neither kernel c
|
||||
| `ernie4_5_moe` | ERNIE 4.5 MoE | softmax → bias → topk | No | **Yes** |
|
||||
| `deepseek_v2` | DeepSeek-V2 | softmax → group_limited_greedy | No | **Yes** |
|
||||
| `hunyuan_v1_moe` | HunYuan V1 MoE | softmax → topk (gate.wg) | No | **Yes** |
|
||||
| `gemma4_text` | Gemma 4 (26B-A4B) | softmax → topk + per_expert_scale | **Yes**\*\* | **Yes**\*\* |
|
||||
| `gpt_oss` | GPT-OSS | fused topk → softmax | No | Planned |
|
||||
|
||||
\* `glm4_moe_lite` with ScatterMoE may have issues — see Limitations.
|
||||
|
||||
\*\* Gemma 4 uses `experts_implementation: scattermoe` path (registered via `ExpertsInterface`) instead of SparseMoeBlock patching, since Gemma 4 embeds MoE directly in its decoder layer (no separate SparseMoeBlock). See the [Gemma 4 section](#gemma-4) below.
|
||||
|
||||
### Feature comparison
|
||||
|
||||
| Feature | ScatterMoE | SonicMoE |
|
||||
@@ -131,6 +135,21 @@ Both kernels handle shared experts identically. Shared expert attribute names ar
|
||||
|
||||
If `shared_expert_gate` exists, sigmoid gating is applied to the shared expert contribution before adding it to the routed output. PEFT wraps shared expert linear layers with standard LoRA — no special handling is needed.
|
||||
|
||||
## Gemma 4
|
||||
|
||||
Gemma 4 (e.g. `google/gemma-4-26B-A4B`) has a unique hybrid MoE architecture:
|
||||
|
||||
- **No SparseMoeBlock**: MoE is embedded directly in the decoder layer alongside a dense MLP. Both run in parallel and their outputs are summed.
|
||||
- **Custom router** (`Gemma4TextRouter`): RMSNorm → learned scale → linear projection → softmax → top-k → renormalization → per-expert learned scales.
|
||||
- **GEGLU activation**: Uses `gelu_pytorch_tanh` (not SiLU/SwiGLU like most other MoE models).
|
||||
- **128 experts, top-k=8** for the 26B-A4B variant.
|
||||
|
||||
Because there is no SparseMoeBlock class to patch, Gemma 4 uses a different integration path: we register `"scattermoe"` as a custom implementation in the transformers `ExpertsInterface`, and set `experts_implementation: scattermoe` in the config. The `@use_experts_implementation` decorator on `Gemma4TextExperts` then dispatches to our ScatterMoE kernel automatically. The router is untouched — it runs as-is.
|
||||
|
||||
**Important limitations:**
|
||||
- **Flash Attention 2 is not supported** — Gemma 4 uses `global_head_dim: 512` for full attention layers, which exceeds FA2's maximum head dimension of 256. Use `sdp_attention: true` instead.
|
||||
- **Multimodal model**: Gemma 4 includes vision and audio encoders. For text-only SFT, use `lora_target_linear_modules` with a regex to restrict LoRA to the text backbone (e.g. `language_model\.model\.layers\.\d+\.self_attn\.(q|k|v|o)_proj`).
|
||||
|
||||
## Limitations
|
||||
|
||||
- **ScatterMoE + GLM4-MoE Lite**: ScatterMoE does not work reliably for GLM 4.7 Flash (`glm4_moe_lite`).
|
||||
|
||||
@@ -34,12 +34,20 @@ class KernelsArgs(BaseModel):
|
||||
@classmethod
|
||||
def check_experts_implementation(cls, data):
|
||||
experts_implementation = data.get("experts_implementation")
|
||||
use_scattermoe = data.get("use_scattermoe", False)
|
||||
if experts_implementation is None:
|
||||
# transformers may default to batched_mm when unset
|
||||
data["experts_implementation"] = "eager"
|
||||
elif experts_implementation != "eager":
|
||||
elif experts_implementation == "scattermoe" and not use_scattermoe:
|
||||
LOG.warning(
|
||||
"`experts_implementation` must be set to 'eager' to use this. Automatically setting it to 'eager'."
|
||||
"`experts_implementation='scattermoe'` requires `use_scattermoe: true`. "
|
||||
"Automatically setting to 'eager'."
|
||||
)
|
||||
data["experts_implementation"] = "eager"
|
||||
elif experts_implementation not in ("eager", "scattermoe"):
|
||||
LOG.warning(
|
||||
f"`experts_implementation={experts_implementation!r}` is not compatible with "
|
||||
f"custom MoE kernels. Automatically setting to 'eager'."
|
||||
)
|
||||
data["experts_implementation"] = "eager"
|
||||
|
||||
|
||||
@@ -11,6 +11,7 @@ Models with custom routing (see sonicmoe/routing.py for implementations):
|
||||
- ernie4_5_moe: softmax→bias correction→topk (softmax_bias_topk_routing)
|
||||
- deepseek_v2: softmax→group_limited_greedy (softmax_group_limited_topk_routing)
|
||||
- hunyuan_v1_moe: softmax→topk via gate.wg (softmax_topk_wg_routing)
|
||||
- gemma4_text: RMSNorm→scale→proj→softmax→topk→renorm→per_expert_scale (experts-level patch)
|
||||
"""
|
||||
|
||||
import importlib
|
||||
@@ -53,6 +54,49 @@ SPARSE_MOE_BLOCK = {
|
||||
}
|
||||
|
||||
|
||||
# Models where MoE is NOT in a separate SparseMoeBlock but embedded in the
# decoder layer. For these, we patch the Experts class forward directly
# (same signature: hidden_states, top_k_index, top_k_weights -> Tensor).
# Routing stays untouched — the original model router runs as-is.
#
# Maps a model_type string to the name of the Experts class to look up in that
# model's transformers modeling module.
EXPERTS_ONLY_BLOCK = {
    # gemma4: hybrid MLP+MoE in decoder layer, custom Gemma4TextRouter,
    # no SparseMoeBlock. Experts use @use_experts_implementation with
    # standard 3D param layout (gate_up_proj [E, 2*I, H], down_proj [E, H, I]).
    "gemma4_text": "Gemma4TextExperts",
}
|
||||
|
||||
|
||||
def resolve_experts_class(model_type: str):
    """Look up the transformers Experts class for an experts-only MoE model.

    Args:
        model_type: Model type string used as the key into
            ``EXPERTS_ONLY_BLOCK``.

    Returns:
        The Experts class, or ``None`` when ``model_type`` has no entry in
        ``EXPERTS_ONLY_BLOCK``.

    Raises:
        ModuleNotFoundError: If the modeling module cannot be imported (after
            the ``_text`` fallback, when that applies).
        ValueError: If the module imports but lacks the expected class.
    """
    entry = EXPERTS_ONLY_BLOCK.get(model_type)
    if entry is None:
        return None

    module_path = f"transformers.models.{model_type}.modeling_{model_type}"
    try:
        module = importlib.import_module(module_path)
    except ModuleNotFoundError:
        # Only "<parent>_text" types get a second chance: their classes are
        # looked up in the parent model's modeling module instead.
        if not model_type.endswith("_text"):
            raise
        parent_type = model_type.removesuffix("_text")
        module_path = f"transformers.models.{parent_type}.modeling_{parent_type}"
        module = importlib.import_module(module_path)

    cls = getattr(module, entry, None)
    if cls is not None:
        return cls
    raise ValueError(f"Could not find class '{entry}' in '{module_path}'")
|
||||
|
||||
|
||||
def is_experts_only_model(model_type: str) -> bool:
    """Return True when ``model_type`` is a key of ``EXPERTS_ONLY_BLOCK``.

    Such models are patched at the Experts level rather than at the
    SparseMoeBlock level.
    """
    needs_experts_level_patch = model_type in EXPERTS_ONLY_BLOCK
    return needs_experts_level_patch
|
||||
|
||||
|
||||
def resolve_moe_block_classes(model_type: str):
|
||||
"""Resolve all MoE block classes from transformers for the given model type.
|
||||
|
||||
|
||||
@@ -0,0 +1,235 @@
|
||||
"""
|
||||
ScatterMoE-accelerated experts forward for Gemma4.
|
||||
|
||||
Gemma4 has no separate SparseMoeBlock — MoE is embedded in the decoder layer.
|
||||
The decoder layer handles routing (Gemma4TextRouter) and calls
|
||||
``experts(hidden_states, top_k_index, top_k_weights)`` directly.
|
||||
|
||||
This module registers a ``"scattermoe"`` implementation in the transformers
|
||||
``ExpertsInterface``, which the ``@use_experts_implementation`` decorator
|
||||
dispatches to when ``config._experts_implementation == "scattermoe"``.
|
||||
|
||||
This hooks into transformers' MoE dispatch without replacing any model's
|
||||
``forward``; the only patch is to ``PreTrainedModel``'s
|
||||
``experts_implementation`` validation so that it accepts ``"scattermoe"``.
|
||||
Works for Gemma4 and any future model that uses
|
||||
``@use_experts_implementation`` with the standard forward signature
|
||||
``(hidden_states, top_k_index, top_k_weights) -> Tensor``.
|
||||
"""
|
||||
|
||||
import torch
|
||||
|
||||
from .parallel_experts import flatten_sort_count, parallel_linear
|
||||
from .parallel_linear_lora import get_lora_params_from_wrapper, parallel_linear_lora
|
||||
|
||||
|
||||
def _has_peft_wrapper(module):
|
||||
"""Check if a module's parameter has been wrapped by PEFT ParamWrapper."""
|
||||
try:
|
||||
from peft.tuners.param_wrapper import ParamWrapper
|
||||
|
||||
for attr in ("gate_up_proj", "down_proj"):
|
||||
param = getattr(module, attr, None)
|
||||
if isinstance(param, ParamWrapper):
|
||||
return True
|
||||
except ImportError:
|
||||
pass
|
||||
return False
|
||||
|
||||
|
||||
def _unwrap_experts_lora(experts):
|
||||
"""Extract base weights and LoRA params from a PEFT-wrapped Experts module.
|
||||
|
||||
Returns:
|
||||
(base_experts, gup_lora, down_lora) where each lora is
|
||||
(lora_A, lora_B, scaling) or None.
|
||||
"""
|
||||
try:
|
||||
from peft.tuners.param_wrapper import ParamWrapper
|
||||
except ImportError:
|
||||
return experts, None, None
|
||||
|
||||
if not isinstance(getattr(experts, "gate_up_proj", None), ParamWrapper):
|
||||
return experts, None, None
|
||||
|
||||
base_experts = experts
|
||||
gup_lora = None
|
||||
down_lora = None
|
||||
|
||||
gup_param = experts.gate_up_proj
|
||||
if isinstance(gup_param, ParamWrapper):
|
||||
lora_A, lora_B, scaling = get_lora_params_from_wrapper(gup_param)
|
||||
if lora_A is not None:
|
||||
num_experts = experts.num_experts
|
||||
rank = lora_A.shape[0] // num_experts
|
||||
from .layers import peft_lora_to_scattermoe
|
||||
|
||||
sm_A, sm_B = peft_lora_to_scattermoe(lora_A, lora_B, num_experts, rank)
|
||||
gup_lora = (sm_A, sm_B, scaling)
|
||||
|
||||
down_param = experts.down_proj
|
||||
if isinstance(down_param, ParamWrapper):
|
||||
lora_A, lora_B, scaling = get_lora_params_from_wrapper(down_param)
|
||||
if lora_A is not None:
|
||||
num_experts = experts.num_experts
|
||||
rank = lora_A.shape[0] // num_experts
|
||||
from .layers import peft_lora_to_scattermoe
|
||||
|
||||
sm_A, sm_B = peft_lora_to_scattermoe(lora_A, lora_B, num_experts, rank)
|
||||
down_lora = (sm_A, sm_B, scaling)
|
||||
|
||||
return base_experts, gup_lora, down_lora
|
||||
|
||||
|
||||
def _get_base_param(param):
|
||||
"""Get the base tensor from a PEFT ParamWrapper or regular Parameter."""
|
||||
try:
|
||||
from peft.tuners.param_wrapper import ParamWrapper
|
||||
|
||||
while isinstance(param, ParamWrapper):
|
||||
param = param.original_parameter
|
||||
except ImportError:
|
||||
pass
|
||||
return param
|
||||
|
||||
|
||||
def _parallel_linear_maybe_lora(
    x,
    weight,
    top_k,
    sorted_expert_idxs,
    sorted_scattered_idxs,
    expert_offsets,
    lora_tuple,
    grouped_in,
    grouped_out,
    gates=None,
):
    """Dispatch to ``parallel_linear`` or ``parallel_linear_lora``.

    ``lora_tuple`` is either None (no LoRA, plain kernel) or a
    ``(lora_A, lora_B, scaling)`` triple that is forwarded to the LoRA kernel.
    All other arguments are passed through unchanged to whichever kernel runs.
    """
    # Arguments shared by both kernels, in their common positional order.
    shared_args = (
        x,
        weight,
        top_k,
        sorted_expert_idxs,
        sorted_scattered_idxs,
        expert_offsets,
    )
    shared_kwargs = {
        "grouped_in": grouped_in,
        "grouped_out": grouped_out,
        "gates": gates,
    }

    if lora_tuple is None:
        return parallel_linear(*shared_args, **shared_kwargs)

    lora_A, lora_B, scaling = lora_tuple
    return parallel_linear_lora(
        *shared_args,
        lora_A,
        lora_B,
        scaling,
        **shared_kwargs,
    )
|
||||
|
||||
|
||||
def scattermoe_experts_forward(
    self,
    hidden_states: torch.Tensor,
    top_k_index: torch.Tensor,
    top_k_weights: torch.Tensor,
) -> torch.Tensor:
    """Run the experts computation through the ScatterMoE kernels.

    Drop-in replacement for the standard Experts forward signature used by
    ``@use_experts_implementation``-decorated classes (Gemma4, Mixtral, etc.):
    ``(hidden_states [T, H], top_k_index [T, K], top_k_weights [T, K]) -> [T, H]``

    Reads ``self.num_experts``, ``self.gate_up_proj``, ``self.down_proj`` and
    ``self.act_fn``. When the projections are PEFT-wrapped, the base weights
    are used for the kernel and the LoRA factors are passed alongside.
    """
    top_k = top_k_index.shape[1]

    # Routing weights are cast to the activation dtype before being applied
    # in the down projection.
    combine_weights = top_k_weights.to(hidden_states.dtype)
    sorted_expert_idxs, sorted_scattered_idxs, expert_offsets = flatten_sort_count(
        top_k_index, num_experts=self.num_experts
    )
    routing = (sorted_expert_idxs, sorted_scattered_idxs, expert_offsets)

    # Base (unwrapped) expert weights with the last two dims swapped for the
    # kernel.
    w_gate_up = _get_base_param(self.gate_up_proj).transpose(2, 1)
    w_down = _get_base_param(self.down_proj).transpose(2, 1)

    # LoRA factors are only extracted when a PEFT wrapper is present.
    if _has_peft_wrapper(self):
        _, gup_lora, down_lora = _unwrap_experts_lora(self)
    else:
        gup_lora = down_lora = None

    # Fused gate/up projection; output stays grouped by expert.
    fused = _parallel_linear_maybe_lora(
        hidden_states,
        w_gate_up,
        top_k,
        *routing,
        gup_lora,
        grouped_in=False,
        grouped_out=True,
    )
    gate_half, up_half = fused.chunk(2, dim=-1)
    activated = self.act_fn(gate_half) * up_half

    # Down projection consumes the grouped activations and applies the
    # routing weights while scattering back.
    return _parallel_linear_maybe_lora(
        activated,
        w_down,
        1,
        *routing,
        down_lora,
        grouped_in=True,
        grouped_out=False,
        gates=combine_weights,
    )
|
||||
|
||||
|
||||
def register_scattermoe_experts():
    """Register ``"scattermoe"`` in the transformers ExpertsInterface.

    After calling this, any model with ``@use_experts_implementation`` will
    dispatch to ScatterMoE when ``config._experts_implementation == "scattermoe"``.

    Also patches ``PreTrainedModel.get_correct_experts_implementation`` to
    accept ``"scattermoe"`` as a valid value (transformers hardcodes an
    allowlist).

    The patch is applied at most once: repeated calls no longer wrap the
    already-patched method in another layer of indirection.
    """
    from transformers.integrations.moe import ALL_EXPERTS_FUNCTIONS
    from transformers.modeling_utils import PreTrainedModel

    # 1. Register the forward function in the global interface
    ALL_EXPERTS_FUNCTIONS.register("scattermoe", scattermoe_experts_forward)

    # 2. Patch the validation to accept "scattermoe" (once only)
    _original_get_correct = PreTrainedModel.get_correct_experts_implementation
    if getattr(_original_get_correct, "_axolotl_scattermoe_patched", False):
        return

    def _patched_get_correct(self_model, requested_experts: str | None) -> str:
        # Short-circuit our custom value; defer everything else to transformers.
        if requested_experts == "scattermoe":
            return "scattermoe"
        return _original_get_correct(self_model, requested_experts)

    # Marker used by the guard above to detect an existing patch.
    _patched_get_correct._axolotl_scattermoe_patched = True
    PreTrainedModel.get_correct_experts_implementation = _patched_get_correct
|
||||
|
||||
|
||||
# Legacy monkeypatch approach (kept for backward compat with existing tests)
|
||||
def patch_gemma4_scattermoe():
    """Replace ``Gemma4TextExperts.forward`` with the ScatterMoE forward.

    The original method is saved on the class as ``_original_forward``; the
    presence of that attribute marks the class as already patched, in which
    case nothing is changed.

    Raises:
        ValueError: If the Gemma4 experts class cannot be resolved.
    """
    from axolotl.integrations.kernels.constants import resolve_experts_class

    experts_cls = resolve_experts_class("gemma4_text")
    if experts_cls is None:
        raise ValueError("Could not resolve Gemma4TextExperts class")

    if not hasattr(experts_cls, "_original_forward"):
        experts_cls._original_forward = experts_cls.forward
        experts_cls.forward = scattermoe_experts_forward
||||
106
src/axolotl/integrations/kernels/libs/sonicmoe/gemma4_experts.py
Normal file
106
src/axolotl/integrations/kernels/libs/sonicmoe/gemma4_experts.py
Normal file
@@ -0,0 +1,106 @@
|
||||
"""
|
||||
SonicMoE-accelerated experts forward for Gemma4.
|
||||
|
||||
Gemma4 has no separate SparseMoeBlock — MoE is embedded in the decoder layer.
|
||||
This module provides a drop-in replacement for ``Gemma4TextExperts.forward``
|
||||
that uses SonicMoE kernels while preserving the original call signature.
|
||||
"""
|
||||
|
||||
import torch
|
||||
|
||||
from .lora import has_lora, materialize_expert_lora, unwrap_experts_lora
|
||||
|
||||
|
||||
def _get_expert_weights_gemma4(experts_module):
    """Return Gemma4 expert weights permuted for SonicMoE, with LoRA applied.

    When ``has_lora`` reports an adapter, the base weights and LoRA entries
    are obtained via ``unwrap_experts_lora`` and combined with
    ``materialize_expert_lora``; otherwise the module's own ``gate_up_proj``
    and ``down_proj`` are used directly.

    Returns:
        (gate_up_weight, down_weight) in SonicMoE layout [dim, dim, E].
    """
    if not has_lora(experts_module):
        gate_up = experts_module.gate_up_proj
        down = experts_module.down_proj
    else:
        base_experts, lora_dict = unwrap_experts_lora(experts_module)
        gate_up = materialize_expert_lora(
            base_experts.gate_up_proj, lora_dict.get("gate_up_proj")
        )
        down = materialize_expert_lora(
            base_experts.down_proj, lora_dict.get("down_proj")
        )

    # Move the leading (expert) dimension to the end:
    #   gate_up: [E, 2*I, H] -> [2*I, H, E]
    #   down:    [E, H, I]   -> [H, I, E]
    expert_last = (1, 2, 0)
    return gate_up.permute(*expert_last), down.permute(*expert_last)
|
||||
|
||||
|
||||
def gemma4_sonicmoe_experts_forward(
    self,
    hidden_states: torch.Tensor,
    top_k_index: torch.Tensor,
    top_k_weights: torch.Tensor,
) -> torch.Tensor:
    """Compute the Gemma4 experts output with the SonicMoE kernel.

    Same signature as the original ``Gemma4TextExperts.forward``:
    (hidden_states [T, H], top_k_index [T, K], top_k_weights [T, K])
    -> output [T, H].

    Raises:
        RuntimeError: If no CUDA device is available.
    """
    from sonicmoe import moe_general_routing_inputs
    from sonicmoe.enums import ActivationType

    num_tokens, _ = hidden_states.shape
    top_k = top_k_index.shape[1]
    num_experts = self.num_experts

    # Flatten the [T, K] routing tensors into SonicMoE's flat format. Each
    # token index is repeated K times, so the flattened indices are ascending
    # (required by SonicMoE).
    row_ids = torch.arange(
        num_tokens, device=hidden_states.device, dtype=torch.int32
    )
    token_grid = row_ids.unsqueeze(1).expand(num_tokens, top_k)
    flat_scores = top_k_weights.to(torch.float32).reshape(-1)  # [T*K]
    flat_token_idx = token_grid.reshape(-1)  # [T*K]
    flat_expert_idx = top_k_index.to(torch.int32).reshape(-1)  # [T*K]

    # Expert weights (LoRA already folded in when active), cast to the
    # activation dtype.
    gate_up_weight, down_weight = _get_expert_weights_gemma4(self)
    gate_up_weight = gate_up_weight.to(hidden_states.dtype)
    down_weight = down_weight.to(hidden_states.dtype)

    if not torch.cuda.is_available():
        raise RuntimeError("SonicMoE requires CUDA. No CUDA device available.")
    cuda_stream = torch.cuda.current_stream().cuda_stream

    output, _ = moe_general_routing_inputs(
        hidden_states,
        flat_scores,
        flat_token_idx,
        flat_expert_idx,
        gate_up_weight,
        None,  # b1 (no gate/up bias)
        down_weight,
        None,  # b2 (no down bias)
        num_experts,
        cuda_stream,
        ActivationType.GEGLU,
        False,  # is_inference_mode
    )
    return output
|
||||
|
||||
|
||||
def patch_gemma4_sonicmoe():
    """Replace ``Gemma4TextExperts.forward`` with the SonicMoE forward.

    The original method is saved on the class as ``_original_forward``; the
    presence of that attribute marks the class as already patched, in which
    case nothing is changed.

    Raises:
        ValueError: If the Gemma4 experts class cannot be resolved.
    """
    from axolotl.integrations.kernels.constants import resolve_experts_class

    experts_cls = resolve_experts_class("gemma4_text")
    if experts_cls is None:
        raise ValueError("Could not resolve Gemma4TextExperts class")

    if not hasattr(experts_cls, "_original_forward"):
        experts_cls._original_forward = experts_cls.forward
        experts_cls.forward = gemma4_sonicmoe_experts_forward
|
||||
@@ -7,6 +7,7 @@ Different MoE architectures use different routing strategies:
|
||||
- glm_moe_dsa / deepseek_v3 / minimax_m2: sigmoid -> topk (with group-based expert selection)
|
||||
- ernie4_5_moe: softmax -> bias correction -> topk -> gather (softmax_bias_topk_routing)
|
||||
- hunyuan_v1_moe: softmax -> topk via gate.wg (softmax_topk_wg_routing)
|
||||
- gemma4_text: RMSNorm -> scale -> proj -> softmax -> topk -> renorm -> per_expert_scale (gemma4_routing)
|
||||
- gpt_oss: topk -> softmax (uses fused moe_TC_softmax_topk_layer, routing_fn=None) [NOT YET SUPPORTED]
|
||||
|
||||
Each model type maps to a (routing_fn, activation_type, router_attr) triple.
|
||||
@@ -66,6 +67,8 @@ def get_model_moe_config(model_type: str):
|
||||
return softmax_bias_topk_routing, ActivationType.SWIGLU, "gate"
|
||||
elif model_type in ("hunyuan_v1_moe",):
|
||||
return softmax_topk_wg_routing, ActivationType.SWIGLU, "gate"
|
||||
elif model_type in ("gemma4_text",):
|
||||
return gemma4_routing, ActivationType.GEGLU, "router"
|
||||
# Fused topk -> softmax path (routing_fn=None):
|
||||
# elif model_type in ("gpt_oss",):
|
||||
# # NOTE: gpt_oss has a router bias which moe_TC_softmax_topk_layer
|
||||
@@ -501,3 +504,73 @@ def softmax_topk_wg_routing(
|
||||
flat_expert_idx = top_indices.to(torch.int32).reshape(-1) # [T*K]
|
||||
|
||||
return flat_scores, flat_token_idx, flat_expert_idx, router_logits
|
||||
|
||||
|
||||
def gemma4_routing(
    hidden_states: torch.Tensor, moe_block
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Gemma4-style routing: RMSNorm → scale → proj → softmax → topk → renorm → per_expert_scale.

    Re-implements the steps of ``Gemma4TextRouter`` using the router's own
    submodules and parameters:
    1. ``router.norm`` on the hidden states
    2. Multiply by ``router.scale * router.scalar_root_size``
    3. Linear projection to expert scores (``router.proj`` weight, plus the
       LoRA delta when one is returned by ``unwrap_gate_lora``)
    4. Softmax → topk
    5. Normalize top-k weights to sum to 1
    6. Multiply by ``router.per_expert_scale`` for the selected experts

    The router lives at ``moe_block.router`` (not ``moe_block.gate``).

    Args:
        hidden_states: [T, H] flattened token representations
        moe_block: MoE block module (accesses moe_block.router)

    Returns:
        router_scores: [T*K] flattened scores (float32)
        token_indices: [T*K] which token each entry belongs to (int32), sorted ascending
        expert_indices: [T*K] which expert (int32)
        router_logits: [T, E] original logits for aux loss
    """
    router = moe_block.router

    # Base projection weight and optional LoRA delta for router.proj.
    _, proj_weight, proj_lora_delta = unwrap_gate_lora(router.proj)

    num_tokens, _ = hidden_states.shape
    if hasattr(router, "top_k"):
        top_k = router.top_k
    else:
        top_k = router.config.top_k_experts

    # Steps 1-2: normalize, then apply the learned scale and scalar factor.
    # The projection is done in float32.
    router_input = (
        router.norm(hidden_states) * router.scale * router.scalar_root_size
    ).float()

    # Step 3: expert logits, with the LoRA contribution added when present.
    router_logits = F.linear(router_input, proj_weight.float())  # [T, E]
    if proj_lora_delta is not None:
        router_logits = router_logits + F.linear(
            router_input, proj_lora_delta.float()
        )

    # Step 4: probabilities over experts, then the K largest per token.
    router_probs = F.softmax(router_logits, dim=-1, dtype=torch.float32)  # [T, E]
    top_values, top_indices = torch.topk(router_probs, top_k, dim=-1)  # [T, K]

    # Steps 5-6: renormalize the selected weights, then apply per-expert scales.
    top_values = top_values / top_values.sum(dim=-1, keepdim=True)
    top_values = top_values * router.per_expert_scale[top_indices]

    # Flatten to the [T*K] format expected by moe_general_routing_inputs.
    row_ids = torch.arange(
        num_tokens, device=hidden_states.device, dtype=torch.int32
    )
    flat_token_idx = row_ids.unsqueeze(1).expand(num_tokens, top_k).reshape(-1)
    flat_scores = top_values.to(torch.float32).reshape(-1)
    flat_expert_idx = top_indices.to(torch.int32).reshape(-1)

    return flat_scores, flat_token_idx, flat_expert_idx, router_logits
|
||||
|
||||
@@ -61,20 +61,31 @@ class KernelsPlugin(BasePlugin):
|
||||
return "axolotl.integrations.kernels.KernelsArgs"
|
||||
|
||||
def pre_model_load(self, cfg):
|
||||
from axolotl.integrations.kernels.constants import SPARSE_MOE_BLOCK
|
||||
from axolotl.integrations.kernels.constants import (
|
||||
SPARSE_MOE_BLOCK,
|
||||
is_experts_only_model,
|
||||
)
|
||||
|
||||
# Prefer text backbone type for VLMs, but fall back to base type
|
||||
# when the text type isn't in the supported mapping (e.g. qwen3_5_moe_text)
|
||||
moe_model_type = cfg.model_config_type_text or cfg.model_config_type
|
||||
if (
|
||||
moe_model_type not in SPARSE_MOE_BLOCK
|
||||
and not is_experts_only_model(moe_model_type)
|
||||
and cfg.model_config_type in SPARSE_MOE_BLOCK
|
||||
):
|
||||
moe_model_type = cfg.model_config_type
|
||||
|
||||
if cfg.use_scattermoe:
|
||||
self._register_kernels()
|
||||
self._kernelize_model(moe_model_type)
|
||||
if is_experts_only_model(moe_model_type):
|
||||
# Models like Gemma4 where MoE is embedded in the decoder layer
|
||||
# — register ScatterMoE in the ExpertsInterface so that
|
||||
# @use_experts_implementation dispatches to it.
|
||||
self._register_experts_interface()
|
||||
cfg.experts_implementation = "scattermoe"
|
||||
else:
|
||||
self._kernelize_model(moe_model_type)
|
||||
elif cfg.use_sonicmoe:
|
||||
if not importlib.util.find_spec("sonicmoe"):
|
||||
raise RuntimeError(
|
||||
@@ -84,14 +95,24 @@ class KernelsPlugin(BasePlugin):
|
||||
|
||||
_check_sonicmoe_gpu_compat()
|
||||
|
||||
from axolotl.integrations.kernels.libs.sonicmoe import patch_sonicmoe
|
||||
if is_experts_only_model(moe_model_type):
|
||||
from axolotl.integrations.kernels.libs.sonicmoe.gemma4_experts import (
|
||||
patch_gemma4_sonicmoe,
|
||||
)
|
||||
|
||||
LOG.info(f"Applying SonicMoE patches for model type: {moe_model_type}")
|
||||
patch_sonicmoe(
|
||||
moe_model_type,
|
||||
torch_compile=bool(getattr(cfg, "torch_compile", False)),
|
||||
base_model_type=cfg.model_config_type,
|
||||
)
|
||||
LOG.info(
|
||||
f"Applying SonicMoE experts-level patch for model type: {moe_model_type}"
|
||||
)
|
||||
patch_gemma4_sonicmoe()
|
||||
else:
|
||||
from axolotl.integrations.kernels.libs.sonicmoe import patch_sonicmoe
|
||||
|
||||
LOG.info(f"Applying SonicMoE patches for model type: {moe_model_type}")
|
||||
patch_sonicmoe(
|
||||
moe_model_type,
|
||||
torch_compile=bool(getattr(cfg, "torch_compile", False)),
|
||||
base_model_type=cfg.model_config_type,
|
||||
)
|
||||
|
||||
def _register_kernels(self):
|
||||
from kernels import (
|
||||
@@ -139,3 +160,16 @@ class KernelsPlugin(BasePlugin):
|
||||
replace_kernel_forward_from_hub(
|
||||
model_moe_cls, "HFScatterMoEParallelExperts"
|
||||
)
|
||||
|
||||
def _register_experts_interface(self):
    """Make ScatterMoE available through the transformers ExpertsInterface.

    After this call, Experts classes decorated with
    ``@use_experts_implementation`` can dispatch to ScatterMoE when
    ``config._experts_implementation == "scattermoe"``.
    """
    # Imported lazily so the ScatterMoE code is only loaded when this
    # registration path is actually taken.
    from axolotl.integrations.kernels.libs.scattermoe_lora.gemma4_experts import (
        register_scattermoe_experts as _install_scattermoe_experts,
    )

    _install_scattermoe_experts()
    LOG.info("Registered 'scattermoe' in transformers ExpertsInterface")
|
||||
|
||||
@@ -73,6 +73,44 @@ QKV_PATCHES = [
|
||||
query_states = self.q_norm(query_states.view(hidden_shape)).transpose(1, 2)
|
||||
key_states = self.k_norm(key_states.view(hidden_shape)).transpose(1, 2)
|
||||
value_states = value_states.view(hidden_shape).transpose(1, 2)
|
||||
""".lstrip("\n"),
|
||||
),
|
||||
# Gemma4: norm between proj and transpose, RoPE between norm and transpose,
|
||||
# conditional KV sharing (is_kv_shared_layer), v_proj may be None (attention_k_eq_v).
|
||||
# We only fuse the projection calls; norms, RoPE, and KV sharing stay as-is.
|
||||
(
|
||||
"""
|
||||
query_states = self.q_proj(hidden_states).view(hidden_shape)
|
||||
query_states = self.q_norm(query_states)
|
||||
query_states = apply_rotary_pos_emb(query_states, cos, sin, unsqueeze_dim=2)
|
||||
query_states = query_states.transpose(1, 2)
|
||||
|
||||
# For layers with shared KV (from kv sharing point onwards), we reuse the same keys/values states as the last non-sharing layer
|
||||
if self.is_kv_shared_layer and past_key_values is not None:
|
||||
key_states, value_states = past_key_values.shared_layers[self.kv_shared_layer_index]
|
||||
# Device of past layer may be different from current one
|
||||
key_states = key_states.to(query_states.device)
|
||||
value_states = value_states.to(query_states.device)
|
||||
else:
|
||||
key_states = self.k_proj(hidden_states).view(hidden_shape)
|
||||
value_states = self.v_proj(hidden_states).view(hidden_shape) if self.v_proj is not None else key_states
|
||||
""".lstrip("\n"),
|
||||
"""
|
||||
query_states, key_states, value_states = self.apply_qkv(hidden_states)
|
||||
query_states = query_states.view(hidden_shape)
|
||||
query_states = self.q_norm(query_states)
|
||||
query_states = apply_rotary_pos_emb(query_states, cos, sin, unsqueeze_dim=2)
|
||||
query_states = query_states.transpose(1, 2)
|
||||
|
||||
# For layers with shared KV (from kv sharing point onwards), we reuse the same keys/values states as the last non-sharing layer
|
||||
if self.is_kv_shared_layer and past_key_values is not None:
|
||||
key_states, value_states = past_key_values.shared_layers[self.kv_shared_layer_index]
|
||||
# Device of past layer may be different from current one
|
||||
key_states = key_states.to(query_states.device)
|
||||
value_states = value_states.to(query_states.device)
|
||||
else:
|
||||
key_states = key_states.view(hidden_shape)
|
||||
value_states = value_states.view(hidden_shape) if self.v_proj is not None else key_states
|
||||
""".lstrip("\n"),
|
||||
),
|
||||
]
|
||||
@@ -113,6 +151,23 @@ def original_apply_qkv(
|
||||
return query_states, key_states, value_states
|
||||
|
||||
|
||||
def original_apply_qkv_optional_v(
    self: nn.Module, hidden_states: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Unfused Q/K/V projection that tolerates a missing ``v_proj``.

    Intended for attention modules whose ``v_proj`` may be ``None``
    (e.g. Gemma4 with ``attention_k_eq_v``).

    Args:
        self: Attention module exposing ``q_proj``, ``k_proj`` and ``v_proj``.
        hidden_states: Input passed unchanged to every projection.

    Returns:
        ``(query_states, key_states, value_states)``. When ``v_proj`` is
        ``None`` the value entry is the very same object as the key entry.
    """
    queries = self.q_proj(hidden_states)
    keys = self.k_proj(hidden_states)

    # No dedicated value projection: reuse the key projection output as-is.
    value_projection = self.v_proj
    values = keys if value_projection is None else value_projection(hidden_states)

    return queries, keys, values
|
||||
|
||||
|
||||
def original_apply_o(self: nn.Module, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Original implementation of output projection without optimizations.
|
||||
@@ -183,6 +238,11 @@ def get_attention_cls_from_config(cfg: DictDefault) -> Type[nn.Module]:
|
||||
|
||||
return Gemma3Attention
|
||||
|
||||
if model_type in ("gemma4", "gemma4_text"):
|
||||
from transformers.models.gemma4.modeling_gemma4 import Gemma4TextAttention
|
||||
|
||||
return Gemma4TextAttention
|
||||
|
||||
try:
|
||||
# Dynamically import the module and attention class
|
||||
module_path = f"transformers.models.{model_type}.modeling_{model_type}"
|
||||
@@ -410,14 +470,24 @@ def apply_lora_kernel_patches(
|
||||
# Add QKV, O fallback implementations to start
|
||||
# These will be overwritten later (if some conditions apply)
|
||||
for self_attn in find_self_attn_in_layer(layer):
|
||||
self_attn.apply_qkv = types.MethodType(original_apply_qkv, self_attn)
|
||||
# Use v_proj-optional fallback for models where v_proj can be None
|
||||
# (e.g. Gemma4 with attention_k_eq_v=True)
|
||||
if getattr(self_attn, "v_proj", None) is None:
|
||||
self_attn.apply_qkv = types.MethodType(
|
||||
original_apply_qkv_optional_v, self_attn
|
||||
)
|
||||
else:
|
||||
self_attn.apply_qkv = types.MethodType(original_apply_qkv, self_attn)
|
||||
self_attn.apply_o = types.MethodType(original_apply_o, self_attn)
|
||||
|
||||
if cfg.lora_qkv_kernel:
|
||||
# Query, key, value patching
|
||||
# Filter out None projections (e.g. Gemma4 v_proj when attention_k_eq_v=True)
|
||||
proj_names = ["q_proj", "k_proj", "v_proj"]
|
||||
layer_modules = [
|
||||
getattr(self_attn, linear_proj)
|
||||
for linear_proj in ["q_proj", "k_proj", "v_proj"]
|
||||
getattr(self_attn, name)
|
||||
for name in proj_names
|
||||
if getattr(self_attn, name, None) is not None
|
||||
]
|
||||
can_patch_qkv = all(
|
||||
hasattr(module, "lora_A") for module in layer_modules
|
||||
|
||||
@@ -111,7 +111,6 @@ def patch_qwen3_next_gateddelta_layer():
|
||||
"""Patch Qwen3NextGatedDeltaNet to parse cu_seqlens and pass to chunk_gated_delta_rule"""
|
||||
try:
|
||||
from transformers.models.qwen3_next.modeling_qwen3_next import (
|
||||
Qwen3NextDynamicCache,
|
||||
Qwen3NextGatedDeltaNet,
|
||||
apply_mask_to_padding_states,
|
||||
)
|
||||
@@ -125,8 +124,7 @@ def patch_qwen3_next_gateddelta_layer():
|
||||
def patched_gated_delta_net_forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
cache_params: Optional[Qwen3NextDynamicCache] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
cache_params=None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
):
|
||||
@@ -137,9 +135,8 @@ def patch_qwen3_next_gateddelta_layer():
|
||||
|
||||
use_precomputed_states = (
|
||||
cache_params is not None
|
||||
and cache_params.has_previous_state
|
||||
and cache_params.has_previous_state(self.layer_idx)
|
||||
and seq_len == 1
|
||||
and cache_position is not None
|
||||
)
|
||||
|
||||
# Compute cu_seqlens early for use by both causal_conv1d and chunk_gated_delta_rule
|
||||
@@ -148,9 +145,9 @@ def patch_qwen3_next_gateddelta_layer():
|
||||
cu_seqlens = get_cu_seqlens(position_ids=position_ids)
|
||||
|
||||
# getting projected states from cache if it exists
|
||||
if cache_params is not None:
|
||||
conv_state = cache_params.conv_states[self.layer_idx]
|
||||
recurrent_state = cache_params.recurrent_states[self.layer_idx]
|
||||
if use_precomputed_states:
|
||||
conv_state = cache_params.layers[self.layer_idx].conv_states
|
||||
recurrent_state = cache_params.layers[self.layer_idx].recurrent_states
|
||||
|
||||
projected_states_qkvz = self.in_proj_qkvz(hidden_states)
|
||||
projected_states_ba = self.in_proj_ba(hidden_states)
|
||||
@@ -162,10 +159,9 @@ def patch_qwen3_next_gateddelta_layer():
|
||||
)
|
||||
|
||||
mixed_qkv = torch.cat((query, key, value), dim=-1) # [B, T, D]
|
||||
mixed_qkv = mixed_qkv.transpose(1, 2) # [B, D, T]
|
||||
|
||||
if use_precomputed_states:
|
||||
# Inference single-token path: causal_conv1d_update expects [B, D, T]
|
||||
mixed_qkv = mixed_qkv.transpose(1, 2)
|
||||
mixed_qkv = self.causal_conv1d_update(
|
||||
mixed_qkv,
|
||||
conv_state,
|
||||
@@ -173,19 +169,17 @@ def patch_qwen3_next_gateddelta_layer():
|
||||
self.conv1d.bias,
|
||||
self.activation,
|
||||
)
|
||||
mixed_qkv = mixed_qkv.transpose(1, 2)
|
||||
else:
|
||||
if cache_params is not None:
|
||||
# Cache state expects [B, D, T] for the inference update path
|
||||
mixed_qkv_t = mixed_qkv.transpose(1, 2)
|
||||
conv_state = F.pad(
|
||||
mixed_qkv_t,
|
||||
(self.conv_kernel_size - mixed_qkv_t.shape[-1], 0),
|
||||
mixed_qkv,
|
||||
(self.conv_kernel_size - mixed_qkv.shape[-1], 0),
|
||||
)
|
||||
cache_params.conv_states[self.layer_idx] = conv_state
|
||||
cache_params.update_conv_state(conv_state, self.layer_idx)
|
||||
|
||||
if fla_causal_conv1d is not None:
|
||||
# FLA Triton causal_conv1d: [B, T, D] in/out, with cu_seqlens support
|
||||
mixed_qkv = mixed_qkv.transpose(1, 2) # [B, T, D] for FLA
|
||||
mixed_qkv, _ = fla_causal_conv1d(
|
||||
x=mixed_qkv,
|
||||
weight=self.conv1d.weight.squeeze(1),
|
||||
@@ -193,6 +187,15 @@ def patch_qwen3_next_gateddelta_layer():
|
||||
activation=self.activation,
|
||||
cu_seqlens=cu_seqlens,
|
||||
)
|
||||
mixed_qkv = mixed_qkv.transpose(1, 2) # back to [B, D, T]
|
||||
elif self.causal_conv1d_fn is not None:
|
||||
mixed_qkv = self.causal_conv1d_fn(
|
||||
x=mixed_qkv,
|
||||
weight=self.conv1d.weight.squeeze(1),
|
||||
bias=self.conv1d.bias,
|
||||
activation=self.activation,
|
||||
seq_idx=None,
|
||||
)
|
||||
else:
|
||||
# PyTorch fallback (no cu_seqlens support)
|
||||
if cu_seqlens is not None and cu_seqlens.shape[0] > batch_size + 1:
|
||||
@@ -203,11 +206,9 @@ def patch_qwen3_next_gateddelta_layer():
|
||||
LOG.warning_once(
|
||||
"FLA causal_conv1d not available. Falling back to PyTorch conv1d."
|
||||
)
|
||||
mixed_qkv = mixed_qkv.transpose(1, 2)
|
||||
mixed_qkv = F.silu(self.conv1d(mixed_qkv)[:, :, :seq_len])
|
||||
mixed_qkv = mixed_qkv.transpose(1, 2)
|
||||
|
||||
# mixed_qkv is [B, T, D] in all paths
|
||||
mixed_qkv = mixed_qkv.transpose(1, 2) # [B, T, D]
|
||||
query, key, value = torch.split(
|
||||
mixed_qkv,
|
||||
[
|
||||
@@ -255,7 +256,7 @@ def patch_qwen3_next_gateddelta_layer():
|
||||
|
||||
# Update cache
|
||||
if cache_params is not None:
|
||||
cache_params.recurrent_states[self.layer_idx] = last_recurrent_state
|
||||
cache_params.update_recurrent_state(last_recurrent_state, self.layer_idx)
|
||||
|
||||
z_shape_og = z.shape
|
||||
# reshape input data into 2D tensor
|
||||
|
||||
271
src/axolotl/utils/chat_templates/templates/gemma4.jinja
Normal file
271
src/axolotl/utils/chat_templates/templates/gemma4.jinja
Normal file
@@ -0,0 +1,271 @@
|
||||
{#- format_parameters(properties, required)
Renders the entries of a `properties` mapping as comma-separated `key:{...}`
groups, visiting keys in `dictsort` order and skipping any key listed in
`standard_keys`. Each group may carry a description, `nullable:true`, a
type-specific part (enum for STRING, nested properties/required for OBJECT,
items for ARRAY) and always ends with the upper-cased type.
`add_comma` tracks whether a separator is needed before the next field.
NOTE(review): the `required` argument is never read inside this macro.
NOTE(review): `,properties:{` and `,items:{` are written with a leading comma
regardless of `add_comma`; confirm against the reference template before
changing it. -#}
{%- macro format_parameters(properties, required) -%}
{%- set standard_keys = ['description', 'type', 'properties', 'required', 'nullable'] -%}
{%- set ns = namespace(found_first=false) -%}
{%- for key, value in properties | dictsort -%}
{%- set add_comma = false -%}
{%- if key not in standard_keys -%}
{%- if ns.found_first %},{% endif -%}
{%- set ns.found_first = true -%}
{{ key }}:{
{%- if value['description'] -%}
description:<|"|>{{ value['description'] }}<|"|>
{%- set add_comma = true -%}
{%- endif -%}
{%- if value['nullable'] %}
{%- if add_comma %},{%- else -%} {%- set add_comma = true -%} {% endif -%}
nullable:true
{%- endif -%}
{%- if value['type'] | upper == 'STRING' -%}
{%- if value['enum'] -%}
{%- if add_comma %},{%- else -%} {%- set add_comma = true -%} {% endif -%}
enum:{{ format_argument(value['enum']) }}
{%- endif -%}
{%- elif value['type'] | upper == 'OBJECT' -%}
,properties:{
{%- if value['properties'] is defined and value['properties'] is mapping -%}
{{- format_parameters(value['properties'], value['required'] | default([])) -}}
{%- elif value is mapping -%}
{{- format_parameters(value, value['required'] | default([])) -}}
{%- endif -%}
}
{%- if value['required'] -%}
,required:[
{%- for item in value['required'] | default([]) -%}
<|"|>{{- item -}}<|"|>
{%- if not loop.last %},{% endif -%}
{%- endfor -%}
]
{%- endif -%}
{%- elif value['type'] | upper == 'ARRAY' -%}
{%- if value['items'] is mapping and value['items'] -%}
,items:{
{%- set ns_items = namespace(found_first=false) -%}
{%- for item_key, item_value in value['items'] | dictsort -%}
{%- if item_value is not none -%}
{%- if ns_items.found_first %},{% endif -%}
{%- set ns_items.found_first = true -%}
{%- if item_key == 'properties' -%}
properties:{
{%- if item_value is mapping -%}
{{- format_parameters(item_value, value['items']['required'] | default([])) -}}
{%- endif -%}
}
{%- elif item_key == 'required' -%}
required:[
{%- for req_item in item_value -%}
<|"|>{{- req_item -}}<|"|>
{%- if not loop.last %},{% endif -%}
{%- endfor -%}
]
{%- elif item_key == 'type' -%}
{%- if item_value is string -%}
type:{{ format_argument(item_value | upper) }}
{%- else -%}
type:{{ format_argument(item_value | map('upper') | list) }}
{%- endif -%}
{%- else -%}
{{ item_key }}:{{ format_argument(item_value) }}
{%- endif -%}
{%- endif -%}
{%- endfor -%}
}
{%- endif -%}
{%- endif -%}
{%- if add_comma %},{%- else -%} {%- set add_comma = true -%} {% endif -%}
type:<|"|>{{ value['type'] | upper }}<|"|>}
{%- endif -%}
{%- endfor -%}
{%- endmacro -%}
|
||||
{#- format_function_declaration(tool_data)
Renders one tool as `declaration:NAME{description:...,parameters:{...},response:{...}}`
using tool_data['function']. The parameters section is written only when
`parameters` is truthy, and the response section only when a 'response' key
exists.
NOTE(review): the brace closing `parameters:{` is written only inside the
params['type'] branch, and the brace closing `response:{` only when the
response type upper-cases to OBJECT; inputs without those fields appear to
produce unbalanced braces. Confirm against the reference template before
changing it. -#}
{%- macro format_function_declaration(tool_data) -%}
declaration:{{- tool_data['function']['name'] -}}{description:<|"|>{{- tool_data['function']['description'] -}}<|"|>
{%- set params = tool_data['function']['parameters'] -%}
{%- if params -%}
,parameters:{
{%- if params['properties'] -%}
properties:{ {{- format_parameters(params['properties'], params['required']) -}} },
{%- endif -%}
{%- if params['required'] -%}
required:[
{%- for item in params['required'] -%}
<|"|>{{- item -}}<|"|>
{{- ',' if not loop.last -}}
{%- endfor -%}
],
{%- endif -%}
{%- if params['type'] -%}
type:<|"|>{{- params['type'] | upper -}}<|"|>}
{%- endif -%}
{%- endif -%}
{%- if 'response' in tool_data['function'] -%}
{%- set response_declaration = tool_data['function']['response'] -%}
,response:{
{%- if response_declaration['description'] -%}
description:<|"|>{{- response_declaration['description'] -}}<|"|>,
{%- endif -%}
{%- if response_declaration['type'] | upper == 'OBJECT' -%}
type:<|"|>{{- response_declaration['type'] | upper -}}<|"|>}
{%- endif -%}
{%- endif -%}
}
{%- endmacro -%}
|
||||
{#- format_argument(argument, escape_keys=True)
Recursively serialises a value:
  string   : wrapped in the <|"|> quote token on both sides
  boolean  : the literal true or false
  mapping  : {key:value,...} with keys in dictsort order; keys are wrapped in
             the quote token only when escape_keys is truthy
  sequence : [item,...]
  other    : emitted unchanged
escape_keys is passed down unchanged to nested values.
NOTE(review): string and mapping are tested before sequence, presumably
because the sequence test would also match them; keep this order. -#}
{%- macro format_argument(argument, escape_keys=True) -%}
{%- if argument is string -%}
{{- '<|"|>' + argument + '<|"|>' -}}
{%- elif argument is boolean -%}
{{- 'true' if argument else 'false' -}}
{%- elif argument is mapping -%}
{{- '{' -}}
{%- set ns = namespace(found_first=false) -%}
{%- for key, value in argument | dictsort -%}
{%- if ns.found_first %},{% endif -%}
{%- set ns.found_first = true -%}
{%- if escape_keys -%}
{{- '<|"|>' + key + '<|"|>' -}}
{%- else -%}
{{- key -}}
{%- endif -%}
:{{- format_argument(value, escape_keys=escape_keys) -}}
{%- endfor -%}
{{- '}' -}}
{%- elif argument is sequence -%}
{{- '[' -}}
{%- for item in argument -%}
{{- format_argument(item, escape_keys=escape_keys) -}}
{%- if not loop.last %},{% endif -%}
{%- endfor -%}
{{- ']' -}}
{%- else -%}
{{- argument -}}
{%- endif -%}
{%- endmacro -%}
|
||||
{#- Removes '<|channel>...<channel|>' thinking blocks from model output.
Splits on the end token '<channel|>', then checks each part for the start
token '<|channel>' and keeps only the text before it. Parts without a start
token are kept whole. The kept pieces are concatenated in order and the final
result is trimmed. A namespace is used as the accumulator so the value
assigned inside the loop is the one emitted after it. -#}
{%- macro strip_thinking(text) -%}
{%- set ns = namespace(cleaned='') -%}
{%- for part in text.split('<channel|>') -%}
{%- if '<|channel>' in part -%}
{%- set ns.cleaned = ns.cleaned + part.split('<|channel>')[0] -%}
{%- else -%}
{%- set ns.cleaned = ns.cleaned + part -%}
{%- endif -%}
{%- endfor -%}
{{- ns.cleaned | trim -}}
{%- endmacro -%}
|
||||
|
||||
{#- Template body. Order of emission:
  1. bos_token
  2. an optional system turn, written when thinking is enabled, tools are
     given, or the first message has role system or developer
  3. one turn per remaining message: tool calls, tool responses, then content
  4. the generation prompt when add_generation_prompt is set
ns.prev_message_type records the last special item written; it is read at the
end to decide whether a new model turn must be opened. -#}
{%- set ns = namespace(prev_message_type=None) -%}
{%- set loop_messages = messages -%}
{{ bos_token }}
{#- Handle System/Tool Definitions Block -#}
{%- if (enable_thinking is defined and enable_thinking) or tools or messages[0]['role'] in ['system', 'developer'] -%}
{{- '<|turn>system\n' -}}

{#- Inject Thinking token at the very top of the FIRST system turn -#}
{%- if enable_thinking is defined and enable_thinking -%}
{{- '<|think|>' -}}
{%- set ns.prev_message_type = 'think' -%}
{%- endif -%}

{#- The leading system/developer message is consumed here, so it is dropped
from the list iterated by the message loop. -#}
{%- if messages[0]['role'] in ['system', 'developer'] -%}
{{- messages[0]['content'] | trim -}}
{%- set loop_messages = messages[1:] -%}
{%- endif -%}

{%- if tools -%}
{%- for tool in tools %}
{{- '<|tool>' -}}
{{- format_function_declaration(tool) | trim -}}
{{- '<tool|>' -}}
{%- endfor %}
{%- set ns.prev_message_type = 'tool' -%}
{%- endif -%}

{{- '<turn|>\n' -}}
{%- endif %}

{#- Loop through messages -#}
{%- for message in loop_messages -%}
{#- Reset so only special message types (tool_call, image, etc.) influence
the generation prompt formatting below. Plain text leaves it as None. -#}
{%- set ns.prev_message_type = None -%}
{#- The 'assistant' role is written as 'model'; other roles pass through. -#}
{%- set role = 'model' if message['role'] == 'assistant' else message['role'] -%}
{{- '<|turn>' + role + '\n' }}

{%- if message['tool_calls'] -%}
{%- for tool_call in message['tool_calls'] -%}
{%- set function = tool_call['function'] -%}
{{- '<|tool_call>call:' + function['name'] + '{' -}}
{#- Mapping arguments are serialised key by key in dictsort order with
unquoted keys; string arguments are inserted verbatim. -#}
{%- if function['arguments'] is mapping -%}
{%- set ns_args = namespace(found_first=false) -%}
{%- for key, value in function['arguments'] | dictsort -%}
{%- if ns_args.found_first %},{% endif -%}
{%- set ns_args.found_first = true -%}
{{- key -}}:{{- format_argument(value, escape_keys=False) -}}
{%- endfor -%}
{%- elif function['arguments'] is string -%}
{{- function['arguments'] -}}
{%- endif -%}
{{- '}<tool_call|>' -}}
{%- endfor -%}
{%- set ns.prev_message_type = 'tool_call' -%}
{%- endif -%}

{%- if message['tool_responses'] -%}
{#- Tool Response handling -#}
{%- for tool_response in message['tool_responses'] -%}
{{- '<|tool_response>' -}}
{#- A mapping response is expanded field by field; any other response is
wrapped under a single 'value' field. A missing name falls back to 'unknown'. -#}
{%- if tool_response['response'] is mapping -%}
{{- 'response:' + tool_response['name'] | default('unknown') + '{' -}}
{%- for key, value in tool_response['response'] | dictsort -%}
{{- key -}}:{{- format_argument(value, escape_keys=False) -}}
{%- if not loop.last %},{% endif -%}
{%- endfor -%}
{{- '}' -}}
{%- else -%}
{{- 'response:' + tool_response['name'] | default('unknown') + '{value:' + format_argument(tool_response['response'], escape_keys=False) + '}' -}}
{%- endif -%}
{{- '<tool_response|>' -}}
{%- endfor -%}
{%- set ns.prev_message_type = 'tool_response' -%}
{%- endif -%}

{#- Content is either a plain string or a list of typed items. Text written
in a model turn goes through strip_thinking; other text is only trimmed. -#}
{%- if message['content'] is string -%}
{%- if role == 'model' -%}
{{- strip_thinking(message['content']) -}}
{%- else -%}
{{- message['content'] | trim -}}
{%- endif -%}
{%- elif message['content'] is sequence -%}
{%- for item in message['content'] -%}
{%- if item['type'] == 'text' -%}
{%- if role == 'model' -%}
{{- strip_thinking(item['text']) -}}
{%- else -%}
{{- item['text'] | trim -}}
{%- endif -%}
{%- elif item['type'] == 'image' -%}
{{- '\n\n<|image|>\n\n' -}}
{%- set ns.prev_message_type = 'image' -%}
{%- elif item['type'] == 'audio' -%}
{{- '<|audio|>' -}}
{%- set ns.prev_message_type = 'audio' -%}
{%- elif item['type'] == 'video' -%}
{{- '\n\n<|video|>\n\n' -}}
{%- set ns.prev_message_type = 'video' -%}
{%- endif -%}
{%- endfor -%}
{%- endif -%}

{#- The turn is left open when the message carries tool responses and no
content. -#}
{%- if not (message['tool_responses'] and not message['content']) -%}
{{- '<turn|>\n' -}}
{%- endif -%}
{%- endfor -%}

{#- Generation prompt: a model turn is opened unless the last message ended
on a tool response. When thinking is not enabled, an empty thought channel is
written after it. -#}
{%- if add_generation_prompt -%}
{%- if ns.prev_message_type != 'tool_response' -%}
{{- '<|turn>model\n' -}}
{%- endif -%}
{%- if not enable_thinking | default(false) -%}
{{- '<|channel>thought\n<channel|>' -}}
{%- endif -%}
{%- endif -%}
|
||||
@@ -72,6 +72,7 @@ class ChatTemplate(str, Enum):
|
||||
qwen2_vl = "qwen2_vl"
|
||||
gemma3 = "gemma3"
|
||||
gemma3n = "gemma3n"
|
||||
gemma4 = "gemma4"
|
||||
command_a = "command_a"
|
||||
command_a_tool_use = "command_a_tool_use"
|
||||
command_a_rag = "command_a_rag"
|
||||
|
||||
@@ -69,9 +69,7 @@ class TRLConfig(BaseModel):
|
||||
generation_batch_size: int | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "Batch size for generation. Controls how many unique "
|
||||
"prompts are generated per step. For full DP utilization, set to "
|
||||
"num_generations * data_parallel_size (or a multiple thereof)."
|
||||
"description": "Batch size for generation. Controls how many unique prompts are generated per step. Should be num_generations * data_parallel_size for full DP utilization."
|
||||
},
|
||||
)
|
||||
num_generations: int | None = Field(
|
||||
|
||||
1052
tests/integrations/test_gemma4_moe.py
Normal file
1052
tests/integrations/test_gemma4_moe.py
Normal file
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user