feat: attempt to quant experts in 8bit mode too

author NanoCode012
date 2026-02-26 16:11:54 +07:00
parent 88a48eff8a
commit f68d9f839d
2 changed files with 123 additions and 31 deletions

View File

@@ -175,10 +175,12 @@ class ModelLoader:
         # Activate loading-time quantization for 3D MoE expert params before
         # from_pretrained() runs. This patches set_param_for_module so each
-        # expert weight is quantized to 4-bit as it's loaded, keeping peak
-        # VRAM to one expert param in bf16 at a time.
+        # expert weight is quantized (4-bit or 8-bit) as it's loaded, keeping
+        # peak VRAM to one expert param in bf16 at a time.
         moe_quant_active = False
-        if self.cfg.adapter in ("qlora", "lora") and self.cfg.load_in_4bit:
+        if self.cfg.adapter in ("qlora", "lora") and (
+            self.cfg.load_in_4bit or self.cfg.load_in_8bit
+        ):
             from axolotl.monkeypatch.moe_quant import patch_moe_quantization_on_load

             patch_moe_quantization_on_load(self.cfg)
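
The comment in this hunk is the whole motivation, and the savings are easy to ballpark. A back-of-envelope sketch with made-up Mixtral-like shapes (every number below is an assumption for illustration, not a measurement from this commit):

    # 32 layers, 3 fused expert params per layer, each (8, 14336, 4096) in bf16.
    bytes_bf16 = 2
    per_param_gib = 8 * 14336 * 4096 * bytes_bf16 / 2**30  # ~0.875 GiB per fused param

    all_bf16_gib = 32 * 3 * per_param_gib  # ~84 GiB resident if all experts load in bf16
    transient_gib = per_param_gib          # quantize-on-load: one bf16 param at a time

Everything already loaded stays in its quantized form, so the bf16 footprint never exceeds a single fused expert param.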

View File

@@ -2,23 +2,23 @@
 Loading-time quantization for MoE expert weights stored as 3D nn.Parameter tensors.

 In transformers v5, many MoE models store expert weights as fused 3D nn.Parameter
-tensors instead of individual nn.Linear modules. BnB 4-bit quantization only targets
+tensors instead of individual nn.Linear modules. BnB quantization only targets
 nn.Linear, so these expert weights are loaded in full precision, causing high peak VRAM.

 This module patches transformers' weight loading to quantize 3D expert parameters
-on-the-fly as they're assigned to modules, using bitsandbytes.nn.parametrize.
-replace_parameter_4bit (requires bitsandbytes >= 0.48.0). This reduces peak VRAM
-from "all experts in bf16" to "one expert param in bf16 at a time."
+on-the-fly as they're assigned to modules. For 4-bit, it uses
+bitsandbytes.nn.parametrize.replace_parameter_4bit (requires bitsandbytes >= 0.48.0).
+For 8-bit, it uses a custom parametrization built on bitsandbytes.functional's
+int8_vectorwise_quant/dequant (row-wise absmax scaling). Both reduce peak VRAM from
+"all experts in bf16" to "one expert param in bf16 at a time."

 PEFT's target_parameters / ParamWrapper can then apply LoRA on top of these quantized
 params via stacked parametrizations.

 Note: FSDP2 cpu ram efficient loading and Tensor Parallel (DTensor) compatibility
 with parametrization is untested.
 """

 import bitsandbytes as bnb
 import torch
 import torch.nn.utils.parametrize as P

 from axolotl.utils.logging import get_logger
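
As a sanity check on the 8-bit path this docstring describes, the two bitsandbytes primitives can be exercised on their own. A minimal round-trip sketch (2D toy shape for clarity; assumes a CUDA device and a bitsandbytes release that ships int8_vectorwise_quant/dequant):

    import torch
    import bitsandbytes as bnb

    w = torch.randn(32, 16, dtype=torch.float16, device="cuda")

    # Row-wise absmax quantization: int8 data plus one float scale per row,
    # i.e. prod(w.shape[:-1]) scales for an N-D input.
    int8_data, row_stats, _ = bnb.functional.int8_vectorwise_quant(w)

    # Dequantization computes int8_data * row_stats * (1/127).
    w_hat = bnb.functional.int8_vectorwise_dequant(int8_data, row_stats)

    print(int8_data.dtype, row_stats.shape)  # torch.int8, torch.Size([32])
    print((w.float() - w_hat).abs().max())   # small per-row rounding error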
@@ -27,27 +27,105 @@ LOG = get_logger(__name__)
 # Module-level state for the loading-time quantization patch.
 _moe_load_state = {
     "count": 0,
+    "mode": "4bit",
     "quant_type": "nf4",
     "compress_statistics": True,
     "patched": False,
 }


+class Bnb8bitParametrization(torch.nn.Module):
+    """Parametrization that dequantizes int8 row-wise quantized data on access.
+
+    Mirrors ``Bnb4bitParametrization`` from ``bitsandbytes.nn.parametrize`` but for
+    int8 row-wise (absmax) quantization. Stores the per-row scales as a buffer and
+    delegates to ``bitsandbytes.functional.int8_vectorwise_dequant``, which computes
+    ``int8_data * row_stats * (1/127)``.
+    """
+
+    def __init__(self, row_stats: torch.Tensor):
+        super().__init__()
+        self.register_buffer("row_stats", row_stats)
+
+    @torch.no_grad()
+    def forward(self, quantized_param: torch.Tensor) -> torch.Tensor:
+        return bnb.functional.int8_vectorwise_dequant(quantized_param, self.row_stats)
+
+
+def _enable_parametrization_cache(module, inputs):
+    P._cache_enabled += 1
+
+
+def _disable_parametrization_cache(module, inputs, output):
+    P._cache_enabled -= 1
+    if not P._cache_enabled:
+        P._cache = {}
+
+
+def replace_parameter_8bit(module, param_name):
+    """Replace a module parameter with an 8-bit quantized version using parametrization.
+
+    Mirrors ``bitsandbytes.nn.parametrize.replace_parameter_4bit`` but for int8
+    row-wise (absmax) quantization. Uses ``int8_vectorwise_quant``, which supports
+    N-D tensors natively (scales shape = ``prod(shape[:-1])``).
+
+    Args:
+        module: The module containing the parameter to quantize.
+        param_name: Name of the parameter within the module.
+    """
+    original_param = getattr(module, param_name)
+    int8_data, row_stats, _ = bnb.functional.int8_vectorwise_quant(
+        original_param.data.to(torch.float16)
+    )
+    setattr(module, param_name, torch.nn.Parameter(int8_data, requires_grad=False))
+    del original_param
+
+    P.register_parametrization(
+        module, param_name, Bnb8bitParametrization(row_stats), unsafe=True
+    )
+
+    # Register caching hooks (same pattern as BnB 4-bit). Caching avoids
+    # redundant dequantization when the same param is accessed multiple times
+    # in a single forward pass.
+    module.register_forward_pre_hook(_enable_parametrization_cache)
+    module.register_forward_hook(_disable_parametrization_cache)
+
+
+def _8bit_state_dict_post_hook(
+    module, state_dict, prefix, local_metadata, *, param_name
+):
+    """Placeholder for an 8-bit state_dict serialization hook.
+
+    For LoRA/QLoRA training, only adapter weights (lora_A/B) are saved; base-model
+    weights, including quantized experts, are not serialized. State-dict hooks for
+    8-bit are therefore not needed for the primary use case. If full-model saving
+    with 8-bit quantized expert params is needed, this hook should store the int8
+    data plus row_stats.
+    """
+    raise NotImplementedError(
+        "State dict serialization for 8-bit quantized expert parameters is not yet "
+        "implemented. This is not needed for LoRA/QLoRA training (only adapter "
+        "weights are saved). If you need to save the full model with 8-bit "
+        "quantized experts, please open an issue."
+    )
+
+
 def patch_moe_quantization_on_load(cfg):
     """Patch transformers' weight loading to quantize 3D MoE expert params on-the-fly.

     Wraps ``transformers.core_model_loading.set_param_for_module`` so that after each
     parameter is assigned to its module, any 3D+ tensor on CUDA that BnB skipped
     (i.e. not inside a Linear4bit/Linear8bitLt) is immediately quantized via
-    ``replace_parameter_4bit``. This keeps peak VRAM to one expert param in bf16
-    at a time, instead of loading all experts in bf16 first.
+    ``replace_parameter_4bit`` (for 4-bit) or ``replace_parameter_8bit`` (for 8-bit).
+    This keeps peak VRAM to one expert param in bf16 at a time, instead of loading
+    all experts in bf16 first.

     The patch stays active permanently — the ``ndim >= 3`` and ``is_cuda`` checks
     make it safe for non-MoE models (no false positives).

     Args:
-        cfg: Axolotl DictDefault config. Reads bnb_4bit_quant_type and
-            bnb_4bit_use_double_quant for quantization settings.
+        cfg: Axolotl DictDefault config. For 4-bit, reads bnb_4bit_quant_type and
+            bnb_4bit_use_double_quant. For 8-bit, no additional settings are needed.
     """
     if _moe_load_state["patched"]:
         LOG.debug("MoE loading-time quantization patch already active")
@@ -55,7 +133,26 @@ def patch_moe_quantization_on_load(cfg):
     import transformers.core_model_loading
     import transformers.modeling_utils

-    from bitsandbytes.nn.parametrize import replace_parameter_4bit
+    # Determine quantization mode from config.
+    if getattr(cfg, "load_in_8bit", False):
+        mode = "8bit"
+    else:
+        mode = "4bit"
+    _moe_load_state["mode"] = mode
+    _moe_load_state["count"] = 0
+
+    if mode == "4bit":
+        from bitsandbytes.nn.parametrize import replace_parameter_4bit
+
+        quant_type = getattr(cfg, "bnb_4bit_quant_type", None) or "nf4"
+        compress_statistics = getattr(cfg, "bnb_4bit_use_double_quant", None)
+        if compress_statistics is None:
+            compress_statistics = True
+        _moe_load_state["quant_type"] = quant_type
+        _moe_load_state["compress_statistics"] = compress_statistics

     # Patch caching_allocator_warmup to be a no-op. This function pre-allocates
     # a single huge GPU tensor equal to the model's total param bytes to warm the
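
The comment is truncated by the hunk, but the next hunk shows the assignment. A sketch of what the no-op plausibly looks like (the original function's signature isn't shown in the diff, so a catch-all is assumed):

    def _noop_warmup(*args, **kwargs):
        # transformers' caching_allocator_warmup pre-allocates a tensor the size
        # of the full bf16 model to warm the CUDA caching allocator; with
        # quantize-on-load that would overshoot real usage, so skip it entirely.
        return None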
@@ -68,16 +165,6 @@ def patch_moe_quantization_on_load(cfg):
     transformers.modeling_utils.caching_allocator_warmup = _noop_warmup

-    # Read quantization settings from config
-    quant_type = getattr(cfg, "bnb_4bit_quant_type", None) or "nf4"
-    compress_statistics = getattr(cfg, "bnb_4bit_use_double_quant", None)
-    if compress_statistics is None:
-        compress_statistics = True
-    _moe_load_state["quant_type"] = quant_type
-    _moe_load_state["compress_statistics"] = compress_statistics
-    _moe_load_state["count"] = 0
-
     original_set_param = transformers.core_model_loading.set_param_for_module

     def _patched_set_param_for_module(model, target_name, param_value, *args, **kwargs):
@@ -88,12 +175,15 @@ def patch_moe_quantization_on_load(cfg):
         mod_path, _, pname = target_name.rpartition(".")
         mod = model.get_submodule(mod_path) if mod_path else model
         if not isinstance(mod, (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt)):
-            replace_parameter_4bit(
-                mod,
-                pname,
-                compress_statistics=_moe_load_state["compress_statistics"],
-                quant_type=_moe_load_state["quant_type"],
-            )
+            if _moe_load_state["mode"] == "4bit":
+                replace_parameter_4bit(
+                    mod,
+                    pname,
+                    compress_statistics=_moe_load_state["compress_statistics"],
+                    quant_type=_moe_load_state["quant_type"],
+                )
+            else:
+                replace_parameter_8bit(mod, pname)
             torch.cuda.empty_cache()
             _moe_load_state["count"] += 1