diff --git a/src/axolotl/loaders/model.py b/src/axolotl/loaders/model.py index 9f6e92075..dc93c659f 100644 --- a/src/axolotl/loaders/model.py +++ b/src/axolotl/loaders/model.py @@ -175,10 +175,12 @@ class ModelLoader: # Activate loading-time quantization for 3D MoE expert params before # from_pretrained() runs. This patches set_param_for_module so each - # expert weight is quantized to 4-bit as it's loaded, keeping peak - # VRAM to one expert param in bf16 at a time. + # expert weight is quantized (4-bit or 8-bit) as it's loaded, keeping + # peak VRAM to one expert param in bf16 at a time. moe_quant_active = False - if self.cfg.adapter in ("qlora", "lora") and self.cfg.load_in_4bit: + if self.cfg.adapter in ("qlora", "lora") and ( + self.cfg.load_in_4bit or self.cfg.load_in_8bit + ): from axolotl.monkeypatch.moe_quant import patch_moe_quantization_on_load patch_moe_quantization_on_load(self.cfg) diff --git a/src/axolotl/monkeypatch/moe_quant.py b/src/axolotl/monkeypatch/moe_quant.py index a22d5b136..360e4c571 100644 --- a/src/axolotl/monkeypatch/moe_quant.py +++ b/src/axolotl/monkeypatch/moe_quant.py @@ -2,23 +2,23 @@ Loading-time quantization for MoE expert weights stored as 3D nn.Parameter tensors. In transformers v5, many MoE models store expert weights as fused 3D nn.Parameter -tensors instead of individual nn.Linear modules. BnB 4-bit quantization only targets +tensors instead of individual nn.Linear modules. BnB quantization only targets nn.Linear, so these expert weights are loaded in full precision, causing high peak VRAM. This module patches transformers' weight loading to quantize 3D expert parameters -on-the-fly as they're assigned to modules, using bitsandbytes.nn.parametrize. -replace_parameter_4bit (requires bitsandbytes >= 0.48.0). This reduces peak VRAM -from "all experts in bf16" to "one expert param in bf16 at a time." +on-the-fly as they're assigned to modules. For 4-bit, it uses +bitsandbytes.nn.parametrize.replace_parameter_4bit (requires bitsandbytes >= 0.48.0). +For 8-bit, it uses a custom parametrization built on bitsandbytes.functional's +int8_vectorwise_quant/dequant (row-wise absmax scaling). Both reduce peak VRAM from +"all experts in bf16" to "one expert param in bf16 at a time." PEFT's target_parameters / ParamWrapper can then apply LoRA on top of these quantized params via stacked parametrizations. - -Note: FSDP2 cpu ram efficient loading and Tensor Parallel (DTensor) compatibility -with parametrization is untested. """ import bitsandbytes as bnb import torch +import torch.nn.utils.parametrize as P from axolotl.utils.logging import get_logger @@ -27,27 +27,105 @@ LOG = get_logger(__name__) # Module-level state for the loading-time quantization patch. _moe_load_state = { "count": 0, + "mode": "4bit", "quant_type": "nf4", "compress_statistics": True, "patched": False, } +class Bnb8bitParametrization(torch.nn.Module): + """Parametrization that dequantizes int8 row-wise quantized data on access. + + Mirrors ``Bnb4bitParametrization`` from ``bitsandbytes.nn.parametrize`` but for + int8 row-wise (absmax) quantization. Stores the per-row scales as a buffer and + delegates to ``bitsandbytes.functional.int8_vectorwise_dequant`` which computes + ``int8_data * row_stats * (1/127)``. + """ + + def __init__(self, row_stats: torch.Tensor): + super().__init__() + self.register_buffer("row_stats", row_stats) + + @torch.no_grad() + def forward(self, quantized_param: torch.Tensor) -> torch.Tensor: + return bnb.functional.int8_vectorwise_dequant(quantized_param, self.row_stats) + + +def _enable_parametrization_cache(module, inputs): + P._cache_enabled += 1 + + +def _disable_parametrization_cache(module, inputs, output): + P._cache_enabled -= 1 + if not P._cache_enabled: + P._cache = {} + + +def replace_parameter_8bit(module, param_name): + """Replace a module parameter with an 8-bit quantized version using parametrization. + + Mirrors ``bitsandbytes.nn.parametrize.replace_parameter_4bit`` but for int8 + row-wise (absmax) quantization. Uses ``int8_vectorwise_quant`` which supports + N-D tensors natively (scales shape = ``prod(shape[:-1])``). + + Args: + module: The module containing the parameter to quantize. + param_name: Name of the parameter within the module. + """ + original_param = getattr(module, param_name) + int8_data, row_stats, _ = bnb.functional.int8_vectorwise_quant( + original_param.data.to(torch.float16) + ) + + setattr(module, param_name, torch.nn.Parameter(int8_data, requires_grad=False)) + del original_param + + P.register_parametrization( + module, param_name, Bnb8bitParametrization(row_stats), unsafe=True + ) + + # Register caching hooks (same pattern as BnB 4-bit). Caching avoids + # redundant dequantization when the same param is accessed multiple times + # in a single forward pass. + module.register_forward_pre_hook(_enable_parametrization_cache) + module.register_forward_hook(_disable_parametrization_cache) + + +def _8bit_state_dict_post_hook( + module, state_dict, prefix, local_metadata, *, param_name +): + """Placeholder for 8-bit state_dict serialization hook. + + For LoRA/QLoRA training, only adapter weights (lora_A/B) are saved — base model + weights including quantized experts are not serialized. State dict hooks for 8-bit + are therefore not needed for the primary use case. If full-model saving with 8-bit + quantized expert params is needed, this hook should store int8 data + row_stats. + """ + raise NotImplementedError( + "State dict serialization for 8-bit quantized expert parameters is not yet " + "implemented. This is not needed for LoRA/QLoRA training (only adapter weights " + "are saved). If you need to save the full model with 8-bit quantized experts, " + "please open an issue." + ) + + def patch_moe_quantization_on_load(cfg): """Patch transformers' weight loading to quantize 3D MoE expert params on-the-fly. Wraps ``transformers.core_model_loading.set_param_for_module`` so that after each parameter is assigned to its module, any 3D+ tensor on CUDA that BnB skipped (i.e. not inside a Linear4bit/Linear8bitLt) is immediately quantized via - ``replace_parameter_4bit``. This keeps peak VRAM to one expert param in bf16 - at a time, instead of loading all experts in bf16 first. + ``replace_parameter_4bit`` (for 4-bit) or ``replace_parameter_8bit`` (for 8-bit). + This keeps peak VRAM to one expert param in bf16 at a time, instead of loading + all experts in bf16 first. The patch stays active permanently — the ``ndim >= 3`` and ``is_cuda`` checks make it safe for non-MoE models (no false positives). Args: - cfg: Axolotl DictDefault config. Reads bnb_4bit_quant_type and - bnb_4bit_use_double_quant for quantization settings. + cfg: Axolotl DictDefault config. For 4-bit, reads bnb_4bit_quant_type and + bnb_4bit_use_double_quant. For 8-bit, no additional settings needed. """ if _moe_load_state["patched"]: LOG.debug("MoE loading-time quantization patch already active") @@ -55,7 +133,26 @@ def patch_moe_quantization_on_load(cfg): import transformers.core_model_loading import transformers.modeling_utils - from bitsandbytes.nn.parametrize import replace_parameter_4bit + + # Determine quantization mode from config. + if getattr(cfg, "load_in_8bit", False): + mode = "8bit" + else: + mode = "4bit" + + _moe_load_state["mode"] = mode + _moe_load_state["count"] = 0 + + if mode == "4bit": + from bitsandbytes.nn.parametrize import replace_parameter_4bit + + quant_type = getattr(cfg, "bnb_4bit_quant_type", None) or "nf4" + compress_statistics = getattr(cfg, "bnb_4bit_use_double_quant", None) + if compress_statistics is None: + compress_statistics = True + + _moe_load_state["quant_type"] = quant_type + _moe_load_state["compress_statistics"] = compress_statistics # Patch caching_allocator_warmup to be a no-op. This function pre-allocates # a single huge GPU tensor equal to the model's total param bytes to warm the @@ -68,16 +165,6 @@ def patch_moe_quantization_on_load(cfg): transformers.modeling_utils.caching_allocator_warmup = _noop_warmup - # Read quantization settings from config - quant_type = getattr(cfg, "bnb_4bit_quant_type", None) or "nf4" - compress_statistics = getattr(cfg, "bnb_4bit_use_double_quant", None) - if compress_statistics is None: - compress_statistics = True - - _moe_load_state["quant_type"] = quant_type - _moe_load_state["compress_statistics"] = compress_statistics - _moe_load_state["count"] = 0 - original_set_param = transformers.core_model_loading.set_param_for_module def _patched_set_param_for_module(model, target_name, param_value, *args, **kwargs): @@ -88,12 +175,15 @@ def patch_moe_quantization_on_load(cfg): mod_path, _, pname = target_name.rpartition(".") mod = model.get_submodule(mod_path) if mod_path else model if not isinstance(mod, (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt)): - replace_parameter_4bit( - mod, - pname, - compress_statistics=_moe_load_state["compress_statistics"], - quant_type=_moe_load_state["quant_type"], - ) + if _moe_load_state["mode"] == "4bit": + replace_parameter_4bit( + mod, + pname, + compress_statistics=_moe_load_state["compress_statistics"], + quant_type=_moe_load_state["quant_type"], + ) + else: + replace_parameter_8bit(mod, pname) torch.cuda.empty_cache() _moe_load_state["count"] += 1