Gemma4 fixes and profiler (#3591)
@@ -2,6 +2,64 @@

Model-specific quirks, required settings, and known issues. Check this before debugging training failures on specific model families.

## VLM (Vision Language Model) Quick Start

All VLM configs require these four lines:
```yaml
processor_type: AutoProcessor
skip_prepare_dataset: true
remove_unused_columns: false
sample_packing: false
```
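
As a quick pre-flight check, the four required settings can be compared against a loaded config. The sketch below is only an illustration: `REQUIRED_VLM_SETTINGS` and `missing_vlm_settings` are hypothetical names, not part of axolotl, and the config is assumed to already be parsed into a plain `dict`.

```python
# Hypothetical helper (not an axolotl API). Keys and values are the four
# required VLM settings listed above.
REQUIRED_VLM_SETTINGS = {
    "processor_type": "AutoProcessor",
    "skip_prepare_dataset": True,
    "remove_unused_columns": False,
    "sample_packing": False,
}


def missing_vlm_settings(cfg: dict) -> dict:
    """Return {key: expected_value} for each required setting that is absent or different."""
    return {
        key: expected
        for key, expected in REQUIRED_VLM_SETTINGS.items()
        if cfg.get(key) != expected
    }


# Example: a config that sets the processor but leaves packing enabled.
cfg = {"processor_type": "AutoProcessor", "sample_packing": True}
for key, expected in missing_vlm_settings(cfg).items():
    print(f"set {key}: {expected}")
```

An empty result means all four settings are present with the expected values.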

Decision tree for VLM config:
```text
Is the model multimodal (has vision/audio encoder)?
├─ YES: Add `freeze_mm_modules: true` if training text only
│       Add `chat_template: <model_template>` (e.g. gemma4, qwen3_5, gemma3)
│       LoRA: use regex `lora_target_modules` to restrict to language model
└─ NO:  Train as a regular text model

Is the model MoE (e.g. Gemma4 26B-A4B, Qwen3.5 35B-A3B)?
├─ YES: Add `lora_target_parameters` for expert LoRA
│       Consider ScatterMoE kernels (see Plugins section)
└─ NO:  Standard LoRA config
```
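
The two questions in the tree can be sketched as a small function, followed by an example of a regex that restricts LoRA to the language model. Everything here is an assumption for illustration: `suggest_settings` is a hypothetical helper, the module names are made up to resemble a multimodal checkpoint, and the pattern is assumed to be applied as a full match against module names. Check the real names with `model.named_modules()` before relying on a pattern.

```python
import re


def suggest_settings(is_multimodal, is_moe, text_only=True, chat_template=None):
    """Hypothetical helper mirroring the decision tree; returns config keys to add."""
    out = {}
    if is_multimodal:
        if text_only:
            out["freeze_mm_modules"] = True
        if chat_template:
            out["chat_template"] = chat_template
    if is_moe:
        # Expert targets as listed in the ScatterMoE section of this document.
        out["lora_target_parameters"] = ["experts.gate_up_proj", "experts.down_proj"]
    return out


# Illustrative pattern: assumes language-model modules carry "language_model"
# in their dotted name and that attention projections end in *_proj.
LANGUAGE_ONLY = r".*language_model.*\.(q_proj|k_proj|v_proj|o_proj)"

# Made-up module names, for demonstration only.
names = [
    "model.language_model.layers.0.self_attn.q_proj",
    "model.vision_tower.encoder.layers.0.self_attn.q_proj",
]
matched = [n for n in names if re.fullmatch(LANGUAGE_ONLY, n)]
print(matched)
print(suggest_settings(is_multimodal=True, is_moe=True, chat_template="gemma4"))
```

Only the first name matches, so the vision tower would be left without adapters under these assumptions.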

## Plugins & Optimizations

### Cut Cross Entropy (CCE)

Computes loss from hidden states + lm_head weight without materializing the full logits tensor, saving significant VRAM. Install if not already present:

```bash
uv pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@main"
```

```yaml
plugins:
  - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
```
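
To make the idea concrete, here is a pure-Python sketch of computing the cross-entropy for one token from its hidden state and the `lm_head` rows, a few vocabulary rows at a time, so the full logits vector is never held at once. This is not the CCE kernel (which is a fused GPU implementation); `chunked_cross_entropy` and `full_cross_entropy` are hypothetical names, and the streaming log-sum-exp is just one simple way to illustrate the principle.

```python
import math


def full_cross_entropy(hidden, lm_head, target):
    """Reference: materialize every logit, then take logsumexp - target logit."""
    logits = [sum(w * h for w, h in zip(row, hidden)) for row in lm_head]
    top = max(logits)
    return top + math.log(sum(math.exp(l - top) for l in logits)) - logits[target]


def chunked_cross_entropy(hidden, lm_head, target, chunk=2):
    """Same loss, but only `chunk` logits exist at any moment."""
    running_max = -math.inf
    running_sum = 0.0  # sum of exp(logit - running_max) over rows seen so far
    target_logit = None
    for start in range(0, len(lm_head), chunk):
        rows = lm_head[start:start + chunk]
        logits = [sum(w * h for w, h in zip(row, hidden)) for row in rows]
        new_max = max(running_max, max(logits))
        # Rescale the running sum to the new maximum before adding this chunk.
        running_sum = running_sum * math.exp(running_max - new_max) + sum(
            math.exp(l - new_max) for l in logits
        )
        running_max = new_max
        if start <= target < start + len(rows):
            target_logit = logits[target - start]
    return running_max + math.log(running_sum) - target_logit


# Toy example: hidden size 3, vocabulary of 5 tokens.
hidden = [0.5, -1.0, 2.0]
lm_head = [
    [0.1, 0.2, 0.3],
    [-0.4, 0.5, 0.6],
    [0.7, -0.8, 0.9],
    [1.0, 1.1, -1.2],
    [0.0, 0.3, 0.2],
]
print(abs(full_cross_entropy(hidden, lm_head, 2) - chunked_cross_entropy(hidden, lm_head, 2)) < 1e-9)
```

Both functions return the same value up to floating-point error; the chunked version trades a little extra arithmetic for a much smaller peak working set.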

### ScatterMoE Kernels

Fuses expert + LoRA computation into a single kernel for MoE models. Significant speedup for models with many experts.

```yaml
plugins:
  - axolotl.integrations.kernels.KernelsPlugin
use_kernels: true
use_scattermoe: true
experts_implementation: scattermoe

# Expert LoRA targets (3D parameter tensors, not nn.Linear):
lora_target_parameters:
  - experts.gate_up_proj
  - experts.down_proj
```

Supported: Gemma4 (`gemma4_text`), Mixtral, Qwen MoE variants. The plugin auto-detects model type and routing function. Without ScatterMoE, expert LoRA still works but runs base expert matmul and LoRA as separate operations.
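
The "separate operations" path can be illustrated with a toy expert stored as a 3D tensor. The sketch assumes a layout of `W[expert][in][out]` with per-expert LoRA factors `A[expert][in][r]` and `B[expert][r][out]`; these shapes, the `scale` parameter, and the function names are assumptions for illustration, and the real parameter layout in a given model may differ. It shows only that the base matmul plus the LoRA path gives the same result as a single matmul with merged weights, which is the computation a fused kernel performs in one pass.

```python
def matvec(x, w):
    """x: [in], w: [in][out] -> [out]."""
    return [sum(x[i] * w[i][j] for i in range(len(x))) for j in range(len(w[0]))]


def matmul(a, b):
    """a: [n][k], b: [k][m] -> [n][m]."""
    return [matvec(row, b) for row in a]


def forward_separate(x, W, A, B, e, scale=1.0):
    """Base expert matmul and LoRA path computed as two separate operations."""
    base = matvec(x, W[e])
    lora = matvec(matvec(x, A[e]), B[e])
    return [b + scale * l for b, l in zip(base, lora)]


def forward_merged(x, W, A, B, e, scale=1.0):
    """Single matmul against W[e] + scale * (A[e] @ B[e])."""
    delta = matmul(A[e], B[e])
    merged = [
        [W[e][i][j] + scale * delta[i][j] for j in range(len(W[e][0]))]
        for i in range(len(W[e]))
    ]
    return matvec(x, merged)


# Two experts, in=2, out=2, LoRA rank r=1 (toy values).
W = [[[1.0, 0.0], [0.0, 1.0]], [[2.0, 1.0], [0.0, 1.0]]]
A = [[[1.0], [0.0]], [[0.5], [0.5]]]
B = [[[0.0, 1.0]], [[1.0, 1.0]]]
x = [1.0, 2.0]
print(forward_separate(x, W, A, B, e=1), forward_merged(x, W, A, B, e=1))
```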

## Gemma 4

**Models**: `google/gemma-4-26B-A4B` (MoE), `google/gemma-4-31B` (dense), `google/gemma-4-E2B`, `google/gemma-4-E4B`
@@ -66,6 +124,36 @@ fsdp_config:
experts_implementation: scattermoe
```

### VLM (Vision) Training

All Gemma4 models load as `Gemma4ForConditionalGeneration` with a vision tower. No custom `ProcessingStrategy` is needed; the base class auto-detects the image token.

```yaml
base_model: google/gemma-4-E2B-it  # or E4B-it, 26B-A4B
processor_type: AutoProcessor
freeze_mm_modules: true
chat_template: gemma4

skip_prepare_dataset: true
remove_unused_columns: false
sample_packing: false
```

A starting VLM loss of ~8-15 is typical. In most runs, loss converges below 1.0 within ~30-50 steps, though results may vary across configurations.

For the 26B-A4B MoE variant with ScatterMoE + expert LoRA + CCE, add:
```yaml
plugins:
  - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
  - axolotl.integrations.kernels.KernelsPlugin
use_kernels: true
use_scattermoe: true
experts_implementation: scattermoe
lora_target_parameters:
  - experts.gate_up_proj
  - experts.down_proj
```

### Common issues

| Symptom | Cause | Fix |
@@ -91,6 +91,30 @@ Watch for: loss never decreasing (check `train_on_inputs`, dataset, LR), loss go
| FSDP save hangs | Use `fsdp_state_dict_type: FULL_STATE_DICT` |
| DeepSpeed CheckpointError | Set `use_reentrant: true` in `gradient_checkpointing_kwargs` |

## Profiling

To profile training and identify optimization opportunities:

```yaml
# Profile steps 3-7 (after warmup/autotuning settles)
profiler_steps_start: 3
profiler_steps: 5
```

This produces `profiler_trace.json` (Chrome trace) and `snapshot.pickle` (memory snapshot) in `output_dir`.
View the Chrome trace at `chrome://tracing`.

To programmatically inspect the trace:
```bash
python scripts/analyze_profile.py output_dir/
```

The trace shows per-kernel CUDA times, memory allocations, and operator-level breakdown. Look for:
- **Large matmul kernels**: candidates for fusion or quantization
- **Memory copies (H2D/D2H)**: unnecessary data movement
- **Small frequent kernels**: candidates for kernel fusion
- **Gaps between kernels**: pipeline bubbles from CPU overhead

Full troubleshooting: [training_stability.qmd](../training_stability.qmd), [debugging.qmd](../debugging.qmd)

## File Map