gemma4 support (#3574)

* gemma4 support

* fixes

* chore: lint
Wing Lian
2026-04-02 17:46:46 -04:00
committed by GitHub
parent 573726c839
commit 08fc7de87e
16 changed files with 2082 additions and 45 deletions

View File

@@ -0,0 +1,104 @@
# Gemma 4 26B-A4B MoE QLoRA with ScatterMoE kernels
#
# Validated: 50 steps on FineTome-100k, loss 7.4 -> 2.4, single RTX 5090 (32GB)
#
# Key notes:
# - Flash Attention 2 is NOT supported (global_head_dim=512 > FA2 max of 256).
# Use sdp_attention instead.
# - Gemma 4 is multimodal (text+vision+audio). For text-only SFT, restrict
# LoRA to the text backbone via lora_target_linear_modules regex.
# - MoE experts use `experts_implementation: scattermoe` — Gemma 4 embeds MoE
# directly in the decoder layer (no SparseMoeBlock), so we register ScatterMoE
# via the transformers ExpertsInterface.
# - Expert LoRA targets are `experts.gate_up_proj` / `experts.down_proj`
# (no `mlp.` prefix, unlike Qwen/Mixtral).
# - micro_batch_size: 1 fits 2048 seq_len on 32GB GPU with SDP attention.
# Use micro_batch_size: 4 with 1024 seq_len, or on 48GB+ GPUs.
base_model: google/gemma-4-26B-A4B
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
- axolotl.integrations.kernels.KernelsPlugin
- axolotl.integrations.liger.LigerPlugin
use_kernels: true
use_scattermoe: true
experts_implementation: scattermoe
torch_compile: false
liger_layer_norm: true
liger_rope: true
liger_rms_norm: true
liger_glu_activation: true
liger_rms_norm_gated: true
strict: false
chat_template: gemma4
datasets:
  - path: mlabonne/FineTome-100k
    type: chat_template
    split: train[:10%]
    field_messages: conversations
    message_property_mappings:
      role: from
      content: value
val_set_size: 0.05
output_dir: ./outputs/gemma4-26b-a4b-qlora
sequence_len: 2048
sample_packing: true
load_in_4bit: true
quantize_moe_experts: true
adapter: qlora
lora_r: 16
lora_alpha: 32
lora_dropout: 0
# Restrict LoRA to text backbone only (skip vision/audio encoders).
# lora_target_modules is intentionally empty — all module targeting is done
# via regex in lora_target_linear_modules below.
lora_target_modules: []
lora_target_linear_modules:
- language_model\.model\.layers\.\d+\.self_attn\.(q|k|v|o)_proj
# MoE expert LoRA (3D Parameter tensors, not nn.Linear)
lora_target_parameters:
- experts.gate_up_proj
- experts.down_proj
lora_mlp_kernel: false
lora_qkv_kernel: false
lora_o_kernel: false
bnb_config_kwargs:
  bnb_4bit_use_double_quant: true
wandb_project: gemma4-qlora
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_torch_8bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: auto
tf32: true
gradient_checkpointing: true
activation_offloading: true
logging_steps: 1
# FA2 not supported — Gemma4 global_head_dim=512 exceeds FA2 max of 256
flash_attention: false
sdp_attention: true
warmup_ratio: 0.1
evals_per_epoch: 4
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:

View File

@@ -12,7 +12,7 @@ packaging==26.0
huggingface_hub>=1.1.7
peft>=0.18.1
tokenizers>=0.22.1
transformers==5.4.0
transformers==5.5.0
accelerate==1.13.0
datasets==4.5.0
deepspeed>=0.18.6,<0.19.0

View File

@@ -381,6 +381,15 @@ class AxolotlTrainer(
# Store per-step trainable tokens for throughput calculation
self.state.tokens["trainable_tokens"] = trainable_tokens.detach().cpu()
# Gemma4 requires mm_token_type_ids during training (even for text-only).
# Inject zeros (= text token type) when not provided by the data collator.
if (
"mm_token_type_ids" not in inputs
and "input_ids" in inputs
and getattr(getattr(model, "config", None), "model_type", None) == "gemma4"
):
inputs["mm_token_type_ids"] = torch.zeros_like(inputs["input_ids"])
if self.args.orpo_alpha:
return self.orpo_compute_loss(
model,
@@ -508,12 +517,24 @@ class AxolotlTrainer(
)
# Perform a single forward pass
forward_kwargs = {
"input_ids": concat_inputs["input_ids"],
"attention_mask": concat_inputs["attention_mask"],
"labels": concat_inputs["labels"],
}
# Gemma4 requires mm_token_type_ids during training (even for text-only)
if (
getattr(getattr(model, "config", None), "model_type", None) == "gemma4"
and "mm_token_type_ids" not in concat_inputs
):
forward_kwargs["mm_token_type_ids"] = torch.zeros_like(
concat_inputs["input_ids"]
)
elif "mm_token_type_ids" in concat_inputs:
forward_kwargs["mm_token_type_ids"] = concat_inputs["mm_token_type_ids"]
outputs = model(
**{
"input_ids": concat_inputs["input_ids"],
"attention_mask": concat_inputs["attention_mask"],
"labels": concat_inputs["labels"],
},
**forward_kwargs,
output_hidden_states=True,
)

View File

@@ -28,7 +28,7 @@ use_scattermoe: true
use_sonicmoe: true
```
**Important:** Setting `experts_implementation` is incompatible with custom kernel options.
**Important:** Setting `experts_implementation` to `batched_mm` or `grouped_mm` is incompatible with custom kernel options. The exception is `experts_implementation: scattermoe`, which is used for models like Gemma 4 that embed MoE directly in the decoder layer (no SparseMoeBlock) and dispatch through the transformers `ExpertsInterface`.
### SonicMoE installation
@@ -63,7 +63,7 @@ Both paths use the shared `resolve_moe_block_classes` utility in `constants.py`
## Model Support Matrix
All models use the **SwiGLU** activation (`act_fn(gate) * up`). Neither kernel currently supports non-SwiGLU MoE architectures.
Most models use the **SwiGLU** activation (`silu(gate) * up`). Gemma 4 uses **GEGLU** (`gelu(gate) * up`). ScatterMoE supports any gated activation (activation is applied in Python between kernel calls). SonicMoE supports SwiGLU, GEGLU, and REGLU via its `ActivationType` enum.
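As a quick illustration, here is a minimal PyTorch sketch of the two gated activations applied to a fused `gate_up` projection. Shapes are illustrative only; in the kernels the activation is applied between the two projection calls rather than in one fused matmul.

```python
import torch
import torch.nn.functional as F

tokens, hidden, intermediate = 4, 8, 16
x = torch.randn(tokens, hidden)
gate_up_proj = torch.randn(hidden, 2 * intermediate)  # fused gate/up weight
down_proj = torch.randn(intermediate, hidden)

gate, up = (x @ gate_up_proj).chunk(2, dim=-1)        # split the fused projection

swiglu_out = (F.silu(gate) * up) @ down_proj                     # most MoE models
geglu_out = (F.gelu(gate, approximate="tanh") * up) @ down_proj  # Gemma 4 (gelu_pytorch_tanh)
```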
### Routing strategies
@@ -76,6 +76,7 @@ All models use the **SwiGLU** activation (`act_fn(gate) * up`). Neither kernel c
| softmax → bias correction → topk | Softmax, bias via `gate.moe_statics`, topk, gather from original probs, clamp-based renorm | No | Yes |
| softmax → group_limited_greedy | Softmax, group selection (max per group), topk, scale only (no renorm) | No | Yes |
| softmax → topk via gate.wg | Softmax, gate weight at `gate.wg.weight` (not `gate.weight`), always renormalize | No | Yes |
| softmax → topk + per_expert_scale | RMSNorm → scale → proj → softmax → topk → renorm → per-expert learned scales | Yes | Yes |
| fused topk → softmax | Routing + expert computation fused in a single kernel | No | Planned |
### Per-model support
@@ -102,10 +103,13 @@ All models use the **SwiGLU** activation (`act_fn(gate) * up`). Neither kernel c
| `ernie4_5_moe` | ERNIE 4.5 MoE | softmax → bias → topk | No | **Yes** |
| `deepseek_v2` | DeepSeek-V2 | softmax → group_limited_greedy | No | **Yes** |
| `hunyuan_v1_moe` | HunYuan V1 MoE | softmax → topk (gate.wg) | No | **Yes** |
| `gemma4_text` | Gemma 4 (26B-A4B) | softmax → topk + per_expert_scale | **Yes**\*\* | **Yes**\*\* |
| `gpt_oss` | GPT-OSS | fused topk → softmax | No | Planned |
\* `glm4_moe_lite` with ScatterMoE may have issues — see Limitations.
\*\* Gemma 4 uses `experts_implementation: scattermoe` path (registered via `ExpertsInterface`) instead of SparseMoeBlock patching, since Gemma 4 embeds MoE directly in its decoder layer (no separate SparseMoeBlock). See the [Gemma 4 section](#gemma-4) below.
### Feature comparison
| Feature | ScatterMoE | SonicMoE |
@@ -131,6 +135,21 @@ Both kernels handle shared experts identically. Shared expert attribute names ar
If `shared_expert_gate` exists, sigmoid gating is applied to the shared expert contribution before adding it to the routed output. PEFT wraps shared expert linear layers with standard LoRA — no special handling is needed.
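For reference, a minimal sketch of that merge, assuming Qwen2-MoE-style attribute names (`shared_expert`, `shared_expert_gate`); the names are illustrative here, since the exact attributes vary per model.

```python
import torch

def add_shared_expert(hidden_states, routed_out, moe_block):
    # Dense shared expert runs on every token
    shared_out = moe_block.shared_expert(hidden_states)
    gate = getattr(moe_block, "shared_expert_gate", None)
    if gate is not None:
        # Sigmoid-gate the shared contribution before summing with the routed output
        shared_out = torch.sigmoid(gate(hidden_states)) * shared_out
    return routed_out + shared_out
```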
## Gemma 4
Gemma 4 (e.g. `google/gemma-4-26B-A4B`) has a unique hybrid MoE architecture:
- **No SparseMoeBlock**: MoE is embedded directly in the decoder layer alongside a dense MLP. Both run in parallel and their outputs are summed.
- **Custom router** (`Gemma4TextRouter`): RMSNorm → learned scale → linear projection → softmax → top-k → renormalization → per-expert learned scales.
- **GEGLU activation**: Uses `gelu_pytorch_tanh` (not SiLU/SwiGLU like most other MoE models).
- **128 experts, top-k=8** for the 26B-A4B variant.
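The router math from the list above, written out as a plain PyTorch sketch. Attribute names (`norm`, `scale`, `scalar_root_size`, `proj`, `per_expert_scale`) follow `Gemma4TextRouter` as used by the SonicMoE routing code in this PR; this is illustrative, not the model's exact forward.

```python
import torch
import torch.nn.functional as F

def gemma4_route(hidden_states, router, top_k):
    # 1-2. RMSNorm without a learnable weight, then learned scale * hidden_size**-0.5
    x = router.norm(hidden_states) * router.scale * router.scalar_root_size
    # 3. Project to per-expert logits
    logits = router.proj(x)                                  # [T, E]
    # 4. Softmax over experts, keep the top-k
    probs = F.softmax(logits, dim=-1, dtype=torch.float32)
    top_vals, top_idx = torch.topk(probs, top_k, dim=-1)     # [T, K]
    # 5. Renormalize the selected weights to sum to 1
    top_vals = top_vals / top_vals.sum(dim=-1, keepdim=True)
    # 6. Apply the per-expert learned scales
    top_vals = top_vals * router.per_expert_scale[top_idx]
    return top_idx, top_vals
```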
Because there is no SparseMoeBlock class to patch, Gemma 4 uses a different integration path: we register `"scattermoe"` as a custom implementation in the transformers `ExpertsInterface`, and set `experts_implementation: scattermoe` in the config. The `@use_experts_implementation` decorator on `Gemma4TextExperts` then dispatches to our ScatterMoE kernel automatically. The router is untouched — it runs as-is.
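Condensed from the registration code added in this PR (error handling and the implementation-allowlist patch are elided):

```python
from transformers.integrations.moe import ALL_EXPERTS_FUNCTIONS

from axolotl.integrations.kernels.libs.scattermoe_lora.gemma4_experts import (
    scattermoe_experts_forward,
)

# After this, any Experts class decorated with @use_experts_implementation
# dispatches to ScatterMoE when config._experts_implementation == "scattermoe".
ALL_EXPERTS_FUNCTIONS.register("scattermoe", scattermoe_experts_forward)
```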
**Important limitations:**
- **Flash Attention 2 is not supported** — Gemma 4 uses `global_head_dim: 512` for full attention layers, which exceeds FA2's maximum head dimension of 256. Use `sdp_attention: true` instead.
- **Multimodal model**: Gemma 4 includes vision and audio encoders. For text-only SFT, use `lora_target_linear_modules` with a regex to restrict LoRA to the text backbone (e.g. `language_model\.model\.layers\.\d+\.self_attn\.(q|k|v|o)_proj`).
## Limitations
- **ScatterMoE + GLM4-MoE Lite**: ScatterMoE does not work reliably for GLM 4.7 Flash (`glm4_moe_lite`).

View File

@@ -34,12 +34,20 @@ class KernelsArgs(BaseModel):
@classmethod
def check_experts_implementation(cls, data):
experts_implementation = data.get("experts_implementation")
use_scattermoe = data.get("use_scattermoe", False)
if experts_implementation is None:
# transformers may default to batched_mm when unset
data["experts_implementation"] = "eager"
elif experts_implementation != "eager":
elif experts_implementation == "scattermoe" and not use_scattermoe:
LOG.warning(
"`experts_implementation` must be set to 'eager' to use this. Automatically setting it to 'eager'."
"`experts_implementation='scattermoe'` requires `use_scattermoe: true`. "
"Automatically setting to 'eager'."
)
data["experts_implementation"] = "eager"
elif experts_implementation not in ("eager", "scattermoe"):
LOG.warning(
f"`experts_implementation={experts_implementation!r}` is not compatible with "
f"custom MoE kernels. Automatically setting to 'eager'."
)
data["experts_implementation"] = "eager"

View File

@@ -11,6 +11,7 @@ Models with custom routing (see sonicmoe/routing.py for implementations):
- ernie4_5_moe: softmax→bias correction→topk (softmax_bias_topk_routing)
- deepseek_v2: softmax→group_limited_greedy (softmax_group_limited_topk_routing)
- hunyuan_v1_moe: softmax→topk via gate.wg (softmax_topk_wg_routing)
- gemma4_text: RMSNorm→scale→proj→softmax→topk→renorm→per_expert_scale (experts-level patch)
"""
import importlib
@@ -53,6 +54,49 @@ SPARSE_MOE_BLOCK = {
}
# Models where MoE is NOT in a separate SparseMoeBlock but embedded in the
# decoder layer. For these, we patch the Experts class forward directly
# (same signature: hidden_states, top_k_index, top_k_weights -> Tensor).
# Routing stays untouched — the original model router runs as-is.
EXPERTS_ONLY_BLOCK = {
# gemma4: hybrid MLP+MoE in decoder layer, custom Gemma4TextRouter,
# no SparseMoeBlock. Experts use @use_experts_implementation with
# standard 3D param layout (gate_up_proj [E, 2*I, H], down_proj [E, H, I]).
"gemma4_text": "Gemma4TextExperts",
}
def resolve_experts_class(model_type: str):
"""Resolve the Experts class for models that need experts-level patching.
Returns the class, or None if the model uses SparseMoeBlock-level patching.
"""
entry = EXPERTS_ONLY_BLOCK.get(model_type)
if entry is None:
return None
module_path = f"transformers.models.{model_type}.modeling_{model_type}"
try:
module = importlib.import_module(module_path)
except ModuleNotFoundError:
if model_type.endswith("_text"):
parent_type = model_type.removesuffix("_text")
module_path = f"transformers.models.{parent_type}.modeling_{parent_type}"
module = importlib.import_module(module_path)
else:
raise
cls = getattr(module, entry, None)
if cls is None:
raise ValueError(f"Could not find class '{entry}' in '{module_path}'")
return cls
def is_experts_only_model(model_type: str) -> bool:
"""Check if a model type requires experts-level (not block-level) patching."""
return model_type in EXPERTS_ONLY_BLOCK
def resolve_moe_block_classes(model_type: str):
"""Resolve all MoE block classes from transformers for the given model type.

View File

@@ -0,0 +1,235 @@
"""
ScatterMoE-accelerated experts forward for Gemma4.
Gemma4 has no separate SparseMoeBlock — MoE is embedded in the decoder layer.
The decoder layer handles routing (Gemma4TextRouter) and calls
``experts(hidden_states, top_k_index, top_k_weights)`` directly.
This module registers a ``"scattermoe"`` implementation in the transformers
``ExpertsInterface``, which the ``@use_experts_implementation`` decorator
dispatches to when ``config._experts_implementation == "scattermoe"``.
This is the clean way to hook into transformers' MoE dispatch — no
monkeypatching required. Works for Gemma4 and any future model that uses
``@use_experts_implementation`` with the standard forward signature
``(hidden_states, top_k_index, top_k_weights) -> Tensor``.
"""
import torch
from .parallel_experts import flatten_sort_count, parallel_linear
from .parallel_linear_lora import get_lora_params_from_wrapper, parallel_linear_lora
def _has_peft_wrapper(module):
"""Check if a module's parameter has been wrapped by PEFT ParamWrapper."""
try:
from peft.tuners.param_wrapper import ParamWrapper
for attr in ("gate_up_proj", "down_proj"):
param = getattr(module, attr, None)
if isinstance(param, ParamWrapper):
return True
except ImportError:
pass
return False
def _unwrap_experts_lora(experts):
"""Extract base weights and LoRA params from a PEFT-wrapped Experts module.
Returns:
(base_experts, gup_lora, down_lora) where each lora is
(lora_A, lora_B, scaling) or None.
"""
try:
from peft.tuners.param_wrapper import ParamWrapper
except ImportError:
return experts, None, None
if not isinstance(getattr(experts, "gate_up_proj", None), ParamWrapper):
return experts, None, None
base_experts = experts
gup_lora = None
down_lora = None
gup_param = experts.gate_up_proj
if isinstance(gup_param, ParamWrapper):
lora_A, lora_B, scaling = get_lora_params_from_wrapper(gup_param)
if lora_A is not None:
num_experts = experts.num_experts
rank = lora_A.shape[0] // num_experts
from .layers import peft_lora_to_scattermoe
sm_A, sm_B = peft_lora_to_scattermoe(lora_A, lora_B, num_experts, rank)
gup_lora = (sm_A, sm_B, scaling)
down_param = experts.down_proj
if isinstance(down_param, ParamWrapper):
lora_A, lora_B, scaling = get_lora_params_from_wrapper(down_param)
if lora_A is not None:
num_experts = experts.num_experts
rank = lora_A.shape[0] // num_experts
from .layers import peft_lora_to_scattermoe
sm_A, sm_B = peft_lora_to_scattermoe(lora_A, lora_B, num_experts, rank)
down_lora = (sm_A, sm_B, scaling)
return base_experts, gup_lora, down_lora
def _get_base_param(param):
"""Get the base tensor from a PEFT ParamWrapper or regular Parameter."""
try:
from peft.tuners.param_wrapper import ParamWrapper
while isinstance(param, ParamWrapper):
param = param.original_parameter
except ImportError:
pass
return param
def _parallel_linear_maybe_lora(
x,
weight,
top_k,
sorted_expert_idxs,
sorted_scattered_idxs,
expert_offsets,
lora_tuple,
grouped_in,
grouped_out,
gates=None,
):
"""Call parallel_linear or parallel_linear_lora depending on whether LoRA is active."""
if lora_tuple is not None:
lora_A, lora_B, scaling = lora_tuple
return parallel_linear_lora(
x,
weight,
top_k,
sorted_expert_idxs,
sorted_scattered_idxs,
expert_offsets,
lora_A,
lora_B,
scaling,
grouped_in=grouped_in,
grouped_out=grouped_out,
gates=gates,
)
return parallel_linear(
x,
weight,
top_k,
sorted_expert_idxs,
sorted_scattered_idxs,
expert_offsets,
grouped_in=grouped_in,
grouped_out=grouped_out,
gates=gates,
)
def scattermoe_experts_forward(
self,
hidden_states: torch.Tensor,
top_k_index: torch.Tensor,
top_k_weights: torch.Tensor,
) -> torch.Tensor:
"""ScatterMoE-accelerated experts forward.
Drop-in replacement for the standard Experts forward signature used by
``@use_experts_implementation``-decorated classes (Gemma4, Mixtral, etc.):
``(hidden_states [T, H], top_k_index [T, K], top_k_weights [T, K]) -> [T, H]``
"""
K = top_k_index.shape[1]
routing_weights = top_k_weights.to(hidden_states.dtype)
sorted_expert_idxs, sorted_scattered_idxs, expert_offsets = flatten_sort_count(
top_k_index, num_experts=self.num_experts
)
# Get base weights (unwrap PEFT if needed)
gate_up_weight = _get_base_param(self.gate_up_proj).transpose(2, 1)
down_weight = _get_base_param(self.down_proj).transpose(2, 1)
# Extract LoRA params if PEFT is active
gup_lora, down_lora = None, None
if _has_peft_wrapper(self):
_, gup_lora, down_lora = _unwrap_experts_lora(self)
# Gate-up projection (with optional LoRA)
gates_h = _parallel_linear_maybe_lora(
hidden_states,
gate_up_weight,
K,
sorted_expert_idxs,
sorted_scattered_idxs,
expert_offsets,
gup_lora,
grouped_in=False,
grouped_out=True,
)
gates, h = gates_h.chunk(2, dim=-1)
h = self.act_fn(gates) * h
# Down projection (with optional LoRA + routing weights)
output = _parallel_linear_maybe_lora(
h,
down_weight,
1,
sorted_expert_idxs,
sorted_scattered_idxs,
expert_offsets,
down_lora,
grouped_in=True,
grouped_out=False,
gates=routing_weights,
)
return output
def register_scattermoe_experts():
"""Register ``"scattermoe"`` in the transformers ExpertsInterface.
After calling this, any model with ``@use_experts_implementation`` will
dispatch to ScatterMoE when ``config._experts_implementation == "scattermoe"``.
Also patches ``get_correct_experts_implementation`` to accept ``"scattermoe"``
as a valid value (transformers hardcodes an allowlist).
"""
from transformers.integrations.moe import ALL_EXPERTS_FUNCTIONS
from transformers.modeling_utils import PreTrainedModel
# 1. Register the forward function in the global interface
ALL_EXPERTS_FUNCTIONS.register("scattermoe", scattermoe_experts_forward)
# 2. Patch the validation to accept "scattermoe"
_original_get_correct = PreTrainedModel.get_correct_experts_implementation
def _patched_get_correct(self_model, requested_experts: str | None) -> str:
if requested_experts == "scattermoe":
return "scattermoe"
return _original_get_correct(self_model, requested_experts)
PreTrainedModel.get_correct_experts_implementation = _patched_get_correct
# Legacy monkeypatch approach (kept for backward compat with existing tests)
def patch_gemma4_scattermoe():
"""Monkeypatch Gemma4TextExperts.forward with ScatterMoE kernel."""
from axolotl.integrations.kernels.constants import resolve_experts_class
experts_cls = resolve_experts_class("gemma4_text")
if experts_cls is None:
raise ValueError("Could not resolve Gemma4TextExperts class")
if hasattr(experts_cls, "_original_forward"):
return # already patched
experts_cls._original_forward = experts_cls.forward
experts_cls.forward = scattermoe_experts_forward

View File

@@ -0,0 +1,106 @@
"""
SonicMoE-accelerated experts forward for Gemma4.
Gemma4 has no separate SparseMoeBlock — MoE is embedded in the decoder layer.
This module provides a drop-in replacement for ``Gemma4TextExperts.forward``
that uses SonicMoE kernels while preserving the original call signature.
"""
import torch
from .lora import has_lora, materialize_expert_lora, unwrap_experts_lora
def _get_expert_weights_gemma4(experts_module):
"""Extract expert weights from Gemma4TextExperts, applying LoRA if active.
Returns:
(gate_up_weight, down_weight) in SonicMoE layout [dim, dim, E].
"""
if has_lora(experts_module):
base_experts, lora_dict = unwrap_experts_lora(experts_module)
gate_up = materialize_expert_lora(
base_experts.gate_up_proj, lora_dict.get("gate_up_proj")
)
down = materialize_expert_lora(
base_experts.down_proj, lora_dict.get("down_proj")
)
else:
gate_up = experts_module.gate_up_proj
down = experts_module.down_proj
# Permute to SonicMoE layout:
# gate_up: [E, 2*I, H] -> [2*I, H, E]
# down: [E, H, I] -> [H, I, E]
return gate_up.permute(1, 2, 0), down.permute(1, 2, 0)
def gemma4_sonicmoe_experts_forward(
self,
hidden_states: torch.Tensor,
top_k_index: torch.Tensor,
top_k_weights: torch.Tensor,
) -> torch.Tensor:
"""SonicMoE-accelerated replacement for Gemma4TextExperts.forward.
Same signature as the original: (hidden_states [T, H], top_k_index [T, K],
top_k_weights [T, K]) -> output [T, H].
"""
from sonicmoe import moe_general_routing_inputs
from sonicmoe.enums import ActivationType
T, _ = hidden_states.shape
K = top_k_index.shape[1]
E = self.num_experts
# Convert routing outputs to SonicMoE's flat format
# Token indices sorted ascending (required by SonicMoE)
token_indices = (
torch.arange(T, device=hidden_states.device, dtype=torch.int32)
.unsqueeze(1)
.expand(T, K)
)
flat_scores = top_k_weights.to(torch.float32).reshape(-1) # [T*K]
flat_token_idx = token_indices.reshape(-1) # [T*K]
flat_expert_idx = top_k_index.to(torch.int32).reshape(-1) # [T*K]
# Get weights (with LoRA materialization if needed)
gate_up_weight, down_weight = _get_expert_weights_gemma4(self)
gate_up_weight = gate_up_weight.to(hidden_states.dtype)
down_weight = down_weight.to(hidden_states.dtype)
if not torch.cuda.is_available():
raise RuntimeError("SonicMoE requires CUDA. No CUDA device available.")
cuda_stream = torch.cuda.current_stream().cuda_stream
output, _ = moe_general_routing_inputs(
hidden_states,
flat_scores,
flat_token_idx,
flat_expert_idx,
gate_up_weight,
None, # b1 (no gate/up bias)
down_weight,
None, # b2 (no down bias)
E,
cuda_stream,
ActivationType.GEGLU,
False, # is_inference_mode
)
return output
def patch_gemma4_sonicmoe():
"""Monkeypatch Gemma4TextExperts.forward with SonicMoE kernel."""
from axolotl.integrations.kernels.constants import resolve_experts_class
experts_cls = resolve_experts_class("gemma4_text")
if experts_cls is None:
raise ValueError("Could not resolve Gemma4TextExperts class")
if hasattr(experts_cls, "_original_forward"):
return # already patched
experts_cls._original_forward = experts_cls.forward
experts_cls.forward = gemma4_sonicmoe_experts_forward

View File

@@ -7,6 +7,7 @@ Different MoE architectures use different routing strategies:
- glm_moe_dsa / deepseek_v3 / minimax_m2: sigmoid -> topk (with group-based expert selection)
- ernie4_5_moe: softmax -> bias correction -> topk -> gather (softmax_bias_topk_routing)
- hunyuan_v1_moe: softmax -> topk via gate.wg (softmax_topk_wg_routing)
- gemma4_text: RMSNorm -> scale -> proj -> softmax -> topk -> renorm -> per_expert_scale (gemma4_routing)
- gpt_oss: topk -> softmax (uses fused moe_TC_softmax_topk_layer, routing_fn=None) [NOT YET SUPPORTED]
Each model type maps to a (routing_fn, activation_type, router_attr) triple.
@@ -66,6 +67,8 @@ def get_model_moe_config(model_type: str):
return softmax_bias_topk_routing, ActivationType.SWIGLU, "gate"
elif model_type in ("hunyuan_v1_moe",):
return softmax_topk_wg_routing, ActivationType.SWIGLU, "gate"
elif model_type in ("gemma4_text",):
return gemma4_routing, ActivationType.GEGLU, "router"
# Fused topk -> softmax path (routing_fn=None):
# elif model_type in ("gpt_oss",):
# # NOTE: gpt_oss has a router bias which moe_TC_softmax_topk_layer
@@ -501,3 +504,73 @@ def softmax_topk_wg_routing(
flat_expert_idx = top_indices.to(torch.int32).reshape(-1) # [T*K]
return flat_scores, flat_token_idx, flat_expert_idx, router_logits
def gemma4_routing(
hidden_states: torch.Tensor, moe_block
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""Gemma4-style routing: RMSNorm → scale → proj → softmax → topk → renorm → per_expert_scale.
Gemma4's router (``Gemma4TextRouter``) has a unique structure:
1. RMSNorm (without learnable scale) on hidden states
2. Multiply by ``scale * hidden_size**-0.5``
3. Linear projection to expert scores
4. Softmax → topk
5. Normalize top-k weights to sum to 1
6. Multiply by per-expert learned scales
The router lives at ``moe_block.router`` (not ``moe_block.gate``).
LoRA on the router targets ``router.proj`` (nn.Linear).
Args:
hidden_states: [T, H] flattened token representations
moe_block: MoE block module (accesses moe_block.router)
Returns:
router_scores: [T*K] flattened scores (float32)
token_indices: [T*K] which token each entry belongs to (int32), sorted ascending
expert_indices: [T*K] which expert (int32)
router_logits: [T, E] original logits for aux loss
"""
router = moe_block.router
# Unwrap PEFT LoRA on router.proj (the nn.Linear)
_, proj_weight, proj_lora_delta = unwrap_gate_lora(router.proj)
T, _ = hidden_states.shape
K = router.top_k if hasattr(router, "top_k") else router.config.top_k_experts
# Reproduce Gemma4TextRouter.forward:
# 1. RMSNorm (no scale) + scale param * hidden_size**-0.5
normed = router.norm(hidden_states)
scaled = normed * router.scale * router.scalar_root_size
# 2. Project to expert scores
router_logits = F.linear(scaled.float(), proj_weight.float()) # [T, E]
if proj_lora_delta is not None:
router_logits = router_logits + F.linear(
scaled.float(), proj_lora_delta.float()
)
# 3. Softmax → topk
router_probs = F.softmax(router_logits, dim=-1, dtype=torch.float32) # [T, E]
top_values, top_indices = torch.topk(router_probs, K, dim=-1) # [T, K]
# 4. Normalize top-k weights
top_values = top_values / top_values.sum(dim=-1, keepdim=True)
# 5. Per-expert scale
top_values = top_values * router.per_expert_scale[top_indices]
# Flatten for moe_general_routing_inputs
token_indices = (
torch.arange(T, device=hidden_states.device, dtype=torch.int32)
.unsqueeze(1)
.expand(T, K)
)
flat_scores = top_values.to(torch.float32).reshape(-1) # [T*K]
flat_token_idx = token_indices.reshape(-1) # [T*K]
flat_expert_idx = top_indices.to(torch.int32).reshape(-1) # [T*K]
return flat_scores, flat_token_idx, flat_expert_idx, router_logits

View File

@@ -61,20 +61,31 @@ class KernelsPlugin(BasePlugin):
return "axolotl.integrations.kernels.KernelsArgs"
def pre_model_load(self, cfg):
from axolotl.integrations.kernels.constants import SPARSE_MOE_BLOCK
from axolotl.integrations.kernels.constants import (
SPARSE_MOE_BLOCK,
is_experts_only_model,
)
# Prefer text backbone type for VLMs, but fall back to base type
# when the text type isn't in the supported mapping (e.g. qwen3_5_moe_text)
moe_model_type = cfg.model_config_type_text or cfg.model_config_type
if (
moe_model_type not in SPARSE_MOE_BLOCK
and not is_experts_only_model(moe_model_type)
and cfg.model_config_type in SPARSE_MOE_BLOCK
):
moe_model_type = cfg.model_config_type
if cfg.use_scattermoe:
self._register_kernels()
self._kernelize_model(moe_model_type)
if is_experts_only_model(moe_model_type):
# Models like Gemma4 where MoE is embedded in the decoder layer
# — register ScatterMoE in the ExpertsInterface so that
# @use_experts_implementation dispatches to it.
self._register_experts_interface()
cfg.experts_implementation = "scattermoe"
else:
self._kernelize_model(moe_model_type)
elif cfg.use_sonicmoe:
if not importlib.util.find_spec("sonicmoe"):
raise RuntimeError(
@@ -84,14 +95,24 @@ class KernelsPlugin(BasePlugin):
_check_sonicmoe_gpu_compat()
from axolotl.integrations.kernels.libs.sonicmoe import patch_sonicmoe
if is_experts_only_model(moe_model_type):
from axolotl.integrations.kernels.libs.sonicmoe.gemma4_experts import (
patch_gemma4_sonicmoe,
)
LOG.info(f"Applying SonicMoE patches for model type: {moe_model_type}")
patch_sonicmoe(
moe_model_type,
torch_compile=bool(getattr(cfg, "torch_compile", False)),
base_model_type=cfg.model_config_type,
)
LOG.info(
f"Applying SonicMoE experts-level patch for model type: {moe_model_type}"
)
patch_gemma4_sonicmoe()
else:
from axolotl.integrations.kernels.libs.sonicmoe import patch_sonicmoe
LOG.info(f"Applying SonicMoE patches for model type: {moe_model_type}")
patch_sonicmoe(
moe_model_type,
torch_compile=bool(getattr(cfg, "torch_compile", False)),
base_model_type=cfg.model_config_type,
)
def _register_kernels(self):
from kernels import (
@@ -139,3 +160,16 @@ class KernelsPlugin(BasePlugin):
replace_kernel_forward_from_hub(
model_moe_cls, "HFScatterMoEParallelExperts"
)
def _register_experts_interface(self):
"""Register ScatterMoE in the transformers ExpertsInterface.
This allows @use_experts_implementation-decorated Experts classes
to dispatch to ScatterMoE when config._experts_implementation == "scattermoe".
"""
from axolotl.integrations.kernels.libs.scattermoe_lora.gemma4_experts import (
register_scattermoe_experts,
)
register_scattermoe_experts()
LOG.info("Registered 'scattermoe' in transformers ExpertsInterface")

View File

@@ -73,6 +73,44 @@ QKV_PATCHES = [
query_states = self.q_norm(query_states.view(hidden_shape)).transpose(1, 2)
key_states = self.k_norm(key_states.view(hidden_shape)).transpose(1, 2)
value_states = value_states.view(hidden_shape).transpose(1, 2)
""".lstrip("\n"),
),
# Gemma4: norm between proj and transpose, RoPE between norm and transpose,
# conditional KV sharing (is_kv_shared_layer), v_proj may be None (attention_k_eq_v).
# We only fuse the projection calls; norms, RoPE, and KV sharing stay as-is.
(
"""
query_states = self.q_proj(hidden_states).view(hidden_shape)
query_states = self.q_norm(query_states)
query_states = apply_rotary_pos_emb(query_states, cos, sin, unsqueeze_dim=2)
query_states = query_states.transpose(1, 2)
# For layers with shared KV (from kv sharing point onwards), we reuse the same keys/values states as the last non-sharing layer
if self.is_kv_shared_layer and past_key_values is not None:
key_states, value_states = past_key_values.shared_layers[self.kv_shared_layer_index]
# Device of past layer may be different from current one
key_states = key_states.to(query_states.device)
value_states = value_states.to(query_states.device)
else:
key_states = self.k_proj(hidden_states).view(hidden_shape)
value_states = self.v_proj(hidden_states).view(hidden_shape) if self.v_proj is not None else key_states
""".lstrip("\n"),
"""
query_states, key_states, value_states = self.apply_qkv(hidden_states)
query_states = query_states.view(hidden_shape)
query_states = self.q_norm(query_states)
query_states = apply_rotary_pos_emb(query_states, cos, sin, unsqueeze_dim=2)
query_states = query_states.transpose(1, 2)
# For layers with shared KV (from kv sharing point onwards), we reuse the same keys/values states as the last non-sharing layer
if self.is_kv_shared_layer and past_key_values is not None:
key_states, value_states = past_key_values.shared_layers[self.kv_shared_layer_index]
# Device of past layer may be different from current one
key_states = key_states.to(query_states.device)
value_states = value_states.to(query_states.device)
else:
key_states = key_states.view(hidden_shape)
value_states = value_states.view(hidden_shape) if self.v_proj is not None else key_states
""".lstrip("\n"),
),
]
@@ -113,6 +151,23 @@ def original_apply_qkv(
return query_states, key_states, value_states
def original_apply_qkv_optional_v(
self: nn.Module, hidden_states: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""QKV projection for models where v_proj may be None (e.g. Gemma4 attention_k_eq_v).
When v_proj is None, key_states are reused as value_states.
"""
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
if self.v_proj is not None:
value_states = self.v_proj(hidden_states)
else:
value_states = key_states
return query_states, key_states, value_states
def original_apply_o(self: nn.Module, hidden_states: torch.Tensor) -> torch.Tensor:
"""
Original implementation of output projection without optimizations.
@@ -183,6 +238,11 @@ def get_attention_cls_from_config(cfg: DictDefault) -> Type[nn.Module]:
return Gemma3Attention
if model_type in ("gemma4", "gemma4_text"):
from transformers.models.gemma4.modeling_gemma4 import Gemma4TextAttention
return Gemma4TextAttention
try:
# Dynamically import the module and attention class
module_path = f"transformers.models.{model_type}.modeling_{model_type}"
@@ -410,14 +470,24 @@ def apply_lora_kernel_patches(
# Add QKV, O fallback implementations to start
# These will be overwritten later (if some conditions apply)
for self_attn in find_self_attn_in_layer(layer):
self_attn.apply_qkv = types.MethodType(original_apply_qkv, self_attn)
# Use v_proj-optional fallback for models where v_proj can be None
# (e.g. Gemma4 with attention_k_eq_v=True)
if getattr(self_attn, "v_proj", None) is None:
self_attn.apply_qkv = types.MethodType(
original_apply_qkv_optional_v, self_attn
)
else:
self_attn.apply_qkv = types.MethodType(original_apply_qkv, self_attn)
self_attn.apply_o = types.MethodType(original_apply_o, self_attn)
if cfg.lora_qkv_kernel:
# Query, key, value patching
# Filter out None projections (e.g. Gemma4 v_proj when attention_k_eq_v=True)
proj_names = ["q_proj", "k_proj", "v_proj"]
layer_modules = [
getattr(self_attn, linear_proj)
for linear_proj in ["q_proj", "k_proj", "v_proj"]
getattr(self_attn, name)
for name in proj_names
if getattr(self_attn, name, None) is not None
]
can_patch_qkv = all(
hasattr(module, "lora_A") for module in layer_modules

View File

@@ -111,7 +111,6 @@ def patch_qwen3_next_gateddelta_layer():
"""Patch Qwen3NextGatedDeltaNet to parse cu_seqlens and pass to chunk_gated_delta_rule"""
try:
from transformers.models.qwen3_next.modeling_qwen3_next import (
Qwen3NextDynamicCache,
Qwen3NextGatedDeltaNet,
apply_mask_to_padding_states,
)
@@ -125,8 +124,7 @@ def patch_qwen3_next_gateddelta_layer():
def patched_gated_delta_net_forward(
self,
hidden_states: torch.Tensor,
cache_params: Optional[Qwen3NextDynamicCache] = None,
cache_position: Optional[torch.LongTensor] = None,
cache_params=None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
):
@@ -137,9 +135,8 @@ def patch_qwen3_next_gateddelta_layer():
use_precomputed_states = (
cache_params is not None
and cache_params.has_previous_state
and cache_params.has_previous_state(self.layer_idx)
and seq_len == 1
and cache_position is not None
)
# Compute cu_seqlens early for use by both causal_conv1d and chunk_gated_delta_rule
@@ -148,9 +145,9 @@ def patch_qwen3_next_gateddelta_layer():
cu_seqlens = get_cu_seqlens(position_ids=position_ids)
# getting projected states from cache if it exists
if cache_params is not None:
conv_state = cache_params.conv_states[self.layer_idx]
recurrent_state = cache_params.recurrent_states[self.layer_idx]
if use_precomputed_states:
conv_state = cache_params.layers[self.layer_idx].conv_states
recurrent_state = cache_params.layers[self.layer_idx].recurrent_states
projected_states_qkvz = self.in_proj_qkvz(hidden_states)
projected_states_ba = self.in_proj_ba(hidden_states)
@@ -162,10 +159,9 @@ def patch_qwen3_next_gateddelta_layer():
)
mixed_qkv = torch.cat((query, key, value), dim=-1) # [B, T, D]
mixed_qkv = mixed_qkv.transpose(1, 2) # [B, D, T]
if use_precomputed_states:
# Inference single-token path: causal_conv1d_update expects [B, D, T]
mixed_qkv = mixed_qkv.transpose(1, 2)
mixed_qkv = self.causal_conv1d_update(
mixed_qkv,
conv_state,
@@ -173,19 +169,17 @@ def patch_qwen3_next_gateddelta_layer():
self.conv1d.bias,
self.activation,
)
mixed_qkv = mixed_qkv.transpose(1, 2)
else:
if cache_params is not None:
# Cache state expects [B, D, T] for the inference update path
mixed_qkv_t = mixed_qkv.transpose(1, 2)
conv_state = F.pad(
mixed_qkv_t,
(self.conv_kernel_size - mixed_qkv_t.shape[-1], 0),
mixed_qkv,
(self.conv_kernel_size - mixed_qkv.shape[-1], 0),
)
cache_params.conv_states[self.layer_idx] = conv_state
cache_params.update_conv_state(conv_state, self.layer_idx)
if fla_causal_conv1d is not None:
# FLA Triton causal_conv1d: [B, T, D] in/out, with cu_seqlens support
mixed_qkv = mixed_qkv.transpose(1, 2) # [B, T, D] for FLA
mixed_qkv, _ = fla_causal_conv1d(
x=mixed_qkv,
weight=self.conv1d.weight.squeeze(1),
@@ -193,6 +187,15 @@ def patch_qwen3_next_gateddelta_layer():
activation=self.activation,
cu_seqlens=cu_seqlens,
)
mixed_qkv = mixed_qkv.transpose(1, 2) # back to [B, D, T]
elif self.causal_conv1d_fn is not None:
mixed_qkv = self.causal_conv1d_fn(
x=mixed_qkv,
weight=self.conv1d.weight.squeeze(1),
bias=self.conv1d.bias,
activation=self.activation,
seq_idx=None,
)
else:
# PyTorch fallback (no cu_seqlens support)
if cu_seqlens is not None and cu_seqlens.shape[0] > batch_size + 1:
@@ -203,11 +206,9 @@ def patch_qwen3_next_gateddelta_layer():
LOG.warning_once(
"FLA causal_conv1d not available. Falling back to PyTorch conv1d."
)
mixed_qkv = mixed_qkv.transpose(1, 2)
mixed_qkv = F.silu(self.conv1d(mixed_qkv)[:, :, :seq_len])
mixed_qkv = mixed_qkv.transpose(1, 2)
# mixed_qkv is [B, T, D] in all paths
mixed_qkv = mixed_qkv.transpose(1, 2) # [B, T, D]
query, key, value = torch.split(
mixed_qkv,
[
@@ -255,7 +256,7 @@ def patch_qwen3_next_gateddelta_layer():
# Update cache
if cache_params is not None:
cache_params.recurrent_states[self.layer_idx] = last_recurrent_state
cache_params.update_recurrent_state(last_recurrent_state, self.layer_idx)
z_shape_og = z.shape
# reshape input data into 2D tensor

View File

@@ -0,0 +1,271 @@
{%- macro format_parameters(properties, required) -%}
{%- set standard_keys = ['description', 'type', 'properties', 'required', 'nullable'] -%}
{%- set ns = namespace(found_first=false) -%}
{%- for key, value in properties | dictsort -%}
{%- set add_comma = false -%}
{%- if key not in standard_keys -%}
{%- if ns.found_first %},{% endif -%}
{%- set ns.found_first = true -%}
{{ key }}:{
{%- if value['description'] -%}
description:<|"|>{{ value['description'] }}<|"|>
{%- set add_comma = true -%}
{%- endif -%}
{%- if value['nullable'] %}
{%- if add_comma %},{%- else -%} {%- set add_comma = true -%} {% endif -%}
nullable:true
{%- endif -%}
{%- if value['type'] | upper == 'STRING' -%}
{%- if value['enum'] -%}
{%- if add_comma %},{%- else -%} {%- set add_comma = true -%} {% endif -%}
enum:{{ format_argument(value['enum']) }}
{%- endif -%}
{%- elif value['type'] | upper == 'OBJECT' -%}
,properties:{
{%- if value['properties'] is defined and value['properties'] is mapping -%}
{{- format_parameters(value['properties'], value['required'] | default([])) -}}
{%- elif value is mapping -%}
{{- format_parameters(value, value['required'] | default([])) -}}
{%- endif -%}
}
{%- if value['required'] -%}
,required:[
{%- for item in value['required'] | default([]) -%}
<|"|>{{- item -}}<|"|>
{%- if not loop.last %},{% endif -%}
{%- endfor -%}
]
{%- endif -%}
{%- elif value['type'] | upper == 'ARRAY' -%}
{%- if value['items'] is mapping and value['items'] -%}
,items:{
{%- set ns_items = namespace(found_first=false) -%}
{%- for item_key, item_value in value['items'] | dictsort -%}
{%- if item_value is not none -%}
{%- if ns_items.found_first %},{% endif -%}
{%- set ns_items.found_first = true -%}
{%- if item_key == 'properties' -%}
properties:{
{%- if item_value is mapping -%}
{{- format_parameters(item_value, value['items']['required'] | default([])) -}}
{%- endif -%}
}
{%- elif item_key == 'required' -%}
required:[
{%- for req_item in item_value -%}
<|"|>{{- req_item -}}<|"|>
{%- if not loop.last %},{% endif -%}
{%- endfor -%}
]
{%- elif item_key == 'type' -%}
{%- if item_value is string -%}
type:{{ format_argument(item_value | upper) }}
{%- else -%}
type:{{ format_argument(item_value | map('upper') | list) }}
{%- endif -%}
{%- else -%}
{{ item_key }}:{{ format_argument(item_value) }}
{%- endif -%}
{%- endif -%}
{%- endfor -%}
}
{%- endif -%}
{%- endif -%}
{%- if add_comma %},{%- else -%} {%- set add_comma = true -%} {% endif -%}
type:<|"|>{{ value['type'] | upper }}<|"|>}
{%- endif -%}
{%- endfor -%}
{%- endmacro -%}
{%- macro format_function_declaration(tool_data) -%}
declaration:{{- tool_data['function']['name'] -}}{description:<|"|>{{- tool_data['function']['description'] -}}<|"|>
{%- set params = tool_data['function']['parameters'] -%}
{%- if params -%}
,parameters:{
{%- if params['properties'] -%}
properties:{ {{- format_parameters(params['properties'], params['required']) -}} },
{%- endif -%}
{%- if params['required'] -%}
required:[
{%- for item in params['required'] -%}
<|"|>{{- item -}}<|"|>
{{- ',' if not loop.last -}}
{%- endfor -%}
],
{%- endif -%}
{%- if params['type'] -%}
type:<|"|>{{- params['type'] | upper -}}<|"|>}
{%- endif -%}
{%- endif -%}
{%- if 'response' in tool_data['function'] -%}
{%- set response_declaration = tool_data['function']['response'] -%}
,response:{
{%- if response_declaration['description'] -%}
description:<|"|>{{- response_declaration['description'] -}}<|"|>,
{%- endif -%}
{%- if response_declaration['type'] | upper == 'OBJECT' -%}
type:<|"|>{{- response_declaration['type'] | upper -}}<|"|>}
{%- endif -%}
{%- endif -%}
}
{%- endmacro -%}
{%- macro format_argument(argument, escape_keys=True) -%}
{%- if argument is string -%}
{{- '<|"|>' + argument + '<|"|>' -}}
{%- elif argument is boolean -%}
{{- 'true' if argument else 'false' -}}
{%- elif argument is mapping -%}
{{- '{' -}}
{%- set ns = namespace(found_first=false) -%}
{%- for key, value in argument | dictsort -%}
{%- if ns.found_first %},{% endif -%}
{%- set ns.found_first = true -%}
{%- if escape_keys -%}
{{- '<|"|>' + key + '<|"|>' -}}
{%- else -%}
{{- key -}}
{%- endif -%}
:{{- format_argument(value, escape_keys=escape_keys) -}}
{%- endfor -%}
{{- '}' -}}
{%- elif argument is sequence -%}
{{- '[' -}}
{%- for item in argument -%}
{{- format_argument(item, escape_keys=escape_keys) -}}
{%- if not loop.last %},{% endif -%}
{%- endfor -%}
{{- ']' -}}
{%- else -%}
{{- argument -}}
{%- endif -%}
{%- endmacro -%}
{#- Removes '<|channel>...<channel|>' thinking blocks from model output.
Splits on the end token '<channel|>', then checks each part for the start
token '<|channel>' and keeps only the text before it. -#}
{%- macro strip_thinking(text) -%}
{%- set ns = namespace(cleaned='') -%}
{%- for part in text.split('<channel|>') -%}
{%- if '<|channel>' in part -%}
{%- set ns.cleaned = ns.cleaned + part.split('<|channel>')[0] -%}
{%- else -%}
{%- set ns.cleaned = ns.cleaned + part -%}
{%- endif -%}
{%- endfor -%}
{{- ns.cleaned | trim -}}
{%- endmacro -%}
{%- set ns = namespace(prev_message_type=None) -%}
{%- set loop_messages = messages -%}
{{ bos_token }}
{#- Handle System/Tool Definitions Block -#}
{%- if (enable_thinking is defined and enable_thinking) or tools or messages[0]['role'] in ['system', 'developer'] -%}
{{- '<|turn>system\n' -}}
{#- Inject Thinking token at the very top of the FIRST system turn -#}
{%- if enable_thinking is defined and enable_thinking -%}
{{- '<|think|>' -}}
{%- set ns.prev_message_type = 'think' -%}
{%- endif -%}
{%- if messages[0]['role'] in ['system', 'developer'] -%}
{{- messages[0]['content'] | trim -}}
{%- set loop_messages = messages[1:] -%}
{%- endif -%}
{%- if tools -%}
{%- for tool in tools %}
{{- '<|tool>' -}}
{{- format_function_declaration(tool) | trim -}}
{{- '<tool|>' -}}
{%- endfor %}
{%- set ns.prev_message_type = 'tool' -%}
{%- endif -%}
{{- '<turn|>\n' -}}
{%- endif %}
{#- Loop through messages -#}
{%- for message in loop_messages -%}
{#- Reset so only special message types (tool_call, image, etc.) influence
the generation prompt formatting below. Plain text leaves it as None. -#}
{%- set ns.prev_message_type = None -%}
{%- set role = 'model' if message['role'] == 'assistant' else message['role'] -%}
{{- '<|turn>' + role + '\n' }}
{%- if message['tool_calls'] -%}
{%- for tool_call in message['tool_calls'] -%}
{%- set function = tool_call['function'] -%}
{{- '<|tool_call>call:' + function['name'] + '{' -}}
{%- if function['arguments'] is mapping -%}
{%- set ns_args = namespace(found_first=false) -%}
{%- for key, value in function['arguments'] | dictsort -%}
{%- if ns_args.found_first %},{% endif -%}
{%- set ns_args.found_first = true -%}
{{- key -}}:{{- format_argument(value, escape_keys=False) -}}
{%- endfor -%}
{%- elif function['arguments'] is string -%}
{{- function['arguments'] -}}
{%- endif -%}
{{- '}<tool_call|>' -}}
{%- endfor -%}
{%- set ns.prev_message_type = 'tool_call' -%}
{%- endif -%}
{%- if message['tool_responses'] -%}
{#- Tool Response handling -#}
{%- for tool_response in message['tool_responses'] -%}
{{- '<|tool_response>' -}}
{%- if tool_response['response'] is mapping -%}
{{- 'response:' + tool_response['name'] | default('unknown') + '{' -}}
{%- for key, value in tool_response['response'] | dictsort -%}
{{- key -}}:{{- format_argument(value, escape_keys=False) -}}
{%- if not loop.last %},{% endif -%}
{%- endfor -%}
{{- '}' -}}
{%- else -%}
{{- 'response:' + tool_response['name'] | default('unknown') + '{value:' + format_argument(tool_response['response'], escape_keys=False) + '}' -}}
{%- endif -%}
{{- '<tool_response|>' -}}
{%- endfor -%}
{%- set ns.prev_message_type = 'tool_response' -%}
{%- endif -%}
{%- if message['content'] is string -%}
{%- if role == 'model' -%}
{{- strip_thinking(message['content']) -}}
{%- else -%}
{{- message['content'] | trim -}}
{%- endif -%}
{%- elif message['content'] is sequence -%}
{%- for item in message['content'] -%}
{%- if item['type'] == 'text' -%}
{%- if role == 'model' -%}
{{- strip_thinking(item['text']) -}}
{%- else -%}
{{- item['text'] | trim -}}
{%- endif -%}
{%- elif item['type'] == 'image' -%}
{{- '\n\n<|image|>\n\n' -}}
{%- set ns.prev_message_type = 'image' -%}
{%- elif item['type'] == 'audio' -%}
{{- '<|audio|>' -}}
{%- set ns.prev_message_type = 'audio' -%}
{%- elif item['type'] == 'video' -%}
{{- '\n\n<|video|>\n\n' -}}
{%- set ns.prev_message_type = 'video' -%}
{%- endif -%}
{%- endfor -%}
{%- endif -%}
{%- if not (message['tool_responses'] and not message['content']) -%}
{{- '<turn|>\n' -}}
{%- endif -%}
{%- endfor -%}
{%- if add_generation_prompt -%}
{%- if ns.prev_message_type != 'tool_response' -%}
{{- '<|turn>model\n' -}}
{%- endif -%}
{%- if not enable_thinking | default(false) -%}
{{- '<|channel>thought\n<channel|>' -}}
{%- endif -%}
{%- endif -%}

View File

@@ -72,6 +72,7 @@ class ChatTemplate(str, Enum):
qwen2_vl = "qwen2_vl"
gemma3 = "gemma3"
gemma3n = "gemma3n"
gemma4 = "gemma4"
command_a = "command_a"
command_a_tool_use = "command_a_tool_use"
command_a_rag = "command_a_rag"

View File

@@ -69,9 +69,7 @@ class TRLConfig(BaseModel):
generation_batch_size: int | None = Field(
default=None,
json_schema_extra={
"description": "Batch size for generation. Controls how many unique "
"prompts are generated per step. For full DP utilization, set to "
"num_generations * data_parallel_size (or a multiple thereof)."
"description": "Batch size for generation. Controls how many unique prompts are generated per step. Should be num_generations * data_parallel_size for full DP utilization."
},
)
num_generations: int | None = Field(

File diff suppressed because it is too large.