fix ddp/fsdp w gemma4 (#3584)

* fix ddp/fsdp w gemma4

* address pr comments

* activation offloading fix and update agent docs for gemma4
This commit is contained in:
Wing Lian
2026-04-09 20:02:36 -07:00
committed by GitHub
parent 7daf7d96f1
commit 4ef608dda3
9 changed files with 398 additions and 2 deletions

View File

@@ -38,6 +38,8 @@ Agent-specific references:
- [docs/agents/grpo.md](docs/agents/grpo.md) — GRPO online RL with reward functions - [docs/agents/grpo.md](docs/agents/grpo.md) — GRPO online RL with reward functions
- [docs/agents/reward_modelling.md](docs/agents/reward_modelling.md) — outcome and process reward models - [docs/agents/reward_modelling.md](docs/agents/reward_modelling.md) — outcome and process reward models
- [docs/agents/pretraining.md](docs/agents/pretraining.md) — continual pretraining - [docs/agents/pretraining.md](docs/agents/pretraining.md) — continual pretraining
- [docs/agents/model_architectures.md](docs/agents/model_architectures.md) — model-specific quirks (Gemma4, Qwen3.5 MoE, etc.)
- [docs/agents/new_model_support.md](docs/agents/new_model_support.md) — debugging and adding support for new model architectures
## Config Pattern ## Config Pattern

View File

@@ -0,0 +1,110 @@
# Model Architectures — Agent Reference
Model-specific quirks, required settings, and known issues. Check this before debugging training failures on specific model families.
## Gemma 4
**Models**: `google/gemma-4-26B-A4B` (MoE), `google/gemma-4-31B` (dense), `google/gemma-4-E2B`, `google/gemma-4-E4B`
**Architecture**: Multimodal wrapper (`Gemma4ForConditionalGeneration`) over a text backbone (`Gemma4TextModel`), with optional vision/audio encoders. All Gemma4 HF repos have `model_type: "gemma4"` — even text-only variants load as multimodal with a vision tower.
### Required settings
```yaml
# Always needed for Gemma4:
freeze_mm_modules: true # Freeze vision/audio encoders for text-only training
gradient_checkpointing_kwargs:
use_reentrant: false # Shared per-layer norms cause "marked ready twice" with reentrant
# LoRA target — restrict to language model only (DO NOT use lora_target_linear: true):
lora_target_modules: 'model.language_model.layers.[\d]+.(_checkpoint_wrapped_module.)?(mlp|self_attn).(up|down|gate|q|k|v|o)_proj'
```
### Auto-detection
Axolotl auto-detects Gemma4 and applies:
- `use_reentrant: false` for gradient checkpointing
- `ddp_find_unused_parameters: true` for DDP (skipped when `activation_offloading: true`)
### Multi-GPU
| Strategy | Works? | Notes |
|----------|--------|-------|
| DDP | Yes | Auto-sets `ddp_find_unused_parameters=True` |
| DDP + activation_offloading | Yes | `find_unused_parameters` is skipped (conflicts with checkpoint wrappers) |
| FSDP1 | No | OOM during dequantization/sharding with QLoRA |
| FSDP2 | Yes | Use `Gemma4TextDecoderLayer` (not `Gemma4DecoderLayer`) as wrap class |
| FSDP2 + activation_offloading | Yes | Lowest VRAM (~26 GiB/GPU for 26B-A4B) |
FSDP2 config:
```yaml
fsdp:
- full_shard
- auto_wrap
fsdp_config:
fsdp_version: 2
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
fsdp_transformer_layer_cls_to_wrap: Gemma4TextDecoderLayer
```
### MoE (26B-A4B)
- `enable_moe_block: true`, 256 experts, top-k routing
- No separate `SparseMoeBlock` — MoE is embedded in each decoder layer
- Expert LoRA targets 3D parameter tensors:
```yaml
lora_target_parameters:
- experts.gate_up_proj
- experts.down_proj
```
- ScatterMoE kernel acceleration:
```yaml
plugins:
- axolotl.integrations.kernels.KernelsPlugin
use_kernels: true
use_scattermoe: true
experts_implementation: scattermoe
```
### Common issues
| Symptom | Cause | Fix |
|---------|-------|-----|
| `mm_token_type_ids is required` in DDP | `model.config` not accessible through DDP wrapper | Already fixed — `unwrap_model()` in `compute_loss` and `prediction_step` |
| `marked a variable ready twice` in DDP | `ddp_find_unused_parameters=True` + activation_offloading checkpoint wrappers | Auto-handled — `find_unused_parameters` is skipped when `activation_offloading: true` |
| Loss ~12 instead of ~0.5 | Using `lora_target_linear: true` (applies LoRA to vision/audio modules) | Use the regex `lora_target_modules` pattern instead |
| FSDP2 `Could not find Gemma4AudioLayer` | Auto-wrap detects `_no_split_modules` including audio layers that don't exist | Explicitly set `fsdp_transformer_layer_cls_to_wrap: Gemma4TextDecoderLayer` |
| `Gemma4ClippableLinear not supported` by PEFT | Vision tower uses a non-standard linear wrapper | Axolotl patches this automatically via `_patch_peft_clippable_linear()` |
### E2B/E4B dense models
These have `hidden_size_per_layer_input: 256` (per-layer input embeddings) and `attention_k_eq_v: False`. Known issue: loss starts higher than expected (~12 vs ~0.5 for 26B). Root cause under investigation — may be related to the per-layer input mechanism or the `Gemma4ForConditionalGeneration` loss computation.
## Gemma 3
**Models**: `google/gemma-3-*`
- `ddp_find_unused_parameters: true` needed (multimodal unused params)
- `use_reentrant: false` recommended
- Attention mask must be dropped for sample packing (handled automatically)
- Multi-GPU test currently skipped (`tests/e2e/multigpu/test_gemma3.py`)
## Qwen 3.5 MoE
**Models**: `Qwen/Qwen3.5-35B-A3B`
- Hybrid architecture: DeltaNet linear attention (30 layers) + full attention (10 layers)
- 256 experts, 8 active per token
- Known weight scale drift in late DeltaNet layers (36-38) due to AdamW + rare expert interaction
- Fix: `normalize_weight_scales` config to detect and rescale outliers:
```yaml
normalize_weight_scales:
- name_pattern: 'linear_attn\.conv1d\.weight'
threshold: 1.3
```
## General MoE Notes
- `lora_target_linear: true` with multimodal MoE models will apply LoRA to ALL linear modules including vision/audio encoders — use regex `lora_target_modules` to restrict to language model only
- Rare experts get larger effective learning rate from AdamW (small second-moment estimates) — can cause weight drift in recurrent/SSM components. Use `normalize_weight_scales` with `dry_run: true` to detect.
- For ScatterMoE kernel support, set `experts_implementation: scattermoe` and add the KernelsPlugin

View File

@@ -0,0 +1,181 @@
# New Model Support — Agent Reference
Guide for debugging and adding support for new model architectures in axolotl. Based on lessons learned from Gemma4, Gemma3, Qwen2-VL, and other multimodal/MoE models.
## Quick Validation Checklist
When testing a new model, run through these checks in order:
1. **Does the model load?** `axolotl preprocess config.yaml` — catches config schema errors
2. **Does LoRA apply?** Check for "Unsupported layer type" warnings from PEFT
3. **Is the initial loss sane?** First-step loss for a pretrained model should be 0.52.0 for SFT
4. **Does sample packing work?** Compare loss with `sample_packing: true` vs `false` — should be similar
5. **Is CCE active?** Check for "Applying Cut Cross Entropy" log and verify peak VRAM is lower
## Loss Debugging
### Expected initial loss
A pretrained model doing SFT should start with loss roughly in the 0.52.0 range. If loss starts above 3.0, something is wrong. If it's near `log(vocab_size)` (≈ 12 for 262K vocab), the model is predicting at random — attention masking or model weights are broken.
### Direct comparison technique
The fastest way to isolate a loss issue — bypass the trainer entirely:
```python
# Load model via axolotl's pipeline (applies all patches)
from axolotl.cli.config import load_cfg
from axolotl.utils.config import normalize_config, prepare_plugins
from axolotl.loaders.tokenizer import load_tokenizer
from axolotl.loaders.model import ModelLoader
cfg = load_cfg("your_config.yaml")
normalize_config(cfg)
prepare_plugins(cfg)
tokenizer = load_tokenizer(cfg)
model, _ = ModelLoader(cfg, tokenizer).load()
# Forward pass on preprocessed data
model.train()
out = model(input_ids, labels=labels)
print(f"Direct loss: {out.loss.item()}") # Compare to trainer's reported loss
```
If direct loss is correct (~1.0) but trainer reports 34x higher, check `model_accepts_loss_kwargs` (see below).
### `model_accepts_loss_kwargs` inflation
HF Trainer checks if the model's `forward()` has `**kwargs` and sets `model_accepts_loss_kwargs=True`. This changes loss normalization: the trainer does NOT divide loss by `gradient_accumulation_steps` before logging. The gradient is correct — only the logged loss is inflated.
**Symptom**: Logged loss ≈ actual_loss × gradient_accumulation_steps.
**Which models are affected**: Any model with `**kwargs` in forward (common in multimodal models for extra inputs like `mm_token_type_ids`, `pixel_values`, etc.).
**Fix location**: `src/axolotl/core/trainers/base.py` `__init__()` — after `super().__init__()`, check if the unwrapped model actually has `num_items_in_batch` in its forward signature. If not, set `self.model_accepts_loss_kwargs = False`.
## Multimodal Models (ForConditionalGeneration)
Many recent models use `ForConditionalGeneration` as the top-level class, not `ForCausalLM`:
- Gemma3 → `Gemma3ForConditionalGeneration`
- Gemma4 → `Gemma4ForConditionalGeneration`
- Qwen2-VL → `Qwen2VLForConditionalGeneration`
- LLaVA → `LlavaForConditionalGeneration`
### Why this matters
| Component | Targets `ForCausalLM` | Needs `ForConditionalGeneration` |
|-----------|----------------------|--------------------------------|
| CCE patches | ✅ (default) | ❌ silently inactive if not patched |
| PEFT LoRA | ✅ | May fail on custom layer types |
| HF Trainer label handling | ✅ | May need extra inputs |
### Required extra inputs
Multimodal models require special inputs during training even for text-only data:
| Model | Required Input | Value for Text-Only |
|-------|---------------|-------------------|
| Gemma4 | `mm_token_type_ids` | `torch.zeros_like(input_ids)` |
| Gemma3 | `token_type_ids` | `torch.zeros_like(input_ids)` |
Auto-inject in `compute_loss()` when not provided by the data collator. See `core/trainers/base.py`.
### Custom layer types and PEFT
Vision towers often use custom module wrappers that PEFT doesn't support:
| Model | Custom Layer | Wraps | Fix |
|-------|-------------|-------|-----|
| Gemma4 | `Gemma4ClippableLinear` | `nn.Linear` | Redirect to `.linear` child |
Fix location: `src/axolotl/loaders/adapter.py` `_patch_peft_clippable_linear()`.
## Sample Packing
### How packed sequence detection works (transformers ≥ 5.x)
`transformers.masking_utils._preprocess_mask_arguments()` detects packed sequences from `position_ids` resets. But **only when `attention_mask is None`**:
```python
# From masking_utils.py:
if position_ids is not None and attention_mask is None and past_key_values is None:
packed_sequence_mask = find_packed_sequence_indices(position_ids)
```
If the collator provides an all-ones `attention_mask`, packing detection is **skipped** and the model builds a single causal mask spanning all packed sequences → cross-sequence attention leakage → very high loss.
### Fix for models using `create_causal_mask_mapping`
For Gemma3, Gemma4, and similar models that use the new transformers masking system, remove `attention_mask` from inputs when sample packing is active:
```python
# In compute_loss():
if (
self.args.sample_packing
and model_type in ("gemma4", "gemma3")
and "attention_mask" in inputs
and "position_ids" in inputs
):
del inputs["attention_mask"]
```
Fix location: `src/axolotl/core/trainers/base.py` `compute_loss()`.
### Models that DON'T need this fix
Older models that use `_prepare_4d_causal_attention_mask` (Llama, Mistral, Qwen2, etc.) handle sample packing via axolotl's multipack attention monkeypatch instead. Only models using the new `create_causal_mask_mapping` / `create_causal_mask` masking system need the `attention_mask` removal.
## Attention Backend Selection
| Backend | Config | head_dim limit | torch_compile | Notes |
|---------|--------|---------------|---------------|-------|
| FA2 | `flash_attention: true` | 256 | ✅ | Fastest when supported |
| FA4 | auto with `flash_attention: true` | 256 (SM90+) | ✅ | Auto-detected on H100+ |
| SDPA | `sdp_attention: true` | None | ✅ | Universal fallback |
| flex | `flex_attention: true` | None | ⚠️ Triton OOM for large head_dim | Good for variable head dims |
| eager | neither set | None | ✅ | Slowest, always works |
**Check model support**: Look at `_supports_flash_attn_2`, `_supports_flex_attn`, `_supports_sdpa` attributes on the model class.
**head_dim gotcha**: The 256 limit is specific to flash-attn CUDA kernels, NOT PyTorch-level. SDPA and flex_attention both handle arbitrary head_dim. Models with `global_head_dim > 256` (Gemma4: 512) must use SDPA or flex.
**flex + compile gotcha**: `torch_compile` with flex_attention can hit Triton shared memory OOM for large head_dim. Falls back to eager per-function (not a crash, but slower). Unsloth disables flex for Gemma4 for this reason.
## Cut Cross Entropy (CCE)
### How CCE patches work
CCE replaces the model's `forward()` with a fused version that computes loss from hidden states + lm_head weight without materializing the full logits tensor. This saves ~`batch × seq_len × vocab_size × dtype_bytes` of VRAM.
### Adding CCE for a new model
1. Check if the model type is in `cut_cross_entropy.transformers.patch.PATCH_FNS`
2. If not, axolotl's generic fallback (`integrations/cut_cross_entropy/__init__.py` `patch_llama_like()`) patches `{Prefix}ForCausalLM.forward` with `cce_forward`
3. For multimodal models (`ForConditionalGeneration`), a model-specific patch is needed in `ml-cross-entropy` repo
4. The multimodal `cce_forward` must accept all extra kwargs (pixel_values, mm_token_type_ids, etc.) and pop any that would conflict before calling `self.model()`
### Common CCE pitfall
If CCE appears active (log says "Applying Cut Cross Entropy") but peak VRAM doesn't decrease, check which class was patched. If the model loads as `ForConditionalGeneration` but CCE patched `ForCausalLM`, the patch is silently inactive.
## MoE Models
### Dense MLP vs MoE experts
Some MoE models (e.g., Gemma4) have BOTH dense MLP layers and MoE expert layers at every decoder layer:
- `gate_proj/up_proj/down_proj` → targets the **dense MLP** (`Gemma4TextMLP`)
- `experts.gate_up_proj/experts.down_proj` → targets the **MoE experts** (`Gemma4TextExperts`)
LoRA on the dense MLP works normally. Expert LoRA via `lora_target_parameters` requires PEFT support for the specific expert module type (may warn "Unsupported layer type").
### ScatterMoE kernels
`use_scattermoe: true` with `experts_implementation: scattermoe` registers fused expert kernels via transformers' `ExpertsInterface`. Significant speedup for MoE models. Requires the kernels plugin:
```yaml
plugins:
- axolotl.integrations.kernels.KernelsPlugin
use_kernels: true
use_scattermoe: true
experts_implementation: scattermoe
```
## Where to Add Model-Specific Fixes
| What | Where | Example |
|------|-------|---------|
| Missing forward inputs | `core/trainers/base.py` `compute_loss()` | mm_token_type_ids injection |
| Attention mask fixes | `core/trainers/base.py` `compute_loss()` | Sample packing mask removal |
| Loss logging fixes | `core/trainers/base.py` `__init__()` | model_accepts_loss_kwargs override |
| PEFT/LoRA patches | `loaders/adapter.py` | ClippableLinear redirect |
| Attention patches | `monkeypatch/attention/` | FA4 tuple fix |
| Model-specific patches | `loaders/patch_manager.py` `_apply_model_specific_patches()` | Llama4, Kimi, NemotronH |
| CCE patches | `ml-cross-entropy` repo `transformers/` | Per-model cce_forward |
| Example configs | `examples/<model>/` | Validated YAML |
| Config validation | `utils/schemas/validation.py` | Compatibility checks |

View File

@@ -19,6 +19,8 @@ TOPICS = {
"preference_tuning": "docs/agents/preference_tuning.md", "preference_tuning": "docs/agents/preference_tuning.md",
"reward_modelling": "docs/agents/reward_modelling.md", "reward_modelling": "docs/agents/reward_modelling.md",
"pretraining": "docs/agents/pretraining.md", "pretraining": "docs/agents/pretraining.md",
"model_architectures": "docs/agents/model_architectures.md",
"new_model_support": "docs/agents/new_model_support.md",
} }

