feat: add moe quant to test by ved
This commit is contained in:
@@ -173,6 +173,14 @@ class ModelLoader:
|
|||||||
PLUGIN_MANAGER.pre_model_load(self.cfg)
|
PLUGIN_MANAGER.pre_model_load(self.cfg)
|
||||||
self.patch_manager.apply_post_plugin_pre_model_load_patches()
|
self.patch_manager.apply_post_plugin_pre_model_load_patches()
|
||||||
skip_move_to_device = self._build_model()
|
skip_move_to_device = self._build_model()
|
||||||
|
|
||||||
|
# Quantize 3D MoE expert nn.Parameter tensors that BnB skips during loading.
|
||||||
|
self.model._moe_experts_quantized = False
|
||||||
|
if self.cfg.adapter in ("qlora", "lora") and self.cfg.load_in_4bit:
|
||||||
|
from axolotl.monkeypatch.moe_quant import quantize_moe_expert_params
|
||||||
|
|
||||||
|
self.model._moe_experts_quantized = quantize_moe_expert_params(self.model)
|
||||||
|
|
||||||
PLUGIN_MANAGER.post_model_build(self.cfg, self.model)
|
PLUGIN_MANAGER.post_model_build(self.cfg, self.model)
|
||||||
|
|
||||||
# Post-build model configuration
|
# Post-build model configuration
|
||||||
@@ -860,6 +868,10 @@ class ModelLoader:
|
|||||||
# Make sure everything is in the same dtype
|
# Make sure everything is in the same dtype
|
||||||
skip_prepare_model_for_kbit_training = True
|
skip_prepare_model_for_kbit_training = True
|
||||||
|
|
||||||
|
if getattr(self.model, "_moe_experts_quantized", False):
|
||||||
|
# Parametrized expert tensors dequantize on access — would OOM.
|
||||||
|
skip_prepare_model_for_kbit_training = True
|
||||||
|
|
||||||
if (
|
if (
|
||||||
not skip_prepare_model_for_kbit_training
|
not skip_prepare_model_for_kbit_training
|
||||||
and self.cfg.adapter in ["lora", "qlora"]
|
and self.cfg.adapter in ["lora", "qlora"]
|
||||||
|
|||||||
86
src/axolotl/monkeypatch/moe_quant.py
Normal file
86
src/axolotl/monkeypatch/moe_quant.py
Normal file
@@ -0,0 +1,86 @@
|
|||||||
|
"""
|
||||||
|
Post-load quantization for MoE expert weights stored as 3D nn.Parameter tensors.
|
||||||
|
|
||||||
|
In transformers v5, many MoE models store expert weights as fused 3D nn.Parameter
|
||||||
|
tensors instead of individual nn.Linear modules. BnB 4-bit quantization only targets
|
||||||
|
nn.Linear, so these expert weights are skipped during model loading, causing OOM.
|
||||||
|
|
||||||
|
This module provides a post-load fixup that quantizes those skipped parameters using
|
||||||
|
bitsandbytes.nn.parametrize.replace_parameter_4bit (requires bitsandbytes >= 0.48.0).
|
||||||
|
PEFT's target_parameters / ParamWrapper can then apply LoRA on top of these quantized
|
||||||
|
params via stacked parametrizations.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import bitsandbytes as bnb
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from axolotl.utils.logging import get_logger
|
||||||
|
|
||||||
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def find_unquantized_expert_params(model):
    """Collect expert-style parameters with 3 or more dimensions.

    Walks every submodule of ``model``. Modules that are already
    ``bnb.nn.Linear4bit`` or ``bnb.nn.Linear8bitLt`` are skipped. For the
    remaining modules, each directly-owned parameter (no recursion) is kept
    when it has at least three dimensions and its name contains one of
    ``"experts"``, ``"gate_up_proj"`` or ``"down_proj"``.

    Returns:
        List of (module, param_name) tuples to quantize.
    """
    expert_keywords = ("experts", "gate_up_proj", "down_proj")
    quantized_linear_types = (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt)

    matches = []
    for _, submodule in model.named_modules():
        # BnB linear layers are already quantized; nothing to pick up there.
        if isinstance(submodule, quantized_linear_types):
            continue
        # recurse=False: children are visited by named_modules() themselves,
        # so each parameter is examined once, against its owning module.
        matches.extend(
            (submodule, name)
            for name, tensor in submodule.named_parameters(recurse=False)
            if tensor.ndim >= 3 and any(kw in name for kw in expert_keywords)
        )
    return matches
|
||||||
|
|
||||||
|
|
||||||
|
def quantize_moe_expert_params(model, quant_type=None, compress_statistics=None):
    """Quantize 3D nn.Parameter expert weights that BnB skips during model loading.

    Reads quant_type and compress_statistics from the model's quantization_config
    when not explicitly provided, so that the same settings used for nn.Linear
    quantization are applied to the MoE expert parameters. The config may be
    either an object exposing ``bnb_4bit_quant_type`` /
    ``bnb_4bit_use_double_quant`` attributes or a plain dict with those keys.

    Args:
        model: Model whose submodules are scanned for expert parameters.
        quant_type: 4-bit quantization type passed to
            ``replace_parameter_4bit``. Falls back to the model's
            quantization config, then to ``"nf4"``.
        compress_statistics: Passed to ``replace_parameter_4bit``. Falls back
            to the model's quantization config, then to ``True``.

    Returns:
        ``True`` if at least one parameter was quantized, ``False`` if no
        matching parameters were found (in which case nothing is touched).
    """
    params_to_quantize = find_unquantized_expert_params(model)
    if not params_to_quantize:
        return False

    # Imported only once there is something to quantize: the parametrize module
    # needs a recent bitsandbytes (the module docstring states >= 0.48.0), and
    # models without matching expert params should not fail on an older install.
    from bitsandbytes.nn.parametrize import replace_parameter_4bit

    # Derive settings from model's BnB config if not explicitly provided
    if quant_type is None or compress_statistics is None:
        bnb_config = getattr(model.config, "quantization_config", None)

        def _config_value(key, default):
            """Read ``key`` from an object- or dict-style quantization config."""
            if bnb_config is None:
                return default
            if isinstance(bnb_config, dict):
                # getattr() on a dict would silently yield the default.
                return bnb_config.get(key, default)
            return getattr(bnb_config, key, default)

        if quant_type is None:
            quant_type = _config_value("bnb_4bit_quant_type", "nf4")
        if compress_statistics is None:
            compress_statistics = _config_value("bnb_4bit_use_double_quant", True)

    # Final defaults (covers a config that stores an explicit None)
    if quant_type is None:
        quant_type = "nf4"
    if compress_statistics is None:
        compress_statistics = True

    for module, param_name in params_to_quantize:
        replace_parameter_4bit(
            module,
            param_name,
            compress_statistics=compress_statistics,
            quant_type=quant_type,
        )

    torch.cuda.empty_cache()
    LOG.info(
        "Quantized %d MoE expert parameters to 4-bit (quant_type=%s, compress_statistics=%s)",
        len(params_to_quantize),
        quant_type,
        compress_statistics,
    )
    return True
|
||||||
Reference in New Issue
Block a user