Gemma4 fixes and profiler (#3591)

This commit is contained in:
Wing Lian
2026-04-10 16:46:17 -04:00
committed by GitHub
parent 315cdeede9
commit 29fa4dedbb
10 changed files with 1926 additions and 1 deletions

View File

@@ -86,7 +86,7 @@ Features:
**Requirements**:
- NVIDIA GPU (Ampere or newer for `bf16` and Flash Attention) or AMD GPU
- Python 3.11
- Python >=3.11 (3.12 recommended)
- PyTorch ≥2.9.1
### Google Colab
@@ -95,6 +95,34 @@ Features:
### Installation
#### Using uv (recommended)
```bash
# install uv if you don't already have it installed
curl -LsSf https://astral.sh/uv/install.sh | sh
source $HOME/.local/bin/env
# CUDA 12.8.1 tends to have better package compatibility
export UV_TORCH_BACKEND=cu128
# create a new virtual environment
uv venv --python 3.12
source .venv/bin/activate
uv pip install torch==2.10.0 torchvision
uv pip install --no-build-isolation axolotl[deepspeed]
# recommended - install cut-cross-entropy
uv pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@main"
# (optional) - prefetch flash-attn2 and causal-conv1d kernels
uv run --python 3.12 python -c "from kernels import get_kernel; get_kernel('kernels-community/flash-attn2'); get_kernel('kernels-community/causal-conv1d')"
# Download example axolotl configs, deepspeed configs
axolotl fetch examples
axolotl fetch deepspeed_configs # OPTIONAL
```
#### Using pip
```bash

View File

@@ -2,6 +2,64 @@
Model-specific quirks, required settings, and known issues. Check this before debugging training failures on specific model families.
## VLM (Vision Language Model) Quick Start
All VLM configs require these four lines:
```yaml
processor_type: AutoProcessor
skip_prepare_dataset: true
remove_unused_columns: false
sample_packing: false
```
Decision tree for VLM config:
```text
Is the model multimodal (has vision/audio encoder)?
├─ YES: Add `freeze_mm_modules: true` if training text only
│ Add `chat_template: <model_template>` (e.g. gemma4, qwen3_5, gemma3)
│ LoRA: use regex `lora_target_modules` to restrict to language model
└─ NO: Train as a regular text model
Is the model MoE (e.g. Gemma4 26B-A4B, Qwen3.5 35B-A3B)?
├─ YES: Add `lora_target_parameters` for expert LoRA
│ Consider ScatterMoE kernels (see Plugins section)
└─ NO: Standard LoRA config
```
## Plugins & Optimizations
### Cut Cross Entropy (CCE)
Computes loss from hidden states + lm_head weight without materializing the full logits tensor, saving significant VRAM. Install if not already present:
```bash
uv pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@main"
```
```yaml
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
```
### ScatterMoE Kernels
Fuses expert + LoRA computation into a single kernel for MoE models. Significant speedup for models with many experts.
```yaml
plugins:
- axolotl.integrations.kernels.KernelsPlugin
use_kernels: true
use_scattermoe: true
experts_implementation: scattermoe
# Expert LoRA targets (3D parameter tensors, not nn.Linear):
lora_target_parameters:
- experts.gate_up_proj
- experts.down_proj
```
Supported: Gemma4 (`gemma4_text`), Mixtral, Qwen MoE variants. The plugin auto-detects model type and routing function. Without ScatterMoE, expert LoRA still works but runs base expert matmul and LoRA as separate operations.
## Gemma 4
**Models**: `google/gemma-4-26B-A4B` (MoE), `google/gemma-4-31B` (dense), `google/gemma-4-E2B`, `google/gemma-4-E4B`
@@ -66,6 +124,36 @@ fsdp_config:
experts_implementation: scattermoe
```
### VLM (Vision) Training
All Gemma4 models load as `Gemma4ForConditionalGeneration` with a vision tower. No custom `ProcessingStrategy` needed — the base class auto-detects the image token.
```yaml
base_model: google/gemma-4-E2B-it # or E4B-it, 26B-A4B
processor_type: AutoProcessor
freeze_mm_modules: true
chat_template: gemma4
skip_prepare_dataset: true
remove_unused_columns: false
sample_packing: false
```
A starting VLM loss of ~8-15 is typical. In most runs, loss converges below 1.0 within ~30-50 steps, though results may vary across configurations.
For the 26B-A4B MoE variant with ScatterMoE + expert LoRA + CCE, add:
```yaml
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
- axolotl.integrations.kernels.KernelsPlugin
use_kernels: true
use_scattermoe: true
experts_implementation: scattermoe
lora_target_parameters:
- experts.gate_up_proj
- experts.down_proj
```
### Common issues
| Symptom | Cause | Fix |

View File

@@ -91,6 +91,30 @@ Watch for: loss never decreasing (check `train_on_inputs`, dataset, LR), loss go
| FSDP save hangs | Use `fsdp_state_dict_type: FULL_STATE_DICT` |
| DeepSpeed CheckpointError | Set `use_reentrant: true` in `gradient_checkpointing_kwargs` |
## Profiling
To profile training and identify optimization opportunities:
```yaml
# Profile steps 3-7 (after warmup/autotuning settles)
profiler_steps_start: 3
profiler_steps: 5
```
This produces `profiler_trace.json` (Chrome trace) and `snapshot.pickle` (memory snapshot) in `output_dir`.
View the Chrome trace at `chrome://tracing`.
To programmatically inspect the trace:
```bash
python scripts/analyze_profile.py output_dir/
```
The trace shows per-kernel CUDA times, memory allocations, and operator-level breakdown. Look for:
- **Large matmul kernels**: candidates for fusion or quantization
- **Memory copies (H2D/D2H)**: unnecessary data movement
- **Small frequent kernels**: candidates for kernel fusion
- **Gaps between kernels**: pipeline bubbles from CPU overhead
Full troubleshooting: [training_stability.qmd](../training_stability.qmd), [debugging.qmd](../debugging.qmd)
## File Map

View File

@@ -8,6 +8,7 @@ format:
## Supported Models
- [Gemma-4](#sec-gemma-4) *(NEW)*
- [Mllama](#sec-mllama)
- [Llama4](#sec-llama4)
- [Pixtral](#sec-pixtral)
@@ -138,6 +139,40 @@ base_model: mistralai/Voxtral-Mini-3B-2507
processor_type: VoxtralProcessor
```
### Gemma-4 {#sec-gemma-4}
All Gemma 4 variants (E2B, E4B, 26B-A4B, 31B) load as multimodal models even for text-only training.
```yaml
base_model: google/gemma-4-E2B-it # or E4B-it, 26B-A4B, 31B
chat_template: gemma4
freeze_mm_modules: true # freeze vision/audio encoders for text-only or vision LoRA
# For the 26B-A4B MoE model, enable ScatterMoE and expert LoRA:
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
- axolotl.integrations.kernels.KernelsPlugin
use_kernels: true
use_scattermoe: true
experts_implementation: scattermoe
lora_target_modules: 'model.language_model.layers.[\d]+.(_checkpoint_wrapped_module.)?(mlp|self_attn).(up|down|gate|q|k|v|o)_proj'
# MoE expert LoRA (3D tensors, not nn.Linear) — only for 26B-A4B:
lora_target_parameters:
- experts.gate_up_proj
- experts.down_proj
```
::: {.callout-warning}
Gemma 4 VLM training starts with high loss (~8-15). This is expected — see the [training stability guide](training_stability.qmd) for details.
:::
::: {.callout-tip}
For DDP training, axolotl auto-detects Gemma4 and sets `use_reentrant=False` and `ddp_find_unused_parameters=True`. However, when `activation_offloading: true`, `ddp_find_unused_parameters` is skipped (checkpoint wrappers conflict with it); use `freeze_mm_modules: true` instead to handle unused vision/audio params. For FSDP2, use `fsdp_transformer_layer_cls_to_wrap: Gemma4TextDecoderLayer`.
:::
### Gemma-3 {#sec-gemma-3}
::: {.callout-tip}

View File

@@ -0,0 +1,62 @@
# Gemma 4 E2B Vision LoRA
#
# Fine-tuning LM LoRA adapters on multimodal Gemma4 with vision/multimodal modules frozen.
# Uses the base ProcessingStrategy (auto-detects image_token from processor).
base_model: google/gemma-4-E2B-it
processor_type: AutoProcessor
freeze_mm_modules: true
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
strict: false
# Required for vision/multimodal training
skip_prepare_dataset: true
remove_unused_columns: false
sample_packing: false
chat_template: gemma4
datasets:
- path: HuggingFaceH4/llava-instruct-mix-vsft
type: chat_template
split: train[:100]
val_set_size: 0
output_dir: ./outputs/gemma4-e2b-vision-lora
adapter: lora
sequence_len: 2048
pad_to_sequence_len: false
lora_r: 16
lora_alpha: 32
lora_dropout: 0
# Target language model only — vision encoder is frozen via freeze_mm_modules
lora_target_modules: 'model.language_model.layers.[\d]+.(_checkpoint_wrapped_module.)?(mlp|self_attn).(up|down|gate|q|k|v|o)_proj'
gradient_accumulation_steps: 4
micro_batch_size: 1
num_epochs: 1
max_steps: 10
optimizer: adamw_torch_8bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: auto
tf32: true
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
logging_steps: 1
sdp_attention: true
warmup_ratio: 0.1
weight_decay: 0.0
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

View File

@@ -0,0 +1,62 @@
# Qwen 3.5 35B-A3B MoE Vision LoRA
#
# Vision fine-tuning of the hybrid DeltaNet + Attention MoE model.
# 256 experts, 8 active per token, with early-fusion vision support.
base_model: Qwen/Qwen3.5-35B-A3B
processor_type: AutoProcessor
# Required for vision/multimodal training
skip_prepare_dataset: true
remove_unused_columns: false
sample_packing: false
chat_template: qwen3_5
datasets:
- path: HuggingFaceH4/llava-instruct-mix-vsft
type: chat_template
split: train[:100]
val_set_size: 0
output_dir: ./outputs/qwen35-35b-a3b-vision-lora
adapter: lora
sequence_len: 4096
pad_to_sequence_len: false
lora_r: 16
lora_alpha: 32
lora_dropout: 0
lora_target_modules:
- q_proj
- k_proj
- v_proj
- o_proj
- down_proj
- up_proj
gradient_accumulation_steps: 4
micro_batch_size: 1
num_epochs: 1
max_steps: 10
optimizer: adamw_torch_8bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: auto
tf32: true
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
logging_steps: 1
flash_attention: true
warmup_ratio: 0.1
weight_decay: 0.0
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

1518
scripts/analyze_profile.py Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -435,6 +435,23 @@ class AxolotlTrainer(
num_items_in_batch=num_items_in_batch,
)
# Gemma4ForConditionalGeneration computes loss with a manual
# nn.CrossEntropyLoss() that bypasses proper num_items_in_batch
# normalization and does redundant attention_mask filtering.
# Compute loss externally using the standard loss_function instead.
if _model_type == "gemma4" and "labels" in inputs:
labels = inputs.pop("labels")
outputs = model(**inputs)
logits = outputs.logits
unwrapped = self.accelerator.unwrap_model(model)
vocab_size = unwrapped.config.get_text_config().vocab_size
loss = unwrapped.loss_function(
logits, labels, vocab_size, num_items_in_batch=num_items_in_batch
)
if return_outputs:
return loss, outputs
return loss
return super().compute_loss(
model,
inputs,

View File

@@ -222,6 +222,56 @@ class LigerPlugin(BasePlugin):
rms_norm=cfg.liger_rms_norm,
swiglu=cfg.liger_glu_activation,
)
elif cfg.model_config_type in ("gemma4", "gemma4_text"):
# Gemma4: offset=0 (NOT 1 like Gemma3), in_place=False required for
# gradient checkpointing compatibility, RoPE incompatible (separate q/k).
from liger_kernel.transformers.geglu import LigerGEGLUMLP
from transformers.models.gemma4 import modeling_gemma4
if cfg.liger_rms_norm:
_OrigGemma4RMSNorm = modeling_gemma4.Gemma4RMSNorm
class _LigerGemma4RMSNorm(LigerRMSNorm):
"""LigerRMSNorm for Gemma4 with in_place=False and with_scale support."""
def __new__(cls, dim, eps=1e-6, with_scale=True):
if not with_scale:
return _OrigGemma4RMSNorm(dim, eps, with_scale=False)
return super().__new__(cls)
def __init__(self, dim, eps=1e-6, with_scale=True):
if not with_scale:
return
# offset=0.0 (standard), in_place=False (gradient checkpointing safe)
super().__init__(
dim, eps, offset=0.0, casting_mode="llama", in_place=False
)
modeling_gemma4.Gemma4RMSNorm = _LigerGemma4RMSNorm
if cfg.liger_glu_activation:
class _LigerGemma4MLP(LigerGEGLUMLP):
def __init__(self, config, layer_idx=None):
super().__init__(config)
modeling_gemma4.Gemma4TextMLP = _LigerGemma4MLP
if cfg.liger_rope:
LOG.warning(
"Liger RoPE is not compatible with Gemma4 (separate q/k application). Skipping."
)
if cfg.liger_layer_norm:
modeling_gemma4.nn.LayerNorm = LigerLayerNorm
if cfg.liger_cross_entropy:
modeling_gemma4.nn.CrossEntropyLoss = LigerCrossEntropyLoss
if cfg.liger_fused_linear_cross_entropy:
LOG.warning(
"Liger fused linear cross entropy is not compatible with Gemma4. Skipping."
)
LOG.info(
f"Applied Liger kernels for gemma4: "
f"rms_norm={cfg.liger_rms_norm}, glu={cfg.liger_glu_activation}, "
f"rope=False (incompatible), layer_norm={cfg.liger_layer_norm}"
)
elif cfg.liger_fused_linear_cross_entropy:
try:
from .models.base import patch_lce_forward

View File

@@ -112,6 +112,47 @@ QKV_PATCHES = [
else:
key_states = key_states.view(hidden_shape)
value_states = value_states.view(hidden_shape) if self.v_proj is not None else key_states
""".lstrip("\n"),
),
# Gemma4 (transformers >= 5.6): shared_kv_states parameter replaces
# past_key_values.shared_layers, and v_norm added after k_norm.
(
"""
query_states = self.q_proj(hidden_states).view(hidden_shape)
query_states = self.q_norm(query_states)
query_states = apply_rotary_pos_emb(query_states, cos, sin, unsqueeze_dim=2)
query_states = query_states.transpose(1, 2)
# For layers with shared KV (from kv sharing point onwards), we reuse the same keys/values states as the last non-sharing layer.
# We cannot simply reuse the cached state if we have a Cache, as sliding layers will not remember the full states in their Cache
# once we are past the sliding window - so we always use `shared_kv_states` instead, even when past_key_values is not None
if self.is_kv_shared_layer:
key_states, value_states = shared_kv_states[self.kv_shared_layer_index]
# Device of past layer may be different from current one
key_states = key_states.to(query_states.device)
value_states = value_states.to(query_states.device)
else:
key_states = self.k_proj(hidden_states).view(hidden_shape)
value_states = self.v_proj(hidden_states).view(hidden_shape) if self.v_proj is not None else key_states
""".lstrip("\n"),
"""
query_states, key_states, value_states = self.apply_qkv(hidden_states)
query_states = query_states.view(hidden_shape)
query_states = self.q_norm(query_states)
query_states = apply_rotary_pos_emb(query_states, cos, sin, unsqueeze_dim=2)
query_states = query_states.transpose(1, 2)
# For layers with shared KV (from kv sharing point onwards), we reuse the same keys/values states as the last non-sharing layer.
# We cannot simply reuse the cached state if we have a Cache, as sliding layers will not remember the full states in their Cache
# once we are past the sliding window - so we always use `shared_kv_states` instead, even when past_key_values is not None
if self.is_kv_shared_layer:
key_states, value_states = shared_kv_states[self.kv_shared_layer_index]
# Device of past layer may be different from current one
key_states = key_states.to(query_states.device)
value_states = value_states.to(query_states.device)
else:
key_states = key_states.view(hidden_shape)
value_states = value_states.view(hidden_shape) if self.v_proj is not None else key_states
""".lstrip("\n"),
),
]