# Model Architectures — Agent Reference

Model-specific quirks, required settings, and known issues. Check this before debugging training failures on specific model families.
## Gemma 4

Models: `google/gemma-4-26B-A4B` (MoE), `google/gemma-4-31B` (dense), `google/gemma-4-E2B`, `google/gemma-4-E4B`

Architecture: multimodal wrapper (`Gemma4ForConditionalGeneration`) over a text backbone (`Gemma4TextModel`), with optional vision/audio encoders. All Gemma4 HF repos have `model_type: "gemma4"` — even text-only variants load as multimodal with a vision tower.

### Required settings
```yaml
# Always needed for Gemma4:
freeze_mm_modules: true  # Freeze vision/audio encoders for text-only training
gradient_checkpointing_kwargs:
  use_reentrant: false  # Shared per-layer norms cause "marked ready twice" with reentrant

# LoRA target — restrict to language model only (DO NOT use lora_target_linear: true):
lora_target_modules: 'model.language_model.layers.[\d]+.(_checkpoint_wrapped_module.)?(mlp|self_attn).(up|down|gate|q|k|v|o)_proj'
```

### Auto-detection

Axolotl auto-detects Gemma4 and applies:

- `use_reentrant: false` for gradient checkpointing
- `ddp_find_unused_parameters: true` for DDP (skipped when `activation_offloading: true`)
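
If a run needs these pinned explicitly in the config rather than relying on detection, a minimal sketch of the equivalent settings is below; the key names follow the ones listed above, and anything beyond them is an assumption:

```yaml
# Sketch: explicit equivalents of the auto-detected Gemma4 settings above
gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: false
ddp_find_unused_parameters: true  # omit (or leave false) when activation_offloading: true
```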
### Multi-GPU

| Strategy | Works? | Notes |
|---|---|---|
| DDP | Yes | Auto-sets `ddp_find_unused_parameters=True` |
| DDP + activation_offloading | Yes | `find_unused_parameters` is skipped (conflicts with checkpoint wrappers) |
| FSDP1 | No | OOM during dequantization/sharding with QLoRA |
| FSDP2 | Yes | Use `Gemma4TextDecoderLayer` (not `Gemma4DecoderLayer`) as wrap class |
| FSDP2 + activation_offloading | Yes | Lowest VRAM (~26 GiB/GPU for 26B-A4B) |
FSDP2 config:

```yaml
fsdp:
  - full_shard
  - auto_wrap
fsdp_config:
  fsdp_version: 2
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_transformer_layer_cls_to_wrap: Gemma4TextDecoderLayer
```
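
For the lowest-VRAM row in the Multi-GPU table, activation offloading layers on top of the same FSDP2 config. A sketch, assuming the `activation_offloading` flag referenced elsewhere in this section:

```yaml
# Sketch: FSDP2 + activation offloading (~26 GiB/GPU for 26B-A4B per the table above)
activation_offloading: true
gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: false
```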
### MoE (26B-A4B)

- `enable_moe_block: true`, 256 experts, top-k routing
- No separate `SparseMoeBlock` — MoE is embedded in each decoder layer
- Expert LoRA targets 3D parameter tensors (see the combined sketch after this list):

  ```yaml
  lora_target_parameters:
    - experts.gate_up_proj
    - experts.down_proj
  ```

- ScatterMoE kernel acceleration:

  ```yaml
  plugins:
    - axolotl.integrations.kernels.KernelsPlugin
  use_kernels: true
  use_scattermoe: true
  experts_implementation: scattermoe
  ```
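
Putting the adapter pieces together for 26B-A4B, a hedged sketch (illustrative rank/alpha values, not a validated recipe) that combines the language-model regex from Required settings with the 3D expert parameter targets:

```yaml
adapter: lora
lora_r: 16       # illustrative
lora_alpha: 32   # illustrative
# Text-backbone attention/MLP projections only (no vision/audio modules):
lora_target_modules: 'model.language_model.layers.[\d]+.(_checkpoint_wrapped_module.)?(mlp|self_attn).(up|down|gate|q|k|v|o)_proj'
# 3D expert weight tensors:
lora_target_parameters:
  - experts.gate_up_proj
  - experts.down_proj
```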
### Common issues

| Symptom | Cause | Fix |
|---|---|---|
| `mm_token_type_ids is required` in DDP | `model.config` not accessible through DDP wrapper | Already fixed — `unwrap_model()` in `compute_loss` and `prediction_step` |
| `marked a variable ready twice` in DDP | `ddp_find_unused_parameters=True` + activation_offloading checkpoint wrappers | Auto-handled — `find_unused_parameters` is skipped when `activation_offloading: true` |
| Loss ~12 instead of ~0.5 | Using `lora_target_linear: true` (applies LoRA to vision/audio modules) | Use the regex `lora_target_modules` pattern instead |
| FSDP2 `Could not find Gemma4AudioLayer` | Auto-wrap detects `_no_split_modules` including audio layers that don't exist | Explicitly set `fsdp_transformer_layer_cls_to_wrap: Gemma4TextDecoderLayer` |
| `Gemma4ClippableLinear` not supported by PEFT | Vision tower uses a non-standard linear wrapper | Axolotl patches this automatically via `_patch_peft_clippable_linear()` |
### E2B/E4B dense models

These have `hidden_size_per_layer_input: 256` (per-layer input embeddings) and `attention_k_eq_v: False`. Known issue: loss starts higher than expected (~12 vs ~0.5 for 26B). Root cause under investigation — may be related to the per-layer input mechanism or the `Gemma4ForConditionalGeneration` loss computation.
## Gemma 3

Models: `google/gemma-3-*`

- `ddp_find_unused_parameters: true` needed (multimodal unused params); see the config sketch after this list
- `use_reentrant: false` recommended
- Attention mask must be dropped for sample packing (handled automatically)
- Multi-GPU test currently skipped (`tests/e2e/multigpu/test_gemma3.py`)
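
A minimal Gemma 3 config sketch combining the notes above (only the first two keys come from the list; the rest are illustrative assumptions):

```yaml
# Sketch: Gemma 3 multi-GPU settings per the notes above
ddp_find_unused_parameters: true
gradient_checkpointing: true        # assumption: checkpointing enabled for the run
gradient_checkpointing_kwargs:
  use_reentrant: false
sample_packing: true                # attention mask drop is handled automatically
```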
## Qwen 3.5 MoE

Models: `Qwen/Qwen3.5-35B-A3B`

- Hybrid architecture: DeltaNet linear attention (30 layers) + full attention (10 layers)
- 256 experts, 8 active per token
- Known weight scale drift in late DeltaNet layers (36-38) due to AdamW + rare expert interaction
- Fix: `normalize_weight_scales` config to detect and rescale outliers:

  ```yaml
  normalize_weight_scales:
    - name_pattern: 'linear_attn\.conv1d\.weight'
      threshold: 1.3
  ```
## General MoE Notes

- `lora_target_linear: true` with multimodal MoE models will apply LoRA to ALL linear modules including vision/audio encoders — use a regex `lora_target_modules` to restrict to the language model only
- Rare experts get a larger effective learning rate from AdamW (small second-moment estimates), which can cause weight drift in recurrent/SSM components. Use `normalize_weight_scales` with `dry_run: true` to detect (see the sketch after this list).
- For ScatterMoE kernel support, set `experts_implementation: scattermoe` and add the KernelsPlugin
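
To detect drift before rescaling anything, `normalize_weight_scales` can be run in detection-only mode. A sketch assuming the same entry schema as the Qwen 3.5 example above; the exact placement of the `dry_run` flag and the pattern shown are assumptions:

```yaml
# Sketch: detection-only pass; report outlier weight scales, do not rescale
normalize_weight_scales:
  - name_pattern: 'experts\..*_proj'   # illustrative pattern
    threshold: 1.3
    dry_run: true                      # assumption: per-entry flag, reports only
```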