feat: attempt to quantize experts in 8-bit mode too
@@ -175,10 +175,12 @@ class ModelLoader:
         # Activate loading-time quantization for 3D MoE expert params before
         # from_pretrained() runs. This patches set_param_for_module so each
-        # expert weight is quantized to 4-bit as it's loaded, keeping peak
-        # VRAM to one expert param in bf16 at a time.
+        # expert weight is quantized (4-bit or 8-bit) as it's loaded, keeping
+        # peak VRAM to one expert param in bf16 at a time.
         moe_quant_active = False
-        if self.cfg.adapter in ("qlora", "lora") and self.cfg.load_in_4bit:
+        if self.cfg.adapter in ("qlora", "lora") and (
+            self.cfg.load_in_4bit or self.cfg.load_in_8bit
+        ):
             from axolotl.monkeypatch.moe_quant import patch_moe_quantization_on_load

             patch_moe_quantization_on_load(self.cfg)

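For illustration, the widened gate above now also fires for 8-bit loads. Below is a minimal sketch of the condition, with types.SimpleNamespace standing in for Axolotl's DictDefault config; it is illustrative only, not the real loader code.

# Sketch only: SimpleNamespace stands in for the real Axolotl config object.
from types import SimpleNamespace


def should_patch(cfg) -> bool:
    return cfg.adapter in ("qlora", "lora") and (cfg.load_in_4bit or cfg.load_in_8bit)


assert should_patch(SimpleNamespace(adapter="qlora", load_in_4bit=True, load_in_8bit=False))
# New with this commit: an 8-bit LoRA config also activates the MoE loading-time patch.
assert should_patch(SimpleNamespace(adapter="lora", load_in_4bit=False, load_in_8bit=True))
assert not should_patch(SimpleNamespace(adapter="lora", load_in_4bit=False, load_in_8bit=False))
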
@@ -2,23 +2,23 @@
 Loading-time quantization for MoE expert weights stored as 3D nn.Parameter tensors.

 In transformers v5, many MoE models store expert weights as fused 3D nn.Parameter
-tensors instead of individual nn.Linear modules. BnB 4-bit quantization only targets
+tensors instead of individual nn.Linear modules. BnB quantization only targets
 nn.Linear, so these expert weights are loaded in full precision, causing high peak VRAM.

 This module patches transformers' weight loading to quantize 3D expert parameters
-on-the-fly as they're assigned to modules, using bitsandbytes.nn.parametrize.
-replace_parameter_4bit (requires bitsandbytes >= 0.48.0). This reduces peak VRAM
-from "all experts in bf16" to "one expert param in bf16 at a time."
+on-the-fly as they're assigned to modules. For 4-bit, it uses
+bitsandbytes.nn.parametrize.replace_parameter_4bit (requires bitsandbytes >= 0.48.0).
+For 8-bit, it uses a custom parametrization built on bitsandbytes.functional's
+int8_vectorwise_quant/dequant (row-wise absmax scaling). Both reduce peak VRAM from
+"all experts in bf16" to "one expert param in bf16 at a time."

 PEFT's target_parameters / ParamWrapper can then apply LoRA on top of these quantized
 params via stacked parametrizations.

 Note: FSDP2 cpu ram efficient loading and Tensor Parallel (DTensor) compatibility
 with parametrization is untested.
 """

 import bitsandbytes as bnb
 import torch
 import torch.nn.utils.parametrize as P

 from axolotl.utils.logging import get_logger

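The row-wise absmax scheme described in the docstring can be written out in plain PyTorch. The sketch below shows the arithmetic that int8_vectorwise_quant/dequant perform; it is a reference illustration, not the bitsandbytes kernels themselves.

import torch


def int8_rowwise_quant_ref(x: torch.Tensor):
    # Treat all leading dims as rows, so the scales have shape prod(shape[:-1]).
    rows = x.reshape(-1, x.shape[-1]).float()
    row_stats = rows.abs().amax(dim=1).clamp_min(1e-8)  # per-row absmax
    quantized = torch.round(rows * (127.0 / row_stats[:, None])).to(torch.int8)
    return quantized.reshape(x.shape), row_stats


def int8_rowwise_dequant_ref(quantized: torch.Tensor, row_stats: torch.Tensor):
    rows = quantized.reshape(-1, quantized.shape[-1]).float()
    return (rows * row_stats[:, None] * (1.0 / 127.0)).reshape(quantized.shape)


# Round-trip error is bounded by half a quantization step per element.
w = torch.randn(8, 64, 32, dtype=torch.bfloat16)  # fused 3D expert weight, made-up sizes
q, stats = int8_rowwise_quant_ref(w)
w_hat = int8_rowwise_dequant_ref(q, stats)
assert (w.float() - w_hat).abs().max() <= stats.max() / 127 / 2 + 1e-6
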
@@ -27,27 +27,105 @@ LOG = get_logger(__name__)
 # Module-level state for the loading-time quantization patch.
 _moe_load_state = {
     "count": 0,
+    "mode": "4bit",
     "quant_type": "nf4",
     "compress_statistics": True,
     "patched": False,
 }


+class Bnb8bitParametrization(torch.nn.Module):
+    """Parametrization that dequantizes int8 row-wise quantized data on access.
+
+    Mirrors ``Bnb4bitParametrization`` from ``bitsandbytes.nn.parametrize`` but for
+    int8 row-wise (absmax) quantization. Stores the per-row scales as a buffer and
+    delegates to ``bitsandbytes.functional.int8_vectorwise_dequant`` which computes
+    ``int8_data * row_stats * (1/127)``.
+    """
+
+    def __init__(self, row_stats: torch.Tensor):
+        super().__init__()
+        self.register_buffer("row_stats", row_stats)
+
+    @torch.no_grad()
+    def forward(self, quantized_param: torch.Tensor) -> torch.Tensor:
+        return bnb.functional.int8_vectorwise_dequant(quantized_param, self.row_stats)
+
+
+def _enable_parametrization_cache(module, inputs):
+    P._cache_enabled += 1
+
+
+def _disable_parametrization_cache(module, inputs, output):
+    P._cache_enabled -= 1
+    if not P._cache_enabled:
+        P._cache = {}
+
+
+def replace_parameter_8bit(module, param_name):
+    """Replace a module parameter with an 8-bit quantized version using parametrization.
+
+    Mirrors ``bitsandbytes.nn.parametrize.replace_parameter_4bit`` but for int8
+    row-wise (absmax) quantization. Uses ``int8_vectorwise_quant`` which supports
+    N-D tensors natively (scales shape = ``prod(shape[:-1])``).
+
+    Args:
+        module: The module containing the parameter to quantize.
+        param_name: Name of the parameter within the module.
+    """
+    original_param = getattr(module, param_name)
+    int8_data, row_stats, _ = bnb.functional.int8_vectorwise_quant(
+        original_param.data.to(torch.float16)
+    )
+
+    setattr(module, param_name, torch.nn.Parameter(int8_data, requires_grad=False))
+    del original_param
+
+    P.register_parametrization(
+        module, param_name, Bnb8bitParametrization(row_stats), unsafe=True
+    )
+
+    # Register caching hooks (same pattern as BnB 4-bit). Caching avoids
+    # redundant dequantization when the same param is accessed multiple times
+    # in a single forward pass.
+    module.register_forward_pre_hook(_enable_parametrization_cache)
+    module.register_forward_hook(_disable_parametrization_cache)
+
+
+def _8bit_state_dict_post_hook(
+    module, state_dict, prefix, local_metadata, *, param_name
+):
+    """Placeholder for 8-bit state_dict serialization hook.
+
+    For LoRA/QLoRA training, only adapter weights (lora_A/B) are saved — base model
+    weights including quantized experts are not serialized. State dict hooks for 8-bit
+    are therefore not needed for the primary use case. If full-model saving with 8-bit
+    quantized expert params is needed, this hook should store int8 data + row_stats.
+    """
+    raise NotImplementedError(
+        "State dict serialization for 8-bit quantized expert parameters is not yet "
+        "implemented. This is not needed for LoRA/QLoRA training (only adapter weights "
+        "are saved). If you need to save the full model with 8-bit quantized experts, "
+        "please open an issue."
+    )
+
+
 def patch_moe_quantization_on_load(cfg):
     """Patch transformers' weight loading to quantize 3D MoE expert params on-the-fly.

     Wraps ``transformers.core_model_loading.set_param_for_module`` so that after each
     parameter is assigned to its module, any 3D+ tensor on CUDA that BnB skipped
     (i.e. not inside a Linear4bit/Linear8bitLt) is immediately quantized via
-    ``replace_parameter_4bit``. This keeps peak VRAM to one expert param in bf16
-    at a time, instead of loading all experts in bf16 first.
+    ``replace_parameter_4bit`` (for 4-bit) or ``replace_parameter_8bit`` (for 8-bit).
+    This keeps peak VRAM to one expert param in bf16 at a time, instead of loading
+    all experts in bf16 first.

     The patch stays active permanently — the ``ndim >= 3`` and ``is_cuda`` checks
     make it safe for non-MoE models (no false positives).

     Args:
-        cfg: Axolotl DictDefault config. Reads bnb_4bit_quant_type and
-            bnb_4bit_use_double_quant for quantization settings.
+        cfg: Axolotl DictDefault config. For 4-bit, reads bnb_4bit_quant_type and
+            bnb_4bit_use_double_quant. For 8-bit, no additional settings needed.
     """
     if _moe_load_state["patched"]:
         LOG.debug("MoE loading-time quantization patch already active")
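A toy usage sketch of the new helper follows. FakeExperts and its sizes are hypothetical, and a CUDA device plus an installed bitsandbytes are assumed.

# Hypothetical module holding a fused 3D expert weight; illustrative only.
import torch

from axolotl.monkeypatch.moe_quant import replace_parameter_8bit


class FakeExperts(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # (num_experts, out_features, in_features), fused as in transformers v5 MoE layers
        self.weight = torch.nn.Parameter(
            torch.randn(8, 64, 32, device="cuda", dtype=torch.bfloat16)
        )


experts = FakeExperts()
replace_parameter_8bit(experts, "weight")

print(experts.parametrizations["weight"].original.dtype)  # int8 storage after quantization
dequantized = experts.weight  # attribute access dequantizes through the parametrization
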
@@ -55,7 +133,26 @@ def patch_moe_quantization_on_load(cfg):

     import transformers.core_model_loading
     import transformers.modeling_utils
-    from bitsandbytes.nn.parametrize import replace_parameter_4bit

+    # Determine quantization mode from config.
+    if getattr(cfg, "load_in_8bit", False):
+        mode = "8bit"
+    else:
+        mode = "4bit"
+
+    _moe_load_state["mode"] = mode
+    _moe_load_state["count"] = 0
+
+    if mode == "4bit":
+        from bitsandbytes.nn.parametrize import replace_parameter_4bit
+
+        quant_type = getattr(cfg, "bnb_4bit_quant_type", None) or "nf4"
+        compress_statistics = getattr(cfg, "bnb_4bit_use_double_quant", None)
+        if compress_statistics is None:
+            compress_statistics = True
+
+        _moe_load_state["quant_type"] = quant_type
+        _moe_load_state["compress_statistics"] = compress_statistics
+
     # Patch caching_allocator_warmup to be a no-op. This function pre-allocates
     # a single huge GPU tensor equal to the model's total param bytes to warm the
@@ -68,16 +165,6 @@ def patch_moe_quantization_on_load(cfg):

     transformers.modeling_utils.caching_allocator_warmup = _noop_warmup

-    # Read quantization settings from config
-    quant_type = getattr(cfg, "bnb_4bit_quant_type", None) or "nf4"
-    compress_statistics = getattr(cfg, "bnb_4bit_use_double_quant", None)
-    if compress_statistics is None:
-        compress_statistics = True
-
-    _moe_load_state["quant_type"] = quant_type
-    _moe_load_state["compress_statistics"] = compress_statistics
-    _moe_load_state["count"] = 0
-
     original_set_param = transformers.core_model_loading.set_param_for_module

     def _patched_set_param_for_module(model, target_name, param_value, *args, **kwargs):
@@ -88,12 +175,15 @@ def patch_moe_quantization_on_load(cfg):
             mod_path, _, pname = target_name.rpartition(".")
             mod = model.get_submodule(mod_path) if mod_path else model
             if not isinstance(mod, (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt)):
-                replace_parameter_4bit(
-                    mod,
-                    pname,
-                    compress_statistics=_moe_load_state["compress_statistics"],
-                    quant_type=_moe_load_state["quant_type"],
-                )
+                if _moe_load_state["mode"] == "4bit":
+                    replace_parameter_4bit(
+                        mod,
+                        pname,
+                        compress_statistics=_moe_load_state["compress_statistics"],
+                        quant_type=_moe_load_state["quant_type"],
+                    )
+                else:
+                    replace_parameter_8bit(mod, pname)
                 torch.cuda.empty_cache()
                 _moe_load_state["count"] += 1

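End to end, the intended 8-bit flow looks roughly like the sketch below. The checkpoint id is hypothetical, and a CUDA device, bitsandbytes, and a transformers version exposing core_model_loading.set_param_for_module are assumed.

from transformers import AutoModelForCausalLM, BitsAndBytesConfig

from axolotl.monkeypatch.moe_quant import _moe_load_state, patch_moe_quantization_on_load
from axolotl.utils.dict import DictDefault

cfg = DictDefault({"adapter": "lora", "load_in_8bit": True})
patch_moe_quantization_on_load(cfg)  # installs the set_param_for_module wrapper

model = AutoModelForCausalLM.from_pretrained(
    "org/some-moe-model",  # hypothetical MoE checkpoint
    quantization_config=BitsAndBytesConfig(load_in_8bit=True),
    device_map="auto",
)
# Fused 3D expert params were quantized one at a time as they were assigned.
print(f"{_moe_load_state['count']} expert params quantized at load time")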