fix ddp/fsdp w gemma4 (#3584)

* fix ddp/fsdp w gemma4

* address pr comments

* activation offloading fix and update agent docs for gemma4
Wing Lian
2026-04-09 20:02:36 -07:00
committed by GitHub
parent 7daf7d96f1
commit 4ef608dda3
9 changed files with 398 additions and 2 deletions

View File

@@ -19,6 +19,8 @@ TOPICS = {
     "preference_tuning": "docs/agents/preference_tuning.md",
     "reward_modelling": "docs/agents/reward_modelling.md",
     "pretraining": "docs/agents/pretraining.md",
+    "model_architectures": "docs/agents/model_architectures.md",
+    "new_model_support": "docs/agents/new_model_support.md",
 }
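
Not part of the diff: a minimal sketch of how an agent-side caller might resolve one of these topic keys to its document. `load_topic_doc` is a hypothetical helper name; it assumes `TOPICS` is importable and paths are relative to the repo root.

from pathlib import Path

def load_topic_doc(topic: str, repo_root: str = ".") -> str:
    # Resolve a TOPICS key (e.g. "new_model_support") to its markdown file.
    return Path(repo_root, TOPICS[topic]).read_text()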

View File

@@ -404,7 +404,9 @@ class AxolotlTrainer(
         # Gemma4 requires mm_token_type_ids during training (even for text-only).
         # Inject zeros (= text token type) when not provided by the data collator.
-        _model_type = getattr(getattr(model, "config", None), "model_type", None)
+        # Use unwrap_model to handle DDP/FSDP wrappers that don't proxy .config.
+        _unwrapped = self.accelerator.unwrap_model(model)
+        _model_type = getattr(getattr(_unwrapped, "config", None), "model_type", None)
         if (
             "mm_token_type_ids" not in inputs
             and "input_ids" in inputs
@@ -445,6 +447,21 @@ class AxolotlTrainer(
         LOG.info("Running evaluation step...")
         return super().evaluate(*args, **kwargs)
+
+    @override
+    def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys=None):
+        # Gemma4 requires mm_token_type_ids even during evaluation.
+        _unwrapped = self.accelerator.unwrap_model(model)
+        _model_type = getattr(getattr(_unwrapped, "config", None), "model_type", None)
+        if (
+            "mm_token_type_ids" not in inputs
+            and "input_ids" in inputs
+            and _model_type == "gemma4"
+        ):
+            inputs["mm_token_type_ids"] = torch.zeros_like(inputs["input_ids"])
+        return super().prediction_step(
+            model, inputs, prediction_loss_only, ignore_keys=ignore_keys
+        )
+
     @staticmethod
     def orpo_concatenate_inputs(inputs, label_pad_token=-100, pad_token=0, device=None):
         concatenated_batch = {}
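
Not part of the diff: compute_loss and prediction_step now share the same injection pattern. A self-contained sketch of that pattern (the helper name is hypothetical), showing that the injected tensor mirrors the shape and dtype of input_ids and marks every position as text:

import torch

def maybe_inject_mm_token_type_ids(inputs: dict, model_type: str | None) -> dict:
    # Zeros = text token type; only injected when the collator omitted the key.
    if (
        "mm_token_type_ids" not in inputs
        and "input_ids" in inputs
        and model_type == "gemma4"
    ):
        inputs["mm_token_type_ids"] = torch.zeros_like(inputs["input_ids"])
    return inputs

batch = maybe_inject_mm_token_type_ids({"input_ids": torch.tensor([[1, 2, 3]])}, "gemma4")
assert torch.equal(batch["mm_token_type_ids"], torch.zeros((1, 3), dtype=torch.long))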

View File

@@ -36,7 +36,7 @@ from axolotl.telemetry.manager import TelemetryManager
 from axolotl.utils.ctx_managers.sequence_parallel import SequenceParallelContextManager
 from axolotl.utils.dict import DictDefault
 from axolotl.utils.distributed import cleanup_distributed
-from axolotl.utils.freeze import freeze_layers_except
+from axolotl.utils.freeze import freeze_layers_except, freeze_mm_modules
 from axolotl.utils.logging import get_logger
 from axolotl.utils.schemas.enums import RLType
 from axolotl.utils.train import determine_last_checkpoint
@@ -114,6 +114,10 @@ def setup_model_and_tokenizer(
     ):
         model.enable_input_require_grads()
+
+    # Freeze multimodal modules for text-only training of multimodal models
+    if cfg.freeze_mm_modules:
+        freeze_mm_modules(model)
     return model, tokenizer, peft_config, processor
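
Not part of the diff: the effect of the new gate on trainable parameters, sketched with a hypothetical toy model (assumes a branch of axolotl that includes this PR is installed):

import torch.nn as nn
from axolotl.utils.freeze import freeze_mm_modules

class ToyMM(nn.Module):
    def __init__(self):
        super().__init__()
        self.vision_tower = nn.Linear(4, 4)    # name matches _MM_MODULE_PREFIXES
        self.language_model = nn.Linear(4, 4)  # untouched by the freeze

model = ToyMM()
freeze_mm_modules(model)
print([n for n, p in model.named_parameters() if p.requires_grad])
# ['language_model.weight', 'language_model.bias']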

View File

@@ -268,6 +268,37 @@ def normalize_config(cfg):
     ):
         cfg.gradient_checkpointing_kwargs = {"use_reentrant": True}
+
+    # Gemma4 requires use_reentrant=False for DDP (shared per-layer norms cause
+    # "marked ready twice" errors with reentrant checkpointing) and
+    # ddp_find_unused_parameters=True (per_layer_projection LoRA params may not
+    # receive gradients on every step).
+    if cfg.model_config_type == "gemma4":
+        if cfg.gradient_checkpointing:
+            if cfg.gradient_checkpointing_kwargs is None:
+                cfg.gradient_checkpointing_kwargs = {}
+            if cfg.gradient_checkpointing_kwargs.get("use_reentrant") is not False:
+                LOG.warning(
+                    "Gemma4 requires use_reentrant=False for gradient checkpointing "
+                    "in distributed training. Setting use_reentrant=False."
+                )
+                cfg.gradient_checkpointing_kwargs["use_reentrant"] = False
+        if cfg.ddp and cfg.ddp_find_unused_parameters is None:
+            if cfg.activation_offloading is True:
+                # activation_offloading uses checkpoint wrappers that conflict
+                # with find_unused_parameters (causes "marked ready twice").
+                # Use freeze_mm_modules instead to eliminate unused params.
+                LOG.info(
+                    "Gemma4 + DDP + activation_offloading: skipping "
+                    "ddp_find_unused_parameters (use freeze_mm_modules to "
+                    "handle unused vision/audio params)."
+                )
+            else:
+                LOG.warning(
+                    "Gemma4 requires ddp_find_unused_parameters=True for DDP. "
+                    "Auto-enabling."
+                )
+                cfg.ddp_find_unused_parameters = True
+
     log_gpu_memory_usage(LOG, "baseline", cfg.device)
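
Not part of the diff: the branching above, restated as a pure function over plain values (names hypothetical, error/warning paths elided) to make the decision table explicit:

def gemma4_distributed_settings(grad_ckpt, ddp, find_unused, act_offload):
    # Returns (gradient_checkpointing_kwargs, ddp_find_unused_parameters).
    gc_kwargs = {"use_reentrant": False} if grad_ckpt else None
    if ddp and find_unused is None and not act_offload:
        find_unused = True  # skipped when activation offloading is active
    return gc_kwargs, find_unused

assert gemma4_distributed_settings(True, True, None, False) == ({"use_reentrant": False}, True)
assert gemma4_distributed_settings(True, True, None, True) == ({"use_reentrant": False}, None)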

View File

@@ -10,6 +10,44 @@ from axolotl.utils.logging import get_logger
 LOG = get_logger(__name__)
+
+# Module name components that belong to vision/audio/multimodal encoders rather
+# than the language backbone. Matched against each dot-separated component of a
+# ``named_parameters()`` path (e.g. "model.vision_tower.encoder" matches "vision_tower").
+_MM_MODULE_PREFIXES = (
+    "vision_tower",
+    "vision_model",
+    "vision_encoder",
+    "embed_vision",
+    "multi_modal_projector",
+    "visual",
+    "audio_tower",
+    "audio_model",
+    "embed_audio",
+)
+
+
+def freeze_mm_modules(model):
+    """Freeze all vision/audio/multimodal-projector parameters.
+
+    Iterates over ``model.named_parameters()`` and sets ``requires_grad = False``
+    for any parameter whose dotted path contains a known vision/audio component.
+    This is useful when fine-tuning only the language backbone of a multimodal
+    model and avoids the need for ``ddp_find_unused_parameters=True``.
+    """
+    frozen_count = 0
+    for name, param in model.named_parameters():
+        # Check if any path component matches a vision/audio prefix
+        parts = name.split(".")
+        if any(part in _MM_MODULE_PREFIXES for part in parts):
+            if param.requires_grad:
+                param.requires_grad = False
+                frozen_count += 1
+                if is_main_process():
+                    LOG.debug(f"freeze_mm_modules: froze {name}")
+
+    if is_main_process():
+        LOG.info(f"freeze_mm_modules: froze {frozen_count} vision/audio parameters")
+
+
 def freeze_layers_except(model, regex_patterns):
     """

View File

@@ -578,6 +578,17 @@ class AxolotlInputConfig(
         },
     )
+    freeze_mm_modules: bool | None = Field(
+        default=None,
+        json_schema_extra={
+            "description": "Freeze multimodal encoder parameters (vision, audio, etc.) for "
+            "text-only training of multimodal models. When True, parameters belonging to "
+            "vision towers, audio towers, multimodal projectors, and similar non-language "
+            "modules are frozen (requires_grad=False). This allows DDP training without "
+            "ddp_find_unused_parameters=True."
+        },
+    )
     unfrozen_parameters: list[str] | None = Field(
         default=None,
         json_schema_extra={
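
Not part of the diff: how a json_schema_extra description surfaces through pydantic, using a toy stand-in for AxolotlInputConfig (pydantic v2 API; the description text is abbreviated):

from pydantic import BaseModel, Field

class Demo(BaseModel):
    freeze_mm_modules: bool | None = Field(
        default=None,
        json_schema_extra={"description": "Freeze multimodal encoder parameters."},
    )

schema = Demo.model_json_schema()
print(schema["properties"]["freeze_mm_modules"]["description"])
# Freeze multimodal encoder parameters.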