From 4ef608dda31e84f6e0f9ffbb26ca7157cf0daccf Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Thu, 9 Apr 2026 20:02:36 -0700 Subject: [PATCH] fix ddp/fsdp w gemma4 (#3584) * fix ddp/fsdp w gemma4 * address pr comments * activation offloading fix and update agent docs for gemma4 --- AGENTS.md | 2 + docs/agents/model_architectures.md | 110 +++++++++++++++ docs/agents/new_model_support.md | 181 +++++++++++++++++++++++++ src/axolotl/cli/agent_docs/__init__.py | 2 + src/axolotl/core/trainers/base.py | 19 ++- src/axolotl/train.py | 6 +- src/axolotl/utils/config/__init__.py | 31 +++++ src/axolotl/utils/freeze.py | 38 ++++++ src/axolotl/utils/schemas/config.py | 11 ++ 9 files changed, 398 insertions(+), 2 deletions(-) create mode 100644 docs/agents/model_architectures.md create mode 100644 docs/agents/new_model_support.md diff --git a/AGENTS.md b/AGENTS.md index 6fb81e506..e9b747ce3 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -38,6 +38,8 @@ Agent-specific references: - [docs/agents/grpo.md](docs/agents/grpo.md) — GRPO online RL with reward functions - [docs/agents/reward_modelling.md](docs/agents/reward_modelling.md) — outcome and process reward models - [docs/agents/pretraining.md](docs/agents/pretraining.md) — continual pretraining +- [docs/agents/model_architectures.md](docs/agents/model_architectures.md) — model-specific quirks (Gemma4, Qwen3.5 MoE, etc.) +- [docs/agents/new_model_support.md](docs/agents/new_model_support.md) — debugging and adding support for new model architectures ## Config Pattern diff --git a/docs/agents/model_architectures.md b/docs/agents/model_architectures.md new file mode 100644 index 000000000..426db4ce9 --- /dev/null +++ b/docs/agents/model_architectures.md @@ -0,0 +1,110 @@ +# Model Architectures — Agent Reference + +Model-specific quirks, required settings, and known issues. Check this before debugging training failures on specific model families. + +## Gemma 4 + +**Models**: `google/gemma-4-26B-A4B` (MoE), `google/gemma-4-31B` (dense), `google/gemma-4-E2B`, `google/gemma-4-E4B` + +**Architecture**: Multimodal wrapper (`Gemma4ForConditionalGeneration`) over a text backbone (`Gemma4TextModel`), with optional vision/audio encoders. All Gemma4 HF repos have `model_type: "gemma4"` — even text-only variants load as multimodal with a vision tower. + +### Required settings + +```yaml +# Always needed for Gemma4: +freeze_mm_modules: true # Freeze vision/audio encoders for text-only training +gradient_checkpointing_kwargs: + use_reentrant: false # Shared per-layer norms cause "marked ready twice" with reentrant + +# LoRA target — restrict to language model only (DO NOT use lora_target_linear: true): +lora_target_modules: 'model.language_model.layers.[\d]+.(_checkpoint_wrapped_module.)?(mlp|self_attn).(up|down|gate|q|k|v|o)_proj' +``` + +### Auto-detection + +Axolotl auto-detects Gemma4 and applies: +- `use_reentrant: false` for gradient checkpointing +- `ddp_find_unused_parameters: true` for DDP (skipped when `activation_offloading: true`) + +### Multi-GPU + +| Strategy | Works? | Notes | +|----------|--------|-------| +| DDP | Yes | Auto-sets `ddp_find_unused_parameters=True` | +| DDP + activation_offloading | Yes | `find_unused_parameters` is skipped (conflicts with checkpoint wrappers) | +| FSDP1 | No | OOM during dequantization/sharding with QLoRA | +| FSDP2 | Yes | Use `Gemma4TextDecoderLayer` (not `Gemma4DecoderLayer`) as wrap class | +| FSDP2 + activation_offloading | Yes | Lowest VRAM (~26 GiB/GPU for 26B-A4B) | + +FSDP2 config: +```yaml +fsdp: + - full_shard + - auto_wrap +fsdp_config: + fsdp_version: 2 + fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP + fsdp_transformer_layer_cls_to_wrap: Gemma4TextDecoderLayer +``` + +### MoE (26B-A4B) + +- `enable_moe_block: true`, 256 experts, top-k routing +- No separate `SparseMoeBlock` — MoE is embedded in each decoder layer +- Expert LoRA targets 3D parameter tensors: + ```yaml + lora_target_parameters: + - experts.gate_up_proj + - experts.down_proj + ``` +- ScatterMoE kernel acceleration: + ```yaml + plugins: + - axolotl.integrations.kernels.KernelsPlugin + use_kernels: true + use_scattermoe: true + experts_implementation: scattermoe + ``` + +### Common issues + +| Symptom | Cause | Fix | +|---------|-------|-----| +| `mm_token_type_ids is required` in DDP | `model.config` not accessible through DDP wrapper | Already fixed — `unwrap_model()` in `compute_loss` and `prediction_step` | +| `marked a variable ready twice` in DDP | `ddp_find_unused_parameters=True` + activation_offloading checkpoint wrappers | Auto-handled — `find_unused_parameters` is skipped when `activation_offloading: true` | +| Loss ~12 instead of ~0.5 | Using `lora_target_linear: true` (applies LoRA to vision/audio modules) | Use the regex `lora_target_modules` pattern instead | +| FSDP2 `Could not find Gemma4AudioLayer` | Auto-wrap detects `_no_split_modules` including audio layers that don't exist | Explicitly set `fsdp_transformer_layer_cls_to_wrap: Gemma4TextDecoderLayer` | +| `Gemma4ClippableLinear not supported` by PEFT | Vision tower uses a non-standard linear wrapper | Axolotl patches this automatically via `_patch_peft_clippable_linear()` | + +### E2B/E4B dense models + +These have `hidden_size_per_layer_input: 256` (per-layer input embeddings) and `attention_k_eq_v: False`. Known issue: loss starts higher than expected (~12 vs ~0.5 for 26B). Root cause under investigation — may be related to the per-layer input mechanism or the `Gemma4ForConditionalGeneration` loss computation. + +## Gemma 3 + +**Models**: `google/gemma-3-*` + +- `ddp_find_unused_parameters: true` needed (multimodal unused params) +- `use_reentrant: false` recommended +- Attention mask must be dropped for sample packing (handled automatically) +- Multi-GPU test currently skipped (`tests/e2e/multigpu/test_gemma3.py`) + +## Qwen 3.5 MoE + +**Models**: `Qwen/Qwen3.5-35B-A3B` + +- Hybrid architecture: DeltaNet linear attention (30 layers) + full attention (10 layers) +- 256 experts, 8 active per token +- Known weight scale drift in late DeltaNet layers (36-38) due to AdamW + rare expert interaction +- Fix: `normalize_weight_scales` config to detect and rescale outliers: + ```yaml + normalize_weight_scales: + - name_pattern: 'linear_attn\.conv1d\.weight' + threshold: 1.3 + ``` + +## General MoE Notes + +- `lora_target_linear: true` with multimodal MoE models will apply LoRA to ALL linear modules including vision/audio encoders — use regex `lora_target_modules` to restrict to language model only +- Rare experts get larger effective learning rate from AdamW (small second-moment estimates) — can cause weight drift in recurrent/SSM components. Use `normalize_weight_scales` with `dry_run: true` to detect. +- For ScatterMoE kernel support, set `experts_implementation: scattermoe` and add the KernelsPlugin diff --git a/docs/agents/new_model_support.md b/docs/agents/new_model_support.md new file mode 100644 index 000000000..8e6028896 --- /dev/null +++ b/docs/agents/new_model_support.md @@ -0,0 +1,181 @@ +# New Model Support — Agent Reference + +Guide for debugging and adding support for new model architectures in axolotl. Based on lessons learned from Gemma4, Gemma3, Qwen2-VL, and other multimodal/MoE models. + +## Quick Validation Checklist + +When testing a new model, run through these checks in order: + +1. **Does the model load?** `axolotl preprocess config.yaml` — catches config schema errors +2. **Does LoRA apply?** Check for "Unsupported layer type" warnings from PEFT +3. **Is the initial loss sane?** First-step loss for a pretrained model should be 0.5–2.0 for SFT +4. **Does sample packing work?** Compare loss with `sample_packing: true` vs `false` — should be similar +5. **Is CCE active?** Check for "Applying Cut Cross Entropy" log and verify peak VRAM is lower + +## Loss Debugging + +### Expected initial loss +A pretrained model doing SFT should start with loss roughly in the 0.5–2.0 range. If loss starts above 3.0, something is wrong. If it's near `log(vocab_size)` (≈ 12 for 262K vocab), the model is predicting at random — attention masking or model weights are broken. + +### Direct comparison technique +The fastest way to isolate a loss issue — bypass the trainer entirely: + +```python +# Load model via axolotl's pipeline (applies all patches) +from axolotl.cli.config import load_cfg +from axolotl.utils.config import normalize_config, prepare_plugins +from axolotl.loaders.tokenizer import load_tokenizer +from axolotl.loaders.model import ModelLoader + +cfg = load_cfg("your_config.yaml") +normalize_config(cfg) +prepare_plugins(cfg) +tokenizer = load_tokenizer(cfg) +model, _ = ModelLoader(cfg, tokenizer).load() + +# Forward pass on preprocessed data +model.train() +out = model(input_ids, labels=labels) +print(f"Direct loss: {out.loss.item()}") # Compare to trainer's reported loss +``` + +If direct loss is correct (~1.0) but trainer reports 3–4x higher, check `model_accepts_loss_kwargs` (see below). + +### `model_accepts_loss_kwargs` inflation +HF Trainer checks if the model's `forward()` has `**kwargs` and sets `model_accepts_loss_kwargs=True`. This changes loss normalization: the trainer does NOT divide loss by `gradient_accumulation_steps` before logging. The gradient is correct — only the logged loss is inflated. + +**Symptom**: Logged loss ≈ actual_loss × gradient_accumulation_steps. + +**Which models are affected**: Any model with `**kwargs` in forward (common in multimodal models for extra inputs like `mm_token_type_ids`, `pixel_values`, etc.). + +**Fix location**: `src/axolotl/core/trainers/base.py` `__init__()` — after `super().__init__()`, check if the unwrapped model actually has `num_items_in_batch` in its forward signature. If not, set `self.model_accepts_loss_kwargs = False`. + +## Multimodal Models (ForConditionalGeneration) + +Many recent models use `ForConditionalGeneration` as the top-level class, not `ForCausalLM`: +- Gemma3 → `Gemma3ForConditionalGeneration` +- Gemma4 → `Gemma4ForConditionalGeneration` +- Qwen2-VL → `Qwen2VLForConditionalGeneration` +- LLaVA → `LlavaForConditionalGeneration` + +### Why this matters + +| Component | Targets `ForCausalLM` | Needs `ForConditionalGeneration` | +|-----------|----------------------|--------------------------------| +| CCE patches | ✅ (default) | ❌ silently inactive if not patched | +| PEFT LoRA | ✅ | May fail on custom layer types | +| HF Trainer label handling | ✅ | May need extra inputs | + +### Required extra inputs +Multimodal models require special inputs during training even for text-only data: + +| Model | Required Input | Value for Text-Only | +|-------|---------------|-------------------| +| Gemma4 | `mm_token_type_ids` | `torch.zeros_like(input_ids)` | +| Gemma3 | `token_type_ids` | `torch.zeros_like(input_ids)` | + +Auto-inject in `compute_loss()` when not provided by the data collator. See `core/trainers/base.py`. + +### Custom layer types and PEFT +Vision towers often use custom module wrappers that PEFT doesn't support: + +| Model | Custom Layer | Wraps | Fix | +|-------|-------------|-------|-----| +| Gemma4 | `Gemma4ClippableLinear` | `nn.Linear` | Redirect to `.linear` child | + +Fix location: `src/axolotl/loaders/adapter.py` `_patch_peft_clippable_linear()`. + +## Sample Packing + +### How packed sequence detection works (transformers ≥ 5.x) +`transformers.masking_utils._preprocess_mask_arguments()` detects packed sequences from `position_ids` resets. But **only when `attention_mask is None`**: + +```python +# From masking_utils.py: +if position_ids is not None and attention_mask is None and past_key_values is None: + packed_sequence_mask = find_packed_sequence_indices(position_ids) +``` + +If the collator provides an all-ones `attention_mask`, packing detection is **skipped** and the model builds a single causal mask spanning all packed sequences → cross-sequence attention leakage → very high loss. + +### Fix for models using `create_causal_mask_mapping` +For Gemma3, Gemma4, and similar models that use the new transformers masking system, remove `attention_mask` from inputs when sample packing is active: + +```python +# In compute_loss(): +if ( + self.args.sample_packing + and model_type in ("gemma4", "gemma3") + and "attention_mask" in inputs + and "position_ids" in inputs +): + del inputs["attention_mask"] +``` + +Fix location: `src/axolotl/core/trainers/base.py` `compute_loss()`. + +### Models that DON'T need this fix +Older models that use `_prepare_4d_causal_attention_mask` (Llama, Mistral, Qwen2, etc.) handle sample packing via axolotl's multipack attention monkeypatch instead. Only models using the new `create_causal_mask_mapping` / `create_causal_mask` masking system need the `attention_mask` removal. + +## Attention Backend Selection + +| Backend | Config | head_dim limit | torch_compile | Notes | +|---------|--------|---------------|---------------|-------| +| FA2 | `flash_attention: true` | 256 | ✅ | Fastest when supported | +| FA4 | auto with `flash_attention: true` | 256 (SM90+) | ✅ | Auto-detected on H100+ | +| SDPA | `sdp_attention: true` | None | ✅ | Universal fallback | +| flex | `flex_attention: true` | None | ⚠️ Triton OOM for large head_dim | Good for variable head dims | +| eager | neither set | None | ✅ | Slowest, always works | + +**Check model support**: Look at `_supports_flash_attn_2`, `_supports_flex_attn`, `_supports_sdpa` attributes on the model class. + +**head_dim gotcha**: The 256 limit is specific to flash-attn CUDA kernels, NOT PyTorch-level. SDPA and flex_attention both handle arbitrary head_dim. Models with `global_head_dim > 256` (Gemma4: 512) must use SDPA or flex. + +**flex + compile gotcha**: `torch_compile` with flex_attention can hit Triton shared memory OOM for large head_dim. Falls back to eager per-function (not a crash, but slower). Unsloth disables flex for Gemma4 for this reason. + +## Cut Cross Entropy (CCE) + +### How CCE patches work +CCE replaces the model's `forward()` with a fused version that computes loss from hidden states + lm_head weight without materializing the full logits tensor. This saves ~`batch × seq_len × vocab_size × dtype_bytes` of VRAM. + +### Adding CCE for a new model +1. Check if the model type is in `cut_cross_entropy.transformers.patch.PATCH_FNS` +2. If not, axolotl's generic fallback (`integrations/cut_cross_entropy/__init__.py` `patch_llama_like()`) patches `{Prefix}ForCausalLM.forward` with `cce_forward` +3. For multimodal models (`ForConditionalGeneration`), a model-specific patch is needed in `ml-cross-entropy` repo +4. The multimodal `cce_forward` must accept all extra kwargs (pixel_values, mm_token_type_ids, etc.) and pop any that would conflict before calling `self.model()` + +### Common CCE pitfall +If CCE appears active (log says "Applying Cut Cross Entropy") but peak VRAM doesn't decrease, check which class was patched. If the model loads as `ForConditionalGeneration` but CCE patched `ForCausalLM`, the patch is silently inactive. + +## MoE Models + +### Dense MLP vs MoE experts +Some MoE models (e.g., Gemma4) have BOTH dense MLP layers and MoE expert layers at every decoder layer: +- `gate_proj/up_proj/down_proj` → targets the **dense MLP** (`Gemma4TextMLP`) +- `experts.gate_up_proj/experts.down_proj` → targets the **MoE experts** (`Gemma4TextExperts`) + +LoRA on the dense MLP works normally. Expert LoRA via `lora_target_parameters` requires PEFT support for the specific expert module type (may warn "Unsupported layer type"). + +### ScatterMoE kernels +`use_scattermoe: true` with `experts_implementation: scattermoe` registers fused expert kernels via transformers' `ExpertsInterface`. Significant speedup for MoE models. Requires the kernels plugin: +```yaml +plugins: + - axolotl.integrations.kernels.KernelsPlugin +use_kernels: true +use_scattermoe: true +experts_implementation: scattermoe +``` + +## Where to Add Model-Specific Fixes + +| What | Where | Example | +|------|-------|---------| +| Missing forward inputs | `core/trainers/base.py` `compute_loss()` | mm_token_type_ids injection | +| Attention mask fixes | `core/trainers/base.py` `compute_loss()` | Sample packing mask removal | +| Loss logging fixes | `core/trainers/base.py` `__init__()` | model_accepts_loss_kwargs override | +| PEFT/LoRA patches | `loaders/adapter.py` | ClippableLinear redirect | +| Attention patches | `monkeypatch/attention/` | FA4 tuple fix | +| Model-specific patches | `loaders/patch_manager.py` `_apply_model_specific_patches()` | Llama4, Kimi, NemotronH | +| CCE patches | `ml-cross-entropy` repo `transformers/` | Per-model cce_forward | +| Example configs | `examples//` | Validated YAML | +| Config validation | `utils/schemas/validation.py` | Compatibility checks | diff --git a/src/axolotl/cli/agent_docs/__init__.py b/src/axolotl/cli/agent_docs/__init__.py index d229184c0..14dbff32d 100644 --- a/src/axolotl/cli/agent_docs/__init__.py +++ b/src/axolotl/cli/agent_docs/__init__.py @@ -19,6 +19,8 @@ TOPICS = { "preference_tuning": "docs/agents/preference_tuning.md", "reward_modelling": "docs/agents/reward_modelling.md", "pretraining": "docs/agents/pretraining.md", + "model_architectures": "docs/agents/model_architectures.md", + "new_model_support": "docs/agents/new_model_support.md", } diff --git a/src/axolotl/core/trainers/base.py b/src/axolotl/core/trainers/base.py index 650a238ec..96183973f 100644 --- a/src/axolotl/core/trainers/base.py +++ b/src/axolotl/core/trainers/base.py @@ -404,7 +404,9 @@ class AxolotlTrainer( # Gemma4 requires mm_token_type_ids during training (even for text-only). # Inject zeros (= text token type) when not provided by the data collator. - _model_type = getattr(getattr(model, "config", None), "model_type", None) + # Use unwrap_model to handle DDP/FSDP wrappers that don't proxy .config. + _unwrapped = self.accelerator.unwrap_model(model) + _model_type = getattr(getattr(_unwrapped, "config", None), "model_type", None) if ( "mm_token_type_ids" not in inputs and "input_ids" in inputs @@ -445,6 +447,21 @@ class AxolotlTrainer( LOG.info("Running evaluation step...") return super().evaluate(*args, **kwargs) + @override + def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys=None): + # Gemma4 requires mm_token_type_ids even during evaluation. + _unwrapped = self.accelerator.unwrap_model(model) + _model_type = getattr(getattr(_unwrapped, "config", None), "model_type", None) + if ( + "mm_token_type_ids" not in inputs + and "input_ids" in inputs + and _model_type == "gemma4" + ): + inputs["mm_token_type_ids"] = torch.zeros_like(inputs["input_ids"]) + return super().prediction_step( + model, inputs, prediction_loss_only, ignore_keys=ignore_keys + ) + @staticmethod def orpo_concatenate_inputs(inputs, label_pad_token=-100, pad_token=0, device=None): concatenated_batch = {} diff --git a/src/axolotl/train.py b/src/axolotl/train.py index 774aa1cec..23388e40e 100644 --- a/src/axolotl/train.py +++ b/src/axolotl/train.py @@ -36,7 +36,7 @@ from axolotl.telemetry.manager import TelemetryManager from axolotl.utils.ctx_managers.sequence_parallel import SequenceParallelContextManager from axolotl.utils.dict import DictDefault from axolotl.utils.distributed import cleanup_distributed -from axolotl.utils.freeze import freeze_layers_except +from axolotl.utils.freeze import freeze_layers_except, freeze_mm_modules from axolotl.utils.logging import get_logger from axolotl.utils.schemas.enums import RLType from axolotl.utils.train import determine_last_checkpoint @@ -114,6 +114,10 @@ def setup_model_and_tokenizer( ): model.enable_input_require_grads() + # Freeze multimodal modules for text-only training of multimodal models + if cfg.freeze_mm_modules: + freeze_mm_modules(model) + return model, tokenizer, peft_config, processor diff --git a/src/axolotl/utils/config/__init__.py b/src/axolotl/utils/config/__init__.py index c5bad62de..8e6d3e7e7 100644 --- a/src/axolotl/utils/config/__init__.py +++ b/src/axolotl/utils/config/__init__.py @@ -268,6 +268,37 @@ def normalize_config(cfg): ): cfg.gradient_checkpointing_kwargs = {"use_reentrant": True} + # Gemma4 requires use_reentrant=False for DDP (shared per-layer norms cause + # "marked ready twice" errors with reentrant checkpointing) and + # ddp_find_unused_parameters=True (per_layer_projection LoRA params may not + # receive gradients on every step). + if cfg.model_config_type == "gemma4": + if cfg.gradient_checkpointing: + if cfg.gradient_checkpointing_kwargs is None: + cfg.gradient_checkpointing_kwargs = {} + if cfg.gradient_checkpointing_kwargs.get("use_reentrant") is not False: + LOG.warning( + "Gemma4 requires use_reentrant=False for gradient checkpointing " + "in distributed training. Setting use_reentrant=False." + ) + cfg.gradient_checkpointing_kwargs["use_reentrant"] = False + if cfg.ddp and cfg.ddp_find_unused_parameters is None: + if cfg.activation_offloading is True: + # activation_offloading uses checkpoint wrappers that conflict + # with find_unused_parameters (causes "marked ready twice"). + # Use freeze_mm_modules instead to eliminate unused params. + LOG.info( + "Gemma4 + DDP + activation_offloading: skipping " + "ddp_find_unused_parameters (use freeze_mm_modules to " + "handle unused vision/audio params)." + ) + else: + LOG.warning( + "Gemma4 requires ddp_find_unused_parameters=True for DDP. " + "Auto-enabling." + ) + cfg.ddp_find_unused_parameters = True + log_gpu_memory_usage(LOG, "baseline", cfg.device) diff --git a/src/axolotl/utils/freeze.py b/src/axolotl/utils/freeze.py index 936708f04..e60c49673 100644 --- a/src/axolotl/utils/freeze.py +++ b/src/axolotl/utils/freeze.py @@ -10,6 +10,44 @@ from axolotl.utils.logging import get_logger LOG = get_logger(__name__) +# Top-level module name prefixes that belong to vision/audio/multimodal encoders +# rather than the language backbone. These are matched against the first component +# of each ``named_parameter`` path (e.g. "model.vision_tower." -> "vision_tower"). +_MM_MODULE_PREFIXES = ( + "vision_tower", + "vision_model", + "vision_encoder", + "embed_vision", + "multi_modal_projector", + "visual", + "audio_tower", + "audio_model", + "embed_audio", +) + + +def freeze_mm_modules(model): + """Freeze all vision/audio/multimodal-projector parameters. + + Iterates over ``model.named_parameters()`` and sets ``requires_grad = False`` + for any parameter whose name contains a known vision/audio module prefix. + This is useful when fine-tuning only the language backbone of a multimodal + model and avoids the need for ``ddp_find_unused_parameters=True``. + """ + frozen_count = 0 + for name, param in model.named_parameters(): + # Check if any path component matches a vision/audio prefix + parts = name.split(".") + if any(part in _MM_MODULE_PREFIXES for part in parts): + if param.requires_grad: + param.requires_grad = False + frozen_count += 1 + if is_main_process(): + LOG.debug(f"freeze_mm_modules: froze {name}") + + if is_main_process(): + LOG.info(f"freeze_mm_modules: froze {frozen_count} vision/audio parameters") + def freeze_layers_except(model, regex_patterns): """ diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py index d0f588d9b..474c3a349 100644 --- a/src/axolotl/utils/schemas/config.py +++ b/src/axolotl/utils/schemas/config.py @@ -578,6 +578,17 @@ class AxolotlInputConfig( }, ) + freeze_mm_modules: bool | None = Field( + default=None, + json_schema_extra={ + "description": "Freeze multimodal encoder parameters (vision, audio, etc.) for " + "text-only training of multimodal models. When True, parameters belonging to " + "vision towers, audio towers, multimodal projectors, and similar non-language " + "modules are frozen (requires_grad=False). This allows DDP training without " + "ddp_find_unused_parameters=True." + }, + ) + unfrozen_parameters: list[str] | None = Field( default=None, json_schema_extra={