Compare commits

..

4 Commits

Author SHA1 Message Date
NanoCode012
53a12282bc fix: log merge command once done 2026-02-14 00:45:01 +07:00
NanoCode012
7271754902 fix: handle plugin logging 2026-02-14 00:40:43 +07:00
NanoCode012
6d5257d92e fix: ignore ds_store 2026-02-14 00:33:53 +07:00
NanoCode012
0e357b5df6 fix: load gemma3 as text only model with dynamic weights 2026-02-14 00:32:48 +07:00
32 changed files with 455 additions and 607 deletions

3
.gitignore vendored
View File

@@ -193,3 +193,6 @@ out/
# scm auto-versioning
src/axolotl/_version.py
# macOS
.DS_Store

View File

@@ -210,8 +210,6 @@ axolotl lm-eval config.yml
Configuration options:
```yaml
lm_eval_model: # model to evaluate (local or hf path)
# List of tasks to evaluate
lm_eval_tasks:
- arc_challenge
@@ -220,7 +218,7 @@ lm_eval_batch_size: # Batch size for evaluation
output_dir: # Directory to save evaluation results
```
See [LM Eval Harness integration docs](https://docs.axolotl.ai/docs/custom_integrations.html#language-model-evaluation-harness-lm-eval) for full configuration details.
See [LM Eval Harness](https://github.com/EleutherAI/lm-evaluation-harness) for more details.
### delinearize-llama4

View File

@@ -1,8 +1,7 @@
base_model: google/gemma-3-4b-it
# Need to set else transformers tries to load vision too
model_type: Gemma3ForCausalLM
cls_model_config: Gemma3TextConfig
plugins:
- axolotl.integrations.gemma3.Gemma3TextFromMultimodalPlugin
load_in_4bit: true
@@ -30,7 +29,6 @@ lora_model_dir:
sequence_len: 2048
sample_packing: true
lora_r: 32
lora_alpha: 16
lora_dropout: 0

View File

@@ -1,12 +1,11 @@
base_model: google/gemma-3-12b-it
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
load_in_8bit: false
load_in_4bit: false
strict: false
plugins:
- axolotl.integrations.gemma3.Gemma3TextFromMultimodalPlugin
- axolotl.integrations.liger.LigerPlugin
liger_rope: true

View File

@@ -7,6 +7,7 @@ load_in_4bit: false
strict: false
plugins:
- axolotl.integrations.gemma3.Gemma3TextFromMultimodalPlugin
- axolotl.integrations.liger.LigerPlugin
liger_rope: true

View File

@@ -1,12 +1,11 @@
base_model: google/gemma-3-12b-it
# Math finetuning configuration for Gemma3-12B
# hub_model_id: username/custom_model_name
load_in_8bit: false
load_in_4bit: false
strict: false
plugins:
- axolotl.integrations.gemma3.Gemma3TextFromMultimodalPlugin
- axolotl.integrations.liger.LigerPlugin
liger_rope: true

View File

@@ -7,6 +7,7 @@ load_in_4bit: false
strict: false
plugins:
- axolotl.integrations.gemma3.Gemma3TextFromMultimodalPlugin
- axolotl.integrations.liger.LigerPlugin
liger_rope: true

View File

@@ -1,12 +1,11 @@
base_model: google/gemma-3-27b-it
# Math finetuning configuration for Gemma3-27B
# hub_model_id: username/custom_model_name
load_in_8bit: false
load_in_4bit: false
strict: false
plugins:
- axolotl.integrations.gemma3.Gemma3TextFromMultimodalPlugin
- axolotl.integrations.liger.LigerPlugin
liger_rope: true

View File

@@ -7,6 +7,7 @@ load_in_4bit: false
strict: false
plugins:
- axolotl.integrations.gemma3.Gemma3TextFromMultimodalPlugin
- axolotl.integrations.liger.LigerPlugin
liger_rope: true

View File

@@ -0,0 +1,225 @@
"""Merge trained text-only Gemma3 weights back into a full multimodal checkpoint.
After training with the Gemma3TextFromMultimodalPlugin, the saved checkpoint
contains only the language model weights (with ``model.language_model.*``
prefix, reversed by transformers v5's key_mapping on save).
This script reconstructs a full ``Gemma3ForConditionalGeneration`` checkpoint by
combining the trained language model weights with the original vision tower and
projector weights from the base multimodal model.
Usage::
python scripts/merge_gemma3_multimodal_weights.py \\
--original-model google/gemma-3-4b-it \\
--trained-model /path/to/trained/output \\
--output-dir /path/to/merged
"""
import argparse
import json
import logging
from pathlib import Path
import torch
from huggingface_hub import split_torch_state_dict_into_shards
from safetensors.torch import load_file, save_file
from transformers import AutoConfig
LOG = logging.getLogger(__name__)
def collect_safetensors(model_dir: Path) -> dict[str, torch.Tensor]:
"""Load and merge all safetensors shard files in a directory."""
shard_files = sorted(model_dir.glob("*.safetensors"))
if not shard_files:
raise FileNotFoundError(f"No safetensors files found in {model_dir}")
state_dict: dict[str, torch.Tensor] = {}
for shard in shard_files:
LOG.info("Loading %s", shard.name)
state_dict.update(load_file(str(shard)))
return state_dict
def merge(
original_model: str,
trained_model: str,
output_dir: str,
*,
trust_remote_code: bool = False,
) -> None:
original_path = Path(original_model)
trained_path = Path(trained_model)
out_path = Path(output_dir)
out_path.mkdir(parents=True, exist_ok=True)
# 1. Load the original multimodal checkpoint
LOG.info("Loading original multimodal weights from %s", original_model)
if original_path.is_dir():
original_sd = collect_safetensors(original_path)
else:
from huggingface_hub import snapshot_download
cached = Path(
snapshot_download(original_model, allow_patterns=["*.safetensors"])
)
original_sd = collect_safetensors(cached)
# 2. Load trained text-only weights (already reversed to model.language_model.* by
# transformers v5 key_mapping on save)
LOG.info("Loading trained text-only weights from %s", trained_model)
trained_sd = collect_safetensors(trained_path)
# 3. Classify original keys
lang_keys = {k for k in original_sd if k.startswith("model.language_model.")}
vision_keys = {k for k in original_sd if k.startswith("model.vision_tower.")}
projector_keys = {
k for k in original_sd if k.startswith("model.multi_modal_projector.")
}
other_keys = set(original_sd.keys()) - lang_keys - vision_keys - projector_keys
LOG.info(
"Original checkpoint: %d language, %d vision, %d projector, %d other keys",
len(lang_keys),
len(vision_keys),
len(projector_keys),
len(other_keys),
)
# 4. Classify trained keys (reverse mapping on save gives model.language_model.* prefix)
trained_lang_keys = {k for k in trained_sd if k.startswith("model.language_model.")}
trained_other = set(trained_sd.keys()) - trained_lang_keys
LOG.info(
"Trained checkpoint: %d language keys, %d other keys",
len(trained_lang_keys),
len(trained_other),
)
# 5. Build merged state dict
merged: dict[str, torch.Tensor] = {}
# Keep vision tower and projector from original
for key in vision_keys | projector_keys:
merged[key] = original_sd[key]
# Use trained language model weights (overwrite original)
for key in trained_lang_keys:
merged[key] = trained_sd[key]
# For other trained keys (like lm_head.weight), use trained version
for key in trained_other:
merged[key] = trained_sd[key]
# For any original other keys not covered by trained (shouldn't usually happen),
# keep original
for key in other_keys:
if key not in merged:
merged[key] = original_sd[key]
# Check for missing language keys that were in original but not in trained
missing_lang = lang_keys - trained_lang_keys
if missing_lang:
LOG.warning(
"%d language keys in original but not in trained; keeping original: %s",
len(missing_lang),
list(missing_lang)[:5],
)
for key in missing_lang:
merged[key] = original_sd[key]
LOG.info("Merged checkpoint: %d total keys", len(merged))
# 6. Save merged weights (sharded at 50GB, matching transformers default)
LOG.info("Saving merged weights to %s", out_path)
state_dict_split = split_torch_state_dict_into_shards(merged, max_shard_size="50GB")
for filename, tensors in state_dict_split.filename_to_tensors.items():
shard = {name: merged[name] for name in tensors}
save_file(shard, str(out_path / filename))
if state_dict_split.is_sharded:
index = {
"metadata": {
"total_size": sum(t.numel() * t.element_size() for t in merged.values())
},
"weight_map": state_dict_split.tensor_to_filename,
}
with open(out_path / "model.safetensors.index.json", "w") as f:
json.dump(index, f, indent=2)
LOG.info("Saved %d shards", len(state_dict_split.filename_to_tensors))
# 7. Copy/update config
LOG.info("Writing config.json")
original_config = AutoConfig.from_pretrained(
original_model, trust_remote_code=trust_remote_code
)
# Update text_config fields from trained model's config if available
trained_config_path = trained_path / "config.json"
if trained_config_path.exists():
with open(trained_config_path) as f:
trained_config_dict = json.load(f)
# The trained config is the text sub-config; merge its fields into
# the original composite config's text_config
if hasattr(original_config, "text_config"):
for key, val in trained_config_dict.items():
if key not in ("model_type", "_name_or_path", "architectures"):
if hasattr(original_config.text_config, key):
setattr(original_config.text_config, key, val)
original_config.save_pretrained(out_path)
# 8. Copy tokenizer files from trained model if present
tokenizer_files = list(trained_path.glob("tokenizer*")) + list(
trained_path.glob("special_tokens_map*")
)
if tokenizer_files:
import shutil
for tok_file in tokenizer_files:
shutil.copy2(tok_file, out_path / tok_file.name)
LOG.info("Copied %d tokenizer files", len(tokenizer_files))
LOG.info("Merge complete. Output saved to %s", out_path)
def main():
parser = argparse.ArgumentParser(
description="Merge trained text-only Gemma3 weights back into a multimodal checkpoint."
)
parser.add_argument(
"--original-model",
required=True,
help="HuggingFace model ID or local path to the original multimodal model",
)
parser.add_argument(
"--trained-model",
required=True,
help="Local path to the trained text-only model output directory",
)
parser.add_argument(
"--output-dir",
required=True,
help="Directory to save the merged multimodal checkpoint",
)
parser.add_argument(
"--trust-remote-code",
action="store_true",
default=False,
help="Trust remote code when loading model config",
)
args = parser.parse_args()
merge(
original_model=args.original_model,
trained_model=args.trained_model,
output_dir=args.output_dir,
trust_remote_code=args.trust_remote_code,
)
if __name__ == "__main__":
main()

View File

@@ -104,7 +104,7 @@ class CutCrossEntropyPlugin(BasePlugin):
def patch_llama_like(
self,
model_type_to_patch: str,
model_type: str,
) -> None:
"""
Generic patch for model architectures with causal lm similar to llama
@@ -112,10 +112,7 @@ class CutCrossEntropyPlugin(BasePlugin):
from cut_cross_entropy.transformers.patch import PATCH_FNS
def patch_generic(
maybe_model,
patch_options,
remote_model_id: str | None,
model_type: str,
maybe_model, patch_options, model_type: str, remote_model_id: str | None
):
import cut_cross_entropy.transformers.llama
from cut_cross_entropy.transformers.llama import cce_forward
@@ -139,13 +136,11 @@ class CutCrossEntropyPlugin(BasePlugin):
f"Error: {str(e)}"
) from e
if model_type_to_patch not in PATCH_FNS:
if model_type not in PATCH_FNS:
LOG.warning_once(
"Setting up generic cce patch for model type: %s", model_type_to_patch
"Setting up generic cce patch for model type: %s", model_type
)
LOG.warning_once(
f"Generic Cut Cross Entropy + {model_type_to_patch} support is experimental and may not work as expected."
)
PATCH_FNS[model_type_to_patch] = partial(
patch_generic, model_type=model_type_to_patch
f"Generic Cut Cross Entropy + {model_type} support is experimental and may not work as expected."
)
PATCH_FNS[model_type] = partial(patch_generic, model_type=model_type)

View File

@@ -0,0 +1,37 @@
# Gemma3 Text-from-Multimodal Plugin
Load a Gemma3 multimodal checkpoint (e.g. `google/gemma-3-4b-it`) directly into `Gemma3ForCausalLM` for text-only training. This bypasses the multimodal trainer path and enables sample packing and other text-specific optimizations.
## How it works
The plugin uses transformers v5's `key_mapping` parameter on `from_pretrained` to remap `model.language_model.*` checkpoint keys to `model.*`, matching what `Gemma3ForCausalLM` expects. Vision tower and projector weights are automatically discarded. On save, transformers reverses the mapping so checkpoints retain the original `model.language_model.*` prefix.
## Usage
Add the plugin to your YAML config:
```yaml
base_model: google/gemma-3-4b-it
plugins:
- axolotl.integrations.gemma3.Gemma3TextFromMultimodalPlugin
```
See `examples/gemma3/gemma-3-4b-qlora.yml` for a complete example.
## Merging weights back into a multimodal checkpoint
After training, the saved checkpoint contains only the language model weights. To reconstruct a full `Gemma3ForConditionalGeneration` checkpoint (with the original vision tower and projector), use the merge script:
```bash
python scripts/merge_gemma3_multimodal_weights.py \
--original-model google/gemma-3-4b-it \
--trained-model /path/to/trained/output \
--output-dir /path/to/merged
```
This combines:
- **Trained language model weights** from your output checkpoint
- **Original vision tower + projector weights** from the base multimodal model
The merged checkpoint can be loaded as `Gemma3ForConditionalGeneration` for multimodal inference or further training.

View File

@@ -0,0 +1,9 @@
"""Gemma3 integration for loading multimodal checkpoints as text-only models."""
from .args import Gemma3TextFromMultimodalArgs
from .plugin import Gemma3TextFromMultimodalPlugin
__all__ = [
"Gemma3TextFromMultimodalArgs",
"Gemma3TextFromMultimodalPlugin",
]

View File

@@ -0,0 +1,31 @@
"""Pydantic input args for the Gemma3 text-from-multimodal plugin."""
from pydantic import BaseModel, model_validator
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
class Gemma3TextFromMultimodalArgs(BaseModel):
"""Configuration args for loading a Gemma3 multimodal checkpoint as text-only."""
gemma3_text_from_multimodal: bool = True
extract_text_config: bool = False
@model_validator(mode="before")
@classmethod
def set_model_type(cls, data):
if not isinstance(data, dict):
return data
if not data.get("gemma3_text_from_multimodal", True):
return data
if not data.get("model_type"):
LOG.info(
"Gemma3TextFromMultimodalPlugin: auto-setting model_type to Gemma3ForCausalLM"
)
data["model_type"] = "Gemma3ForCausalLM"
return data

View File

@@ -0,0 +1,107 @@
"""Plugin for loading Gemma3 multimodal checkpoints into Gemma3ForCausalLM (text-only).
Uses transformers v5's ``key_mapping`` parameter on ``from_pretrained`` to remap
``model.language_model.*`` keys to ``model.*``, discarding vision tower and projector
weights. On save, transformers automatically reverses the mapping so saved
checkpoints retain the original ``model.language_model.*`` prefix.
"""
from axolotl.integrations.base import BasePlugin
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
# key_mapping for transformers from_pretrained:
# Remap checkpoint keys matching ^model.language_model -> model
# Vision tower / projector keys won't match any model parameter and are discarded.
GEMMA3_KEY_MAPPING = {"^model.language_model": "model"}
class Gemma3TextFromMultimodalPlugin(BasePlugin):
"""Load a Gemma3 multimodal checkpoint as a text-only Gemma3ForCausalLM.
Hooks
-----
register(cfg)
Runs before config validation. Sets the ``_extract_text_config`` flag,
ensures ``model_type`` is ``Gemma3ForCausalLM``, and injects
``key_mapping`` into ``model_kwargs`` so that ``from_pretrained`` remaps
``model.language_model.*`` → ``model.*``.
pre_model_load(cfg)
Runs after config validation/normalization but before model instantiation.
Validates that ``model_config_type`` is ``gemma3_text`` and
``is_multimodal`` is False (confirming that ``_extract_text_config``
worked correctly).
"""
def get_input_args(self) -> str:
return "axolotl.integrations.gemma3.Gemma3TextFromMultimodalArgs"
def register(self, cfg: dict):
"""Set up config for multimodal → text-only loading.
This runs before Pydantic validation, so ``cfg`` is a raw dict.
"""
if not cfg.get("gemma3_text_from_multimodal", True):
raise ValueError(
"Gemma3TextFromMultimodalPlugin: disabled via config, but plugin selected"
)
# Flag for load_model_config() to extract the text sub-config
cfg["extract_text_config"] = True
# Ensure model_type is set for the text-only model class
if not cfg.get("model_type"):
cfg["model_type"] = "Gemma3ForCausalLM"
# Inject key_mapping into model_kwargs so from_pretrained remaps weights
model_kwargs = cfg.setdefault("model_kwargs", {})
model_kwargs["key_mapping"] = GEMMA3_KEY_MAPPING
def pre_model_load(self, cfg):
"""Validate that config extraction worked before model instantiation."""
if not getattr(cfg, "gemma3_text_from_multimodal", True):
return
if cfg.model_config_type != "gemma3_text":
LOG.warning(
"Gemma3TextFromMultimodalPlugin: expected model_config_type='gemma3_text' "
"but got '%s'. The text config extraction may not have worked.",
cfg.model_config_type,
)
if cfg.is_multimodal or cfg.processor_type:
raise ValueError(
"Multimodal mode is enabled (processor_type set), but "
"Gemma3TextFromMultimodalPlugin enabled. "
"Please disable one of the two."
)
def post_train(self, cfg, model):
"""Log merge command after training completes."""
if cfg.adapter:
LOG.info(
"Adapter training detected. To reconstruct the multimodal checkpoint:\n"
" 1. Merge adapter weights into the text-only base model:\n"
" axolotl merge_lora <your_config.yml>\n"
" 2. Then merge the resulting full model back into the multimodal checkpoint:\n"
" python scripts/merge_gemma3_multimodal_weights.py \\\n"
" --original-model %s \\\n"
" --trained-model %s/merged \\\n"
" --output-dir %s/multi-modal/merged",
cfg.base_model,
cfg.output_dir,
cfg.output_dir,
)
else:
LOG.info(
"To merge trained weights back into the multimodal checkpoint, run:\n"
" python scripts/merge_gemma3_multimodal_weights.py \\\n"
" --original-model %s \\\n"
" --trained-model %s \\\n"
" --output-dir %s/multi-modal/merged",
cfg.base_model,
cfg.output_dir,
cfg.output_dir,
)

View File

@@ -1,44 +0,0 @@
# Kernels Integration
MoE (Mixture of Experts) kernels speed up training for MoE layers and reduce VRAM costs. In transformers v5, `batched_mm` and `grouped_mm` were integrated as built-in options via the `experts_implementation` config kwarg:
```python
class ExpertsInterface(GeneralInterface):
_global_mapping = {
"batched_mm": batched_mm_experts_forward,
"grouped_mm": grouped_mm_experts_forward,
}
```
In our custom integration, we add support for **ScatterMoE**, which is even more efficient and faster than `grouped_mm`.
## Usage
Add the following to your axolotl YAML config:
```yaml
plugins:
- axolotl.integrations.kernels.KernelsPlugin
use_kernels: true
use_scattermoe: true
```
**Important:** Setting `experts_implementation` is incompatible with `use_scattermoe`.
## How It Works
The `KernelsPlugin` runs before model loading and:
1. Registers the ScatterMoE kernel from the [`axolotl-ai-co/scattermoe`](https://huggingface.co/axolotl-ai-co/scattermoe) Hub repo.
2. Patches the model's `SparseMoeBlock` forward method with the optimized ScatterMoE implementation.
This works for any MoE model in transformers that uses a `SparseMoeBlock` class (Mixtral, Qwen2-MoE, OLMoE, etc.).
## Limitations
ScatterMoE uses a softmax -> topk routing, so results may be different for some model arch as baseline (GPT-OSS, GLM_MOE_DSA).
## Note on MegaBlocks
We tested [MegaBlocks](https://huggingface.co/kernels-community/megablocks) but were unable to ensure numerical accuracy, so we did not integrate it. It was also incompatible with many newer model architectures in transformers.

View File

@@ -6,12 +6,6 @@ See https://github.com/EleutherAI/lm-evaluation-harness
## Usage
There are two ways to use the LM Eval integration:
### 1. Post-Training Evaluation
When training with the plugin enabled, evaluation runs automatically after training completes:
```yaml
plugins:
- axolotl.integrations.lm_eval.LMEvalPlugin
@@ -22,50 +16,9 @@ lm_eval_tasks:
- arc_easy
lm_eval_batch_size: # Batch size for evaluation
# Directory to save evaluation results.
# The final model is loaded from this directory
# unless specified otherwise (see below)
output_dir:
output_dir: # Directory to save evaluation results
```
Run training as usual:
```bash
axolotl train config.yml
```
### 2. Standalone CLI Evaluation
Evaluate any model directly without training:
```yaml
lm_eval_model: meta-llama/Llama-2-7b-hf
plugins:
- axolotl.integrations.lm_eval.LMEvalPlugin
lm_eval_tasks:
- gsm8k
- hellaswag
- arc_easy
lm_eval_batch_size: 8
output_dir: ./outputs
```
Run evaluation:
```bash
axolotl lm-eval config.yml
```
## Model Selection Priority
The model to evaluate is selected in the following priority order:
1. **`lm_eval_model`** - Explicit model path or HuggingFace repo (highest priority)
2. **`hub_model_id`** - Trained model pushed to HuggingFace Hub
3. **`output_dir`** - Local checkpoint directory containing trained model weights
## Citation
```bib

View File

@@ -5,7 +5,7 @@ Module for the Plugin for LM Eval Harness
import subprocess # nosec
from axolotl.integrations.base import BasePlugin
from axolotl.integrations.lm_eval.cli import build_lm_eval_command, get_model_path
from axolotl.integrations.lm_eval.cli import build_lm_eval_command
from .args import LMEvalArgs as LMEvalArgs
@@ -29,7 +29,7 @@ class LMEvalPlugin(BasePlugin):
wandb_project=cfg.wandb_project,
wandb_entity=cfg.wandb_entity,
wandb_name=cfg.wandb_name,
model=get_model_path(cfg),
model=cfg.lm_eval_model or cfg.hub_model_id,
):
subprocess.run( # nosec
lm_eval_args,

View File

@@ -13,21 +13,6 @@ import yaml
from axolotl.utils.dict import DictDefault
def get_model_path(cfg: DictDefault) -> str | None:
"""
Determine which model path to use for evaluation.
Priority order (highest to lowest):
1. lm_eval_model - Explicit model path override
2. hub_model_id - Model pushed to HuggingFace Hub
3. None - Falls back to output_dir in build_lm_eval_command
Returns:
Model path string or None to use output_dir fallback
"""
return cfg.lm_eval_model or cfg.hub_model_id or None
def build_lm_eval_command(
tasks: list[str],
bfloat16=True,
@@ -123,7 +108,7 @@ def lm_eval(config: str, cloud: Optional[str] = None):
wandb_project=cfg.wandb_project,
wandb_entity=cfg.wandb_entity,
wandb_name=cfg.wandb_name,
model=get_model_path(cfg),
model=cfg.lm_eval_model or cfg.hub_model_id,
revision=cfg.revision,
apply_chat_template=cfg.apply_chat_template,
fewshot_as_multiturn=cfg.fewshot_as_multiturn,

View File

@@ -15,7 +15,7 @@ from torch import nn
from torch.distributed.tensor import DTensor
from .geglu import geglu_backward, geglu_forward
from .quantize import dequantize_weight
from .quantize import dequantize
from .swiglu import swiglu_backward, swiglu_forward
from .utils import torch_amp_custom_bwd, torch_amp_custom_fwd
@@ -46,12 +46,6 @@ def get_lora_parameters(
W = base_layer.weight
b = base_layer.bias
# Unwrap DTensor if FSDP2 left the weight wrapped -- DTensor does not proxy
# attribute access to the underlying tensor subclass, so torchao methods like
# .dequantize() or .get_original_weight() would not be visible.
if isinstance(W, DTensor):
W = W.full_tensor()
if not hasattr(proj, "disable_adapters") or proj.disable_adapters or proj.merged:
quant_state = getattr(W, "quant_state", None)
return W, b, quant_state, None, None, None
@@ -92,7 +86,6 @@ def matmul_lora(
B: torch.Tensor | None,
s: float | None,
out: torch.Tensor | None = None,
transpose: bool = True,
) -> torch.Tensor:
"""
Efficient fused matmul + LoRA computation.
@@ -105,15 +98,12 @@ def matmul_lora(
B: LoRA B matrix [out_features, rank]
s: LoRA scaling factor
out: Optional output tensor for inplace operations
transpose: If True (default), transpose W before matmul (forward path).
Set to False for backward paths where W is already in the correct layout.
Returns:
Result of X @ W + X @ A @ B
"""
dtype = X.dtype
is_quantized = W_quant is not None or type(W) is not torch.Tensor
W = dequantize_weight(W, W_quant, transpose=transpose)
W = dequantize(W.t(), W_quant)
reshape = False
if X.dim() == 3:
@@ -122,7 +112,7 @@ def matmul_lora(
reshape = True
out = torch.matmul(X, W, out=out)
if is_quantized:
if W_quant is not None:
del W
if A is not None:
@@ -302,16 +292,15 @@ class LoRA_MLP(torch.autograd.Function):
up = up.view(-1, up.shape[-1])
dtype = X.dtype
# Down projection (backward: no transpose needed, W is already [out, in])
# Down projection
grad_down = matmul_lora(
grad_output,
down_weight,
down_weight.t(),
None,
down_quant,
down_B,
down_A,
down_scale,
transpose=False,
)
# Activation backward
@@ -343,7 +332,7 @@ class LoRA_MLP(torch.autograd.Function):
if dX is not None:
# Up projection gradients
up_weight = dequantize_weight(up_weight, up_quant, transpose=True)
up_weight = dequantize(up_weight.t(), up_quant)
if ctx.inplace:
dX = torch.matmul(grad_up, up_weight.t(), out=X)
else:
@@ -355,7 +344,7 @@ class LoRA_MLP(torch.autograd.Function):
dX += grad_up @ up_B.to(dtype).t() @ (up_scale * up_A.to(dtype).t())
# Gate projection gradients
gate_weight = dequantize_weight(gate_weight, gate_quant)
gate_weight = dequantize(gate_weight, gate_quant)
dX += grad_gate @ gate_weight
del gate_weight
@@ -642,7 +631,7 @@ class LoRA_QKV(torch.autograd.Function):
out_buffer = X if ctx.inplace else None
# Q path
q_weight_t = dequantize_weight(q_weight, q_quant)
q_weight_t = dequantize(q_weight, q_quant)
grad_X = torch.mm(q_grad, q_weight_t, out=out_buffer)
del q_weight
del q_weight_t
@@ -650,7 +639,7 @@ class LoRA_QKV(torch.autograd.Function):
grad_X.addmm_(q_grad, torch.mm(B_q_scaled, A_q_scaled))
# K path
k_weight_t = dequantize_weight(k_weight, k_quant)
k_weight_t = dequantize(k_weight, k_quant)
grad_X.addmm_(k_grad, k_weight_t)
del k_weight
del k_weight_t
@@ -658,7 +647,7 @@ class LoRA_QKV(torch.autograd.Function):
grad_X.addmm_(k_grad, torch.mm(B_k_scaled, A_k_scaled))
# V path
v_weight_t = dequantize_weight(v_weight, v_quant)
v_weight_t = dequantize(v_weight, v_quant)
grad_X.addmm_(v_grad, v_weight_t)
del v_weight
del v_weight_t
@@ -821,7 +810,7 @@ class LoRA_O(torch.autograd.Function):
d_B = s * A @ dY_X
# Get derivative for dX
W = dequantize_weight(W, W_quant, transpose=True)
W = dequantize(W.t(), W_quant)
dX = dY @ W.t()
del W

View File

@@ -146,43 +146,3 @@ def dequantize(
# Handle transposed data
is_transposed: bool = W.shape[0] == 1
return out.t() if is_transposed else out
def dequantize_weight(
W: torch.Tensor,
quant_state: QuantState | list | None = None,
transpose: bool = False,
) -> torch.Tensor:
"""Unified dequantization for both torchao and bnb quantized weights.
For torchao tensor subclasses (AffineQuantizedTensor, NF4Tensor), dequantizes
using the appropriate instance method. For bnb Params4bit, delegates to the
optimized CUDA kernel in ``dequantize``.
Args:
W: Quantized weight tensor ``[out_features, in_features]``.
quant_state: bnb ``QuantState`` (None for torchao / unquantized).
transpose: If True, return ``[in_features, out_features]``.
Returns:
Dequantized float tensor, optionally transposed.
"""
# torchao path: tensor subclass with embedded quantization state
if quant_state is None and type(W) is not torch.Tensor:
result = None
# NF4Tensor (check first — NF4Tensor.dequantize is a static method)
if hasattr(W, "get_original_weight"):
result = W.get_original_weight()
else:
# AffineQuantizedTensor (INT4, etc.)
try:
result = W.dequantize()
except (TypeError, RuntimeError):
pass
if result is not None:
return result.t() if transpose else result
# bnb path: transpose input before the CUDA kernel (existing convention)
if transpose:
return dequantize(W.t(), quant_state)
return dequantize(W, quant_state)

View File

@@ -23,7 +23,6 @@ from axolotl.loaders.utils import get_linear_embedding_layers
from axolotl.telemetry.errors import send_errors
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
from axolotl.utils.schemas.enums import TorchAOQuantDType
LOG = get_logger(__name__)
@@ -135,13 +134,11 @@ def load_lora(
rank = int(os.environ.get("LOCAL_RANK", 0))
is_torchao = cfg.peft and cfg.peft.backend == "torchao"
if (
cfg.fsdp_config
and cfg.adapter
and cfg.fsdp_config.cpu_ram_efficient_loading
and rank != 0
and not is_torchao
):
setup_quantized_meta_for_peft(model)
@@ -149,15 +146,6 @@ def load_lora(
if cfg.peft_autocast_adapter_dtype is not None:
model_kwargs["autocast_adapter_dtype"] = cfg.peft_autocast_adapter_dtype
# Patch PEFT's torchao dispatch before any model creation/loading.
# Must happen before both get_peft_model and PeftModel.from_pretrained,
# as both trigger LoRA layer dispatch that would fail for INT4/NF4 weights.
# INT8 is natively supported by PEFT's TorchaoLoraLinear, so skip the patch.
if is_torchao and cfg.peft.weight_dtype != TorchAOQuantDType.int8:
from axolotl.monkeypatch.peft.utils import patch_peft_torchao_dispatch
patch_peft_torchao_dispatch()
if cfg.lora_model_dir:
LOG.debug("Loading pretrained PEFT - LoRA")
if cfg.lora_on_cpu:
@@ -184,7 +172,6 @@ def load_lora(
and cfg.adapter
and cfg.fsdp_config.cpu_ram_efficient_loading
and rank != 0
and not is_torchao
):
setup_quantized_peft_meta_for_training(model)

View File

@@ -158,15 +158,6 @@ class ModelLoader:
"""Property that determines if FSDP with QLoRA is enabled."""
return self.is_fsdp_enabled and self.cfg.adapter == "qlora"
@property
def is_torchao_qlora(self):
"""Property that determines if torchao backend is used for QLoRA."""
return (
self.cfg.adapter == "qlora"
and self.cfg.peft
and self.cfg.peft.backend == "torchao"
)
@send_errors
def load(self) -> tuple[PreTrainedModel | PeftModelForCausalLM, PeftConfig | None]:
"""Load and prepare the model with all configurations and patches.
@@ -500,9 +491,8 @@ class ModelLoader:
# FSDP requires control over device placement, so don't set device_map when FSDP is enabled
if self.is_fsdp_enabled:
# For QLoRA + FSDP with bnb, we still need to set device_map for proper initialization
# torchao tensors work natively with FSDP2, no device_map override needed
if self.is_qlora_and_fsdp_enabled and not self.is_torchao_qlora:
# For QLoRA + FSDP, we still need to set device_map to "auto" for proper initialization
if self.is_qlora_and_fsdp_enabled:
self.model_kwargs["device_map"] = {
"": int(os.environ.get("LOCAL_RANK", 0))
}
@@ -571,44 +561,6 @@ class ModelLoader:
self.model_kwargs["quantization_config"] = BitsAndBytesConfig(
**self.model_config.quantization_config
)
elif (
self.cfg.adapter == "qlora"
and self.cfg.peft
and self.cfg.peft.backend == "torchao"
and not self.cfg.merge_lora
):
from transformers import TorchAoConfig
from axolotl.utils.schemas.enums import TorchAOQuantDType
weight_dtype = self.cfg.peft.weight_dtype
if weight_dtype == TorchAOQuantDType.int4:
group_size = self.cfg.peft.group_size or 128
self.model_kwargs["quantization_config"] = TorchAoConfig(
quant_type="int4_weight_only",
group_size=group_size,
)
elif weight_dtype == TorchAOQuantDType.int8:
group_size = self.cfg.peft.group_size or 128
self.model_kwargs["quantization_config"] = TorchAoConfig(
quant_type="int8_weight_only",
group_size=group_size,
)
elif weight_dtype == TorchAOQuantDType.nf4:
from torchao.dtypes._nf4tensor_api import NF4WeightOnlyConfig
block_size = self.cfg.peft.group_size or 64
self.model_kwargs["quantization_config"] = TorchAoConfig(
quant_type=NF4WeightOnlyConfig(
block_size=block_size,
scaler_block_size=256,
),
)
else:
raise ValueError(
f"Unsupported torchao weight_dtype for QLoRA: {weight_dtype}. "
"Supported: int4, int8, nf4"
)
elif self.cfg.adapter == "qlora" and self.cfg.load_in_4bit:
bnb_config = {
"load_in_4bit": True,
@@ -908,10 +860,6 @@ class ModelLoader:
# Make sure everything is in the same dtype
skip_prepare_model_for_kbit_training = True
# torchao quantized models don't use Params4bit and don't need kbit preparation
if self.is_torchao_qlora:
skip_prepare_model_for_kbit_training = True
if (
not skip_prepare_model_for_kbit_training
and self.cfg.adapter in ["lora", "qlora"]

View File

@@ -348,12 +348,10 @@ class PatchManager:
def _apply_fsdp2_bnb_patches(self):
"""Apply FSDP2 BNB patches."""
is_torchao = self.cfg.peft and self.cfg.peft.backend == "torchao"
if (
self.cfg.fsdp_config
and str(self.cfg.fsdp_version) == "2"
and self.cfg.adapter == "qlora"
and not is_torchao
):
from axolotl.monkeypatch.fsdp2_qlora import (
apply_init_sharded_param_patch,

View File

@@ -204,6 +204,13 @@ def load_model_config(cfg: DictDefault) -> PretrainedConfig | addict.Dict:
check_model_config(cfg, model_config)
# Extract text config from composite config when explicitly requested
# (set by plugins like Gemma3TextFromMultimodalPlugin)
if getattr(cfg, "extract_text_config", False) and hasattr(
model_config, "get_text_config"
):
model_config = model_config.get_text_config()
return model_config

View File

@@ -78,30 +78,3 @@ def patch_peft_prep_code():
axolotl.loaders.model.prepare_model_for_kbit_training = (
fixed_prepare_model_for_kbit_training
)
def patch_peft_torchao_dispatch():
"""Skip PEFT's TorchaoLoraLinear for non-INT8 torchao weights.
PEFT's dispatch_torchao() matches AffineQuantizedTensor but then errors in
_check_dtype_supported() because it only allows INT8. Our LoRA kernels handle
dequantization explicitly, so we bypass PEFT's torchao dispatch entirely and
let it fall back to standard Linear LoRA layers.
"""
try:
from peft.tuners.lora import torchao as peft_torchao
except ImportError:
LOG.warning("Could not import peft.tuners.lora.torchao for patching")
return
if getattr(peft_torchao, "_axolotl_patched", False):
return
def patched_dispatch(target, adapter_name, lora_config, **kwargs):
# Return None so PEFT falls back to standard Linear LoRA layers.
# Our LoRA kernels handle torchao dequantization explicitly.
return None
peft_torchao.dispatch_torchao = patched_dispatch
peft_torchao._axolotl_patched = True
LOG.info("Patched PEFT dispatch_torchao to skip TorchaoLoraLinear")

View File

@@ -8,7 +8,6 @@ import torch
class TorchAOQuantDType(Enum):
int4 = torch.int4
int8 = torch.int8
nf4 = "nf4"
float8_e4m3fn = torch.float8_e4m3fn
nvfp4 = "nvfp4"
@@ -17,8 +16,6 @@ class TorchAOQuantDType(Enum):
return TorchAOQuantDType.int4
if str == "int8":
return TorchAOQuantDType.int8
if str == "nf4":
return TorchAOQuantDType.nf4
if str in ["float8_e4m3fn", "fp8", "float8"]:
return TorchAOQuantDType.float8_e4m3fn
if str == "nvfp4":

View File

@@ -1,12 +1,9 @@
"""Pydantic models for PEFT-related configuration"""
from typing import Any, Literal
from typing import Any
from pydantic import BaseModel, Field, field_validator, model_validator
from axolotl.utils.schemas.enums import TorchAOQuantDType
from axolotl.utils.schemas.quantization import validate_ao_dtype
class LoftQConfig(BaseModel):
"""LoftQ configuration subset"""
@@ -18,7 +15,7 @@ class LoftQConfig(BaseModel):
class PeftConfig(BaseModel):
"""PEFT configuration subset"""
"""peftq configuration subset"""
loftq_config: LoftQConfig | None = Field(
default=None,
@@ -26,29 +23,6 @@ class PeftConfig(BaseModel):
"description": "Configuration options for loftq initialization for LoRA"
},
)
backend: Literal["bnb", "torchao"] | None = Field(
default=None,
json_schema_extra={
"description": "Quantization backend for QLoRA. 'bnb' for bitsandbytes (default), 'torchao' for torchao."
},
)
weight_dtype: TorchAOQuantDType | None = Field(
default=None,
json_schema_extra={
"description": "Weight quantization dtype (int4, int8, or nf4). Also used with bnb backend to auto-configure quantization."
},
)
group_size: int | None = Field(
default=None,
json_schema_extra={
"description": "Group size for quantization. Defaults to 128 for int4, 64 for nf4."
},
)
@field_validator("weight_dtype", mode="before")
@classmethod
def validate_weight_dtype(cls, v):
return validate_ao_dtype(v)
class LoraConfig(BaseModel):
@@ -182,56 +156,6 @@ class LoraConfig(BaseModel):
merge_lora: bool | None = None
@model_validator(mode="before")
@classmethod
def auto_detect_qlora(cls, data):
"""Auto-set adapter type and quantization flags from peft config.
When peft.backend and peft.weight_dtype are set, this infers the correct
adapter type and internal flags (load_in_4bit, load_in_8bit) so users
don't need to set them manually.
"""
peft = data.get("peft")
if not isinstance(peft, dict):
return data
backend = peft.get("backend")
weight_dtype = peft.get("weight_dtype")
# Validate: weight_dtype requires backend
if weight_dtype and not backend:
raise ValueError(
"peft.backend is required when peft.weight_dtype is set. "
"Use 'torchao' or 'bnb'."
)
if not weight_dtype:
return data
adapter = data.get("adapter")
if backend == "torchao":
# torchao: any quantized weight_dtype means qlora
if adapter == "lora":
data["adapter"] = "qlora"
elif backend == "bnb":
if weight_dtype == "nf4":
# bnb nf4 = qlora with load_in_4bit
if adapter == "lora":
data["adapter"] = "qlora"
data.setdefault("load_in_4bit", True)
elif weight_dtype == "int8":
# bnb int8 = lora with load_in_8bit
data.setdefault("load_in_8bit", True)
else:
raise ValueError(
f"peft.weight_dtype '{weight_dtype}' is not supported with bnb backend. "
"Supported: nf4, int8."
)
return data
@model_validator(mode="before")
@classmethod
def validate_adapter(cls, data):
@@ -249,8 +173,6 @@ class LoraConfig(BaseModel):
@model_validator(mode="after")
def validate_qlora(self):
if self.adapter == "qlora":
is_torchao = self.peft and self.peft.backend == "torchao"
if self.merge_lora:
# can't merge qlora if loaded in 8bit or 4bit
if self.load_in_8bit:
@@ -262,20 +184,7 @@ class LoraConfig(BaseModel):
if self.load_in_4bit:
raise ValueError("Can't merge qlora if loaded in 4bit")
elif is_torchao:
# torchao backend: validate torchao-specific requirements
if self.load_in_4bit or self.load_in_8bit:
raise ValueError(
"load_in_4bit/load_in_8bit are for bitsandbytes. "
"With peft.backend: torchao, quantization is handled by torchao."
)
if not self.peft.weight_dtype:
raise ValueError(
"peft.weight_dtype is required when peft.backend is 'torchao'"
)
else:
# Default bnb path
if self.load_in_8bit:
raise ValueError("Can't load qlora in 8bit")

View File

@@ -16,8 +16,6 @@ def validate_ao_dtype(v: Any) -> TorchAOQuantDType | None:
return TorchAOQuantDType.int4
if v == "int8":
return TorchAOQuantDType.int8
if v == "nf4":
return TorchAOQuantDType.nf4
if v in ["float8_e4m3fn", "fp8", "float8"]:
return TorchAOQuantDType.float8_e4m3fn
if v == "nvfp4":

View File

@@ -247,7 +247,7 @@ def drop_long_seq(sample, sequence_len=2048, min_sequence_len=2, raise_on_drop=F
def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
drop_attn_mask = cfg.model_config_type in ["mamba", "gemma3"]
drop_attn_mask = cfg.model_config_type in ["mamba", "gemma3", "gemma3_text"]
if drop_attn_mask:
LOG.info("dropping attention_mask column")
train_dataset = train_dataset.remove_columns("attention_mask")

View File

@@ -3,7 +3,7 @@
import torch
from bitsandbytes.functional import QuantState
from axolotl.kernels.quantize import dequantize, dequantize_weight
from axolotl.kernels.quantize import dequantize
def test_dequantize_null_state():
@@ -100,18 +100,3 @@ def test_dequantize_output_tensor():
result = dequantize(W, quant_state, out=out)
assert result is out
def test_dequantize_weight_plain_tensor():
"""Test that dequantize_weight passes through unquantized tensors unchanged"""
W = torch.randn(32, 64)
result = dequantize_weight(W, quant_state=None, transpose=False)
assert torch.equal(result, W)
def test_dequantize_weight_plain_tensor_transpose():
"""Test that dequantize_weight transposes unquantized tensors"""
W = torch.randn(32, 64)
result = dequantize_weight(W, quant_state=None, transpose=True)
assert result.shape == (64, 32)
assert torch.equal(result, W.t())

View File

@@ -3,14 +3,6 @@ import pytest
from axolotl.utils.config import validate_config
from axolotl.utils.dict import DictDefault
BASE_CFG = {
"datasets": [{"path": "dummy_dataset", "type": "alpaca"}],
"micro_batch_size": 1,
"gradient_accumulation_steps": 1,
"learning_rate": 1e-5,
"base_model": "dummy_model",
}
class TestLoRAConfigValidation:
"""Test suite for LoRA/QLoRA configuration validation"""
@@ -157,195 +149,3 @@ class TestLoRAConfigValidation:
result = validate_config(valid_config)
assert result["lora_qkv_kernel"] is True
assert result["trust_remote_code"] is None
class TestTorchaoQLoRAConfigValidation:
"""Test suite for torchao QLoRA auto-detection and validation"""
# --- Auto-detection: torchao ---
@pytest.mark.parametrize("weight_dtype", ["int4", "int8", "nf4"])
def test_torchao_auto_detect_from_lora(self, weight_dtype):
"""adapter: lora + peft.backend: torchao auto-upgrades to qlora"""
cfg = DictDefault(
{
"adapter": "lora",
"peft": {"backend": "torchao", "weight_dtype": weight_dtype},
**BASE_CFG,
}
)
result = validate_config(cfg)
assert result["adapter"] == "qlora"
assert result["peft"]["backend"] == "torchao"
def test_torchao_explicit_qlora(self):
"""adapter: qlora + peft.backend: torchao works directly"""
cfg = DictDefault(
{
"adapter": "qlora",
"peft": {"backend": "torchao", "weight_dtype": "int4"},
**BASE_CFG,
}
)
result = validate_config(cfg)
assert result["adapter"] == "qlora"
# --- Auto-detection: bnb ---
def test_bnb_nf4_auto_detect_from_lora(self):
"""adapter: lora + peft.backend: bnb + weight_dtype: nf4 → qlora + load_in_4bit"""
cfg = DictDefault(
{
"adapter": "lora",
"peft": {"backend": "bnb", "weight_dtype": "nf4"},
**BASE_CFG,
}
)
result = validate_config(cfg)
assert result["adapter"] == "qlora"
assert result["load_in_4bit"] is True
def test_bnb_int8_auto_detect_from_lora(self):
"""adapter: lora + peft.backend: bnb + weight_dtype: int8 → lora + load_in_8bit"""
cfg = DictDefault(
{
"adapter": "lora",
"peft": {"backend": "bnb", "weight_dtype": "int8"},
**BASE_CFG,
}
)
result = validate_config(cfg)
assert result["adapter"] == "lora"
assert result["load_in_8bit"] is True
def test_bnb_nf4_explicit_qlora_auto_sets_load_in_4bit(self):
"""adapter: qlora + peft.backend: bnb + weight_dtype: nf4 auto-sets load_in_4bit"""
cfg = DictDefault(
{
"adapter": "qlora",
"peft": {"backend": "bnb", "weight_dtype": "nf4"},
**BASE_CFG,
}
)
result = validate_config(cfg)
assert result["adapter"] == "qlora"
assert result["load_in_4bit"] is True
# --- Backward compat ---
def test_old_style_qlora_unchanged(self):
"""Old-style adapter: qlora + load_in_4bit: true still works"""
cfg = DictDefault(
{
"adapter": "qlora",
"load_in_4bit": True,
**BASE_CFG,
}
)
result = validate_config(cfg)
assert result["adapter"] == "qlora"
assert result["load_in_4bit"] is True
def test_old_style_lora_8bit_unchanged(self):
"""Old-style adapter: lora + load_in_8bit: true still works"""
cfg = DictDefault(
{
"adapter": "lora",
"load_in_8bit": True,
**BASE_CFG,
}
)
result = validate_config(cfg)
assert result["adapter"] == "lora"
assert result["load_in_8bit"] is True
def test_plain_lora_unchanged(self):
"""adapter: lora without peft block stays as lora"""
cfg = DictDefault(
{
"adapter": "lora",
**BASE_CFG,
}
)
result = validate_config(cfg)
assert result["adapter"] == "lora"
# --- Validation errors ---
def test_torchao_with_load_in_4bit_errors(self):
"""peft.backend: torchao + load_in_4bit is a conflict"""
cfg = DictDefault(
{
"adapter": "qlora",
"load_in_4bit": True,
"peft": {"backend": "torchao", "weight_dtype": "int4"},
**BASE_CFG,
}
)
with pytest.raises(ValueError, match="load_in_4bit.*bitsandbytes"):
validate_config(cfg)
def test_torchao_with_load_in_8bit_errors(self):
"""peft.backend: torchao + load_in_8bit is a conflict"""
cfg = DictDefault(
{
"adapter": "qlora",
"load_in_8bit": True,
"peft": {"backend": "torchao", "weight_dtype": "int4"},
**BASE_CFG,
}
)
with pytest.raises(ValueError, match="load_in_4bit.*bitsandbytes"):
validate_config(cfg)
def test_torchao_without_weight_dtype_errors(self):
"""peft.backend: torchao without weight_dtype errors"""
cfg = DictDefault(
{
"adapter": "qlora",
"peft": {"backend": "torchao"},
**BASE_CFG,
}
)
with pytest.raises(ValueError, match="peft.weight_dtype is required"):
validate_config(cfg)
def test_weight_dtype_without_backend_errors(self):
"""peft.weight_dtype without peft.backend errors"""
cfg = DictDefault(
{
"adapter": "lora",
"peft": {"weight_dtype": "int4"},
**BASE_CFG,
}
)
with pytest.raises(ValueError, match="peft.backend is required"):
validate_config(cfg)
def test_bnb_unsupported_weight_dtype_errors(self):
"""peft.backend: bnb + unsupported weight_dtype errors"""
cfg = DictDefault(
{
"adapter": "lora",
"peft": {"backend": "bnb", "weight_dtype": "int4"},
**BASE_CFG,
}
)
with pytest.raises(ValueError, match="not supported with bnb"):
validate_config(cfg)
# --- Redundant flags don't conflict ---
def test_bnb_nf4_with_explicit_load_in_4bit(self):
"""peft.backend: bnb + weight_dtype: nf4 + load_in_4bit: true is fine (redundant)"""
cfg = DictDefault(
{
"adapter": "lora",
"load_in_4bit": True,
"peft": {"backend": "bnb", "weight_dtype": "nf4"},
**BASE_CFG,
}
)
result = validate_config(cfg)
assert result["adapter"] == "qlora"
assert result["load_in_4bit"] is True