feat: attempt to quantize experts in 8-bit mode too
This commit is contained in:
@@ -175,10 +175,12 @@ class ModelLoader:
|
|||||||
|
|
||||||
# Activate loading-time quantization for 3D MoE expert params before
|
# Activate loading-time quantization for 3D MoE expert params before
|
||||||
# from_pretrained() runs. This patches set_param_for_module so each
|
# from_pretrained() runs. This patches set_param_for_module so each
|
||||||
# expert weight is quantized to 4-bit as it's loaded, keeping peak
|
# expert weight is quantized (4-bit or 8-bit) as it's loaded, keeping
|
||||||
# VRAM to one expert param in bf16 at a time.
|
# peak VRAM to one expert param in bf16 at a time.
|
||||||
moe_quant_active = False
|
moe_quant_active = False
|
||||||
if self.cfg.adapter in ("qlora", "lora") and self.cfg.load_in_4bit:
|
if self.cfg.adapter in ("qlora", "lora") and (
|
||||||
|
self.cfg.load_in_4bit or self.cfg.load_in_8bit
|
||||||
|
):
|
||||||
from axolotl.monkeypatch.moe_quant import patch_moe_quantization_on_load
|
from axolotl.monkeypatch.moe_quant import patch_moe_quantization_on_load
|
||||||
|
|
||||||
patch_moe_quantization_on_load(self.cfg)
|
patch_moe_quantization_on_load(self.cfg)
|
||||||
|
|||||||
@@ -2,23 +2,23 @@
|
|||||||
Loading-time quantization for MoE expert weights stored as 3D nn.Parameter tensors.
|
Loading-time quantization for MoE expert weights stored as 3D nn.Parameter tensors.
|
||||||
|
|
||||||
In transformers v5, many MoE models store expert weights as fused 3D nn.Parameter
|
In transformers v5, many MoE models store expert weights as fused 3D nn.Parameter
|
||||||
tensors instead of individual nn.Linear modules. BnB 4-bit quantization only targets
|
tensors instead of individual nn.Linear modules. BnB quantization only targets
|
||||||
nn.Linear, so these expert weights are loaded in full precision, causing high peak VRAM.
|
nn.Linear, so these expert weights are loaded in full precision, causing high peak VRAM.
|
||||||
|
|
||||||
This module patches transformers' weight loading to quantize 3D expert parameters
|
This module patches transformers' weight loading to quantize 3D expert parameters
|
||||||
on-the-fly as they're assigned to modules, using bitsandbytes.nn.parametrize.
|
on-the-fly as they're assigned to modules. For 4-bit, it uses
|
||||||
replace_parameter_4bit (requires bitsandbytes >= 0.48.0). This reduces peak VRAM
|
bitsandbytes.nn.parametrize.replace_parameter_4bit (requires bitsandbytes >= 0.48.0).
|
||||||
from "all experts in bf16" to "one expert param in bf16 at a time."
|
For 8-bit, it uses a custom parametrization built on bitsandbytes.functional's
|
||||||
|
int8_vectorwise_quant/dequant (row-wise absmax scaling). Both reduce peak VRAM from
|
||||||
|
"all experts in bf16" to "one expert param in bf16 at a time."
|
||||||
|
|
||||||
PEFT's target_parameters / ParamWrapper can then apply LoRA on top of these quantized
|
PEFT's target_parameters / ParamWrapper can then apply LoRA on top of these quantized
|
||||||
params via stacked parametrizations.
|
params via stacked parametrizations.
|
||||||
|
|
||||||
Note: FSDP2 cpu ram efficient loading and Tensor Parallel (DTensor) compatibility
|
|
||||||
with parametrization is untested.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import bitsandbytes as bnb
|
import bitsandbytes as bnb
|
||||||
import torch
|
import torch
|
||||||
|
import torch.nn.utils.parametrize as P
|
||||||
|
|
||||||
from axolotl.utils.logging import get_logger
|
from axolotl.utils.logging import get_logger
|
||||||
|
|
||||||
@@ -27,27 +27,105 @@ LOG = get_logger(__name__)
|
|||||||
# Module-level state for the loading-time quantization patch.
|
# Module-level state for the loading-time quantization patch.
|
||||||
_moe_load_state = {
|
_moe_load_state = {
|
||||||
"count": 0,
|
"count": 0,
|
||||||
|
"mode": "4bit",
|
||||||
"quant_type": "nf4",
|
"quant_type": "nf4",
|
||||||
"compress_statistics": True,
|
"compress_statistics": True,
|
||||||
"patched": False,
|
"patched": False,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class Bnb8bitParametrization(torch.nn.Module):
    """Parametrization that dequantizes int8 row-wise quantized data on access.

    Mirrors ``Bnb4bitParametrization`` from ``bitsandbytes.nn.parametrize`` but for
    int8 row-wise (absmax) quantization. Stores the per-row scales as a buffer and
    delegates to ``bitsandbytes.functional.int8_vectorwise_dequant`` which computes
    ``int8_data * row_stats * (1/127)``.

    The scales are kept flat (one entry per row, i.e. ``prod(shape[:-1])`` values),
    so tensors with more than two dimensions are flattened to 2-D before the
    dequantization call and reshaped back afterwards. This keeps exactly one scale
    per row no matter how the bitsandbytes helper broadcasts its ``stats`` argument.

    Args:
        row_stats: Per-row absmax scales produced when the parameter was quantized.
    """

    def __init__(self, row_stats: torch.Tensor):
        super().__init__()
        # Registered as a buffer so the scales follow the module across
        # ``.to()`` / ``.cuda()`` calls without becoming trainable.
        self.register_buffer("row_stats", row_stats)

    @torch.no_grad()
    def forward(self, quantized_param: torch.Tensor) -> torch.Tensor:
        """Return the dequantized tensor with the same shape as ``quantized_param``."""
        full_shape = quantized_param.shape
        if quantized_param.ndim > 2:
            # Fused expert weights are 3-D; collapse the leading dims so that
            # row i of the 2-D view pairs with row_stats[i].
            quantized_param = quantized_param.reshape(-1, full_shape[-1])
        # NOTE(review): the helper presumably returns float32 rather than the
        # original bf16/fp16 dtype — confirm callers tolerate that.
        dequantized = bnb.functional.int8_vectorwise_dequant(
            quantized_param, self.row_stats
        )
        return dequantized.reshape(full_shape)
|
||||||
|
|
||||||
|
|
||||||
|
def _enable_parametrization_cache(module, inputs):
    """Forward pre-hook: bump torch's parametrization cache nesting counter.

    ``module`` and ``inputs`` are supplied by the hook machinery and are unused.
    """
    P._cache_enabled = P._cache_enabled + 1
|
||||||
|
|
||||||
|
|
||||||
|
def _disable_parametrization_cache(module, inputs, output):
    """Forward hook: decrement the cache nesting counter and reset the cache at zero.

    ``module``, ``inputs`` and ``output`` are supplied by the hook machinery and
    are unused.
    """
    P._cache_enabled -= 1
    if P._cache_enabled:
        # Another enable is still outstanding; keep the cached tensors.
        return
    P._cache = {}
|
||||||
|
|
||||||
|
|
||||||
|
def replace_parameter_8bit(module, param_name):
    """Swap ``module.<param_name>`` for an int8 copy that dequantizes on access.

    Counterpart of ``bitsandbytes.nn.parametrize.replace_parameter_4bit`` for int8
    row-wise (absmax) quantization. Quantization goes through
    ``bitsandbytes.functional.int8_vectorwise_quant``, which supports N-D tensors
    natively (scales shape = ``prod(shape[:-1])``).

    Args:
        module: The module containing the parameter to quantize.
        param_name: Name of the parameter within the module.
    """
    float_param = getattr(module, param_name)
    # The data is cast to float16 before quantization; the third return value
    # is not used here.
    int8_weights, scales, _ = bnb.functional.int8_vectorwise_quant(
        float_param.data.to(torch.float16)
    )

    # Install the frozen int8 tensor first, then drop the local reference to
    # the original parameter.
    quantized_param = torch.nn.Parameter(int8_weights, requires_grad=False)
    setattr(module, param_name, quantized_param)
    del float_param

    dequantizer = Bnb8bitParametrization(scales)
    P.register_parametrization(module, param_name, dequantizer, unsafe=True)

    # Same caching-hook pattern as BnB 4-bit: while the module's forward runs,
    # repeated accesses to the parameter reuse one dequantized tensor instead
    # of dequantizing again.
    module.register_forward_pre_hook(_enable_parametrization_cache)
    module.register_forward_hook(_disable_parametrization_cache)
|
||||||
|
|
||||||
|
|
||||||
|
def _8bit_state_dict_post_hook(
    module, state_dict, prefix, local_metadata, *, param_name
):
    """Stand-in for an 8-bit state_dict serialization hook; always raises.

    LoRA/QLoRA training saves only the adapter weights (lora_A/B), so the
    quantized expert weights of the base model are never serialized and no
    8-bit state dict hook is required for that use case. Supporting full-model
    saving would mean storing the int8 data together with its row_stats here.

    Raises:
        NotImplementedError: Always; 8-bit expert serialization is unsupported.
    """
    message = (
        "State dict serialization for 8-bit quantized expert parameters is not yet "
        "implemented. This is not needed for LoRA/QLoRA training (only adapter weights "
        "are saved). If you need to save the full model with 8-bit quantized experts, "
        "please open an issue."
    )
    raise NotImplementedError(message)
|
||||||
|
|
||||||
|
|
||||||
def patch_moe_quantization_on_load(cfg):
|
def patch_moe_quantization_on_load(cfg):
|
||||||
"""Patch transformers' weight loading to quantize 3D MoE expert params on-the-fly.
|
"""Patch transformers' weight loading to quantize 3D MoE expert params on-the-fly.
|
||||||
|
|
||||||
Wraps ``transformers.core_model_loading.set_param_for_module`` so that after each
|
Wraps ``transformers.core_model_loading.set_param_for_module`` so that after each
|
||||||
parameter is assigned to its module, any 3D+ tensor on CUDA that BnB skipped
|
parameter is assigned to its module, any 3D+ tensor on CUDA that BnB skipped
|
||||||
(i.e. not inside a Linear4bit/Linear8bitLt) is immediately quantized via
|
(i.e. not inside a Linear4bit/Linear8bitLt) is immediately quantized via
|
||||||
``replace_parameter_4bit``. This keeps peak VRAM to one expert param in bf16
|
``replace_parameter_4bit`` (for 4-bit) or ``replace_parameter_8bit`` (for 8-bit).
|
||||||
at a time, instead of loading all experts in bf16 first.
|
This keeps peak VRAM to one expert param in bf16 at a time, instead of loading
|
||||||
|
all experts in bf16 first.
|
||||||
|
|
||||||
The patch stays active permanently — the ``ndim >= 3`` and ``is_cuda`` checks
|
The patch stays active permanently — the ``ndim >= 3`` and ``is_cuda`` checks
|
||||||
make it safe for non-MoE models (no false positives).
|
make it safe for non-MoE models (no false positives).
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
cfg: Axolotl DictDefault config. Reads bnb_4bit_quant_type and
|
cfg: Axolotl DictDefault config. For 4-bit, reads bnb_4bit_quant_type and
|
||||||
bnb_4bit_use_double_quant for quantization settings.
|
bnb_4bit_use_double_quant. For 8-bit, no additional settings needed.
|
||||||
"""
|
"""
|
||||||
if _moe_load_state["patched"]:
|
if _moe_load_state["patched"]:
|
||||||
LOG.debug("MoE loading-time quantization patch already active")
|
LOG.debug("MoE loading-time quantization patch already active")
|
||||||
@@ -55,7 +133,26 @@ def patch_moe_quantization_on_load(cfg):
|
|||||||
|
|
||||||
import transformers.core_model_loading
|
import transformers.core_model_loading
|
||||||
import transformers.modeling_utils
|
import transformers.modeling_utils
|
||||||
from bitsandbytes.nn.parametrize import replace_parameter_4bit
|
|
||||||
|
# Determine quantization mode from config.
|
||||||
|
if getattr(cfg, "load_in_8bit", False):
|
||||||
|
mode = "8bit"
|
||||||
|
else:
|
||||||
|
mode = "4bit"
|
||||||
|
|
||||||
|
_moe_load_state["mode"] = mode
|
||||||
|
_moe_load_state["count"] = 0
|
||||||
|
|
||||||
|
if mode == "4bit":
|
||||||
|
from bitsandbytes.nn.parametrize import replace_parameter_4bit
|
||||||
|
|
||||||
|
quant_type = getattr(cfg, "bnb_4bit_quant_type", None) or "nf4"
|
||||||
|
compress_statistics = getattr(cfg, "bnb_4bit_use_double_quant", None)
|
||||||
|
if compress_statistics is None:
|
||||||
|
compress_statistics = True
|
||||||
|
|
||||||
|
_moe_load_state["quant_type"] = quant_type
|
||||||
|
_moe_load_state["compress_statistics"] = compress_statistics
|
||||||
|
|
||||||
# Patch caching_allocator_warmup to be a no-op. This function pre-allocates
|
# Patch caching_allocator_warmup to be a no-op. This function pre-allocates
|
||||||
# a single huge GPU tensor equal to the model's total param bytes to warm the
|
# a single huge GPU tensor equal to the model's total param bytes to warm the
|
||||||
@@ -68,16 +165,6 @@ def patch_moe_quantization_on_load(cfg):
|
|||||||
|
|
||||||
transformers.modeling_utils.caching_allocator_warmup = _noop_warmup
|
transformers.modeling_utils.caching_allocator_warmup = _noop_warmup
|
||||||
|
|
||||||
# Read quantization settings from config
|
|
||||||
quant_type = getattr(cfg, "bnb_4bit_quant_type", None) or "nf4"
|
|
||||||
compress_statistics = getattr(cfg, "bnb_4bit_use_double_quant", None)
|
|
||||||
if compress_statistics is None:
|
|
||||||
compress_statistics = True
|
|
||||||
|
|
||||||
_moe_load_state["quant_type"] = quant_type
|
|
||||||
_moe_load_state["compress_statistics"] = compress_statistics
|
|
||||||
_moe_load_state["count"] = 0
|
|
||||||
|
|
||||||
original_set_param = transformers.core_model_loading.set_param_for_module
|
original_set_param = transformers.core_model_loading.set_param_for_module
|
||||||
|
|
||||||
def _patched_set_param_for_module(model, target_name, param_value, *args, **kwargs):
|
def _patched_set_param_for_module(model, target_name, param_value, *args, **kwargs):
|
||||||
@@ -88,12 +175,15 @@ def patch_moe_quantization_on_load(cfg):
|
|||||||
mod_path, _, pname = target_name.rpartition(".")
|
mod_path, _, pname = target_name.rpartition(".")
|
||||||
mod = model.get_submodule(mod_path) if mod_path else model
|
mod = model.get_submodule(mod_path) if mod_path else model
|
||||||
if not isinstance(mod, (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt)):
|
if not isinstance(mod, (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt)):
|
||||||
replace_parameter_4bit(
|
if _moe_load_state["mode"] == "4bit":
|
||||||
mod,
|
replace_parameter_4bit(
|
||||||
pname,
|
mod,
|
||||||
compress_statistics=_moe_load_state["compress_statistics"],
|
pname,
|
||||||
quant_type=_moe_load_state["quant_type"],
|
compress_statistics=_moe_load_state["compress_statistics"],
|
||||||
)
|
quant_type=_moe_load_state["quant_type"],
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
replace_parameter_8bit(mod, pname)
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
_moe_load_state["count"] += 1
|
_moe_load_state["count"] += 1
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user