From 0e357b5df6c6296cbc66a7c4fcc0eacd5f32c03d Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Sat, 14 Feb 2026 00:32:48 +0700 Subject: [PATCH] fix: load gemma3 as text only model with dynamic weights --- examples/gemma3/gemma-3-4b-qlora.yml | 6 +- examples/qat_nvfp4/Gemma3-12B_baseline.yml | 3 +- examples/qat_nvfp4/Gemma3-12B_qat.yml | 1 + .../qat_nvfp4/Math-Gemma3-12B_baseline.yml | 3 +- examples/qat_nvfp4/Math-Gemma3-12B_qat.yml | 1 + .../qat_nvfp4/Math-Gemma3-27B_baseline.yml | 3 +- examples/qat_nvfp4/Math-Gemma3-27B_qat.yml | 1 + scripts/merge_gemma3_multimodal_weights.py | 225 ++++++++++++++++++ src/axolotl/integrations/gemma3/README.md | 37 +++ src/axolotl/integrations/gemma3/__init__.py | 9 + src/axolotl/integrations/gemma3/args.py | 31 +++ src/axolotl/integrations/gemma3/plugin.py | 88 +++++++ src/axolotl/loaders/utils.py | 7 + src/axolotl/utils/trainer.py | 2 +- 14 files changed, 406 insertions(+), 11 deletions(-) create mode 100644 scripts/merge_gemma3_multimodal_weights.py create mode 100644 src/axolotl/integrations/gemma3/README.md create mode 100644 src/axolotl/integrations/gemma3/__init__.py create mode 100644 src/axolotl/integrations/gemma3/args.py create mode 100644 src/axolotl/integrations/gemma3/plugin.py diff --git a/examples/gemma3/gemma-3-4b-qlora.yml b/examples/gemma3/gemma-3-4b-qlora.yml index 7d44f3c9b..7a78ada54 100644 --- a/examples/gemma3/gemma-3-4b-qlora.yml +++ b/examples/gemma3/gemma-3-4b-qlora.yml @@ -1,8 +1,7 @@ base_model: google/gemma-3-4b-it -# Need to set else transformers tries to load vision too -model_type: Gemma3ForCausalLM -cls_model_config: Gemma3TextConfig +plugins: + - axolotl.integrations.gemma3.Gemma3TextFromMultimodalPlugin load_in_4bit: true @@ -30,7 +29,6 @@ lora_model_dir: sequence_len: 2048 sample_packing: true - lora_r: 32 lora_alpha: 16 lora_dropout: 0 diff --git a/examples/qat_nvfp4/Gemma3-12B_baseline.yml b/examples/qat_nvfp4/Gemma3-12B_baseline.yml index be4e86635..80dd57228 100644 --- a/examples/qat_nvfp4/Gemma3-12B_baseline.yml +++ b/examples/qat_nvfp4/Gemma3-12B_baseline.yml @@ -1,12 +1,11 @@ base_model: google/gemma-3-12b-it -# Automatically upload checkpoint and final model to HF # hub_model_id: username/custom_model_name - load_in_8bit: false load_in_4bit: false strict: false plugins: + - axolotl.integrations.gemma3.Gemma3TextFromMultimodalPlugin - axolotl.integrations.liger.LigerPlugin liger_rope: true diff --git a/examples/qat_nvfp4/Gemma3-12B_qat.yml b/examples/qat_nvfp4/Gemma3-12B_qat.yml index 7fa81163f..4408e6235 100644 --- a/examples/qat_nvfp4/Gemma3-12B_qat.yml +++ b/examples/qat_nvfp4/Gemma3-12B_qat.yml @@ -7,6 +7,7 @@ load_in_4bit: false strict: false plugins: + - axolotl.integrations.gemma3.Gemma3TextFromMultimodalPlugin - axolotl.integrations.liger.LigerPlugin liger_rope: true diff --git a/examples/qat_nvfp4/Math-Gemma3-12B_baseline.yml b/examples/qat_nvfp4/Math-Gemma3-12B_baseline.yml index 9f209515b..76f5c14c8 100644 --- a/examples/qat_nvfp4/Math-Gemma3-12B_baseline.yml +++ b/examples/qat_nvfp4/Math-Gemma3-12B_baseline.yml @@ -1,12 +1,11 @@ base_model: google/gemma-3-12b-it -# Math finetuning configuration for Gemma3-12B # hub_model_id: username/custom_model_name - load_in_8bit: false load_in_4bit: false strict: false plugins: + - axolotl.integrations.gemma3.Gemma3TextFromMultimodalPlugin - axolotl.integrations.liger.LigerPlugin liger_rope: true diff --git a/examples/qat_nvfp4/Math-Gemma3-12B_qat.yml b/examples/qat_nvfp4/Math-Gemma3-12B_qat.yml index ef7e754be..833d6a94f 100644 --- 
a/examples/qat_nvfp4/Math-Gemma3-12B_qat.yml +++ b/examples/qat_nvfp4/Math-Gemma3-12B_qat.yml @@ -7,6 +7,7 @@ load_in_4bit: false strict: false plugins: + - axolotl.integrations.gemma3.Gemma3TextFromMultimodalPlugin - axolotl.integrations.liger.LigerPlugin liger_rope: true diff --git a/examples/qat_nvfp4/Math-Gemma3-27B_baseline.yml b/examples/qat_nvfp4/Math-Gemma3-27B_baseline.yml index 3a262d342..bb3ab8bc1 100644 --- a/examples/qat_nvfp4/Math-Gemma3-27B_baseline.yml +++ b/examples/qat_nvfp4/Math-Gemma3-27B_baseline.yml @@ -1,12 +1,11 @@ base_model: google/gemma-3-27b-it -# Math finetuning configuration for Gemma3-27B # hub_model_id: username/custom_model_name - load_in_8bit: false load_in_4bit: false strict: false plugins: + - axolotl.integrations.gemma3.Gemma3TextFromMultimodalPlugin - axolotl.integrations.liger.LigerPlugin liger_rope: true diff --git a/examples/qat_nvfp4/Math-Gemma3-27B_qat.yml b/examples/qat_nvfp4/Math-Gemma3-27B_qat.yml index 87016ae9c..17783169e 100644 --- a/examples/qat_nvfp4/Math-Gemma3-27B_qat.yml +++ b/examples/qat_nvfp4/Math-Gemma3-27B_qat.yml @@ -7,6 +7,7 @@ load_in_4bit: false strict: false plugins: + - axolotl.integrations.gemma3.Gemma3TextFromMultimodalPlugin - axolotl.integrations.liger.LigerPlugin liger_rope: true diff --git a/scripts/merge_gemma3_multimodal_weights.py b/scripts/merge_gemma3_multimodal_weights.py new file mode 100644 index 000000000..740380fa0 --- /dev/null +++ b/scripts/merge_gemma3_multimodal_weights.py @@ -0,0 +1,225 @@ +"""Merge trained text-only Gemma3 weights back into a full multimodal checkpoint. + +After training with the Gemma3TextFromMultimodalPlugin, the saved checkpoint +contains only the language model weights (with ``model.language_model.*`` +prefix, reversed by transformers v5's key_mapping on save). + +This script reconstructs a full ``Gemma3ForConditionalGeneration`` checkpoint by +combining the trained language model weights with the original vision tower and +projector weights from the base multimodal model. + +Usage:: + + python scripts/merge_gemma3_multimodal_weights.py \\ + --original-model google/gemma-3-4b-it \\ + --trained-model /path/to/trained/output \\ + --output-dir /path/to/merged +""" + +import argparse +import json +import logging +from pathlib import Path + +import torch +from huggingface_hub import split_torch_state_dict_into_shards +from safetensors.torch import load_file, save_file +from transformers import AutoConfig + +LOG = logging.getLogger(__name__) + + +def collect_safetensors(model_dir: Path) -> dict[str, torch.Tensor]: + """Load and merge all safetensors shard files in a directory.""" + shard_files = sorted(model_dir.glob("*.safetensors")) + if not shard_files: + raise FileNotFoundError(f"No safetensors files found in {model_dir}") + + state_dict: dict[str, torch.Tensor] = {} + for shard in shard_files: + LOG.info("Loading %s", shard.name) + state_dict.update(load_file(str(shard))) + return state_dict + + +def merge( + original_model: str, + trained_model: str, + output_dir: str, + *, + trust_remote_code: bool = False, +) -> None: + original_path = Path(original_model) + trained_path = Path(trained_model) + out_path = Path(output_dir) + out_path.mkdir(parents=True, exist_ok=True) + + # 1. 
Load the original multimodal checkpoint
+    LOG.info("Loading original multimodal weights from %s", original_model)
+    if original_path.is_dir():
+        original_sd = collect_safetensors(original_path)
+    else:
+        from huggingface_hub import snapshot_download
+
+        cached = Path(
+            snapshot_download(original_model, allow_patterns=["*.safetensors"])
+        )
+        original_sd = collect_safetensors(cached)
+
+    # 2. Load trained text-only weights (already reversed to model.language_model.* by
+    # transformers v5 key_mapping on save)
+    LOG.info("Loading trained text-only weights from %s", trained_model)
+    trained_sd = collect_safetensors(trained_path)
+
+    # 3. Classify original keys
+    lang_keys = {k for k in original_sd if k.startswith("model.language_model.")}
+    vision_keys = {k for k in original_sd if k.startswith("model.vision_tower.")}
+    projector_keys = {
+        k for k in original_sd if k.startswith("model.multi_modal_projector.")
+    }
+    other_keys = set(original_sd.keys()) - lang_keys - vision_keys - projector_keys
+
+    LOG.info(
+        "Original checkpoint: %d language, %d vision, %d projector, %d other keys",
+        len(lang_keys),
+        len(vision_keys),
+        len(projector_keys),
+        len(other_keys),
+    )
+
+    # 4. Classify trained keys (reverse mapping on save gives model.language_model.* prefix)
+    trained_lang_keys = {k for k in trained_sd if k.startswith("model.language_model.")}
+    trained_other = set(trained_sd.keys()) - trained_lang_keys
+
+    LOG.info(
+        "Trained checkpoint: %d language keys, %d other keys",
+        len(trained_lang_keys),
+        len(trained_other),
+    )
+
+    # 5. Build merged state dict
+    merged: dict[str, torch.Tensor] = {}
+
+    # Keep vision tower and projector from original
+    for key in vision_keys | projector_keys:
+        merged[key] = original_sd[key]
+
+    # Use trained language model weights (overwrite original)
+    for key in trained_lang_keys:
+        merged[key] = trained_sd[key]
+
+    # For other trained keys (like lm_head.weight), use trained version
+    for key in trained_other:
+        merged[key] = trained_sd[key]
+
+    # For any original other keys not covered by trained (shouldn't usually happen),
+    # keep original
+    for key in other_keys:
+        if key not in merged:
+            merged[key] = original_sd[key]
+
+    # Check for missing language keys that were in original but not in trained
+    missing_lang = lang_keys - trained_lang_keys
+    if missing_lang:
+        LOG.warning(
+            "%d language keys in original but not in trained; keeping original: %s",
+            len(missing_lang),
+            list(missing_lang)[:5],
+        )
+        for key in missing_lang:
+            merged[key] = original_sd[key]
+
+    LOG.info("Merged checkpoint: %d total keys", len(merged))
+
+    # 6. Save merged weights (sharded at 50GB, with safetensors "format" metadata so transformers can load them)
+    LOG.info("Saving merged weights to %s", out_path)
+    state_dict_split = split_torch_state_dict_into_shards(merged, max_shard_size="50GB")
+
+    for filename, tensors in state_dict_split.filename_to_tensors.items():
+        shard = {name: merged[name] for name in tensors}
+        save_file(shard, str(out_path / filename), metadata={"format": "pt"})
+
+    if state_dict_split.is_sharded:
+        index = {
+            "metadata": {
+                "total_size": sum(t.numel() * t.element_size() for t in merged.values())
+            },
+            "weight_map": state_dict_split.tensor_to_filename,
+        }
+        with open(out_path / "model.safetensors.index.json", "w") as f:
+            json.dump(index, f, indent=2)
+        LOG.info("Saved %d shards", len(state_dict_split.filename_to_tensors))
+
+    # 7.
Copy/update config + LOG.info("Writing config.json") + original_config = AutoConfig.from_pretrained( + original_model, trust_remote_code=trust_remote_code + ) + + # Update text_config fields from trained model's config if available + trained_config_path = trained_path / "config.json" + if trained_config_path.exists(): + with open(trained_config_path) as f: + trained_config_dict = json.load(f) + + # The trained config is the text sub-config; merge its fields into + # the original composite config's text_config + if hasattr(original_config, "text_config"): + for key, val in trained_config_dict.items(): + if key not in ("model_type", "_name_or_path", "architectures"): + if hasattr(original_config.text_config, key): + setattr(original_config.text_config, key, val) + + original_config.save_pretrained(out_path) + + # 8. Copy tokenizer files from trained model if present + tokenizer_files = list(trained_path.glob("tokenizer*")) + list( + trained_path.glob("special_tokens_map*") + ) + if tokenizer_files: + import shutil + + for tok_file in tokenizer_files: + shutil.copy2(tok_file, out_path / tok_file.name) + LOG.info("Copied %d tokenizer files", len(tokenizer_files)) + + LOG.info("Merge complete. Output saved to %s", out_path) + + +def main(): + parser = argparse.ArgumentParser( + description="Merge trained text-only Gemma3 weights back into a multimodal checkpoint." + ) + parser.add_argument( + "--original-model", + required=True, + help="HuggingFace model ID or local path to the original multimodal model", + ) + parser.add_argument( + "--trained-model", + required=True, + help="Local path to the trained text-only model output directory", + ) + parser.add_argument( + "--output-dir", + required=True, + help="Directory to save the merged multimodal checkpoint", + ) + parser.add_argument( + "--trust-remote-code", + action="store_true", + default=False, + help="Trust remote code when loading model config", + ) + args = parser.parse_args() + + merge( + original_model=args.original_model, + trained_model=args.trained_model, + output_dir=args.output_dir, + trust_remote_code=args.trust_remote_code, + ) + + +if __name__ == "__main__": + main() diff --git a/src/axolotl/integrations/gemma3/README.md b/src/axolotl/integrations/gemma3/README.md new file mode 100644 index 000000000..a27b53880 --- /dev/null +++ b/src/axolotl/integrations/gemma3/README.md @@ -0,0 +1,37 @@ +# Gemma3 Text-from-Multimodal Plugin + +Load a Gemma3 multimodal checkpoint (e.g. `google/gemma-3-4b-it`) directly into `Gemma3ForCausalLM` for text-only training. This bypasses the multimodal trainer path and enables sample packing and other text-specific optimizations. + +## How it works + +The plugin uses transformers v5's `key_mapping` parameter on `from_pretrained` to remap `model.language_model.*` checkpoint keys to `model.*`, matching what `Gemma3ForCausalLM` expects. Vision tower and projector weights are automatically discarded. On save, transformers reverses the mapping so checkpoints retain the original `model.language_model.*` prefix. + +## Usage + +Add the plugin to your YAML config: + +```yaml +base_model: google/gemma-3-4b-it + +plugins: + - axolotl.integrations.gemma3.Gemma3TextFromMultimodalPlugin +``` + +See `examples/gemma3/gemma-3-4b-qlora.yml` for a complete example. + +## Merging weights back into a multimodal checkpoint + +After training, the saved checkpoint contains only the language model weights. 
To reconstruct a full `Gemma3ForConditionalGeneration` checkpoint (with the original vision tower and projector), use the merge script: + +```bash +python scripts/merge_gemma3_multimodal_weights.py \ + --original-model google/gemma-3-4b-it \ + --trained-model /path/to/trained/output \ + --output-dir /path/to/merged +``` + +This combines: +- **Trained language model weights** from your output checkpoint +- **Original vision tower + projector weights** from the base multimodal model + +The merged checkpoint can be loaded as `Gemma3ForConditionalGeneration` for multimodal inference or further training. diff --git a/src/axolotl/integrations/gemma3/__init__.py b/src/axolotl/integrations/gemma3/__init__.py new file mode 100644 index 000000000..cfea2a384 --- /dev/null +++ b/src/axolotl/integrations/gemma3/__init__.py @@ -0,0 +1,9 @@ +"""Gemma3 integration for loading multimodal checkpoints as text-only models.""" + +from .args import Gemma3TextFromMultimodalArgs +from .plugin import Gemma3TextFromMultimodalPlugin + +__all__ = [ + "Gemma3TextFromMultimodalArgs", + "Gemma3TextFromMultimodalPlugin", +] diff --git a/src/axolotl/integrations/gemma3/args.py b/src/axolotl/integrations/gemma3/args.py new file mode 100644 index 000000000..128fc3a6e --- /dev/null +++ b/src/axolotl/integrations/gemma3/args.py @@ -0,0 +1,31 @@ +"""Pydantic input args for the Gemma3 text-from-multimodal plugin.""" + +from pydantic import BaseModel, model_validator + +from axolotl.utils.logging import get_logger + +LOG = get_logger(__name__) + + +class Gemma3TextFromMultimodalArgs(BaseModel): + """Configuration args for loading a Gemma3 multimodal checkpoint as text-only.""" + + gemma3_text_from_multimodal: bool = True + extract_text_config: bool = False + + @model_validator(mode="before") + @classmethod + def set_model_type(cls, data): + if not isinstance(data, dict): + return data + + if not data.get("gemma3_text_from_multimodal", True): + return data + + if not data.get("model_type"): + LOG.info( + "Gemma3TextFromMultimodalPlugin: auto-setting model_type to Gemma3ForCausalLM" + ) + data["model_type"] = "Gemma3ForCausalLM" + + return data diff --git a/src/axolotl/integrations/gemma3/plugin.py b/src/axolotl/integrations/gemma3/plugin.py new file mode 100644 index 000000000..68bbafe29 --- /dev/null +++ b/src/axolotl/integrations/gemma3/plugin.py @@ -0,0 +1,88 @@ +"""Plugin for loading Gemma3 multimodal checkpoints into Gemma3ForCausalLM (text-only). + +Uses transformers v5's ``key_mapping`` parameter on ``from_pretrained`` to remap +``model.language_model.*`` keys to ``model.*``, discarding vision tower and projector +weights. On save, transformers automatically reverses the mapping so saved +checkpoints retain the original ``model.language_model.*`` prefix. +""" + +from axolotl.integrations.base import BasePlugin +from axolotl.utils.logging import get_logger + +LOG = get_logger(__name__) + +# key_mapping for transformers from_pretrained: +# Remap checkpoint keys matching ^model.language_model -> model +# Vision tower / projector keys won't match any model parameter and are discarded. +GEMMA3_KEY_MAPPING = {"^model.language_model": "model"} + + +class Gemma3TextFromMultimodalPlugin(BasePlugin): + """Load a Gemma3 multimodal checkpoint as a text-only Gemma3ForCausalLM. + + Hooks + ----- + register(cfg) + Runs before config validation. 
Sets the ``extract_text_config`` flag,
+        ensures ``model_type`` is ``Gemma3ForCausalLM``, and injects
+        ``key_mapping`` into ``model_kwargs`` so that ``from_pretrained`` remaps
+        ``model.language_model.*`` → ``model.*``.
+
+    pre_model_load(cfg)
+        Runs after config validation/normalization but before model instantiation.
+        Checks that ``model_config_type`` is ``gemma3_text`` and
+        ``is_multimodal`` is False (confirming that ``extract_text_config``
+        took effect), and logs a warning otherwise.
+    """
+
+    def get_input_args(self) -> str:
+        return "axolotl.integrations.gemma3.Gemma3TextFromMultimodalArgs"
+
+    def register(self, cfg: dict):
+        """Set up config for multimodal → text-only loading.
+
+        This runs before Pydantic validation, so ``cfg`` is a raw dict.
+        """
+        if not cfg.get("gemma3_text_from_multimodal", True):
+            LOG.info("Gemma3TextFromMultimodalPlugin: disabled via config")
+            return
+
+        LOG.info(
+            "Gemma3TextFromMultimodalPlugin: configuring multimodal → text-only loading"
+        )
+
+        # Flag for load_model_config() to extract the text sub-config
+        cfg["extract_text_config"] = True
+
+        # Ensure model_type is set for the text-only model class
+        if not cfg.get("model_type"):
+            cfg["model_type"] = "Gemma3ForCausalLM"
+
+        # Inject key_mapping into model_kwargs so from_pretrained remaps weights
+        model_kwargs = cfg.setdefault("model_kwargs", {})
+        model_kwargs["key_mapping"] = GEMMA3_KEY_MAPPING
+
+    def pre_model_load(self, cfg):
+        """Validate that config extraction worked before model instantiation."""
+        if not getattr(cfg, "gemma3_text_from_multimodal", True):
+            return
+
+        if cfg.model_config_type != "gemma3_text":
+            LOG.warning(
+                "Gemma3TextFromMultimodalPlugin: expected model_config_type='gemma3_text' "
+                "but got '%s'. The text config extraction may not have worked.",
+                cfg.model_config_type,
+            )
+
+        if cfg.is_multimodal:
+            LOG.warning(
+                "Gemma3TextFromMultimodalPlugin: cfg.is_multimodal is True. "
+                "The model will be loaded via the multimodal trainer path, "
+                "which may not support sample packing or LoRA kernels."
+            )
+
+        LOG.info(
+            "Gemma3TextFromMultimodalPlugin: model_config_type=%s, is_multimodal=%s",
+            cfg.model_config_type,
+            cfg.is_multimodal,
+        )
diff --git a/src/axolotl/loaders/utils.py b/src/axolotl/loaders/utils.py
index 187784b93..35b77636a 100644
--- a/src/axolotl/loaders/utils.py
+++ b/src/axolotl/loaders/utils.py
@@ -204,6 +204,13 @@ def load_model_config(cfg: DictDefault) -> PretrainedConfig | addict.Dict:
 
     check_model_config(cfg, model_config)
 
+    # Extract text config from composite config when explicitly requested
+    # (set by plugins like Gemma3TextFromMultimodalPlugin)
+    if getattr(cfg, "extract_text_config", False) and hasattr(
+        model_config, "get_text_config"
+    ):
+        model_config = model_config.get_text_config()
+
     return model_config
 
 
diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py
index fb381a8a1..b0785a761 100644
--- a/src/axolotl/utils/trainer.py
+++ b/src/axolotl/utils/trainer.py
@@ -247,7 +247,7 @@ def drop_long_seq(sample, sequence_len=2048, min_sequence_len=2, raise_on_drop=F
 
 
 def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
-    drop_attn_mask = cfg.model_config_type in ["mamba", "gemma3"]
+    drop_attn_mask = cfg.model_config_type in ["mamba", "gemma3", "gemma3_text"]
     if drop_attn_mask:
         LOG.info("dropping attention_mask column")
         train_dataset = train_dataset.remove_columns("attention_mask")