MoE quant patch for merge mismatch (#3483)

* MoE quant patch for merge mismatch

* lint

* revert test + fix moe patch

* comment fixes

* e2e tests

* mismatch fix tested

* mismatch fix with vLLM compatibility + test

* comment lint

* fix: missing os import, duplicate no op

* chore: simplify comments

---------

Co-authored-by: NanoCode012 <nano@axolotl.ai>
This commit is contained in:
VED
2026-03-16 07:40:30 +05:30
committed by GitHub
parent d8a05744d7
commit a806704e94
3 changed files with 220 additions and 33 deletions

View File

@@ -416,16 +416,21 @@ class PatchManager:
os.environ["HF_DEACTIVATE_ASYNC_LOAD"] = "1"
def _apply_moe_expert_quantization_patch(self):
"""Patch transformers weight loading to quantize MoE expert params on-the-fly."""
if not self.cfg.quantize_moe_experts:
"""Patch transformers weight loading and PEFT for MoE expert quantization."""
has_target_params = bool(getattr(self.cfg, "lora_target_parameters", None))
if not self.cfg.quantize_moe_experts and not has_target_params:
return
from axolotl.monkeypatch.moe_quant import (
patch_moe_quantization_on_load,
patch_peft_target_parameters_matching,
)
patch_moe_quantization_on_load(self.cfg)
if self.cfg.quantize_moe_experts:
from axolotl.monkeypatch.moe_quant import patch_moe_quantization_on_load
patch_moe_quantization_on_load(self.cfg)
patch_peft_target_parameters_matching()
def _finalize_moe_expert_quantization(self, model: PreTrainedModel):

View File

@@ -1,11 +1,4 @@
"""
Loading-time quantization for MoE expert weights stored as 3D nn.Parameter tensors.
In transformers v5, MoE models store expert weights as fused 3D tensors that BnB
skips (only targets nn.Linear). This module patches weight loading to quantize them
on-the-fly (4-bit via bitsandbytes parametrize, 8-bit via custom int8 parametrization),
reducing peak VRAM from "all experts in bf16" to "one expert at a time."
"""
"""Loading-time quantization for MoE expert weights stored as 3D nn.Parameter tensors."""
import bitsandbytes as bnb
import torch
@@ -15,18 +8,20 @@ from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
# Module-level state for the loading-time quantization patch.
_moe_load_state = {
    # Reset to 0 every time patch_moe_quantization_on_load() runs.
    # NOTE(review): presumably the number of expert params quantized so far --
    # the increment is not visible here; confirm.
    "count": 0,
    # Set by patch_moe_quantization_on_load(): "8bit" when cfg.load_in_8bit is
    # truthy, otherwise "4bit". Selects which replace_parameter_* path is used.
    "mode": "4bit",
    # NOTE(review): these two look like bitsandbytes 4-bit quantization options;
    # where they are set/consumed is not visible here -- confirm.
    "quant_type": "nf4",
    "compress_statistics": True,
    # Checked by patch_moe_quantization_on_load() to detect that the loading
    # patch is already active.
    "patched": False,
    # Module path → param names in definition order, captured before quantization.
    # Without this, alphabetical loading order would mismatch merge order.
    # Cleared on each patch_moe_quantization_on_load() call and read back by the
    # patched PEFT _inject_parameters.
    "expert_param_order": {},
}
class Bnb8bitParametrization(torch.nn.Module):
"""Parametrization that dequantizes int8 row-wise quantized data on access."""
"""Dequantizes int8 row-wise quantized data on access."""
def __init__(self, row_stats: torch.Tensor):
super().__init__()
@@ -34,7 +29,7 @@ class Bnb8bitParametrization(torch.nn.Module):
@torch.no_grad()
def forward(self, quantized_param: torch.Tensor) -> torch.Tensor:
# Flatten 3D+ to 2D for BnB's dequant, then reshape back.
"""Flatten 3D+ to 2D for BnB's dequant, then reshape back."""
orig_shape = quantized_param.shape
if quantized_param.ndim > 2:
quantized_param = quantized_param.reshape(-1, orig_shape[-1])
@@ -74,14 +69,11 @@ def replace_parameter_8bit(module, param_name):
def patch_moe_quantization_on_load(cfg):
"""Patch transformers' weight loading to quantize MoE expert params on-the-fly.
Wraps ``set_param_for_module`` so that 3D+ CUDA tensors with "expert" in their
name are quantized (4-bit or 8-bit) as they're loaded, keeping peak VRAM low.
"""
"""Patch transformers' weight loading to quantize MoE expert params on-the-fly."""
mode = "8bit" if getattr(cfg, "load_in_8bit", False) else "4bit"
_moe_load_state["mode"] = mode
_moe_load_state["count"] = 0
_moe_load_state["expert_param_order"] = {}
if _moe_load_state["patched"]:
LOG.debug("MoE loading-time quantization patch already active")
@@ -113,7 +105,6 @@ def patch_moe_quantization_on_load(cfg):
def _patched_set_param_for_module(model, target_name, param_value, *args, **kwargs):
original_set_param(model, target_name, param_value, *args, **kwargs)
# Quantize 3D+ expert params that BnB skipped (only on CUDA).
if param_value.ndim >= 3 and param_value.is_cuda:
mod_path, _, pname = target_name.rpartition(".")
mod = model.get_submodule(mod_path) if mod_path else model
@@ -126,6 +117,13 @@ def patch_moe_quantization_on_load(cfg):
)
return
# Record definition order before parametrizations override it
# with alphabetical order.
if mod_path not in _moe_load_state["expert_param_order"]:
_moe_load_state["expert_param_order"][mod_path] = list(
mod._parameters.keys()
)
if _moe_load_state["mode"] == "4bit":
replace_parameter_4bit(
mod,
@@ -151,20 +149,28 @@ def get_moe_quantized_count():
def patch_peft_target_parameters_matching():
"""Fix PEFT's _inject_parameters to use suffix matching for parametrized modules."""
"""Fix PEFT's _inject_parameters for target_parameters on quantized MoE experts.
1. Expands short suffixes to full module paths for parametrized modules.
2. Iterates params in definition order (not alphabetical order) so saved
adapters are compatible with standard PEFT, vLLM, etc.
"""
if getattr(patch_peft_target_parameters_matching, "_axolotl_patched", False):
return
from peft.tuners.tuners_utils import BaseTuner
original_inject = BaseTuner._inject_parameters
from contextlib import nullcontext
from peft.tuners.tuners_utils import BaseTuner, BaseTunerLayer
from peft.utils.integrations import init_empty_weights
from peft.utils.other import _get_submodules
def _patched_inject_parameters(
self, peft_config, model, adapter_name, low_cpu_mem_usage
):
# Patch target_parameters to use full paths for parametrized modules
original_targets = list(peft_config.target_parameters)
expanded = set(original_targets)
# Expand short suffixes to full paths for parametrized modules.
for module_name, module in model.named_modules():
if not hasattr(module, "parametrizations"):
continue
@@ -175,14 +181,74 @@ def patch_peft_target_parameters_matching():
) and hasattr(module, param_name):
expanded.add(f"{module_name}.{param_name}")
peft_config.target_parameters = sorted(expanded)
try:
return original_inject(
self, peft_config, model, adapter_name, low_cpu_mem_usage
)
finally:
peft_config.target_parameters = original_targets
target_names_set = expanded
def strip_base_layer_from_name(module_name):
    """Return ``module_name`` with every ``.base_layer`` segment removed."""
    marker = ".base_layer"
    stripped = module_name
    while marker in stripped:
        # Drop the right-most occurrence and re-check, so nested wrappers
        # (several ``.base_layer`` segments) are all removed.
        head, _sep, tail = stripped.rpartition(marker)
        stripped = head + tail
    return stripped
