Fix: quantize and target moe layers in transformers v5 for adapters and many misc fixes (#3439)

* fix: saving clones state dict

* fix: apply fix for only CP mode

* fix: add dropout check when using lora target param

* fix: re-add patch from transformers PR #39866

* feat: add moe quant to test by ved

* fix: try to match target param properly with endswith

* fix: clear cache per param quant

* fix: attempt on-load quantization of experts instead of post-load

* fix: attempt disable async load

* chore: add log

* chore: adjust log

* fix: remove cuda alloc for moe and enable async load

* chore: remove leftover logs

* chore: add extra empty cache

* fix(doc): clarify support

* fix: handle fsdp2 for paramwrapper dtensor

* feat: attempt to quant experts in 8bit mode too

* feat: attempt to release bf16 experts from vram

* feat: upgrade cce

* fix: fsdp2 init_sharded_param loads int8/uint4 dtensor as requires_grad=true on init

* fix: remove unnecessary gc and empty cache

* Revert "fix: remove unnecessary gc and empty cache"

This reverts commit 1d54518990.

* fix: do not call full_tensor on non-dtensors

* fix: attempt to address fsdp2 with quant exp high loss

* fix: attempt lora quant experts wrong dim

* fix: ensure requires_grad patch applied for lora 8bit

* fix: attempt lora 8bit fsdp2

* fix: attribute access on save for lora 8bit fsdp2

* fix: wrong weight attrib access

* chore(refactor): add config, re-arrange position of patches, clean comments

* feat: add example docs

* chore: cherry pick trinity fixes from PR 3399

* chore: comments refactor; add guards

* fix: guard using wrong key

* fix: mamba save does not accept main process param

* fix: add guard to prevent double hook

* fix: move gc to upper scope

* chore: add comment on proxy forward patch

* fix: add comment to clarify

* feat: add idempotency test

* fix: AttributeError: `e_score_correction_bias` is not an nn.Parameter

* fix: AttributeError: 'NoneType' object has no attribute 'to'

* fix: update docs on cpu_ram_efficient_loading
NanoCode012
2026-03-03 22:06:23 +07:00
committed by GitHub
parent e672d37f33
commit 945c8aeb10
24 changed files with 1015 additions and 29 deletions

View File

@@ -18,4 +18,7 @@ MOE_ARCH_BLOCK = {
"gpt_oss": "GptOssDecoderLayer",
"lfm2_moe": "Lfm2MoeSparseMoeBlock",
"afmoe": "AfmoeMoE",
"glm4_moe": "Glm4MoeDecoderLayer",
"glm4_moe_lite": "Glm4MoeLiteDecoderLayer",
"glm_moe_dsa": "GlmMoeDsaDecoderLayer",
}

View File

@@ -720,12 +720,16 @@ class AxolotlTrainer(
os.makedirs(output_dir, exist_ok=True)
LOG.info(f"Saving model checkpoint to {output_dir}")
# fix for Context Parallel save
if state_dict is None:
state_dict = self.accelerator.get_state_dict(self.model)
if state_dict is not None:
# fix for Context Parallel save: CP eval invalidates tensor storage
# pointers, so clone to CPU to get fresh valid storage for safetensors
if (
state_dict is not None
and self.axolotl_cfg
and self.axolotl_cfg.context_parallel_size
and self.axolotl_cfg.context_parallel_size > 1
):
state_dict = {
k: v.clone() if isinstance(v, torch.Tensor) else v
k: v.detach().cpu() if isinstance(v, torch.Tensor) else v
for k, v in state_dict.items()
}
@@ -761,7 +765,11 @@ class AxolotlTrainer(
metadata={"format": "pt"},
)
else:
self.model.save_pretrained(output_dir, state_dict=state_dict)
self.model.save_pretrained(
output_dir,
state_dict=state_dict,
is_main_process=self.accelerator.is_main_process,
)
if self.processing_class is not None:
self.processing_class.save_pretrained(output_dir)

View File

@@ -19,7 +19,7 @@ python scripts/cutcrossentropy_install.py | sh
- If you are installing from pip
```bash
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@58d6572"
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@a668583"
```
## Usage
@@ -88,9 +88,9 @@ plugins:
- qwen2_vl
- qwen3
- qwen3_5
- qwen3_5_text
- qwen3_5_moe
- qwen3_5_moe_vl
- qwen3_5_vl
- qwen3_5_moe_text
- qwen3_moe
- qwen3_next
- qwen3_vl

View File

@@ -35,7 +35,7 @@ LOG = get_logger(__name__)
_CCE_INSTALL_MESSAGE = (
"Please install Axolotl's fork of cut_cross_entropy with transformers support using "
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@58d6572"`'
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@a668583"`'
)

View File

@@ -39,6 +39,8 @@ This works for any MoE model in transformers that uses a `SparseMoeBlock` class
ScatterMoE uses a softmax -> topk routing, so results may differ from the baseline for some model architectures (GPT-OSS, GLM_MOE_DSA).
ScatterMoE does not work with GLM4.7 Flash (glm4_moe_lite) at the moment.
## Note on MegaBlocks
We tested [MegaBlocks](https://huggingface.co/kernels-community/megablocks) but were unable to ensure numerical accuracy, so we did not integrate it. It was also incompatible with many newer model architectures in transformers.

View File

@@ -34,7 +34,7 @@ def setup_quantized_meta_for_peft(model: torch.nn.Module):
return self
for param in model.parameters():
if isinstance(param, Params4bit):
if isinstance(param, Params4bit) and param.quant_state is not None:
param.quant_state._orig_to = param.quant_state.to
param.quant_state.to = types.MethodType(temp_to_method, param.quant_state)

View File

@@ -172,7 +172,10 @@ class ModelLoader:
# Build the model
PLUGIN_MANAGER.pre_model_load(self.cfg)
self.patch_manager.apply_post_plugin_pre_model_load_patches()
skip_move_to_device = self._build_model()
self.patch_manager.apply_post_model_build_patches(self.model)
PLUGIN_MANAGER.post_model_build(self.cfg, self.model)
# Post-build model configuration
@@ -860,6 +863,10 @@ class ModelLoader:
# Make sure everything is in the same dtype
skip_prepare_model_for_kbit_training = True
if getattr(self.model, "_moe_experts_quantized", False):
# Parametrized expert tensors dequantize on access — would OOM.
skip_prepare_model_for_kbit_training = True
if (
not skip_prepare_model_for_kbit_training
and self.cfg.adapter in ["lora", "qlora"]

View File

@@ -118,6 +118,7 @@ class PatchManager:
def apply_post_plugin_pre_model_load_patches(self):
"""Apply post plugin-pre_model_load load patches based on config."""
self._apply_tiled_mlp(self.cfg.model_config_type)
self._apply_moe_expert_quantization_patch()
def _apply_transformers_patches(self):
from axolotl.monkeypatch.transformers.trainer_loss_calc import (
@@ -135,6 +136,10 @@ class PatchManager:
patch_prepare_context_parallel_inputs()
def apply_post_model_build_patches(self, model: PreTrainedModel):
"""Apply patches right after model build, before post-load setup."""
self._finalize_moe_expert_quantization(model)
def apply_post_model_load_patches(self, model: PreTrainedModel):
"""Apply patches that require the model instance."""
self._apply_llama_flash_attn_patches(model)
@@ -170,9 +175,14 @@ class PatchManager:
patch_parallelism_config()
if self.cfg.fsdp_config and str(self.cfg.fsdp_version) == "2":
from axolotl.monkeypatch.accelerate.fsdp2 import patch_accelerate_fsdp2
from axolotl.monkeypatch.accelerate.fsdp2 import (
patch_accelerate_fsdp2,
patch_tied_keys_for_meta_device,
)
patch_accelerate_fsdp2()
if self.cfg.fsdp_config.cpu_ram_efficient_loading:
patch_tied_keys_for_meta_device()
if self.cfg.rl:
from axolotl.monkeypatch.trainer.trl import patch_trl_prepare_fsdp2
@@ -352,15 +362,54 @@ class PatchManager:
if (
self.cfg.fsdp_config
and str(self.cfg.fsdp_version) == "2"
and self.cfg.adapter == "qlora"
and (self.cfg.load_in_4bit or self.cfg.load_in_8bit)
):
from axolotl.monkeypatch.fsdp2_qlora import (
apply_init_dtype_attrs_patch,
apply_init_sharded_param_patch,
apply_init_unsharded_param_patch,
apply_linear8bitlt_save_patch,
)
apply_init_sharded_param_patch()
apply_init_unsharded_param_patch()
apply_init_dtype_attrs_patch()
if self.cfg.load_in_8bit:
apply_linear8bitlt_save_patch()
def _apply_moe_expert_quantization_patch(self):
"""Patch transformers weight loading to quantize MoE expert params on-the-fly."""
if not self.cfg.quantize_moe_experts:
return
from axolotl.monkeypatch.moe_quant import (
patch_moe_quantization_on_load,
patch_peft_target_parameters_matching,
)
patch_moe_quantization_on_load(self.cfg)
patch_peft_target_parameters_matching()
def _finalize_moe_expert_quantization(self, model: PreTrainedModel):
"""Log quantization results and set model flag for downstream use."""
import torch
model._moe_experts_quantized = False
if self.cfg.quantize_moe_experts:
from axolotl.monkeypatch.moe_quant import get_moe_quantized_count
count = get_moe_quantized_count()
if count > 0:
import gc
model._moe_experts_quantized = True
LOG.info(
"Quantized %d MoE expert parameter(s) to %s during model loading",
count,
"4-bit" if self.cfg.load_in_4bit else "8-bit",
)
gc.collect()
torch.cuda.empty_cache()
def _apply_tiled_mlp(self, model_type: str):
if self.cfg.tiled_mlp:

View File

@@ -111,6 +111,7 @@ class MambaLMHeadModel(nn.Module, GenerationMixin):
self,
save_directory: Union[str, os.PathLike],
state_dict: Optional[dict] = None,
**kwargs,
):
if state_dict is None:
state_dict = self.state_dict()

View File

@@ -150,13 +150,17 @@ def get_state_dict(self, model, unwrap=True):
)
elif self.is_fsdp2:
# https://github.com/pytorch/torchtune/blob/main/torchtune/training/_distributed.py#L465
from torch.distributed.tensor import DTensor
state_dict = {}
sharded_state_dict = model.state_dict()
for param_name, param in sharded_state_dict.items():
if param.is_cpu:
param = param.to(torch.device("cuda"))
param = param.full_tensor()
if isinstance(param, DTensor):
param = param.full_tensor()
if torch.distributed.get_rank() == 0:
state_dict[param_name] = param.cpu()
torch.distributed.barrier()
@@ -182,10 +186,56 @@ def get_state_dict(self, model, unwrap=True):
return state_dict
def patch_peft_param_wrapper_for_fsdp2():
"""Patch PEFT's _LoraParameterProxy.forward for FSDP2 DTensor compatibility.
PEFT's ParamWrapper applies LoRA via torch.nn.utils.parametrize, which adds
delta_weight to the base weight W inside _LoraParameterProxy.forward().
Under FSDP2, W may be a DTensor (from FSDP unshard) while delta_weight is a
regular Tensor (or vice versa), causing a RuntimeError on mixed types.
This patch promotes the non-DTensor operand to match the DTensor's spec
using DTensor.from_local(), which is free for Replicate placement (just
metadata wrapping, no communication).
"""
from peft.tuners.lora.layer import _LoraParameterProxy
if getattr(_LoraParameterProxy, "_axolotl_fsdp2_patched", False):
return
_original_forward = _LoraParameterProxy.forward
# NOTE: Replaces (not wraps) forward; assumes original is just `W + self.delta_weight`.
def _patched_forward(self, W):
from torch.distributed.tensor import DTensor
delta = self.delta_weight
w_is_dt = isinstance(W, DTensor)
d_is_dt = isinstance(delta, DTensor)
with torch.nn.utils.parametrize.cached():
if w_is_dt == d_is_dt:
return W + delta
if w_is_dt:
return W + DTensor.from_local(delta, W.device_mesh, W.placements)
return DTensor.from_local(W, delta.device_mesh, delta.placements) + delta
_LoraParameterProxy.forward = _patched_forward
_LoraParameterProxy._axolotl_fsdp2_patched = True
LOG.info("Patched PEFT _LoraParameterProxy.forward for FSDP2 DTensor compatibility")
def _process_lora_module_for_fsdp(module, fsdp2_kwargs):
"""Helper function to process LoRA modules for FSDP2."""
from peft.tuners.lora.layer import ParamWrapper
from torch.distributed.fsdp import fully_shard
# Skip ParamWrapper — its lora_A/B must not be independently sharded.
# The parent decoder layer's FSDP wrapper handles unsharding them.
# TODO: review if we even need to shard them separately in the first place.
if isinstance(module, ParamWrapper):
return False
log_bias_dtype_mismatch = False
# Linear4Bit will keep its bias term in fp32. If the weight dtype is in bf16 we are not able to
@@ -327,6 +377,14 @@ def fsdp2_prepare_model(accelerator, model: torch.nn.Module) -> torch.nn.Module:
is_peft_model = isinstance(model, PeftModel)
# Patch PEFT's _LoraParameterProxy for DTensor compatibility if any
# ParamWrapper modules exist (used for target_parameters / 3D expert params).
if is_peft_model:
from peft.tuners.lora.layer import ParamWrapper
if any(isinstance(m, ParamWrapper) for m in model.modules()):
patch_peft_param_wrapper_for_fsdp2()
auto_wrap_policy = fsdp2_prepare_auto_wrap_policy(fsdp2_plugin, model)
log_bias_dtype_mismatch = False
if auto_wrap_policy is not None:
@@ -376,6 +434,43 @@ def fsdp2_prepare_model(accelerator, model: torch.nn.Module) -> torch.nn.Module:
return model
def patch_tied_keys_for_meta_device():
"""Patch _adjust_tied_keys_with_tied_pointers to skip meta tensors.
Meta tensors all share data_ptr()==0, causing every parameter to be incorrectly
grouped as "tied". Skipping them is safe since they have no real storage.
"""
from collections import defaultdict
from transformers import PreTrainedModel
def _patched_adjust_tied_keys_with_tied_pointers(self, missing_keys):
param_pointers = defaultdict(list)
for param_name, param_value in self.state_dict().items():
if param_value.is_meta:
continue
param_pointers[param_value.data_ptr()].append(param_name)
tied_param_names = [
names
for names in param_pointers.values()
if len(names) > 1
and not any(name in self.all_tied_weights_keys.keys() for name in names)
and not all(name in missing_keys for name in names)
]
tied_weights_keys_by_pointers = {
param_name: group[0]
for group in tied_param_names
for param_name in group[1:]
}
self.all_tied_weights_keys.update(tied_weights_keys_by_pointers)
PreTrainedModel._adjust_tied_keys_with_tied_pointers = (
_patched_adjust_tied_keys_with_tied_pointers
)
def patch_accelerate_fsdp2():
import accelerate
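
For context, a minimal standalone sketch (outside this diff) of the failure mode the tied-keys patch above guards against: parameters on the meta device have no real storage, so every one of them reports `data_ptr() == 0`, and pointer-based tying detection would group them all as tied.

```python
import torch
from collections import defaultdict

# Two unrelated parameters on the meta device: neither has real storage,
# so both report a data pointer of 0.
a = torch.nn.Parameter(torch.empty(4, 4, device="meta"))
b = torch.nn.Parameter(torch.empty(8, device="meta"))
print(a.data_ptr(), b.data_ptr())  # 0 0

# Grouping by data_ptr(), as the unpatched tying logic does, would report
# these independent parameters as tied weights.
pointers = defaultdict(list)
for name, param in {"a": a, "b": b}.items():
    pointers[param.data_ptr()].append(name)
print({ptr: names for ptr, names in pointers.items() if len(names) > 1})  # {0: ['a', 'b']}
```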

View File

@@ -1,9 +1,10 @@
"""
Monkeypatch to add Params4bit support to FSDP2. This enables QLoRA + FSDP2, as well as
our LoRA / QLoRA Triton kernels to work with FSDP2.
Monkeypatch to add Params4bit and Int8Params support to FSDP2. This enables QLoRA + FSDP2
and 8-bit LoRA + FSDP2, as well as our LoRA / QLoRA Triton kernels to work with FSDP2.
This patch modifies the _init_sharded_param method in FSDPParam to handle bitsandbytes
Params4bit parameters.
This patch modifies the _init_sharded_param and init_unsharded_param methods in FSDPParam
to handle bitsandbytes Params4bit and Int8Params parameters, preserving their quantization
metadata through the FSDP2 shard/unshard cycle.
"""
import importlib
@@ -17,6 +18,8 @@ LOG = get_logger(__name__)
def apply_init_sharded_param_patch():
"""Apply patch to FSDPParam._init_sharded_param to support Params4bit."""
if getattr(apply_init_sharded_param_patch, "_axolotl_patched", False):
return
from torch.distributed.fsdp._fully_shard._fsdp_param import FSDPParam
# Get original source
@@ -41,9 +44,20 @@ def apply_init_sharded_param_patch():
bnb_quantized=param.bnb_quantized,
)
self.sharded_param = self.to_sharded_dtensor(self.sharded_param)
elif isinstance(param, bnb.nn.modules.Int8Params):
self.sharded_param = bnb.nn.modules.Int8Params(
data=sharded_param,
requires_grad=param.requires_grad,
has_fp16_weights=param.has_fp16_weights,
CB=None,
SCB=param.SCB,
)
self.sharded_param = self.to_sharded_dtensor(self.sharded_param)
else:
self.sharded_param = nn.Parameter(self.to_sharded_dtensor(sharded_param))
self.sharded_param.requires_grad_(param.requires_grad)"""
self.sharded_param = nn.Parameter(
self.to_sharded_dtensor(sharded_param),
requires_grad=param.requires_grad,
)"""
# Apply the replacement
if original_param_creation in original_source:
@@ -73,6 +87,7 @@ def apply_init_sharded_param_patch():
# Replace the method
FSDPParam._init_sharded_param = patched_init_sharded_param
apply_init_sharded_param_patch._axolotl_patched = True
LOG.info("Successfully applied FSDP _init_sharded_param patch")
else:
LOG.warning("Could not find target code for _init_sharded_param patching")
@@ -80,6 +95,8 @@ def apply_init_sharded_param_patch():
def apply_init_unsharded_param_patch():
"""Apply patch to FSDPParam.init_unsharded_param to support Params4bit."""
if getattr(apply_init_unsharded_param_patch, "_axolotl_patched", False):
return
from torch.distributed.fsdp._fully_shard._fsdp_param import FSDPParam
# Get original source
@@ -105,6 +122,14 @@ def apply_init_unsharded_param_patch():
module=local_tensor.module,
bnb_quantized=local_tensor.bnb_quantized,
)
elif isinstance(local_tensor, bnb.nn.modules.Int8Params):
self._unsharded_param = bnb.nn.modules.Int8Params(
data=unsharded_param,
requires_grad=self.sharded_param.requires_grad,
has_fp16_weights=local_tensor.has_fp16_weights,
CB=unsharded_param,
SCB=local_tensor.SCB,
)
else:
self._unsharded_param = nn.Parameter(
unsharded_param, requires_grad=self.sharded_param.requires_grad
@@ -138,6 +163,74 @@ def apply_init_unsharded_param_patch():
# Replace the method
FSDPParam.init_unsharded_param = patched_init_unsharded_param
apply_init_unsharded_param_patch._axolotl_patched = True
LOG.info("Successfully applied FSDP init_unsharded_param patch")
else:
LOG.warning("Could not find target code for patching")
def apply_linear8bitlt_save_patch():
"""Patch Linear8bitLt._save_to_state_dict to handle DTensor-wrapped Int8Params.
After FSDP2 sharding, Linear8bitLt.weight is a DTensor wrapping Int8Params.
BnB's _save_to_state_dict accesses self.weight.SCB directly, but DTensor
doesn't proxy custom attribute access to its _local_tensor. This patch
temporarily unwraps the DTensor during saving so BnB can find the SCB attribute.
"""
if getattr(apply_linear8bitlt_save_patch, "_axolotl_patched", False):
return
import bitsandbytes as bnb
from torch.distributed.tensor import DTensor
original_save = bnb.nn.Linear8bitLt._save_to_state_dict
def _patched_save_to_state_dict(self, destination, prefix, keep_vars):
# Use _parameters dict directly to bypass nn.Module.__setattr__ type check.
weight = self._parameters["weight"]
unwrapped = False
if isinstance(weight, DTensor) and hasattr(weight, "_local_tensor"):
self._parameters["weight"] = weight._local_tensor
unwrapped = True
try:
original_save(self, destination, prefix, keep_vars)
finally:
if unwrapped:
self._parameters["weight"] = weight
bnb.nn.Linear8bitLt._save_to_state_dict = _patched_save_to_state_dict
apply_linear8bitlt_save_patch._axolotl_patched = True
LOG.info("Patched Linear8bitLt._save_to_state_dict for DTensor compatibility")
def apply_init_dtype_attrs_patch():
"""Prevent FSDP2 mixed precision from casting non-float quantized params.
When mixed precision is enabled (e.g., bf16), FSDP2's init_dtype_attrs sets
param_dtype=bf16 for ALL params. During all-gather, _to_dtype_if_needed casts
the sharded param to param_dtype. For non-float params (uint8 packed 4-bit,
int8 quantized) without FSDP2 extensions, this destroys the quantized data.
Params4bit handles this via fsdp_pre/post_all_gather extensions, but our
parametrize-based expert quantization uses plain nn.Parameter(uint8/int8)
without extensions.
"""
if getattr(apply_init_dtype_attrs_patch, "_axolotl_patched", False):
return
from torch.distributed.fsdp._fully_shard._fsdp_param import FSDPParam
original_init_dtype_attrs = FSDPParam.init_dtype_attrs
def patched_init_dtype_attrs(self, mp_policy):
original_init_dtype_attrs(self, mp_policy)
# Skip casting non-float quantized params (uint8/int8) without FSDP2
# extensions — the parametrization chain handles dequantization.
if self.param_dtype is not None and not self.sharded_param.is_floating_point():
local = self.sharded_param
if hasattr(local, "_local_tensor"):
local = local._local_tensor
if not hasattr(local, "fsdp_pre_all_gather"):
self.param_dtype = None
FSDPParam.init_dtype_attrs = patched_init_dtype_attrs
apply_init_dtype_attrs_patch._axolotl_patched = True
LOG.info("Patched FSDPParam.init_dtype_attrs for non-float quantized params")

View File

@@ -0,0 +1,188 @@
"""
Loading-time quantization for MoE expert weights stored as 3D nn.Parameter tensors.
In transformers v5, MoE models store expert weights as fused 3D tensors that BnB
skips (only targets nn.Linear). This module patches weight loading to quantize them
on-the-fly (4-bit via bitsandbytes parametrize, 8-bit via custom int8 parametrization),
reducing peak VRAM from "all experts in bf16" to "one expert at a time."
"""
import bitsandbytes as bnb
import torch
import torch.nn.utils.parametrize as P
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
# Module-level state for the loading-time quantization patch.
_moe_load_state = {
"count": 0,
"mode": "4bit",
"quant_type": "nf4",
"compress_statistics": True,
"patched": False,
}
class Bnb8bitParametrization(torch.nn.Module):
"""Parametrization that dequantizes int8 row-wise quantized data on access."""
def __init__(self, row_stats: torch.Tensor):
super().__init__()
self.register_buffer("row_stats", row_stats)
@torch.no_grad()
def forward(self, quantized_param: torch.Tensor) -> torch.Tensor:
# Flatten 3D+ to 2D for BnB's dequant, then reshape back.
orig_shape = quantized_param.shape
if quantized_param.ndim > 2:
quantized_param = quantized_param.reshape(-1, orig_shape[-1])
result = bnb.functional.int8_vectorwise_dequant(quantized_param, self.row_stats)
return result.reshape(orig_shape)
def _enable_parametrization_cache(module, inputs):
P._cache_enabled += 1
def _disable_parametrization_cache(module, inputs, output):
P._cache_enabled -= 1
if not P._cache_enabled:
P._cache = {}
def replace_parameter_8bit(module, param_name):
"""Replace a module parameter with an 8-bit quantized version using parametrization."""
original_param = getattr(module, param_name)
int8_data, row_stats, _ = bnb.functional.int8_vectorwise_quant(
original_param.data.to(torch.float16)
)
setattr(module, param_name, torch.nn.Parameter(int8_data, requires_grad=False))
del original_param
P.register_parametrization(
module, param_name, Bnb8bitParametrization(row_stats), unsafe=True
)
# Cache dequantized values during forward to avoid redundant dequantization.
if not getattr(module, "_axolotl_8bit_hooks_registered", False):
module.register_forward_pre_hook(_enable_parametrization_cache)
module.register_forward_hook(_disable_parametrization_cache)
module._axolotl_8bit_hooks_registered = True
def patch_moe_quantization_on_load(cfg):
"""Patch transformers' weight loading to quantize MoE expert params on-the-fly.
Wraps ``set_param_for_module`` so that 3D+ CUDA tensors with "expert" in their
name are quantized (4-bit or 8-bit) as they're loaded, keeping peak VRAM low.
"""
mode = "8bit" if getattr(cfg, "load_in_8bit", False) else "4bit"
_moe_load_state["mode"] = mode
_moe_load_state["count"] = 0
if _moe_load_state["patched"]:
LOG.debug("MoE loading-time quantization patch already active")
return
import transformers.core_model_loading
import transformers.modeling_utils
if mode == "4bit":
from bitsandbytes.nn.parametrize import replace_parameter_4bit
quant_type = getattr(cfg, "bnb_4bit_quant_type", None) or "nf4"
compress_statistics = getattr(cfg, "bnb_4bit_use_double_quant", None)
if compress_statistics is None:
compress_statistics = True
_moe_load_state["quant_type"] = quant_type
_moe_load_state["compress_statistics"] = compress_statistics
# Disable caching_allocator_warmup — it pre-allocates a huge tensor at bf16
# size for all params, defeating our on-load quantization VRAM savings.
def _noop_warmup(*args, **kwargs):
pass
transformers.modeling_utils.caching_allocator_warmup = _noop_warmup
original_set_param = transformers.core_model_loading.set_param_for_module
def _patched_set_param_for_module(model, target_name, param_value, *args, **kwargs):
original_set_param(model, target_name, param_value, *args, **kwargs)
# Quantize 3D+ expert params that BnB skipped (only on CUDA).
if param_value.ndim >= 3 and param_value.is_cuda:
mod_path, _, pname = target_name.rpartition(".")
mod = model.get_submodule(mod_path) if mod_path else model
if not isinstance(mod, (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt)):
if "expert" not in target_name.lower():
LOG.debug(
"Skipping non-expert 3D param: %s (shape=%s)",
target_name,
list(param_value.shape),
)
return
if _moe_load_state["mode"] == "4bit":
replace_parameter_4bit(
mod,
pname,
compress_statistics=_moe_load_state["compress_statistics"],
quant_type=_moe_load_state["quant_type"],
)
else:
replace_parameter_8bit(mod, pname)
_moe_load_state["count"] += 1
# Release the bf16 tensor so CUDA memory is freed immediately.
param_value.data = torch.empty(0, device="cpu")
torch.cuda.empty_cache()
transformers.core_model_loading.set_param_for_module = _patched_set_param_for_module
_moe_load_state["patched"] = True
def get_moe_quantized_count():
"""Return the number of expert parameters quantized during loading."""
return _moe_load_state["count"]
def patch_peft_target_parameters_matching():
"""Fix PEFT's _inject_parameters to use suffix matching for parametrized modules."""
if getattr(patch_peft_target_parameters_matching, "_axolotl_patched", False):
return
from peft.tuners.tuners_utils import BaseTuner
original_inject = BaseTuner._inject_parameters
def _patched_inject_parameters(
self, peft_config, model, adapter_name, low_cpu_mem_usage
):
# Patch target_parameters to use full paths for parametrized modules
original_targets = list(peft_config.target_parameters)
expanded = set(original_targets)
for module_name, module in model.named_modules():
if not hasattr(module, "parametrizations"):
continue
for target in original_targets:
mod_path, _, param_name = target.rpartition(".")
if (
module_name == mod_path or module_name.endswith("." + mod_path)
) and hasattr(module, param_name):
expanded.add(f"{module_name}.{param_name}")
peft_config.target_parameters = sorted(expanded)
try:
return original_inject(
self, peft_config, model, adapter_name, low_cpu_mem_usage
)
finally:
peft_config.target_parameters = original_targets
BaseTuner._inject_parameters = _patched_inject_parameters
patch_peft_target_parameters_matching._axolotl_patched = True
LOG.info("Patched PEFT _inject_parameters for parametrized module suffix matching")

View File

@@ -629,6 +629,17 @@ class AxolotlInputConfig(
},
)
quantize_moe_experts: bool = Field(
default=False,
json_schema_extra={
"description": "Quantize MoE expert weights on load to reduce VRAM. "
"Requires adapter (lora/qlora) with load_in_4bit or load_in_8bit. "
"Requires CUDA (not compatible with ROCm or other backends). "
"Note: total parameter count may be reported incorrectly when enabled "
"(trainable param count is correct)."
},
)
scaling_softmax: bool | None = Field(
default=None,
json_schema_extra={
@@ -1289,6 +1300,26 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
)
return data
@model_validator(mode="before")
@classmethod
def check_quantize_moe_experts(cls, data):
if data.get("quantize_moe_experts"):
if data.get("adapter") not in ("lora", "qlora"):
raise ValueError("quantize_moe_experts requires adapter: lora or qlora")
if not (data.get("load_in_4bit") or data.get("load_in_8bit")):
raise ValueError(
"quantize_moe_experts requires load_in_4bit or load_in_8bit"
)
if (
data.get("capabilities")
and data["capabilities"].get("compute_capability")
and not data["capabilities"]["compute_capability"].startswith("sm_")
):
raise ValueError(
"quantize_moe_experts requires CUDA (not compatible with ROCm or other backends)"
)
return data
@model_validator(mode="before")
@classmethod
def check_auto_enable_lora_kernels(cls, data):

View File

@@ -209,6 +209,19 @@ class LoraConfig(BaseModel):
data["lora_dropout"] = 0.0
return data
@model_validator(mode="after")
def validate_lora_target_parameters_dropout(self):
if (
self.lora_target_parameters
and self.lora_dropout
and self.lora_dropout != 0.0
):
raise ValueError(
"lora_dropout must be 0 when lora_target_parameters is set. "
"PEFT's ParamWrapper does not support lora_dropout != 0."
)
return self
class ReLoRAConfig(BaseModel):
"""ReLoRA configuration subset"""