fix: attempt on-load quantization of experts instead of post-load

Author: NanoCode012
Date:   2026-02-25 17:20:40 +07:00
parent  593599a217
commit  d3d6cb6b67
2 changed files with 107 additions and 78 deletions


@@ -172,18 +172,30 @@ class ModelLoader:
         # Build the model
         PLUGIN_MANAGER.pre_model_load(self.cfg)
         self.patch_manager.apply_post_plugin_pre_model_load_patches()
 
+        # Activate loading-time quantization for 3D MoE expert params before
+        # from_pretrained() runs. This patches set_param_for_module so each
+        # expert weight is quantized to 4-bit as it's loaded, keeping peak
+        # VRAM to one expert param in bf16 at a time.
+        moe_quant_active = False
+        if self.cfg.adapter in ("qlora", "lora") and self.cfg.load_in_4bit:
+            from axolotl.monkeypatch.moe_quant import patch_moe_quantization_on_load
+
+            patch_moe_quantization_on_load(self.cfg)
+            moe_quant_active = True
+
         skip_move_to_device = self._build_model()
 
-        # Quantize 3D MoE expert nn.Parameter tensors that BnB skips during loading.
+        # Check if any MoE expert params were quantized during loading.
         self.model._moe_experts_quantized = False
-        if self.cfg.adapter in ("qlora", "lora") and self.cfg.load_in_4bit:
+        if moe_quant_active:
             from axolotl.monkeypatch.moe_quant import (
+                get_moe_quantized_count,
                 patch_peft_target_parameters_matching,
-                quantize_moe_expert_params,
             )
 
-            self.model._moe_experts_quantized = quantize_moe_expert_params(self.model)
-            if self.model._moe_experts_quantized:
+            if get_moe_quantized_count() > 0:
+                self.model._moe_experts_quantized = True
                 patch_peft_target_parameters_matching()
 
         PLUGIN_MANAGER.post_model_build(self.cfg, self.model)

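For reference, the loader path above only fires when the config requests a 4-bit adapter. A minimal sketch of the cfg fields it reads (SimpleNamespace is a stand-in for axolotl's DictDefault; values are illustrative, not from this commit):

from types import SimpleNamespace

# Attribute access matches the cfg.adapter / cfg.load_in_4bit checks above
# and the getattr() reads inside patch_moe_quantization_on_load().
cfg = SimpleNamespace(
    adapter="qlora",                 # or "lora"
    load_in_4bit=True,               # required to activate the patch
    bnb_4bit_quant_type="nf4",       # forwarded as quant_type
    bnb_4bit_use_double_quant=True,  # forwarded as compress_statistics
)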

@@ -1,14 +1,20 @@
 """
-Post-load quantization for MoE expert weights stored as 3D nn.Parameter tensors.
+Loading-time quantization for MoE expert weights stored as 3D nn.Parameter tensors.
 
 In transformers v5, many MoE models store expert weights as fused 3D nn.Parameter
 tensors instead of individual nn.Linear modules. BnB 4-bit quantization only targets
-nn.Linear, so these expert weights are skipped during model loading, causing OOM.
+nn.Linear, so these expert weights are loaded in full precision, causing high peak VRAM.
 
-This module provides a post-load fixup that quantizes those skipped parameters using
-bitsandbytes.nn.parametrize.replace_parameter_4bit (requires bitsandbytes >= 0.48.0).
+This module patches transformers' weight loading to quantize 3D expert parameters
+on-the-fly as they're assigned to modules, using bitsandbytes.nn.parametrize.
+replace_parameter_4bit (requires bitsandbytes >= 0.48.0). This reduces peak VRAM
+from "all experts in bf16" to "one expert param in bf16 at a time."
 
 PEFT's target_parameters / ParamWrapper can then apply LoRA on top of these quantized
 params via stacked parametrizations.
+
+Note: FSDP2 cpu ram efficient loading and Tensor Parallel (DTensor) compatibility
+with parametrization is untested.
 """
 
 import bitsandbytes as bnb
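As an aside, the parametrization mechanism the docstring describes can be exercised standalone. A minimal sketch, assuming bitsandbytes >= 0.48.0 and a CUDA device; the ToyExperts module and its shapes are invented for illustration, but the replace_parameter_4bit call matches the signature used in this commit:

import torch
from torch import nn
from bitsandbytes.nn.parametrize import replace_parameter_4bit


class ToyExperts(nn.Module):
    # Invented stand-in for a fused-expert module: a single 3D nn.Parameter
    # of shape (num_experts, out_features, in_features), like the tensors
    # this module targets.
    def __init__(self, num_experts=8, dout=64, din=32):
        super().__init__()
        self.gate_up_proj = nn.Parameter(
            torch.randn(num_experts, dout, din, dtype=torch.bfloat16)
        )


experts = ToyExperts().cuda()
replace_parameter_4bit(
    experts, "gate_up_proj", compress_statistics=True, quant_type="nf4"
)
# Storage is now 4-bit; attribute access goes through the parametrization,
# so downstream code still sees a tensor of the original shape.
print(experts.gate_up_proj.shape)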
@@ -18,6 +24,86 @@ from axolotl.utils.logging import get_logger
 
 LOG = get_logger(__name__)
 
+# Module-level state for the loading-time quantization patch.
+_moe_load_state = {
+    "count": 0,
+    "quant_type": "nf4",
+    "compress_statistics": True,
+    "patched": False,
+}
+
+
+def patch_moe_quantization_on_load(cfg):
+    """Patch transformers' weight loading to quantize 3D MoE expert params on-the-fly.
+
+    Wraps ``transformers.core_model_loading.set_param_for_module`` so that after each
+    parameter is assigned to its module, any 3D+ tensor on CUDA that BnB skipped
+    (i.e. not inside a Linear4bit/Linear8bitLt) is immediately quantized via
+    ``replace_parameter_4bit``. This keeps peak VRAM to one expert param in bf16
+    at a time, instead of loading all experts in bf16 first.
+
+    The patch stays active permanently — the ``ndim >= 3`` and ``is_cuda`` checks
+    make it safe for non-MoE models (no false positives).
+
+    Args:
+        cfg: Axolotl DictDefault config. Reads bnb_4bit_quant_type and
+            bnb_4bit_use_double_quant for quantization settings.
+    """
+    if _moe_load_state["patched"]:
+        LOG.debug("MoE loading-time quantization patch already active")
+        return
+
+    import transformers.core_model_loading
+    from bitsandbytes.nn.parametrize import replace_parameter_4bit
+
+    # Read quantization settings from config
+    quant_type = getattr(cfg, "bnb_4bit_quant_type", None) or "nf4"
+    compress_statistics = getattr(cfg, "bnb_4bit_use_double_quant", None)
+    if compress_statistics is None:
+        compress_statistics = True
+
+    _moe_load_state["quant_type"] = quant_type
+    _moe_load_state["compress_statistics"] = compress_statistics
+    _moe_load_state["count"] = 0
+
+    original_set_param = transformers.core_model_loading.set_param_for_module
+
+    def _patched_set_param_for_module(model, target_name, param_value, *args, **kwargs):
+        original_set_param(model, target_name, param_value, *args, **kwargs)
+
+        # Quantize 3D+ expert params that BnB skipped (only on CUDA).
+        if param_value.ndim >= 3 and param_value.is_cuda:
+            mod_path, _, pname = target_name.rpartition(".")
+            mod = model.get_submodule(mod_path) if mod_path else model
+            if not isinstance(mod, (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt)):
+                replace_parameter_4bit(
+                    mod,
+                    pname,
+                    compress_statistics=_moe_load_state["compress_statistics"],
+                    quant_type=_moe_load_state["quant_type"],
+                )
+                torch.cuda.empty_cache()
+                _moe_load_state["count"] += 1
+                LOG.debug(
+                    "Quantized 3D expert param during loading: %s (shape %s)",
+                    target_name,
+                    param_value.shape,
+                )
+
+    transformers.core_model_loading.set_param_for_module = _patched_set_param_for_module
+    _moe_load_state["patched"] = True
+    LOG.info(
+        "Activated MoE loading-time quantization patch "
+        "(quant_type=%s, compress_statistics=%s)",
+        quant_type,
+        compress_statistics,
+    )
+
+
+def get_moe_quantized_count():
+    """Return the number of expert parameters quantized during loading."""
+    return _moe_load_state["count"]
+
+
 def patch_peft_target_parameters_matching():
     """Fix PEFT's _inject_parameters to use suffix matching for parametrized modules.
@@ -59,72 +145,3 @@ def patch_peft_target_parameters_matching():
 
     BaseTuner._inject_parameters = _patched_inject_parameters
     LOG.info("Patched PEFT _inject_parameters for parametrized module suffix matching")
-
-
-def find_unquantized_expert_params(model):
-    """Find 3D+ nn.Parameter tensors that BnB quantization skipped.
-
-    Returns:
-        List of (module, param_name) tuples to quantize.
-    """
-    params_to_quantize = []
-    for _, module in model.named_modules():
-        if isinstance(module, (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt)):
-            continue
-        for param_name, param in module.named_parameters(recurse=False):
-            if param.ndim >= 3 and any(
-                kw in param_name for kw in ("experts", "gate_up_proj", "down_proj")
-            ):
-                params_to_quantize.append((module, param_name))
-    return params_to_quantize
-
-
-def quantize_moe_expert_params(model, quant_type=None, compress_statistics=None):
-    """Quantize 3D nn.Parameter expert weights that BnB skips during model loading.
-
-    Reads quant_type and compress_statistics from the model's quantization_config
-    when not explicitly provided, so that the same settings used for nn.Linear
-    quantization are applied to the MoE expert parameters.
-    """
-    from bitsandbytes.nn.parametrize import replace_parameter_4bit
-
-    params_to_quantize = find_unquantized_expert_params(model)
-    if not params_to_quantize:
-        return False
-
-    # Derive settings from model's BnB config if not explicitly provided
-    if quant_type is None or compress_statistics is None:
-        bnb_config = getattr(model.config, "quantization_config", None)
-        if bnb_config is not None:
-            if quant_type is None:
-                quant_type = getattr(bnb_config, "bnb_4bit_quant_type", "nf4")
-            if compress_statistics is None:
-                compress_statistics = getattr(
-                    bnb_config, "bnb_4bit_use_double_quant", True
-                )
-
-    # Final defaults
-    if quant_type is None:
-        quant_type = "nf4"
-    if compress_statistics is None:
-        compress_statistics = True
-
-    count = 0
-    for module, param_name in params_to_quantize:
-        replace_parameter_4bit(
-            module,
-            param_name,
-            compress_statistics=compress_statistics,
-            quant_type=quant_type,
-        )
-        count += 1
-        # Free the bf16 → 4-bit conversion buffers after each parameter
-        # to avoid accumulating peak reserved VRAM.
-        torch.cuda.empty_cache()
-
-    LOG.info(
-        "Quantized %d MoE expert parameters to 4-bit (quant_type=%s, compress_statistics=%s)",
-        count,
-        quant_type,
-        compress_statistics,
-    )
-    return True
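For intuition about why the switch from post-load to on-load quantization matters, here is a back-of-envelope comparison of the transient bf16 footprint. The shapes are invented for illustration; actual savings depend on the model:

# Invented shapes: 48 MoE layers, each with one fused 3D expert tensor of
# 64 experts x 768 x 4096 in bf16 (2 bytes per element).
layers, experts, dout, din = 48, 64, 768, 4096
per_layer_bf16 = experts * dout * din * 2  # bytes for one fused expert param

post_load_peak = layers * per_layer_bf16   # old: all expert params in bf16 first
on_load_peak = per_layer_bf16              # new: one fused param in bf16 at a time
# (both paths additionally hold the ~0.5 byte/elem 4-bit copies)

print(f"post-load transient bf16: {post_load_peak / 2**30:.1f} GiB")
print(f"on-load transient bf16:   {on_load_peak / 2**30:.2f} GiB")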