def create_and_replace_param(module_name, key, param_name):
    """Wrap one targeted parameter of ``module_name`` via ``_create_and_replace``.

    ``key`` is forwarded as ``current_key``; only the last dotted component of
    ``param_name`` is forwarded as ``parameter_name``.

    Raises:
        ValueError: if the module (with ``.base_layer`` segments stripped from
            its name) is a ``BaseTunerLayer`` whose class is not named
            ``ParamWrapper``.
    """
    parent, target, target_name = _get_submodules(model, module_name)

    # Look at the outermost wrapper: the same path without ``.base_layer``.
    unwrapped_module_name = strip_base_layer_from_name(module_name)
    outer_module = model.get_submodule(unwrapped_module_name)
    is_tuner_layer = isinstance(outer_module, BaseTunerLayer)
    if is_tuner_layer and outer_module.__class__.__name__ != "ParamWrapper":
        raise ValueError(
            f"Trying to wrap an `nn.Parameter` of layer "
            f"'{unwrapped_module_name}' of type "
            f"{type(target).__name__}, which is not a valid target. "
            f"Make sure that this layer is not also targeted with "
            f"`target_modules`."
        )

    self._check_target_module_compatiblity(peft_config, model, target_name)

    # Build the adapter under init_empty_weights when low_cpu_mem_usage is set.
    weight_ctx = init_empty_weights if low_cpu_mem_usage else nullcontext
    _, _, leaf_param_name = param_name.rpartition(".")
    with weight_ctx():
        self._create_and_replace(
            peft_config,
            adapter_name,
            target,
            target_name,
            parent,
            current_key=key,
            parameter_name=leaf_param_name,
        )
# Use definition order (not alphabetical order) for parametrized modules
# so ParamWrapper nesting matches vanilla PEFT on a plain model.
expert_param_order = _moe_load_state.get("expert_param_order", {})
for module_name, module in model.named_modules():
if hasattr(module, "parametrizations"):
stored_order = expert_param_order.get(module_name)
if stored_order is not None:
params_iter = [
p for p in stored_order if p in module.parametrizations
]
else:
# Fallback for paths that bypass model loading (e.g. unit tests).
params_iter = list(module.parametrizations.keys())
for param_name in params_iter:
key = f"{module_name}.{param_name}"
if (key in target_names_set) or any(
key.endswith(f".{t}") for t in target_names_set
):
create_and_replace_param(module_name, key, param_name)
self.targeted_parameter_names.append(key)
else:
unwrapped_module_name = strip_base_layer_from_name(module_name)
for param_name, _ in module.named_parameters(recurse=False):
key = f"{unwrapped_module_name}.{param_name}"
if (key in target_names_set) or any(
key.endswith(f".{t}") for t in target_names_set
):
create_and_replace_param(module_name, key, param_name)
self.targeted_parameter_names.append(key)
BaseTuner._inject_parameters = _patched_inject_parameters
patch_peft_target_parameters_matching._axolotl_patched = True
LOG.info("Patched PEFT _inject_parameters for parametrized module suffix matching")
LOG.info("Patched PEFT _inject_parameters for consistent ParamWrapper ordering")