fix: attempt on-load quantization of experts instead of post-load
@@ -172,18 +172,30 @@ class ModelLoader:
         # Build the model
         PLUGIN_MANAGER.pre_model_load(self.cfg)
         self.patch_manager.apply_post_plugin_pre_model_load_patches()

+        # Activate loading-time quantization for 3D MoE expert params before
+        # from_pretrained() runs. This patches set_param_for_module so each
+        # expert weight is quantized to 4-bit as it's loaded, keeping peak
+        # VRAM to one expert param in bf16 at a time.
+        moe_quant_active = False
+        if self.cfg.adapter in ("qlora", "lora") and self.cfg.load_in_4bit:
+            from axolotl.monkeypatch.moe_quant import patch_moe_quantization_on_load
+
+            patch_moe_quantization_on_load(self.cfg)
+            moe_quant_active = True
+
         skip_move_to_device = self._build_model()

-        # Quantize 3D MoE expert nn.Parameter tensors that BnB skips during loading.
+        # Check if any MoE expert params were quantized during loading.
         self.model._moe_experts_quantized = False
-        if self.cfg.adapter in ("qlora", "lora") and self.cfg.load_in_4bit:
+        if moe_quant_active:
             from axolotl.monkeypatch.moe_quant import (
+                get_moe_quantized_count,
                 patch_peft_target_parameters_matching,
-                quantize_moe_expert_params,
             )

-            self.model._moe_experts_quantized = quantize_moe_expert_params(self.model)
-            if self.model._moe_experts_quantized:
+            if get_moe_quantized_count() > 0:
+                self.model._moe_experts_quantized = True
                 patch_peft_target_parameters_matching()

         PLUGIN_MANAGER.post_model_build(self.cfg, self.model)
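To make the peak-VRAM claim in the new comment concrete, here is a back-of-envelope sketch. The expert count and projection sizes below are illustrative placeholders, not taken from any specific model:

n_experts, d_model, d_ff = 64, 4096, 14336  # hypothetical MoE dimensions
bytes_per_bf16 = 2

# One fused 3D expert parameter (e.g. a stacked down_proj) held in bf16:
one_fused_param = n_experts * d_model * d_ff * bytes_per_bf16
print(f"{one_fused_param / 2**30:.1f} GiB")  # ~7.0 GiB

# On-load quantization holds at most one such tensor in bf16 before packing
# it to 4-bit (~0.5 byte per weight plus quant-state overhead). Post-load
# quantization instead keeps every fused expert param resident in bf16
# until from_pretrained() returns.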
@@ -1,14 +1,20 @@
 """
-Post-load quantization for MoE expert weights stored as 3D nn.Parameter tensors.
+Loading-time quantization for MoE expert weights stored as 3D nn.Parameter tensors.

 In transformers v5, many MoE models store expert weights as fused 3D nn.Parameter
 tensors instead of individual nn.Linear modules. BnB 4-bit quantization only targets
-nn.Linear, so these expert weights are skipped during model loading, causing OOM.
+nn.Linear, so these expert weights are loaded in full precision, causing high peak VRAM.

-This module provides a post-load fixup that quantizes those skipped parameters using
-bitsandbytes.nn.parametrize.replace_parameter_4bit (requires bitsandbytes >= 0.48.0).
+This module patches transformers' weight loading to quantize 3D expert parameters
+on the fly as they are assigned to modules, using
+bitsandbytes.nn.parametrize.replace_parameter_4bit (requires bitsandbytes >= 0.48.0).
+This reduces peak VRAM from "all experts in bf16" to "one expert param in bf16
+at a time."
+
 PEFT's target_parameters / ParamWrapper can then apply LoRA on top of these quantized
 params via stacked parametrizations.

 Note: compatibility of FSDP2 CPU-RAM-efficient loading and Tensor Parallel (DTensor)
 with parametrization is untested.
 """

 import bitsandbytes as bnb
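For readers unfamiliar with the parametrization mechanism the docstring leans on, the sketch below hand-rolls the same idea from public APIs. quantize_4bit / dequantize_4bit and register_parametrization are real bitsandbytes/torch calls; Dequant4bit and quantize_param_4bit are illustrative names, and the real replace_parameter_4bit does considerably more bookkeeping:

import torch
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize
import bitsandbytes as bnb


class Dequant4bit(nn.Module):
    """Dequantizes packed 4-bit data back to the original tensor on access."""

    def __init__(self, quant_state):
        super().__init__()
        self.quant_state = quant_state  # shape/dtype live in the quant state

    def forward(self, packed):
        return bnb.functional.dequantize_4bit(packed, self.quant_state)


def quantize_param_4bit(module, name, quant_type="nf4", compress_statistics=True):
    weight = getattr(module, name)  # assumed to already be on CUDA
    packed, state = bnb.functional.quantize_4bit(
        weight.data, quant_type=quant_type, compress_statistics=compress_statistics
    )
    # Swap the bf16 parameter for its packed 4-bit payload, then register a
    # parametrization so module.<name> dequantizes on the fly. unsafe=True is
    # required because the parametrization changes dtype and shape.
    setattr(module, name, nn.Parameter(packed, requires_grad=False))
    parametrize.register_parametrization(module, name, Dequant4bit(state), unsafe=True)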
@@ -18,6 +24,86 @@ from axolotl.utils.logging import get_logger
 LOG = get_logger(__name__)

+# Module-level state for the loading-time quantization patch.
+_moe_load_state = {
+    "count": 0,
+    "quant_type": "nf4",
+    "compress_statistics": True,
+    "patched": False,
+}
+
+
+def patch_moe_quantization_on_load(cfg):
+    """Patch transformers' weight loading to quantize 3D MoE expert params on the fly.
+
+    Wraps ``transformers.core_model_loading.set_param_for_module`` so that after each
+    parameter is assigned to its module, any 3D+ tensor on CUDA that BnB skipped
+    (i.e. not inside a Linear4bit/Linear8bitLt) is immediately quantized via
+    ``replace_parameter_4bit``. This keeps peak VRAM to one expert param in bf16
+    at a time, instead of loading all experts in bf16 first.
+
+    The patch stays active permanently; the ``ndim >= 3`` and ``is_cuda`` checks
+    make it safe for non-MoE models (no false positives).
+
+    Args:
+        cfg: Axolotl DictDefault config. Reads bnb_4bit_quant_type and
+            bnb_4bit_use_double_quant for quantization settings.
+    """
+    if _moe_load_state["patched"]:
+        LOG.debug("MoE loading-time quantization patch already active")
+        return
+
+    import transformers.core_model_loading
+    from bitsandbytes.nn.parametrize import replace_parameter_4bit
+
+    # Read quantization settings from config.
+    quant_type = getattr(cfg, "bnb_4bit_quant_type", None) or "nf4"
+    compress_statistics = getattr(cfg, "bnb_4bit_use_double_quant", None)
+    if compress_statistics is None:
+        compress_statistics = True
+
+    _moe_load_state["quant_type"] = quant_type
+    _moe_load_state["compress_statistics"] = compress_statistics
+    _moe_load_state["count"] = 0
+
+    original_set_param = transformers.core_model_loading.set_param_for_module
+
+    def _patched_set_param_for_module(model, target_name, param_value, *args, **kwargs):
+        original_set_param(model, target_name, param_value, *args, **kwargs)
+
+        # Quantize 3D+ expert params that BnB skipped (only on CUDA).
+        if param_value.ndim >= 3 and param_value.is_cuda:
+            mod_path, _, pname = target_name.rpartition(".")
+            mod = model.get_submodule(mod_path) if mod_path else model
+            if not isinstance(mod, (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt)):
+                replace_parameter_4bit(
+                    mod,
+                    pname,
+                    compress_statistics=_moe_load_state["compress_statistics"],
+                    quant_type=_moe_load_state["quant_type"],
+                )
+                torch.cuda.empty_cache()
+                _moe_load_state["count"] += 1
+                LOG.debug(
+                    "Quantized 3D expert param during loading: %s (shape %s)",
+                    target_name,
+                    param_value.shape,
+                )
+
+    transformers.core_model_loading.set_param_for_module = _patched_set_param_for_module
+    _moe_load_state["patched"] = True
+    LOG.info(
+        "Activated MoE loading-time quantization patch "
+        "(quant_type=%s, compress_statistics=%s)",
+        quant_type,
+        compress_statistics,
+    )
+
+
+def get_moe_quantized_count():
+    """Return the number of expert parameters quantized during loading."""
+    return _moe_load_state["count"]
+
+
 def patch_peft_target_parameters_matching():
     """Fix PEFT's _inject_parameters to use suffix matching for parametrized modules.
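Outside of ModelLoader, the new API can also be driven directly. A minimal sketch, assuming a CUDA machine: the model id is a placeholder, and SimpleNamespace stands in for Axolotl's config object since the patch only reads two attributes via getattr:

from types import SimpleNamespace

from transformers import AutoModelForCausalLM, BitsAndBytesConfig

from axolotl.monkeypatch.moe_quant import (
    get_moe_quantized_count,
    patch_moe_quantization_on_load,
    patch_peft_target_parameters_matching,
)

cfg = SimpleNamespace(bnb_4bit_quant_type="nf4", bnb_4bit_use_double_quant=True)
patch_moe_quantization_on_load(cfg)  # must run before from_pretrained()

model = AutoModelForCausalLM.from_pretrained(
    "some-org/some-moe-model",  # placeholder model id
    quantization_config=BitsAndBytesConfig(load_in_4bit=True),
    device_map="auto",  # params must land on CUDA for the is_cuda check to fire
)

if get_moe_quantized_count() > 0:  # any fused 3D expert params hit the patch
    patch_peft_target_parameters_matching()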
@@ -59,72 +145,3 @@ def patch_peft_target_parameters_matching():
     BaseTuner._inject_parameters = _patched_inject_parameters
     LOG.info("Patched PEFT _inject_parameters for parametrized module suffix matching")
-
-
-def find_unquantized_expert_params(model):
-    """Find 3D+ nn.Parameter tensors that BnB quantization skipped.
-
-    Returns:
-        List of (module, param_name) tuples to quantize.
-    """
-    params_to_quantize = []
-    for _, module in model.named_modules():
-        if isinstance(module, (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt)):
-            continue
-        for param_name, param in module.named_parameters(recurse=False):
-            if param.ndim >= 3 and any(
-                kw in param_name for kw in ("experts", "gate_up_proj", "down_proj")
-            ):
-                params_to_quantize.append((module, param_name))
-    return params_to_quantize
-
-
-def quantize_moe_expert_params(model, quant_type=None, compress_statistics=None):
-    """Quantize 3D nn.Parameter expert weights that BnB skips during model loading.
-
-    Reads quant_type and compress_statistics from the model's quantization_config
-    when not explicitly provided, so that the same settings used for nn.Linear
-    quantization are applied to the MoE expert parameters.
-    """
-    from bitsandbytes.nn.parametrize import replace_parameter_4bit
-
-    params_to_quantize = find_unquantized_expert_params(model)
-    if not params_to_quantize:
-        return False
-
-    # Derive settings from the model's BnB config if not explicitly provided.
-    if quant_type is None or compress_statistics is None:
-        bnb_config = getattr(model.config, "quantization_config", None)
-        if bnb_config is not None:
-            if quant_type is None:
-                quant_type = getattr(bnb_config, "bnb_4bit_quant_type", "nf4")
-            if compress_statistics is None:
-                compress_statistics = getattr(
-                    bnb_config, "bnb_4bit_use_double_quant", True
-                )
-        # Final defaults.
-        if quant_type is None:
-            quant_type = "nf4"
-        if compress_statistics is None:
-            compress_statistics = True
-
-    count = 0
-    for module, param_name in params_to_quantize:
-        replace_parameter_4bit(
-            module,
-            param_name,
-            compress_statistics=compress_statistics,
-            quant_type=quant_type,
-        )
-        count += 1
-        # Free the bf16 → 4-bit conversion buffers after each parameter
-        # to avoid accumulating peak reserved VRAM.
-        torch.cuda.empty_cache()
-
-    LOG.info(
-        "Quantized %d MoE expert parameters to 4-bit (quant_type=%s, compress_statistics=%s)",
-        count,
-        quant_type,
-        compress_statistics,
-    )
-    return True
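Once experts are quantized on load and the PEFT matching patch above is applied, LoRA can be stacked on the parametrized tensors through PEFT's target_parameters, as the module docstring describes. A hedged sketch: the rank/alpha values are arbitrary, the parameter suffixes are placeholders that depend on the architecture, and target_parameters requires a recent PEFT release:

from peft import LoraConfig, get_peft_model

lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_parameters=[
        "experts.gate_up_proj",  # placeholder parameter suffixes
        "experts.down_proj",
    ],
)
# `model` is the 4-bit-loaded MoE model from the usage sketch above.
model = get_peft_model(model, lora_config)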