fix: attempt on-load quantize experts instead of post-load
This commit is contained in:
@@ -172,18 +172,30 @@ class ModelLoader:
|
|||||||
# Build the model
|
# Build the model
|
||||||
PLUGIN_MANAGER.pre_model_load(self.cfg)
|
PLUGIN_MANAGER.pre_model_load(self.cfg)
|
||||||
self.patch_manager.apply_post_plugin_pre_model_load_patches()
|
self.patch_manager.apply_post_plugin_pre_model_load_patches()
|
||||||
|
|
||||||
|
# Activate loading-time quantization for 3D MoE expert params before
|
||||||
|
# from_pretrained() runs. This patches set_param_for_module so each
|
||||||
|
# expert weight is quantized to 4-bit as it's loaded, keeping peak
|
||||||
|
# VRAM to one expert param in bf16 at a time.
|
||||||
|
moe_quant_active = False
|
||||||
|
if self.cfg.adapter in ("qlora", "lora") and self.cfg.load_in_4bit:
|
||||||
|
from axolotl.monkeypatch.moe_quant import patch_moe_quantization_on_load
|
||||||
|
|
||||||
|
patch_moe_quantization_on_load(self.cfg)
|
||||||
|
moe_quant_active = True
|
||||||
|
|
||||||
skip_move_to_device = self._build_model()
|
skip_move_to_device = self._build_model()
|
||||||
|
|
||||||
# Quantize 3D MoE expert nn.Parameter tensors that BnB skips during loading.
|
# Check if any MoE expert params were quantized during loading.
|
||||||
self.model._moe_experts_quantized = False
|
self.model._moe_experts_quantized = False
|
||||||
if self.cfg.adapter in ("qlora", "lora") and self.cfg.load_in_4bit:
|
if moe_quant_active:
|
||||||
from axolotl.monkeypatch.moe_quant import (
|
from axolotl.monkeypatch.moe_quant import (
|
||||||
|
get_moe_quantized_count,
|
||||||
patch_peft_target_parameters_matching,
|
patch_peft_target_parameters_matching,
|
||||||
quantize_moe_expert_params,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
self.model._moe_experts_quantized = quantize_moe_expert_params(self.model)
|
if get_moe_quantized_count() > 0:
|
||||||
if self.model._moe_experts_quantized:
|
self.model._moe_experts_quantized = True
|
||||||
patch_peft_target_parameters_matching()
|
patch_peft_target_parameters_matching()
|
||||||
|
|
||||||
PLUGIN_MANAGER.post_model_build(self.cfg, self.model)
|
PLUGIN_MANAGER.post_model_build(self.cfg, self.model)
|
||||||
|
|||||||
@@ -1,14 +1,20 @@
|
|||||||
"""
|
"""
|
||||||
Post-load quantization for MoE expert weights stored as 3D nn.Parameter tensors.
|
Loading-time quantization for MoE expert weights stored as 3D nn.Parameter tensors.
|
||||||
|
|
||||||
In transformers v5, many MoE models store expert weights as fused 3D nn.Parameter
|
In transformers v5, many MoE models store expert weights as fused 3D nn.Parameter
|
||||||
tensors instead of individual nn.Linear modules. BnB 4-bit quantization only targets
|
tensors instead of individual nn.Linear modules. BnB 4-bit quantization only targets
|
||||||
nn.Linear, so these expert weights are skipped during model loading, causing OOM.
|
nn.Linear, so these expert weights are loaded in full precision, causing high peak VRAM.
|
||||||
|
|
||||||
|
This module patches transformers' weight loading to quantize 3D expert parameters
|
||||||
|
on-the-fly as they're assigned to modules, using bitsandbytes.nn.parametrize.
|
||||||
|
replace_parameter_4bit (requires bitsandbytes >= 0.48.0). This reduces peak VRAM
|
||||||
|
from "all experts in bf16" to "one expert param in bf16 at a time."
|
||||||
|
|
||||||
This module provides a post-load fixup that quantizes those skipped parameters using
|
|
||||||
bitsandbytes.nn.parametrize.replace_parameter_4bit (requires bitsandbytes >= 0.48.0).
|
|
||||||
PEFT's target_parameters / ParamWrapper can then apply LoRA on top of these quantized
|
PEFT's target_parameters / ParamWrapper can then apply LoRA on top of these quantized
|
||||||
params via stacked parametrizations.
|
params via stacked parametrizations.
|
||||||
|
|
||||||
|
Note: FSDP2 CPU RAM-efficient loading and Tensor Parallel (DTensor) compatibility
|
||||||
|
with parametrization is untested.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import bitsandbytes as bnb
|
import bitsandbytes as bnb
|
||||||
@@ -18,6 +24,86 @@ from axolotl.utils.logging import get_logger
|
|||||||
|
|
||||||
LOG = get_logger(__name__)
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
|
# Module-level state for the loading-time quantization patch.
|
||||||
|
_moe_load_state = {
|
||||||
|
"count": 0,
|
||||||
|
"quant_type": "nf4",
|
||||||
|
"compress_statistics": True,
|
||||||
|
"patched": False,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def patch_moe_quantization_on_load(cfg):
    """Patch transformers' weight loading to quantize 3D MoE expert params on-the-fly.

    Wraps ``transformers.core_model_loading.set_param_for_module`` so that after
    each parameter is assigned to its module, any tensor with ``ndim >= 3`` that
    lives on CUDA and whose owning module is not a ``Linear4bit`` /
    ``Linear8bitLt`` is immediately quantized via ``replace_parameter_4bit``.
    The intent is to keep peak VRAM to one expert param in bf16 at a time,
    instead of loading all experts in bf16 first.

    The wrapper is installed once per process and stays active. Every call to
    this function, including calls made after the wrapper is already installed,
    refreshes the quantization settings and resets the quantized-parameter
    counter, so ``get_moe_quantized_count()`` reflects only the model load that
    follows the most recent call.

    NOTE(review): the wrapper selects parameters purely by rank and device, not
    by name. Any non-expert parameter with three or more dimensions (for
    example a convolution kernel) that is loaded onto CUDA would presumably be
    quantized too -- confirm this is acceptable for the models in use.

    Args:
        cfg: Axolotl DictDefault config. Reads bnb_4bit_quant_type and
            bnb_4bit_use_double_quant for quantization settings; missing or
            ``None`` values fall back to ``"nf4"`` and ``True``.
    """
    # Read quantization settings from config.
    quant_type = getattr(cfg, "bnb_4bit_quant_type", None) or "nf4"
    compress_statistics = getattr(cfg, "bnb_4bit_use_double_quant", None)
    if compress_statistics is None:
        compress_statistics = True

    # Refresh the shared state *before* the already-patched check. The wrapper
    # reads these values at call time, so a later load in the same process
    # picks up the new settings, and a stale count from a previous load cannot
    # make the caller believe experts were quantized this time.
    _moe_load_state["quant_type"] = quant_type
    _moe_load_state["compress_statistics"] = compress_statistics
    _moe_load_state["count"] = 0

    if _moe_load_state["patched"]:
        LOG.debug("MoE loading-time quantization patch already active")
        return

    import transformers.core_model_loading
    from bitsandbytes.nn.parametrize import replace_parameter_4bit

    original_set_param = transformers.core_model_loading.set_param_for_module

    def _patched_set_param_for_module(
        model, target_name, param_value, *args, **kwargs
    ):
        # Let transformers assign the parameter first; quantization then
        # replaces the freshly assigned parameter in place on its module.
        result = original_set_param(model, target_name, param_value, *args, **kwargs)

        # Quantize 3D+ expert params that BnB skipped (only on CUDA).
        if param_value.ndim >= 3 and param_value.is_cuda:
            mod_path, _, pname = target_name.rpartition(".")
            # An empty module path means the parameter sits on the root model.
            mod = model.get_submodule(mod_path) if mod_path else model
            if not isinstance(mod, (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt)):
                replace_parameter_4bit(
                    mod,
                    pname,
                    compress_statistics=_moe_load_state["compress_statistics"],
                    quant_type=_moe_load_state["quant_type"],
                )
                # Release cached blocks after each conversion so reserved
                # VRAM does not accumulate across parameters.
                torch.cuda.empty_cache()
                _moe_load_state["count"] += 1
                LOG.debug(
                    "Quantized 3D expert param during loading: %s (shape %s)",
                    target_name,
                    param_value.shape,
                )

        # Preserve whatever the wrapped function returns.
        return result

    transformers.core_model_loading.set_param_for_module = _patched_set_param_for_module
    _moe_load_state["patched"] = True
    LOG.info(
        "Activated MoE loading-time quantization patch "
        "(quant_type=%s, compress_statistics=%s)",
        quant_type,
        compress_statistics,
    )
|
||||||
|
|
||||||
|
|
||||||
|
def get_moe_quantized_count():
    """Report how many parameters the loading-time patch has quantized.

    Returns:
        The current value of the ``"count"`` entry in the module-level
        ``_moe_load_state`` dict.
    """
    quantized_total = _moe_load_state["count"]
    return quantized_total
|
||||||
|
|
||||||
|
|
||||||
def patch_peft_target_parameters_matching():
|
def patch_peft_target_parameters_matching():
|
||||||
"""Fix PEFT's _inject_parameters to use suffix matching for parametrized modules.
|
"""Fix PEFT's _inject_parameters to use suffix matching for parametrized modules.
|
||||||
@@ -59,72 +145,3 @@ def patch_peft_target_parameters_matching():
|
|||||||
|
|
||||||
BaseTuner._inject_parameters = _patched_inject_parameters
|
BaseTuner._inject_parameters = _patched_inject_parameters
|
||||||
LOG.info("Patched PEFT _inject_parameters for parametrized module suffix matching")
|
LOG.info("Patched PEFT _inject_parameters for parametrized module suffix matching")
|
||||||
|
|
||||||
|
|
||||||
def find_unquantized_expert_params(model):
    """Collect 3D+ expert-like parameters that are not owned by a BnB linear layer.

    Walks every module of ``model``, ignores ``Linear4bit`` / ``Linear8bitLt``
    modules, and inspects only each remaining module's own (non-recursive)
    parameters. A parameter is selected when it has at least three dimensions
    and its name contains ``"experts"``, ``"gate_up_proj"`` or ``"down_proj"``.

    Returns:
        List of (module, param_name) tuples to quantize.
    """
    expert_keywords = ("experts", "gate_up_proj", "down_proj")

    def _is_expert_candidate(name, tensor):
        # The rank check runs first; the name is only examined for 3D+ tensors.
        if tensor.ndim < 3:
            return False
        return any(keyword in name for keyword in expert_keywords)

    found = []
    for _, submodule in model.named_modules():
        # Parameters of BnB linear layers are left alone.
        if isinstance(submodule, (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt)):
            continue
        found.extend(
            (submodule, name)
            for name, tensor in submodule.named_parameters(recurse=False)
            if _is_expert_candidate(name, tensor)
        )
    return found
|
|
||||||
|
|
||||||
|
|
||||||
def quantize_moe_expert_params(model, quant_type=None, compress_statistics=None):
    """Quantize the 3D expert parameters reported by find_unquantized_expert_params.

    Settings that are not passed explicitly are taken from
    ``model.config.quantization_config`` (``bnb_4bit_quant_type`` and
    ``bnb_4bit_use_double_quant``) when that config exists, and otherwise
    default to ``"nf4"`` and ``True``.

    Args:
        model: Model whose expert parameters should be quantized in place.
        quant_type: Quantization type passed to ``replace_parameter_4bit``.
        compress_statistics: ``compress_statistics`` flag passed to
            ``replace_parameter_4bit``.

    Returns:
        ``False`` when there was nothing to quantize, ``True`` otherwise.
    """
    from bitsandbytes.nn.parametrize import replace_parameter_4bit

    targets = find_unquantized_expert_params(model)
    if not targets:
        return False

    # Only consult the model's BnB config when at least one setting is missing.
    bnb_config = None
    if quant_type is None or compress_statistics is None:
        bnb_config = getattr(model.config, "quantization_config", None)

    if quant_type is None:
        if bnb_config is not None:
            quant_type = getattr(bnb_config, "bnb_4bit_quant_type", "nf4")
        # The config attribute may itself be None; fall back to the default.
        if quant_type is None:
            quant_type = "nf4"

    if compress_statistics is None:
        if bnb_config is not None:
            compress_statistics = getattr(
                bnb_config, "bnb_4bit_use_double_quant", True
            )
        if compress_statistics is None:
            compress_statistics = True

    for owner, name in targets:
        replace_parameter_4bit(
            owner,
            name,
            compress_statistics=compress_statistics,
            quant_type=quant_type,
        )
        # Drop the bf16 → 4-bit conversion buffers right away so reserved
        # VRAM does not pile up across parameters.
        torch.cuda.empty_cache()

    LOG.info(
        "Quantized %d MoE expert parameters to 4-bit (quant_type=%s, compress_statistics=%s)",
        len(targets),
        quant_type,
        compress_statistics,
    )
    return True
|
|
||||||
|
|||||||
Reference in New Issue
Block a user