View File

@@ -404,7 +404,9 @@ class AxolotlTrainer(
# Gemma4 requires mm_token_type_ids during training (even for text-only). # Gemma4 requires mm_token_type_ids during training (even for text-only).
# Inject zeros (= text token type) when not provided by the data collator. # Inject zeros (= text token type) when not provided by the data collator.
_model_type = getattr(getattr(model, "config", None), "model_type", None) # Use unwrap_model to handle DDP/FSDP wrappers that don't proxy .config.
_unwrapped = self.accelerator.unwrap_model(model)
_model_type = getattr(getattr(_unwrapped, "config", None), "model_type", None)
if ( if (
"mm_token_type_ids" not in inputs "mm_token_type_ids" not in inputs
and "input_ids" in inputs and "input_ids" in inputs
@@ -445,6 +447,21 @@ class AxolotlTrainer(
LOG.info("Running evaluation step...") LOG.info("Running evaluation step...")
return super().evaluate(*args, **kwargs) return super().evaluate(*args, **kwargs)
@override
def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys=None):
# Gemma4 requires mm_token_type_ids even during evaluation.
_unwrapped = self.accelerator.unwrap_model(model)
_model_type = getattr(getattr(_unwrapped, "config", None), "model_type", None)
if (
"mm_token_type_ids" not in inputs
and "input_ids" in inputs
and _model_type == "gemma4"
):
inputs["mm_token_type_ids"] = torch.zeros_like(inputs["input_ids"])
return super().prediction_step(
model, inputs, prediction_loss_only, ignore_keys=ignore_keys
)
@staticmethod @staticmethod
def orpo_concatenate_inputs(inputs, label_pad_token=-100, pad_token=0, device=None): def orpo_concatenate_inputs(inputs, label_pad_token=-100, pad_token=0, device=None):
concatenated_batch = {} concatenated_batch = {}

View File

@@ -36,7 +36,7 @@ from axolotl.telemetry.manager import TelemetryManager
from axolotl.utils.ctx_managers.sequence_parallel import SequenceParallelContextManager from axolotl.utils.ctx_managers.sequence_parallel import SequenceParallelContextManager
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import cleanup_distributed from axolotl.utils.distributed import cleanup_distributed
from axolotl.utils.freeze import freeze_layers_except from axolotl.utils.freeze import freeze_layers_except, freeze_mm_modules
from axolotl.utils.logging import get_logger from axolotl.utils.logging import get_logger
from axolotl.utils.schemas.enums import RLType from axolotl.utils.schemas.enums import RLType
from axolotl.utils.train import determine_last_checkpoint from axolotl.utils.train import determine_last_checkpoint
@@ -114,6 +114,10 @@ def setup_model_and_tokenizer(
): ):
model.enable_input_require_grads() model.enable_input_require_grads()
# Freeze multimodal modules for text-only training of multimodal models
if cfg.freeze_mm_modules:
freeze_mm_modules(model)
return model, tokenizer, peft_config, processor return model, tokenizer, peft_config, processor

View File

@@ -268,6 +268,37 @@ def normalize_config(cfg):
): ):
cfg.gradient_checkpointing_kwargs = {"use_reentrant": True} cfg.gradient_checkpointing_kwargs = {"use_reentrant": True}
# Gemma4 requires use_reentrant=False for DDP (shared per-layer norms cause
# "marked ready twice" errors with reentrant checkpointing) and
# ddp_find_unused_parameters=True (per_layer_projection LoRA params may not
# receive gradients on every step).
if cfg.model_config_type == "gemma4":
if cfg.gradient_checkpointing:
if cfg.gradient_checkpointing_kwargs is None:
cfg.gradient_checkpointing_kwargs = {}
if cfg.gradient_checkpointing_kwargs.get("use_reentrant") is not False:
LOG.warning(
"Gemma4 requires use_reentrant=False for gradient checkpointing "
"in distributed training. Setting use_reentrant=False."
)
cfg.gradient_checkpointing_kwargs["use_reentrant"] = False
if cfg.ddp and cfg.ddp_find_unused_parameters is None:
if cfg.activation_offloading is True:
# activation_offloading uses checkpoint wrappers that conflict
# with find_unused_parameters (causes "marked ready twice").
# Use freeze_mm_modules instead to eliminate unused params.
LOG.info(
"Gemma4 + DDP + activation_offloading: skipping "
"ddp_find_unused_parameters (use freeze_mm_modules to "
"handle unused vision/audio params)."
)
else:
LOG.warning(
"Gemma4 requires ddp_find_unused_parameters=True for DDP. "
"Auto-enabling."
)
cfg.ddp_find_unused_parameters = True
log_gpu_memory_usage(LOG, "baseline", cfg.device) log_gpu_memory_usage(LOG, "baseline", cfg.device)

