fix ddp/fsdp w gemma4 (#3584)
* fix ddp/fsdp w gemma4
* address pr comments
* activation offloading fix and update agent docs for gemma4
@@ -38,6 +38,8 @@ Agent-specific references:
 - [docs/agents/grpo.md](docs/agents/grpo.md) — GRPO online RL with reward functions
 - [docs/agents/reward_modelling.md](docs/agents/reward_modelling.md) — outcome and process reward models
 - [docs/agents/pretraining.md](docs/agents/pretraining.md) — continual pretraining
+- [docs/agents/model_architectures.md](docs/agents/model_architectures.md) — model-specific quirks (Gemma4, Qwen3.5 MoE, etc.)
+- [docs/agents/new_model_support.md](docs/agents/new_model_support.md) — debugging and adding support for new model architectures
 
 ## Config Pattern
 
110
docs/agents/model_architectures.md
Normal file
@@ -0,0 +1,110 @@
# Model Architectures — Agent Reference

Model-specific quirks, required settings, and known issues. Check this before debugging training failures on specific model families.

## Gemma 4

**Models**: `google/gemma-4-26B-A4B` (MoE), `google/gemma-4-31B` (dense), `google/gemma-4-E2B`, `google/gemma-4-E4B`

**Architecture**: Multimodal wrapper (`Gemma4ForConditionalGeneration`) over a text backbone (`Gemma4TextModel`), with optional vision/audio encoders. All Gemma4 HF repos have `model_type: "gemma4"` — even text-only variants load as multimodal with a vision tower.

### Required settings

```yaml
# Always needed for Gemma4:
freeze_mm_modules: true  # Freeze vision/audio encoders for text-only training
gradient_checkpointing_kwargs:
  use_reentrant: false  # Shared per-layer norms cause "marked ready twice" with reentrant

# LoRA target — restrict to language model only (DO NOT use lora_target_linear: true):
lora_target_modules: 'model.language_model.layers.[\d]+.(_checkpoint_wrapped_module.)?(mlp|self_attn).(up|down|gate|q|k|v|o)_proj'
```
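
To see what the `lora_target_modules` pattern selects, the sketch below applies it to a handful of module names. The names are hypothetical examples of how Gemma4 submodules might be laid out, not a dump from a real checkpoint, and the sketch assumes the pattern is applied as a full match against each module name (which is how PEFT is understood to treat a string-valued `target_modules`).

```python
import re

# The pattern from the required settings above, copied verbatim.
PATTERN = r"model.language_model.layers.[\d]+.(_checkpoint_wrapped_module.)?(mlp|self_attn).(up|down|gate|q|k|v|o)_proj"

# Hypothetical module names, for illustration only.
candidates = [
    "model.language_model.layers.0.self_attn.q_proj",
    "model.language_model.layers.17.mlp.down_proj",
    "model.language_model.layers.3._checkpoint_wrapped_module.mlp.gate_proj",
    "model.vision_tower.encoder.layers.0.self_attn.q_proj",
    "model.language_model.layers.0.experts.down_proj",
]

# Assumption: a full match on the whole module name decides whether LoRA is applied.
selected = [name for name in candidates if re.fullmatch(PATTERN, name)]

for name in candidates:
    label = "LoRA" if name in selected else "skip"
    print(f"{label:4}  {name}")
```

Under these assumptions the three language-model projections are selected, while the vision-tower projection and the expert tensor are skipped, which is the behaviour the required settings aim for.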

### Auto-detection

Axolotl auto-detects Gemma4 and applies:
- `use_reentrant: false` for gradient checkpointing
- `ddp_find_unused_parameters: true` for DDP (skipped when `activation_offloading: true`)

### Multi-GPU

| Strategy | Works? | Notes |
|----------|--------|-------|
| DDP | Yes | Auto-sets `ddp_find_unused_parameters=True` |
| DDP + activation_offloading | Yes | `find_unused_parameters` is skipped (conflicts with checkpoint wrappers) |
| FSDP1 | No | OOM during dequantization/sharding with QLoRA |
| FSDP2 | Yes | Use `Gemma4TextDecoderLayer` (not `Gemma4DecoderLayer`) as wrap class |
| FSDP2 + activation_offloading | Yes | Lowest VRAM (~26 GiB/GPU for 26B-A4B) |

FSDP2 config:
```yaml
fsdp:
  - full_shard
  - auto_wrap
fsdp_config:
  fsdp_version: 2
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_transformer_layer_cls_to_wrap: Gemma4TextDecoderLayer
```

### MoE (26B-A4B)

- `enable_moe_block: true`, 256 experts, top-k routing
- No separate `SparseMoeBlock` — MoE is embedded in each decoder layer
- Expert LoRA targets 3D parameter tensors:
  ```yaml
  lora_target_parameters:
    - experts.gate_up_proj
    - experts.down_proj
  ```
- ScatterMoE kernel acceleration:
  ```yaml
  plugins:
    - axolotl.integrations.kernels.KernelsPlugin
  use_kernels: true
  use_scattermoe: true
  experts_implementation: scattermoe
  ```

### Common issues

| Symptom | Cause | Fix |
|---------|-------|-----|
| `mm_token_type_ids is required` in DDP | `model.config` not accessible through DDP wrapper | Already fixed — `unwrap_model()` in `compute_loss` and `prediction_step` |
| `marked a variable ready twice` in DDP | `ddp_find_unused_parameters=True` + activation_offloading checkpoint wrappers | Auto-handled — `find_unused_parameters` is skipped when `activation_offloading: true` |
| Loss ~12 instead of ~0.5 | Using `lora_target_linear: true` (applies LoRA to vision/audio modules) | Use the regex `lora_target_modules` pattern instead |
| FSDP2 `Could not find Gemma4AudioLayer` | Auto-wrap detects `_no_split_modules` including audio layers that don't exist | Explicitly set `fsdp_transformer_layer_cls_to_wrap: Gemma4TextDecoderLayer` |
| `Gemma4ClippableLinear not supported` by PEFT | Vision tower uses a non-standard linear wrapper | Axolotl patches this automatically via `_patch_peft_clippable_linear()` |

### E2B/E4B dense models

These have `hidden_size_per_layer_input: 256` (per-layer input embeddings) and `attention_k_eq_v: False`. Known issue: loss starts higher than expected (~12 vs ~0.5 for 26B). Root cause under investigation — may be related to the per-layer input mechanism or the `Gemma4ForConditionalGeneration` loss computation.

## Gemma 3

**Models**: `google/gemma-3-*`

- `ddp_find_unused_parameters: true` needed (multimodal unused params)
- `use_reentrant: false` recommended
- Attention mask must be dropped for sample packing (handled automatically)
- Multi-GPU test currently skipped (`tests/e2e/multigpu/test_gemma3.py`)

## Qwen 3.5 MoE

**Models**: `Qwen/Qwen3.5-35B-A3B`

- Hybrid architecture: DeltaNet linear attention (30 layers) + full attention (10 layers)
- 256 experts, 8 active per token
- Known weight scale drift in late DeltaNet layers (36-38) due to AdamW + rare expert interaction
- Fix: `normalize_weight_scales` config to detect and rescale outliers:
  ```yaml
  normalize_weight_scales:
    - name_pattern: 'linear_attn\.conv1d\.weight'
      threshold: 1.3
  ```

## General MoE Notes

- `lora_target_linear: true` with multimodal MoE models will apply LoRA to ALL linear modules including vision/audio encoders — use regex `lora_target_modules` to restrict to language model only
- Rare experts get larger effective learning rate from AdamW (small second-moment estimates) — can cause weight drift in recurrent/SSM components. Use `normalize_weight_scales` with `dry_run: true` to detect.
- For ScatterMoE kernel support, set `experts_implementation: scattermoe` and add the KernelsPlugin
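
The note about rare experts and AdamW can be illustrated with a toy simulation. This is a sketch under strong simplifying assumptions that are not from the docs above: a single scalar parameter, unit-sized gradients, no weight decay (so plain Adam rather than AdamW), and a "rare" parameter that receives a gradient only on every 1000th step. It compares the size of the update each parameter takes on a step where both see the same gradient.

```python
import math


def adam_update_sizes(grads, lr=1.0, beta1=0.9, beta2=0.999, eps=1e-8):
    """Return |update| per step for a scalar parameter under plain Adam."""
    m = v = 0.0
    sizes = []
    for t, g in enumerate(grads, start=1):
        m = beta1 * m + (1 - beta1) * g          # first-moment estimate
        v = beta2 * v + (1 - beta2) * g * g      # second-moment estimate
        m_hat = m / (1 - beta1**t)               # bias correction
        v_hat = v / (1 - beta2**t)
        sizes.append(lr * abs(m_hat) / (math.sqrt(v_hat) + eps))
    return sizes


steps = 20_000
dense = [1.0] * steps                                              # gradient on every step
rare = [1.0 if t % 1000 == 0 else 0.0 for t in range(1, steps + 1)]  # gradient every 1000th step

dense_step = adam_update_sizes(dense)[-1]
rare_step = adam_update_sizes(rare)[-1]
ratio = rare_step / dense_step
print(f"dense update: {dense_step:.3f}, rare update: {rare_step:.3f}, ratio: {ratio:.2f}")
```

In this toy setting the rarely-updated parameter takes a noticeably larger step for the same gradient, because its second-moment estimate has decayed between updates. Real expert routing is far less regular, so treat the ratio as qualitative only.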
181
docs/agents/new_model_support.md
Normal file
@@ -0,0 +1,181 @@
# New Model Support — Agent Reference

Guide for debugging and adding support for new model architectures in axolotl. Based on lessons learned from Gemma4, Gemma3, Qwen2-VL, and other multimodal/MoE models.

## Quick Validation Checklist

When testing a new model, run through these checks in order:

1. **Does the model load?** `axolotl preprocess config.yaml` — catches config schema errors
2. **Does LoRA apply?** Check for "Unsupported layer type" warnings from PEFT
3. **Is the initial loss sane?** First-step loss for a pretrained model should be 0.5–2.0 for SFT
4. **Does sample packing work?** Compare loss with `sample_packing: true` vs `false` — should be similar
5. **Is CCE active?** Check for "Applying Cut Cross Entropy" log and verify peak VRAM is lower

## Loss Debugging

### Expected initial loss
A pretrained model doing SFT should start with loss roughly in the 0.5–2.0 range. If loss starts above 3.0, something is wrong. If it's near `log(vocab_size)` (≈ 12 for 262K vocab), the model is predicting at random — attention masking or model weights are broken.
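
The thresholds above can be turned into a quick sanity check. The sketch below assumes "262K" means exactly 2**18 tokens and picks an arbitrary tolerance of 1.0 nat for "near `log(vocab_size)`"; both are illustrative assumptions, and `classify_initial_loss` is a hypothetical helper, not part of axolotl.

```python
import math

VOCAB_SIZE = 262_144  # assumption: "262K" taken as 2**18

# Cross-entropy of a uniform guess over the vocabulary (natural log).
uniform_loss = math.log(VOCAB_SIZE)
print(f"uniform-guess loss: {uniform_loss:.2f}")  # 12.48


def classify_initial_loss(loss, vocab_size=VOCAB_SIZE, near=1.0):
    """Hypothetical helper mapping a first-step SFT loss to the ranges described above."""
    if abs(loss - math.log(vocab_size)) <= near:
        return "random-level: check attention masking and weight loading"
    if loss > 3.0:
        return "too high: something is wrong"
    if 0.5 <= loss <= 2.0:
        return "expected range"
    return "outside the stated ranges: inspect manually"


for loss in (1.1, 4.2, 12.3):
    print(loss, "->", classify_initial_loss(loss))
```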

### Direct comparison technique
The fastest way to isolate a loss issue — bypass the trainer entirely:

```python
# Load model via axolotl's pipeline (applies all patches)
from axolotl.cli.config import load_cfg
from axolotl.utils.config import normalize_config, prepare_plugins
from axolotl.loaders.tokenizer import load_tokenizer
from axolotl.loaders.model import ModelLoader

cfg = load_cfg("your_config.yaml")
normalize_config(cfg)
prepare_plugins(cfg)
tokenizer = load_tokenizer(cfg)
model, _ = ModelLoader(cfg, tokenizer).load()

# Forward pass on preprocessed data
model.train()
out = model(input_ids, labels=labels)
print(f"Direct loss: {out.loss.item()}")  # Compare to trainer's reported loss
```

If direct loss is correct (~1.0) but trainer reports 3–4x higher, check `model_accepts_loss_kwargs` (see below).

### `model_accepts_loss_kwargs` inflation
HF Trainer checks if the model's `forward()` has `**kwargs` and sets `model_accepts_loss_kwargs=True`. This changes loss normalization: the trainer does NOT divide loss by `gradient_accumulation_steps` before logging. The gradient is correct — only the logged loss is inflated.

**Symptom**: Logged loss ≈ actual_loss × gradient_accumulation_steps.

**Which models are affected**: Any model with `**kwargs` in forward (common in multimodal models for extra inputs like `mm_token_type_ids`, `pixel_values`, etc.).

**Fix location**: `src/axolotl/core/trainers/base.py` `__init__()` — after `super().__init__()`, check if the unwrapped model actually has `num_items_in_batch` in its forward signature. If not, set `self.model_accepts_loss_kwargs = False`.
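
The distinction the fix relies on, a `forward()` that merely has `**kwargs` versus one that explicitly declares `num_items_in_batch`, can be checked with `inspect`. The sketch below is not the axolotl or HF Trainer code; the two model classes are hypothetical stand-ins and the helper names are made up for illustration.

```python
import inspect


def has_var_kwargs(forward_fn):
    """True if the signature contains a **kwargs catch-all."""
    params = inspect.signature(forward_fn).parameters.values()
    return any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params)


def declares_num_items_in_batch(forward_fn):
    """True only if `num_items_in_batch` is an explicit parameter."""
    return "num_items_in_batch" in inspect.signature(forward_fn).parameters


class KwargsOnlyModel:  # hypothetical: swallows extra inputs via **kwargs
    def forward(self, input_ids, labels=None, **kwargs):
        return None


class LossAwareModel:  # hypothetical: explicitly accepts num_items_in_batch
    def forward(self, input_ids, labels=None, num_items_in_batch=None, **kwargs):
        return None


for model in (KwargsOnlyModel(), LossAwareModel()):
    looks_ok = has_var_kwargs(model.forward)
    really_ok = declares_num_items_in_batch(model.forward)
    print(type(model).__name__, "has **kwargs:", looks_ok, "| declares num_items_in_batch:", really_ok)

# Symptom arithmetic with assumed numbers: true loss 1.0, 4 accumulation steps.
actual_loss, gradient_accumulation_steps = 1.0, 4
inflated_logged_loss = actual_loss * gradient_accumulation_steps
print("logged loss when the division is skipped:", inflated_logged_loss)
```

A model like `KwargsOnlyModel` passes the `**kwargs` check but fails the explicit-parameter check, which is the case where the override to `False` applies.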

## Multimodal Models (ForConditionalGeneration)

Many recent models use `ForConditionalGeneration` as the top-level class, not `ForCausalLM`:
- Gemma3 → `Gemma3ForConditionalGeneration`
- Gemma4 → `Gemma4ForConditionalGeneration`
- Qwen2-VL → `Qwen2VLForConditionalGeneration`
- LLaVA → `LlavaForConditionalGeneration`

### Why this matters

| Component | Targets `ForCausalLM` | Needs `ForConditionalGeneration` |
|-----------|----------------------|--------------------------------|
| CCE patches | ✅ (default) | ❌ silently inactive if not patched |
| PEFT LoRA | ✅ | May fail on custom layer types |
| HF Trainer label handling | ✅ | May need extra inputs |

### Required extra inputs
Multimodal models require special inputs during training even for text-only data:

| Model | Required Input | Value for Text-Only |
|-------|---------------|-------------------|
| Gemma4 | `mm_token_type_ids` | `torch.zeros_like(input_ids)` |
| Gemma3 | `token_type_ids` | `torch.zeros_like(input_ids)` |

Auto-inject in `compute_loss()` when not provided by the data collator. See `core/trainers/base.py`.

### Custom layer types and PEFT
Vision towers often use custom module wrappers that PEFT doesn't support:

| Model | Custom Layer | Wraps | Fix |
|-------|-------------|-------|-----|
| Gemma4 | `Gemma4ClippableLinear` | `nn.Linear` | Redirect to `.linear` child |

Fix location: `src/axolotl/loaders/adapter.py` `_patch_peft_clippable_linear()`.

## Sample Packing

### How packed sequence detection works (transformers ≥ 5.x)
`transformers.masking_utils._preprocess_mask_arguments()` detects packed sequences from `position_ids` resets. But **only when `attention_mask is None`**:

```python
# From masking_utils.py:
if position_ids is not None and attention_mask is None and past_key_values is None:
    packed_sequence_mask = find_packed_sequence_indices(position_ids)
```

If the collator provides an all-ones `attention_mask`, packing detection is **skipped** and the model builds a single causal mask spanning all packed sequences → cross-sequence attention leakage → very high loss.
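
To make the idea of "detecting packed sequences from `position_ids` resets" concrete, here is a minimal pure-Python sketch. It is an illustration of the concept, not the transformers implementation (which works on tensors); `packed_sequence_ids` and `may_attend` are hypothetical helper names.

```python
def packed_sequence_ids(position_ids):
    """Label each token with the index of the packed sequence it belongs to.

    A new sequence is assumed to start wherever the position does not
    continue from the previous token (i.e. it is not previous + 1).
    """
    seq_ids, current, prev = [], 0, None
    for pos in position_ids:
        if prev is not None and pos != prev + 1:
            current += 1
        seq_ids.append(current)
        prev = pos
    return seq_ids


def may_attend(seq_ids, query_idx, key_idx):
    """Causal attention restricted to tokens of the same packed sequence."""
    return key_idx <= query_idx and seq_ids[query_idx] == seq_ids[key_idx]


# Three sequences of lengths 3, 2 and 4 packed into one row.
position_ids = [0, 1, 2, 0, 1, 0, 1, 2, 3]
seq_ids = packed_sequence_ids(position_ids)
print(seq_ids)

# Token 3 starts the second sequence, so it must not see token 2.
print("token 3 -> token 2 allowed:", may_attend(seq_ids, 3, 2))
print("token 4 -> token 3 allowed:", may_attend(seq_ids, 4, 3))
```

With a plain causal mask (no sequence labels), token 3 would be allowed to attend to token 2, which is the cross-sequence leakage described above.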

### Fix for models using `create_causal_mask_mapping`
For Gemma3, Gemma4, and similar models that use the new transformers masking system, remove `attention_mask` from inputs when sample packing is active:

```python
# In compute_loss():
if (
    self.args.sample_packing
    and model_type in ("gemma4", "gemma3")
    and "attention_mask" in inputs
    and "position_ids" in inputs
):
    del inputs["attention_mask"]
```

Fix location: `src/axolotl/core/trainers/base.py` `compute_loss()`.

### Models that DON'T need this fix
Older models that use `_prepare_4d_causal_attention_mask` (Llama, Mistral, Qwen2, etc.) handle sample packing via axolotl's multipack attention monkeypatch instead. Only models using the new `create_causal_mask_mapping` / `create_causal_mask` masking system need the `attention_mask` removal.

## Attention Backend Selection

| Backend | Config | head_dim limit | torch_compile | Notes |
|---------|--------|---------------|---------------|-------|
| FA2 | `flash_attention: true` | 256 | ✅ | Fastest when supported |
| FA4 | auto with `flash_attention: true` | 256 (SM90+) | ✅ | Auto-detected on H100+ |
| SDPA | `sdp_attention: true` | None | ✅ | Universal fallback |
| flex | `flex_attention: true` | None | ⚠️ Triton OOM for large head_dim | Good for variable head dims |
| eager | neither set | None | ✅ | Slowest, always works |

**Check model support**: Look at `_supports_flash_attn_2`, `_supports_flex_attn`, `_supports_sdpa` attributes on the model class.

**head_dim gotcha**: The 256 limit is specific to flash-attn CUDA kernels, NOT PyTorch-level. SDPA and flex_attention both handle arbitrary head_dim. Models with `global_head_dim > 256` (Gemma4: 512) must use SDPA or flex.

**flex + compile gotcha**: `torch_compile` with flex_attention can hit Triton shared memory OOM for large head_dim. Falls back to eager per-function (not a crash, but slower). Unsloth disables flex for Gemma4 for this reason.

## Cut Cross Entropy (CCE)

### How CCE patches work
CCE replaces the model's `forward()` with a fused version that computes loss from hidden states + lm_head weight without materializing the full logits tensor. This saves ~`batch × seq_len × vocab_size × dtype_bytes` of VRAM.
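
As a worked instance of that formula, the sketch below computes the size of the logits tensor for one assumed configuration: micro-batch 1, sequence length 8192, a 262,144-token vocabulary, and bf16 (2 bytes per element). These numbers are illustrative assumptions, not measurements from any run.

```python
def logits_bytes(batch, seq_len, vocab_size, dtype_bytes):
    """Bytes needed to materialize a full [batch, seq_len, vocab_size] logits tensor."""
    return batch * seq_len * vocab_size * dtype_bytes


# Assumed shapes for illustration only.
size = logits_bytes(batch=1, seq_len=8192, vocab_size=262_144, dtype_bytes=2)
size_gib = size / 2**30
print(f"logits tensor: {size_gib:.1f} GiB")  # 4.0 GiB
```

This counts only the logits themselves; any upcast to float32 or gradient buffers for the logits would add to the figure.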

### Adding CCE for a new model
1. Check if the model type is in `cut_cross_entropy.transformers.patch.PATCH_FNS`
2. If not, axolotl's generic fallback (`integrations/cut_cross_entropy/__init__.py` `patch_llama_like()`) patches `{Prefix}ForCausalLM.forward` with `cce_forward`
3. For multimodal models (`ForConditionalGeneration`), a model-specific patch is needed in `ml-cross-entropy` repo
4. The multimodal `cce_forward` must accept all extra kwargs (pixel_values, mm_token_type_ids, etc.) and pop any that would conflict before calling `self.model()`

### Common CCE pitfall
If CCE appears active (log says "Applying Cut Cross Entropy") but peak VRAM doesn't decrease, check which class was patched. If the model loads as `ForConditionalGeneration` but CCE patched `ForCausalLM`, the patch is silently inactive.

## MoE Models

### Dense MLP vs MoE experts
Some MoE models (e.g., Gemma4) have BOTH dense MLP layers and MoE expert layers at every decoder layer:
- `gate_proj/up_proj/down_proj` → targets the **dense MLP** (`Gemma4TextMLP`)
- `experts.gate_up_proj/experts.down_proj` → targets the **MoE experts** (`Gemma4TextExperts`)

LoRA on the dense MLP works normally. Expert LoRA via `lora_target_parameters` requires PEFT support for the specific expert module type (may warn "Unsupported layer type").

### ScatterMoE kernels
`use_scattermoe: true` with `experts_implementation: scattermoe` registers fused expert kernels via transformers' `ExpertsInterface`. Significant speedup for MoE models. Requires the kernels plugin:
```yaml
plugins:
  - axolotl.integrations.kernels.KernelsPlugin
use_kernels: true
use_scattermoe: true
experts_implementation: scattermoe
```

## Where to Add Model-Specific Fixes

| What | Where | Example |
|------|-------|---------|
| Missing forward inputs | `core/trainers/base.py` `compute_loss()` | mm_token_type_ids injection |
| Attention mask fixes | `core/trainers/base.py` `compute_loss()` | Sample packing mask removal |
| Loss logging fixes | `core/trainers/base.py` `__init__()` | model_accepts_loss_kwargs override |
| PEFT/LoRA patches | `loaders/adapter.py` | ClippableLinear redirect |
| Attention patches | `monkeypatch/attention/` | FA4 tuple fix |
| Model-specific patches | `loaders/patch_manager.py` `_apply_model_specific_patches()` | Llama4, Kimi, NemotronH |
| CCE patches | `ml-cross-entropy` repo `transformers/` | Per-model cce_forward |
| Example configs | `examples/<model>/` | Validated YAML |
| Config validation | `utils/schemas/validation.py` | Compatibility checks |
@@ -19,6 +19,8 @@ TOPICS = {
     "preference_tuning": "docs/agents/preference_tuning.md",
     "reward_modelling": "docs/agents/reward_modelling.md",
     "pretraining": "docs/agents/pretraining.md",
+    "model_architectures": "docs/agents/model_architectures.md",
+    "new_model_support": "docs/agents/new_model_support.md",
 }
 
 
@@ -404,7 +404,9 @@ class AxolotlTrainer(
 
         # Gemma4 requires mm_token_type_ids during training (even for text-only).
         # Inject zeros (= text token type) when not provided by the data collator.
-        _model_type = getattr(getattr(model, "config", None), "model_type", None)
+        # Use unwrap_model to handle DDP/FSDP wrappers that don't proxy .config.
+        _unwrapped = self.accelerator.unwrap_model(model)
+        _model_type = getattr(getattr(_unwrapped, "config", None), "model_type", None)
         if (
             "mm_token_type_ids" not in inputs
             and "input_ids" in inputs
@@ -445,6 +447,21 @@ class AxolotlTrainer(
         LOG.info("Running evaluation step...")
         return super().evaluate(*args, **kwargs)
 
+    @override
+    def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys=None):
+        # Gemma4 requires mm_token_type_ids even during evaluation.
+        _unwrapped = self.accelerator.unwrap_model(model)
+        _model_type = getattr(getattr(_unwrapped, "config", None), "model_type", None)
+        if (
+            "mm_token_type_ids" not in inputs
+            and "input_ids" in inputs
+            and _model_type == "gemma4"
+        ):
+            inputs["mm_token_type_ids"] = torch.zeros_like(inputs["input_ids"])
+        return super().prediction_step(
+            model, inputs, prediction_loss_only, ignore_keys=ignore_keys
+        )
+
     @staticmethod
     def orpo_concatenate_inputs(inputs, label_pad_token=-100, pad_token=0, device=None):
         concatenated_batch = {}
@@ -36,7 +36,7 @@ from axolotl.telemetry.manager import TelemetryManager
 from axolotl.utils.ctx_managers.sequence_parallel import SequenceParallelContextManager
 from axolotl.utils.dict import DictDefault
 from axolotl.utils.distributed import cleanup_distributed
-from axolotl.utils.freeze import freeze_layers_except
+from axolotl.utils.freeze import freeze_layers_except, freeze_mm_modules
 from axolotl.utils.logging import get_logger
 from axolotl.utils.schemas.enums import RLType
 from axolotl.utils.train import determine_last_checkpoint
@@ -114,6 +114,10 @@ def setup_model_and_tokenizer(
     ):
         model.enable_input_require_grads()
 
+    # Freeze multimodal modules for text-only training of multimodal models
+    if cfg.freeze_mm_modules:
+        freeze_mm_modules(model)
+
     return model, tokenizer, peft_config, processor
 
 
@@ -268,6 +268,37 @@ def normalize_config(cfg):
     ):
         cfg.gradient_checkpointing_kwargs = {"use_reentrant": True}
 
+    # Gemma4 requires use_reentrant=False for DDP (shared per-layer norms cause
+    # "marked ready twice" errors with reentrant checkpointing) and
+    # ddp_find_unused_parameters=True (per_layer_projection LoRA params may not
+    # receive gradients on every step).
+    if cfg.model_config_type == "gemma4":
+        if cfg.gradient_checkpointing:
+            if cfg.gradient_checkpointing_kwargs is None:
+                cfg.gradient_checkpointing_kwargs = {}
+            if cfg.gradient_checkpointing_kwargs.get("use_reentrant") is not False:
+                LOG.warning(
+                    "Gemma4 requires use_reentrant=False for gradient checkpointing "
+                    "in distributed training. Setting use_reentrant=False."
+                )
+                cfg.gradient_checkpointing_kwargs["use_reentrant"] = False
+        if cfg.ddp and cfg.ddp_find_unused_parameters is None:
+            if cfg.activation_offloading is True:
+                # activation_offloading uses checkpoint wrappers that conflict
+                # with find_unused_parameters (causes "marked ready twice").
+                # Use freeze_mm_modules instead to eliminate unused params.
+                LOG.info(
+                    "Gemma4 + DDP + activation_offloading: skipping "
+                    "ddp_find_unused_parameters (use freeze_mm_modules to "
+                    "handle unused vision/audio params)."
+                )
+            else:
+                LOG.warning(
+                    "Gemma4 requires ddp_find_unused_parameters=True for DDP. "
+                    "Auto-enabling."
+                )
+                cfg.ddp_find_unused_parameters = True
+
     log_gpu_memory_usage(LOG, "baseline", cfg.device)
 
 
@@ -10,6 +10,44 @@ from axolotl.utils.logging import get_logger
 
 LOG = get_logger(__name__)
 
+# Module names that belong to vision/audio/multimodal encoders rather than the
+# language backbone. These are matched against every dot-separated component
+# of each ``named_parameter`` path (e.g. "model.vision_tower." -> "vision_tower").
+_MM_MODULE_PREFIXES = (
+    "vision_tower",
+    "vision_model",
+    "vision_encoder",
+    "embed_vision",
+    "multi_modal_projector",
+    "visual",
+    "audio_tower",
+    "audio_model",
+    "embed_audio",
+)
+
+
+def freeze_mm_modules(model):
+    """Freeze all vision/audio/multimodal-projector parameters.
+
+    Iterates over ``model.named_parameters()`` and sets ``requires_grad = False``
+    for any parameter whose name contains a known vision/audio module prefix.
+    This is useful when fine-tuning only the language backbone of a multimodal
+    model and avoids the need for ``ddp_find_unused_parameters=True``.
+    """
+    frozen_count = 0
+    for name, param in model.named_parameters():
+        # Check if any path component matches a vision/audio prefix
+        parts = name.split(".")
+        if any(part in _MM_MODULE_PREFIXES for part in parts):
+            if param.requires_grad:
+                param.requires_grad = False
+                frozen_count += 1
+                if is_main_process():
+                    LOG.debug(f"freeze_mm_modules: froze {name}")
+
+    if is_main_process():
+        LOG.info(f"freeze_mm_modules: froze {frozen_count} vision/audio parameters")
+
+
 
 def freeze_layers_except(model, regex_patterns):
     """
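
Note on the matching rule in `freeze_mm_modules` above: a parameter is frozen when any dot-separated component of its name equals one of the listed module names, so a substring match is not enough. The standalone sketch below reproduces only that rule, without torch; `is_mm_param` is a hypothetical helper and the parameter names are made-up examples, not taken from a real checkpoint.

```python
# Same names as the tuple in the hunk above.
MM_MODULE_NAMES = (
    "vision_tower",
    "vision_model",
    "vision_encoder",
    "embed_vision",
    "multi_modal_projector",
    "visual",
    "audio_tower",
    "audio_model",
    "embed_audio",
)


def is_mm_param(name):
    """True if any dot-separated component of `name` is a multimodal module name."""
    return any(part in MM_MODULE_NAMES for part in name.split("."))


# Hypothetical parameter names, for illustration only.
examples = [
    "model.vision_tower.encoder.layers.0.attn.q_proj.weight",
    "model.embed_audio.embedding.weight",
    "model.language_model.layers.0.self_attn.q_proj.weight",
    "model.language_model.visual_gate.weight",
]

frozen = [name for name in examples if is_mm_param(name)]
for name in examples:
    print("freeze" if name in frozen else "keep  ", name)
```

The last example is kept trainable because `visual_gate` is a different component from `visual`, which shows why whole-component matching is safer than a substring test.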
@@ -578,6 +578,17 @@ class AxolotlInputConfig(
         },
     )
 
+    freeze_mm_modules: bool | None = Field(
+        default=None,
+        json_schema_extra={
+            "description": "Freeze multimodal encoder parameters (vision, audio, etc.) for "
+            "text-only training of multimodal models. When True, parameters belonging to "
+            "vision towers, audio towers, multimodal projectors, and similar non-language "
+            "modules are frozen (requires_grad=False). This allows DDP training without "
+            "ddp_find_unused_parameters=True."
+        },
+    )
+
     unfrozen_parameters: list[str] | None = Field(
         default=None,
         json_schema_extra={