Fix DDP/FSDP with Gemma4 (#3584)
* Fix DDP/FSDP with Gemma4 * Address PR review comments * Fix activation offloading and update agent docs for Gemma4
This commit is contained in:
@@ -19,6 +19,8 @@ TOPICS = {
|
||||
"preference_tuning": "docs/agents/preference_tuning.md",
|
||||
"reward_modelling": "docs/agents/reward_modelling.md",
|
||||
"pretraining": "docs/agents/pretraining.md",
|
||||
"model_architectures": "docs/agents/model_architectures.md",
|
||||
"new_model_support": "docs/agents/new_model_support.md",
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -404,7 +404,9 @@ class AxolotlTrainer(
|
||||
|
||||
# Gemma4 requires mm_token_type_ids during training (even for text-only).
|
||||
# Inject zeros (= text token type) when not provided by the data collator.
|
||||
_model_type = getattr(getattr(model, "config", None), "model_type", None)
|
||||
# Use unwrap_model to handle DDP/FSDP wrappers that don't proxy .config.
|
||||
_unwrapped = self.accelerator.unwrap_model(model)
|
||||
_model_type = getattr(getattr(_unwrapped, "config", None), "model_type", None)
|
||||
if (
|
||||
"mm_token_type_ids" not in inputs
|
||||
and "input_ids" in inputs
|
||||
@@ -445,6 +447,21 @@ class AxolotlTrainer(
|
||||
LOG.info("Running evaluation step...")
|
||||
return super().evaluate(*args, **kwargs)
|
||||
|
||||
@override
|
||||
def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys=None):
|
||||
# Gemma4 requires mm_token_type_ids even during evaluation.
|
||||
_unwrapped = self.accelerator.unwrap_model(model)
|
||||
_model_type = getattr(getattr(_unwrapped, "config", None), "model_type", None)
|
||||
if (
|
||||
"mm_token_type_ids" not in inputs
|
||||
and "input_ids" in inputs
|
||||
and _model_type == "gemma4"
|
||||
):
|
||||
inputs["mm_token_type_ids"] = torch.zeros_like(inputs["input_ids"])
|
||||
return super().prediction_step(
|
||||
model, inputs, prediction_loss_only, ignore_keys=ignore_keys
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def orpo_concatenate_inputs(inputs, label_pad_token=-100, pad_token=0, device=None):
|
||||
concatenated_batch = {}
|
||||
|
||||
@@ -36,7 +36,7 @@ from axolotl.telemetry.manager import TelemetryManager
|
||||
from axolotl.utils.ctx_managers.sequence_parallel import SequenceParallelContextManager
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.distributed import cleanup_distributed
|
||||
from axolotl.utils.freeze import freeze_layers_except
|
||||
from axolotl.utils.freeze import freeze_layers_except, freeze_mm_modules
|
||||
from axolotl.utils.logging import get_logger
|
||||
from axolotl.utils.schemas.enums import RLType
|
||||
from axolotl.utils.train import determine_last_checkpoint
|
||||
@@ -114,6 +114,10 @@ def setup_model_and_tokenizer(
|
||||
):
|
||||
model.enable_input_require_grads()
|
||||
|
||||
# Freeze multimodal modules for text-only training of multimodal models
|
||||
if cfg.freeze_mm_modules:
|
||||
freeze_mm_modules(model)
|
||||
|
||||
return model, tokenizer, peft_config, processor
|
||||
|
||||
|
||||
|
||||
@@ -268,6 +268,37 @@ def normalize_config(cfg):
|
||||
):
|
||||
cfg.gradient_checkpointing_kwargs = {"use_reentrant": True}
|
||||
|
||||
# Gemma4 requires use_reentrant=False for DDP (shared per-layer norms cause
|
||||
# "marked ready twice" errors with reentrant checkpointing) and
|
||||
# ddp_find_unused_parameters=True (per_layer_projection LoRA params may not
|
||||
# receive gradients on every step).
|
||||
if cfg.model_config_type == "gemma4":
|
||||
if cfg.gradient_checkpointing:
|
||||
if cfg.gradient_checkpointing_kwargs is None:
|
||||
cfg.gradient_checkpointing_kwargs = {}
|
||||
if cfg.gradient_checkpointing_kwargs.get("use_reentrant") is not False:
|
||||
LOG.warning(
|
||||
"Gemma4 requires use_reentrant=False for gradient checkpointing "
|
||||
"in distributed training. Setting use_reentrant=False."
|
||||
)
|
||||
cfg.gradient_checkpointing_kwargs["use_reentrant"] = False
|
||||
if cfg.ddp and cfg.ddp_find_unused_parameters is None:
|
||||
if cfg.activation_offloading is True:
|
||||
# activation_offloading uses checkpoint wrappers that conflict
|
||||
# with find_unused_parameters (causes "marked ready twice").
|
||||
# Use freeze_mm_modules instead to eliminate unused params.
|
||||
LOG.info(
|
||||
"Gemma4 + DDP + activation_offloading: skipping "
|
||||
"ddp_find_unused_parameters (use freeze_mm_modules to "
|
||||
"handle unused vision/audio params)."
|
||||
)
|
||||
else:
|
||||
LOG.warning(
|
||||
"Gemma4 requires ddp_find_unused_parameters=True for DDP. "
|
||||
"Auto-enabling."
|
||||
)
|
||||
cfg.ddp_find_unused_parameters = True
|
||||
|
||||
log_gpu_memory_usage(LOG, "baseline", cfg.device)
|
||||
|
||||
|
||||
|
||||
@@ -10,6 +10,44 @@ from axolotl.utils.logging import get_logger
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
# Module name prefixes that belong to vision/audio/multimodal encoders rather
# than the language backbone. ``freeze_mm_modules`` matches these against each
# dotted component of a ``named_parameters()`` path (e.g.
# "model.vision_tower.patch_embed.weight" matches via its "vision_tower"
# component), so a prefix may appear at any depth, not only first position.
_MM_MODULE_PREFIXES = (
    "vision_tower",
    "vision_model",
    "vision_encoder",
    "embed_vision",
    "multi_modal_projector",
    "visual",
    "audio_tower",
    "audio_model",
    "embed_audio",
)
|
||||
|
||||
|
||||
def freeze_mm_modules(model):
    """Freeze all vision/audio/multimodal-projector parameters.

    Iterates over ``model.named_parameters()`` and sets ``requires_grad = False``
    for any parameter whose dotted path contains a component that exactly matches
    one of ``_MM_MODULE_PREFIXES``. This is useful when fine-tuning only the
    language backbone of a multimodal model and avoids the need for
    ``ddp_find_unused_parameters=True``.

    Args:
        model: A ``torch.nn.Module`` (or anything exposing ``named_parameters()``).

    Returns:
        int: The number of parameters newly frozen by this call (parameters
        that were already frozen are not counted).
    """
    frozen_count = 0
    for name, param in model.named_parameters():
        # Match exact path components (e.g. "model.vision_tower.x" -> "vision_tower"),
        # not substrings, to avoid false positives on unrelated parameter names.
        parts = name.split(".")
        if any(part in _MM_MODULE_PREFIXES for part in parts):
            if param.requires_grad:
                param.requires_grad = False
                frozen_count += 1
                if is_main_process():
                    LOG.debug(f"freeze_mm_modules: froze {name}")

    if is_main_process():
        LOG.info(f"freeze_mm_modules: froze {frozen_count} vision/audio parameters")
    return frozen_count
|
||||
|
||||
|
||||
def freeze_layers_except(model, regex_patterns):
|
||||
"""
|
||||
|
||||
@@ -578,6 +578,17 @@ class AxolotlInputConfig(
|
||||
},
|
||||
)
|
||||
|
||||
freeze_mm_modules: bool | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "Freeze multimodal encoder parameters (vision, audio, etc.) for "
|
||||
"text-only training of multimodal models. When True, parameters belonging to "
|
||||
"vision towers, audio towers, multimodal projectors, and similar non-language "
|
||||
"modules are frozen (requires_grad=False). This allows DDP training without "
|
||||
"ddp_find_unused_parameters=True."
|
||||
},
|
||||
)
|
||||
|
||||
unfrozen_parameters: list[str] | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
|
||||
Reference in New Issue
Block a user