View File

@@ -10,6 +10,44 @@ from axolotl.utils.logging import get_logger
LOG = get_logger(__name__) LOG = get_logger(__name__)
# Top-level module name prefixes that belong to vision/audio/multimodal encoders
# rather than the language backbone. These are matched against the first component
# of each ``named_parameter`` path (e.g. "model.vision_tower." -> "vision_tower").
_MM_MODULE_PREFIXES = (
"vision_tower",
"vision_model",
"vision_encoder",
"embed_vision",
"multi_modal_projector",
"visual",
"audio_tower",
"audio_model",
"embed_audio",
)
def freeze_mm_modules(model):
"""Freeze all vision/audio/multimodal-projector parameters.
Iterates over ``model.named_parameters()`` and sets ``requires_grad = False``
for any parameter whose name contains a known vision/audio module prefix.
This is useful when fine-tuning only the language backbone of a multimodal
model and avoids the need for ``ddp_find_unused_parameters=True``.
"""
frozen_count = 0
for name, param in model.named_parameters():
# Check if any path component matches a vision/audio prefix
parts = name.split(".")
if any(part in _MM_MODULE_PREFIXES for part in parts):
if param.requires_grad:
param.requires_grad = False
frozen_count += 1
if is_main_process():
LOG.debug(f"freeze_mm_modules: froze {name}")
if is_main_process():
LOG.info(f"freeze_mm_modules: froze {frozen_count} vision/audio parameters")
def freeze_layers_except(model, regex_patterns): def freeze_layers_except(model, regex_patterns):
""" """

View File

@@ -578,6 +578,17 @@ class AxolotlInputConfig(
}, },
) )
freeze_mm_modules: bool | None = Field(
default=None,
json_schema_extra={
"description": "Freeze multimodal encoder parameters (vision, audio, etc.) for "
"text-only training of multimodal models. When True, parameters belonging to "
"vision towers, audio towers, multimodal projectors, and similar non-language "
"modules are frozen (requires_grad=False). This allows DDP training without "
"ddp_find_unused_parameters=True."
},
)
unfrozen_parameters: list[str] | None = Field( unfrozen_parameters: list[str] | None = Field(
default=None, default=None,
json_schema_extra={ json_schema_extra={