Gemma4 fixes and profiler (#3591)
README.md
@@ -86,7 +86,7 @@ Features:

**Requirements**:

- NVIDIA GPU (Ampere or newer for `bf16` and Flash Attention) or AMD GPU
- Python >=3.11 (3.12 recommended)
- PyTorch >=2.9.1

### Google Colab

@@ -95,6 +95,34 @@ Features:

### Installation

#### Using uv (recommended)

```bash
# install uv if you don't already have it installed
curl -LsSf https://astral.sh/uv/install.sh | sh
source "$HOME/.local/bin/env"

# CUDA 12.8.1 tends to have better package compatibility
export UV_TORCH_BACKEND=cu128

# create a new virtual environment
uv venv --python 3.12
source .venv/bin/activate

uv pip install torch==2.10.0 torchvision
uv pip install --no-build-isolation axolotl[deepspeed]

# recommended - install cut-cross-entropy
uv pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@main"

# (optional) - prefetch flash-attn2 and causal-conv1d kernels
uv run --python 3.12 python -c "from kernels import get_kernel; get_kernel('kernels-community/flash-attn2'); get_kernel('kernels-community/causal-conv1d')"

# download example axolotl configs and deepspeed configs
axolotl fetch examples
axolotl fetch deepspeed_configs  # optional
```

#### Using pip

```bash
```
@@ -2,6 +2,64 @@

Model-specific quirks, required settings, and known issues. Check this before debugging training failures on specific model families.

## VLM (Vision Language Model) Quick Start

All VLM configs require these four lines:

```yaml
processor_type: AutoProcessor
skip_prepare_dataset: true
remove_unused_columns: false
sample_packing: false
```

Decision tree for VLM config:

```text
Is the model multimodal (has vision/audio encoder)?
├─ YES: Add `freeze_mm_modules: true` if training text only
│       Add `chat_template: <model_template>` (e.g. gemma4, qwen3_5, gemma3)
│       LoRA: use a regex `lora_target_modules` to restrict targets to the language model
└─ NO:  Train as a regular text model

Is the model MoE (e.g. Gemma4 26B-A4B, Qwen3.5 35B-A3B)?
├─ YES: Add `lora_target_parameters` for expert LoRA
│       Consider ScatterMoE kernels (see Plugins section)
└─ NO:  Standard LoRA config
```

## Plugins & Optimizations

### Cut Cross Entropy (CCE)

Computes the loss directly from the hidden states and the `lm_head` weight without materializing the full logits tensor, saving significant VRAM. Install if not already present:

```bash
uv pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@main"
```

```yaml
plugins:
  - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
```
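To get a feel for the saving, here is a back-of-envelope estimate of the logits tensor CCE avoids materializing. The shapes below (batch size, sequence length, a ~256k vocabulary, bf16) are illustrative assumptions, not measured values from any particular model:

```python
# Size of the full [batch, seq_len, vocab] logits tensor that CCE skips.
# Shapes are illustrative assumptions, not measured values.
def logits_bytes(batch: int, seq_len: int, vocab: int, dtype_bytes: int = 2) -> int:
    """Bytes for a dense logits tensor in a dtype of dtype_bytes width."""
    return batch * seq_len * vocab * dtype_bytes

# e.g. batch=1, seq_len=2048, a ~256k vocabulary, bf16 (2 bytes):
gib = logits_bytes(1, 2048, 262_144) / 2**30
print(f"{gib:.1f} GiB")  # 1.0 GiB for the forward logits alone
```

Gradients roughly double this, which is why the saving matters most for large-vocabulary models.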
### ScatterMoE Kernels

Fuses the expert and LoRA computation into a single kernel for MoE models, giving a significant speedup for models with many experts.

```yaml
plugins:
  - axolotl.integrations.kernels.KernelsPlugin
use_kernels: true
use_scattermoe: true
experts_implementation: scattermoe

# Expert LoRA targets (3D parameter tensors, not nn.Linear):
lora_target_parameters:
  - experts.gate_up_proj
  - experts.down_proj
```

Supported: Gemma4 (`gemma4_text`), Mixtral, and Qwen MoE variants. The plugin auto-detects the model type and routing function. Without ScatterMoE, expert LoRA still works but runs the base expert matmul and the LoRA path as separate operations.
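The two execution paths compute the same result; only the kernel structure differs. A pure-Python sketch with toy sizes (all shapes and values invented for illustration) of one expert's weight with a rank-1 LoRA pair:

```python
# Toy illustration: the fused path applies (W + B @ A) in one matmul,
# the unfused path runs the base matmul and the LoRA path separately.
# All shapes and values are invented, not taken from any real model.

def matmul(a, b):
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

def add(a, b):
    return [[x + y for x, y in zip(r1, r2)] for r1, r2 in zip(a, b)]

# One expert's base weight W (2x2) and a rank-1 LoRA pair B (2x1), A (1x2)
W = [[1.0, 2.0], [3.0, 4.0]]
B = [[0.5], [1.0]]
A = [[2.0, 0.0]]
x = [[1.0], [1.0]]  # a single token's hidden state as a column vector

fused = matmul(add(W, matmul(B, A)), x)                # (W + B@A) @ x
separate = add(matmul(W, x), matmul(B, matmul(A, x)))  # W@x + B@(A@x)
assert fused == separate  # same math, different kernel launches
```

ScatterMoE's advantage is launching this once per routed token group instead of twice per expert.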
## Gemma 4

**Models**: `google/gemma-4-26B-A4B` (MoE), `google/gemma-4-31B` (dense), `google/gemma-4-E2B`, `google/gemma-4-E4B`
@@ -66,6 +124,36 @@ fsdp_config:
experts_implementation: scattermoe
```

### VLM (Vision) Training

All Gemma4 models load as `Gemma4ForConditionalGeneration` with a vision tower. No custom `ProcessingStrategy` is needed; the base class auto-detects the image token.

```yaml
base_model: google/gemma-4-E2B-it  # or E4B-it, 26B-A4B
processor_type: AutoProcessor
freeze_mm_modules: true
chat_template: gemma4

skip_prepare_dataset: true
remove_unused_columns: false
sample_packing: false
```

A starting VLM loss of ~8-15 is typical. In most runs, loss converges below 1.0 within ~30-50 steps, though results vary across configurations.
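That starting range is roughly what cross-entropy predicts for near-uniform predictions over a large vocabulary, since the uniform-distribution loss is ln(V). A quick sanity check, using vocabulary sizes that are illustrative orders of magnitude rather than values from the Gemma4 config:

```python
import math

# Cross-entropy of a uniform distribution over V tokens is ln(V); a model
# whose new multimodal inputs are effectively unseen starts near this.
# Vocabulary sizes are illustrative, not taken from the Gemma4 config.
for vocab in (32_768, 131_072, 262_144):
    print(vocab, round(math.log(vocab), 2))
# ln(32768) ≈ 10.4 and ln(262144) ≈ 12.5, squarely inside the ~8-15 range
```

A starting loss far above ln(V) usually points at a template or labeling problem rather than normal warmup.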
For the 26B-A4B MoE variant with ScatterMoE + expert LoRA + CCE, add:

```yaml
plugins:
  - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
  - axolotl.integrations.kernels.KernelsPlugin
use_kernels: true
use_scattermoe: true
experts_implementation: scattermoe
lora_target_parameters:
  - experts.gate_up_proj
  - experts.down_proj
```

### Common issues

| Symptom | Cause | Fix |
@@ -91,6 +91,30 @@ Watch for: loss never decreasing (check `train_on_inputs`, dataset, LR), loss go

| FSDP save hangs | Use `fsdp_state_dict_type: FULL_STATE_DICT` |
| DeepSpeed CheckpointError | Set `use_reentrant: true` in `gradient_checkpointing_kwargs` |

## Profiling

To profile training and identify optimization opportunities:

```yaml
# Profile steps 3-7 (after warmup/autotuning settles)
profiler_steps_start: 3
profiler_steps: 5
```

This produces `profiler_trace.json` (a Chrome trace) and `snapshot.pickle` (a memory snapshot) in `output_dir`. View the Chrome trace at `chrome://tracing`.

To programmatically inspect the trace:

```bash
python scripts/analyze_profile.py output_dir/
```

The trace shows per-kernel CUDA times, memory allocations, and an operator-level breakdown. Look for:

- **Large matmul kernels**: candidates for fusion or quantization
- **Memory copies (H2D/D2H)**: unnecessary data movement
- **Small frequent kernels**: candidates for kernel fusion
- **Gaps between kernels**: pipeline bubbles from CPU overhead
Full troubleshooting: [training_stability.qmd](../training_stability.qmd), [debugging.qmd](../debugging.qmd)

## File Map

@@ -8,6 +8,7 @@ format:

## Supported Models

- [Gemma-4](#sec-gemma-4) *(NEW)*
- [Mllama](#sec-mllama)
- [Llama4](#sec-llama4)
- [Pixtral](#sec-pixtral)
@@ -138,6 +139,40 @@ base_model: mistralai/Voxtral-Mini-3B-2507
processor_type: VoxtralProcessor
```

### Gemma-4 {#sec-gemma-4}

All Gemma 4 variants (E2B, E4B, 26B-A4B, 31B) load as multimodal models, even for text-only training.

```yaml
base_model: google/gemma-4-E2B-it  # or E4B-it, 26B-A4B, 31B

chat_template: gemma4
freeze_mm_modules: true  # freeze vision/audio encoders for text-only or vision LoRA

# For the 26B-A4B MoE model, enable ScatterMoE and expert LoRA:
plugins:
  - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
  - axolotl.integrations.kernels.KernelsPlugin
use_kernels: true
use_scattermoe: true
experts_implementation: scattermoe

lora_target_modules: 'model.language_model.layers.[\d]+.(_checkpoint_wrapped_module.)?(mlp|self_attn).(up|down|gate|q|k|v|o)_proj'

# MoE expert LoRA (3D tensors, not nn.Linear); only for 26B-A4B:
lora_target_parameters:
  - experts.gate_up_proj
  - experts.down_proj
```

::: {.callout-warning}
Gemma 4 VLM training starts with high loss (~8-15). This is expected; see the [training stability guide](training_stability.qmd) for details.
:::

::: {.callout-tip}
For DDP training, axolotl auto-detects Gemma4 and sets `use_reentrant=False` and `ddp_find_unused_parameters=True`. However, when `activation_offloading: true`, `ddp_find_unused_parameters` is skipped (checkpoint wrappers conflict with it); use `freeze_mm_modules: true` instead to handle unused vision/audio params. For FSDP2, use `fsdp_transformer_layer_cls_to_wrap: Gemma4TextDecoderLayer`.
:::

### Gemma-3 {#sec-gemma-3}

::: {.callout-tip}
examples/gemma4/e2b-vision-lora.yaml (new file)
@@ -0,0 +1,62 @@
# Gemma 4 E2B Vision LoRA
#
# Fine-tunes LM LoRA adapters on multimodal Gemma4 with vision/multimodal modules frozen.
# Uses the base ProcessingStrategy (auto-detects image_token from the processor).

base_model: google/gemma-4-E2B-it
processor_type: AutoProcessor
freeze_mm_modules: true

plugins:
  - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
strict: false

# Required for vision/multimodal training
skip_prepare_dataset: true
remove_unused_columns: false
sample_packing: false

chat_template: gemma4
datasets:
  - path: HuggingFaceH4/llava-instruct-mix-vsft
    type: chat_template
    split: train[:100]

val_set_size: 0
output_dir: ./outputs/gemma4-e2b-vision-lora

adapter: lora
sequence_len: 2048
pad_to_sequence_len: false

lora_r: 16
lora_alpha: 32
lora_dropout: 0
# Target the language model only; the vision encoder is frozen via freeze_mm_modules
lora_target_modules: 'model.language_model.layers.[\d]+.(_checkpoint_wrapped_module.)?(mlp|self_attn).(up|down|gate|q|k|v|o)_proj'

gradient_accumulation_steps: 4
micro_batch_size: 1
num_epochs: 1
max_steps: 10
optimizer: adamw_torch_8bit
lr_scheduler: cosine
learning_rate: 0.0002

bf16: auto
tf32: true

gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: false
logging_steps: 1
sdp_attention: true

warmup_ratio: 0.1
weight_decay: 0.0

wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
examples/qwen3.5/35b-a3b-moe-vision-lora.yaml (new file)
@@ -0,0 +1,62 @@
# Qwen 3.5 35B-A3B MoE Vision LoRA
#
# Vision fine-tuning of the hybrid DeltaNet + Attention MoE model.
# 256 experts, 8 active per token, with early-fusion vision support.

base_model: Qwen/Qwen3.5-35B-A3B
processor_type: AutoProcessor

# Required for vision/multimodal training
skip_prepare_dataset: true
remove_unused_columns: false
sample_packing: false

chat_template: qwen3_5
datasets:
  - path: HuggingFaceH4/llava-instruct-mix-vsft
    type: chat_template
    split: train[:100]

val_set_size: 0
output_dir: ./outputs/qwen35-35b-a3b-vision-lora

adapter: lora
sequence_len: 4096
pad_to_sequence_len: false

lora_r: 16
lora_alpha: 32
lora_dropout: 0
lora_target_modules:
  - q_proj
  - k_proj
  - v_proj
  - o_proj
  - down_proj
  - up_proj

gradient_accumulation_steps: 4
micro_batch_size: 1
num_epochs: 1
max_steps: 10
optimizer: adamw_torch_8bit
lr_scheduler: cosine
learning_rate: 0.0002

bf16: auto
tf32: true

gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: false
logging_steps: 1
flash_attention: true

warmup_ratio: 0.1
weight_decay: 0.0

wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
scripts/analyze_profile.py (new file, 1518 lines)
File diff suppressed because it is too large
@@ -435,6 +435,23 @@ class AxolotlTrainer(
                num_items_in_batch=num_items_in_batch,
            )

        # Gemma4ForConditionalGeneration computes loss with a manual
        # nn.CrossEntropyLoss() that bypasses proper num_items_in_batch
        # normalization and does redundant attention_mask filtering.
        # Compute the loss externally using the standard loss_function instead.
        if _model_type == "gemma4" and "labels" in inputs:
            labels = inputs.pop("labels")
            outputs = model(**inputs)
            logits = outputs.logits
            unwrapped = self.accelerator.unwrap_model(model)
            vocab_size = unwrapped.config.get_text_config().vocab_size
            loss = unwrapped.loss_function(
                logits, labels, vocab_size, num_items_in_batch=num_items_in_batch
            )
            if return_outputs:
                return loss, outputs
            return loss

        return super().compute_loss(
            model,
            inputs,
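The `num_items_in_batch` issue the hunk above works around is easy to see with toy numbers (all invented for illustration): under gradient accumulation, averaging each microbatch's mean loss is not the same as dividing the summed loss by the total label-token count when microbatches hold unequal numbers of label tokens.

```python
# Toy illustration (invented numbers) of the normalization mismatch:
# per-microbatch mean vs. global num_items_in_batch normalization.
micro_a = [2.0] * 10   # 10 label tokens, per-token loss 2.0
micro_b = [4.0] * 30   # 30 label tokens, per-token loss 4.0

mean_of_means = (sum(micro_a) / len(micro_a) + sum(micro_b) / len(micro_b)) / 2
global_mean = (sum(micro_a) + sum(micro_b)) / (len(micro_a) + len(micro_b))

print(mean_of_means)  # 3.0, over-weights the short microbatch
print(global_mean)    # 3.5, the num_items_in_batch-normalized value
```

A manual per-batch `nn.CrossEntropyLoss()` mean gives the first behavior; routing through the standard `loss_function` with `num_items_in_batch` gives the second.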
@@ -222,6 +222,56 @@ class LigerPlugin(BasePlugin):
                rms_norm=cfg.liger_rms_norm,
                swiglu=cfg.liger_glu_activation,
            )
        elif cfg.model_config_type in ("gemma4", "gemma4_text"):
            # Gemma4: offset=0 (NOT 1 like Gemma3), in_place=False required for
            # gradient checkpointing compatibility, RoPE incompatible (separate q/k).
            from liger_kernel.transformers.geglu import LigerGEGLUMLP
            from transformers.models.gemma4 import modeling_gemma4

            if cfg.liger_rms_norm:
                _OrigGemma4RMSNorm = modeling_gemma4.Gemma4RMSNorm

                class _LigerGemma4RMSNorm(LigerRMSNorm):
                    """LigerRMSNorm for Gemma4 with in_place=False and with_scale support."""

                    def __new__(cls, dim, eps=1e-6, with_scale=True):
                        if not with_scale:
                            return _OrigGemma4RMSNorm(dim, eps, with_scale=False)
                        return super().__new__(cls)

                    def __init__(self, dim, eps=1e-6, with_scale=True):
                        if not with_scale:
                            return
                        # offset=0.0 (standard), in_place=False (gradient checkpointing safe)
                        super().__init__(
                            dim, eps, offset=0.0, casting_mode="llama", in_place=False
                        )

                modeling_gemma4.Gemma4RMSNorm = _LigerGemma4RMSNorm
            if cfg.liger_glu_activation:

                class _LigerGemma4MLP(LigerGEGLUMLP):
                    def __init__(self, config, layer_idx=None):
                        super().__init__(config)

                modeling_gemma4.Gemma4TextMLP = _LigerGemma4MLP
            if cfg.liger_rope:
                LOG.warning(
                    "Liger RoPE is not compatible with Gemma4 (separate q/k application). Skipping."
                )
            if cfg.liger_layer_norm:
                modeling_gemma4.nn.LayerNorm = LigerLayerNorm
            if cfg.liger_cross_entropy:
                modeling_gemma4.nn.CrossEntropyLoss = LigerCrossEntropyLoss
            if cfg.liger_fused_linear_cross_entropy:
                LOG.warning(
                    "Liger fused linear cross entropy is not compatible with Gemma4. Skipping."
                )
            LOG.info(
                f"Applied Liger kernels for gemma4: "
                f"rms_norm={cfg.liger_rms_norm}, glu={cfg.liger_glu_activation}, "
                f"rope=False (incompatible), layer_norm={cfg.liger_layer_norm}"
            )
        elif cfg.liger_fused_linear_cross_entropy:
            try:
                from .models.base import patch_lce_forward
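The `offset` the hunk above sets to 0.0 for Gemma4 (versus Gemma3's offset of 1) controls how the learned weight scales the normalized input: offset=1 scales by `(1 + w)`, offset=0 by `w` alone. A pure-Python sketch with toy values (not Liger's actual kernel) of why that matters for zero-initialized weights:

```python
import math

# RMSNorm with a weight offset, mirroring the offset semantics in the
# Liger patch above. Toy values for illustration only.
def rms_norm(x, w, offset=0.0, eps=1e-6):
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [(v / rms) * (offset + wi) for v, wi in zip(x, w)]

x = [1.0, -2.0, 3.0]
w = [0.0, 0.0, 0.0]  # freshly initialized zero weights

# With offset=1, zero weights still pass the normalized input through;
# with offset=0 they zero the output entirely.
assert rms_norm(x, w, offset=1.0) != [0.0, 0.0, 0.0]
assert rms_norm(x, w, offset=0.0) == [0.0, 0.0, 0.0]
```

Using the wrong offset for a model family silently rescales every hidden state, which is why the patch pins it per architecture.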
@@ -112,6 +112,47 @@ QKV_PATCHES = [
        else:
            key_states = key_states.view(hidden_shape)
            value_states = value_states.view(hidden_shape) if self.v_proj is not None else key_states
        """.lstrip("\n"),
    ),
    # Gemma4 (transformers >= 5.6): shared_kv_states parameter replaces
    # past_key_values.shared_layers, and v_norm added after k_norm.
    (
        """
        query_states = self.q_proj(hidden_states).view(hidden_shape)
        query_states = self.q_norm(query_states)
        query_states = apply_rotary_pos_emb(query_states, cos, sin, unsqueeze_dim=2)
        query_states = query_states.transpose(1, 2)

        # For layers with shared KV (from kv sharing point onwards), we reuse the same keys/values states as the last non-sharing layer.
        # We cannot simply reuse the cached state if we have a Cache, as sliding layers will not remember the full states in their Cache
        # once we are past the sliding window - so we always use `shared_kv_states` instead, even when past_key_values is not None
        if self.is_kv_shared_layer:
            key_states, value_states = shared_kv_states[self.kv_shared_layer_index]
            # Device of past layer may be different from current one
            key_states = key_states.to(query_states.device)
            value_states = value_states.to(query_states.device)
        else:
            key_states = self.k_proj(hidden_states).view(hidden_shape)
            value_states = self.v_proj(hidden_states).view(hidden_shape) if self.v_proj is not None else key_states
        """.lstrip("\n"),
        """
        query_states, key_states, value_states = self.apply_qkv(hidden_states)
        query_states = query_states.view(hidden_shape)
        query_states = self.q_norm(query_states)
        query_states = apply_rotary_pos_emb(query_states, cos, sin, unsqueeze_dim=2)
        query_states = query_states.transpose(1, 2)

        # For layers with shared KV (from kv sharing point onwards), we reuse the same keys/values states as the last non-sharing layer.
        # We cannot simply reuse the cached state if we have a Cache, as sliding layers will not remember the full states in their Cache
        # once we are past the sliding window - so we always use `shared_kv_states` instead, even when past_key_values is not None
        if self.is_kv_shared_layer:
            key_states, value_states = shared_kv_states[self.kv_shared_layer_index]
            # Device of past layer may be different from current one
            key_states = key_states.to(query_states.device)
            value_states = value_states.to(query_states.device)
        else:
            key_states = key_states.view(hidden_shape)
            value_states = value_states.view(hidden_shape) if self.v_proj is not None else key_states
        """.lstrip("\n"),
    ),
]