moe quant patch for merge mismatch (#3483)

* moe quant patch for merge mismatch

* lint

* revert test + fix moe patch

* comment fixes

* e2e tests

* mismatch fix tested

* mismatch fix with vllm compatibility + test

* comment lint

* fix: missing os import, duplicate no-op

* chore: simplify comments

---------

Co-authored-by: NanoCode012 <nano@axolotl.ai>
VED
2026-03-16 07:40:30 +05:30
committed by GitHub
parent d8a05744d7
commit a806704e94
3 changed files with 220 additions and 33 deletions

View File

@@ -416,16 +416,21 @@ class PatchManager:
os.environ["HF_DEACTIVATE_ASYNC_LOAD"] = "1"
def _apply_moe_expert_quantization_patch(self):
"""Patch transformers weight loading to quantize MoE expert params on-the-fly."""
if not self.cfg.quantize_moe_experts:
"""Patch transformers weight loading and PEFT for MoE expert quantization."""
has_target_params = bool(getattr(self.cfg, "lora_target_parameters", None))
if not self.cfg.quantize_moe_experts and not has_target_params:
return
from axolotl.monkeypatch.moe_quant import (
patch_moe_quantization_on_load,
patch_peft_target_parameters_matching,
)
patch_moe_quantization_on_load(self.cfg)
if self.cfg.quantize_moe_experts:
from axolotl.monkeypatch.moe_quant import patch_moe_quantization_on_load
patch_moe_quantization_on_load(self.cfg)
patch_peft_target_parameters_matching()
def _finalize_moe_expert_quantization(self, model: PreTrainedModel):
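Note: in config terms, the two entry points above map to a handful of cfg fields. A minimal sketch of the gating, with types.SimpleNamespace standing in for the parsed axolotl config (field names come from the diff; the values are illustrative):

from types import SimpleNamespace

# Stand-in for the parsed axolotl config; only the fields read above are shown.
cfg = SimpleNamespace(
    quantize_moe_experts=True,   # quantize fused 3D expert weights at load time
    load_in_8bit=False,          # False selects the 4-bit (nf4) path
    lora_target_parameters=["experts.gate_up_proj", "experts.down_proj"],
)

has_target_params = bool(getattr(cfg, "lora_target_parameters", None))
# The loading-time quantization patch applies only when requested; the PEFT
# ordering patch applies whenever the method does not return early.
apply_quant_patch = bool(cfg.quantize_moe_experts)
apply_peft_patch = apply_quant_patch or has_target_params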

View File

@@ -1,11 +1,4 @@
"""
Loading-time quantization for MoE expert weights stored as 3D nn.Parameter tensors.
In transformers v5, MoE models store expert weights as fused 3D tensors that BnB
skips (only targets nn.Linear). This module patches weight loading to quantize them
on-the-fly (4-bit via bitsandbytes parametrize, 8-bit via custom int8 parametrization),
reducing peak VRAM from "all experts in bf16" to "one expert at a time."
"""
"""Loading-time quantization for MoE expert weights stored as 3D nn.Parameter tensors."""
import bitsandbytes as bnb
import torch
@@ -15,18 +8,20 @@ from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
# Module-level state for the loading-time quantization patch.
_moe_load_state = {
"count": 0,
"mode": "4bit",
"quant_type": "nf4",
"compress_statistics": True,
"patched": False,
# Module path → param names in definition order, captured before quantization.
# Without this, alphabetical loading order would mismatch merge order.
"expert_param_order": {},
}
class Bnb8bitParametrization(torch.nn.Module):
"""Parametrization that dequantizes int8 row-wise quantized data on access."""
"""Dequantizes int8 row-wise quantized data on access."""
def __init__(self, row_stats: torch.Tensor):
super().__init__()
@@ -34,7 +29,7 @@ class Bnb8bitParametrization(torch.nn.Module):
@torch.no_grad()
def forward(self, quantized_param: torch.Tensor) -> torch.Tensor:
# Flatten 3D+ to 2D for BnB's dequant, then reshape back.
"""Flatten 3D+ to 2D for BnB's dequant, then reshape back."""
orig_shape = quantized_param.shape
if quantized_param.ndim > 2:
quantized_param = quantized_param.reshape(-1, orig_shape[-1])
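Note: the core trick in both the 4-bit and 8-bit paths is a dequantize-on-access parametrization. Below is a minimal, torch-only sketch of that idea, with a naive row-wise int8 quantizer standing in for bitsandbytes; the class and quantization helper are illustrative, not the module's actual implementation.

import torch
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize

class FakeInt8Parametrization(nn.Module):
    """Illustrative stand-in: dequantize row-wise int8 storage on access."""

    def __init__(self, scales: torch.Tensor):
        super().__init__()
        self.register_buffer("scales", scales)  # one absmax-derived scale per row

    def forward(self, q: torch.Tensor) -> torch.Tensor:
        orig_shape = q.shape
        if q.ndim > 2:  # fused experts: (n_experts, out_features, in_features)
            q = q.reshape(-1, orig_shape[-1])
        return (q.float() * self.scales.unsqueeze(1)).reshape(orig_shape)

experts = nn.Module()
experts.gate_up_proj = nn.Parameter(torch.randn(4, 16, 8))  # 3D fused expert weights

# Naive row-wise int8 quantization of the flattened (rows, in_features) view.
with torch.no_grad():
    w2d = experts.gate_up_proj.reshape(-1, 8)
    scales = w2d.abs().amax(dim=1) / 127.0
    q = torch.round(w2d / scales.unsqueeze(1)).clamp(-127, 127).to(torch.int8)

# Swap in the int8 storage; every attribute access now dequantizes on the fly.
experts.gate_up_proj = nn.Parameter(q.reshape(4, 16, 8), requires_grad=False)
parametrize.register_parametrization(
    experts, "gate_up_proj", FakeInt8Parametrization(scales), unsafe=True
)
print(experts.gate_up_proj.shape, experts.gate_up_proj.dtype)  # (4, 16, 8) float32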
@@ -74,14 +69,11 @@ def replace_parameter_8bit(module, param_name):
def patch_moe_quantization_on_load(cfg):
"""Patch transformers' weight loading to quantize MoE expert params on-the-fly.
Wraps ``set_param_for_module`` so that 3D+ CUDA tensors with "expert" in their
name are quantized (4-bit or 8-bit) as they're loaded, keeping peak VRAM low.
"""
"""Patch transformers' weight loading to quantize MoE expert params on-the-fly."""
mode = "8bit" if getattr(cfg, "load_in_8bit", False) else "4bit"
_moe_load_state["mode"] = mode
_moe_load_state["count"] = 0
_moe_load_state["expert_param_order"] = {}
if _moe_load_state["patched"]:
LOG.debug("MoE loading-time quantization patch already active")
@@ -113,7 +105,6 @@ def patch_moe_quantization_on_load(cfg):
def _patched_set_param_for_module(model, target_name, param_value, *args, **kwargs):
original_set_param(model, target_name, param_value, *args, **kwargs)
# Quantize 3D+ expert params that BnB skipped (only on CUDA).
if param_value.ndim >= 3 and param_value.is_cuda:
mod_path, _, pname = target_name.rpartition(".")
mod = model.get_submodule(mod_path) if mod_path else model
@@ -126,6 +117,13 @@ def patch_moe_quantization_on_load(cfg):
)
return
# Record definition order before parametrizations override it
# with alphabetical order.
if mod_path not in _moe_load_state["expert_param_order"]:
_moe_load_state["expert_param_order"][mod_path] = list(
mod._parameters.keys()
)
if _moe_load_state["mode"] == "4bit":
replace_parameter_4bit(
mod,
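Note: a small demonstration of the ordering mismatch that the recorded expert_param_order guards against, using a fake experts module and a passthrough parametrization (both illustrative, mirroring the test fixture added below):

import torch
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize

class Experts(nn.Module):
    def __init__(self):
        super().__init__()
        # Definition order: gate_up_proj first, then down_proj.
        self.gate_up_proj = nn.Parameter(torch.randn(4, 16, 8))
        self.down_proj = nn.Parameter(torch.randn(4, 8, 16))

class Passthrough(nn.Module):
    def forward(self, x):
        return x

experts = Experts()
print(list(experts._parameters))  # ['gate_up_proj', 'down_proj'] (definition order)

# Checkpoint keys are visited in sorted order during loading, so the on-the-fly
# quantization registers parametrizations alphabetically.
for name in sorted(experts._parameters):
    parametrize.register_parametrization(experts, name, Passthrough(), unsafe=True)

print(list(experts.parametrizations))  # ['down_proj', 'gate_up_proj'] (alphabetical)
# Iterating module.parametrizations would wrap down_proj first, while a plain
# model iterates parameters in definition order; expert_param_order is what lets
# the PEFT patch restore the original ordering.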
@@ -151,20 +149,28 @@ def get_moe_quantized_count():
def patch_peft_target_parameters_matching():
"""Fix PEFT's _inject_parameters to use suffix matching for parametrized modules."""
"""Fix PEFT's _inject_parameters for target_parameters on quantized MoE experts.
1. Expands short suffixes to full module paths for parametrized modules.
2. Iterates params in definition order (not alphabetical order) so saved
adapters are compatible with standard PEFT, vLLM, etc.
"""
if getattr(patch_peft_target_parameters_matching, "_axolotl_patched", False):
return
from peft.tuners.tuners_utils import BaseTuner
original_inject = BaseTuner._inject_parameters
from contextlib import nullcontext
from peft.tuners.tuners_utils import BaseTuner, BaseTunerLayer
from peft.utils.integrations import init_empty_weights
from peft.utils.other import _get_submodules
def _patched_inject_parameters(
self, peft_config, model, adapter_name, low_cpu_mem_usage
):
# Patch target_parameters to use full paths for parametrized modules
original_targets = list(peft_config.target_parameters)
expanded = set(original_targets)
# Expand short suffixes to full paths for parametrized modules.
for module_name, module in model.named_modules():
if not hasattr(module, "parametrizations"):
continue
@@ -175,14 +181,74 @@ def patch_peft_target_parameters_matching():
) and hasattr(module, param_name):
expanded.add(f"{module_name}.{param_name}")
peft_config.target_parameters = sorted(expanded)
try:
return original_inject(
self, peft_config, model, adapter_name, low_cpu_mem_usage
)
finally:
peft_config.target_parameters = original_targets
target_names_set = expanded
def strip_base_layer_from_name(module_name):
name = ".base_layer"
while name in module_name:
prefix, _, suffix = module_name.rpartition(name)
module_name = prefix + suffix
return module_name
def create_and_replace_param(module_name, key, param_name):
parent, target, target_name = _get_submodules(model, module_name)
unwrapped_module_name = strip_base_layer_from_name(module_name)
unwrapped_module = model.get_submodule(unwrapped_module_name)
if (
isinstance(unwrapped_module, BaseTunerLayer)
and unwrapped_module.__class__.__name__ != "ParamWrapper"
):
raise ValueError(
f"Trying to wrap an `nn.Parameter` of layer "
f"'{unwrapped_module_name}' of type "
f"{type(target).__name__}, which is not a valid target. "
f"Make sure that this layer is not also targeted with "
f"`target_modules`."
)
self._check_target_module_compatiblity(peft_config, model, target_name)
ctx = init_empty_weights if low_cpu_mem_usage else nullcontext
with ctx():
self._create_and_replace(
peft_config,
adapter_name,
target,
target_name,
parent,
current_key=key,
parameter_name=param_name.rpartition(".")[-1],
)
# Use definition order (not alphabetical order) for parametrized modules
# so ParamWrapper nesting matches vanilla PEFT on a plain model.
expert_param_order = _moe_load_state.get("expert_param_order", {})
for module_name, module in model.named_modules():
if hasattr(module, "parametrizations"):
stored_order = expert_param_order.get(module_name)
if stored_order is not None:
params_iter = [
p for p in stored_order if p in module.parametrizations
]
else:
# Fallback for paths that bypass model loading (e.g. unit tests).
params_iter = list(module.parametrizations.keys())
for param_name in params_iter:
key = f"{module_name}.{param_name}"
if (key in target_names_set) or any(
key.endswith(f".{t}") for t in target_names_set
):
create_and_replace_param(module_name, key, param_name)
self.targeted_parameter_names.append(key)
else:
unwrapped_module_name = strip_base_layer_from_name(module_name)
for param_name, _ in module.named_parameters(recurse=False):
key = f"{unwrapped_module_name}.{param_name}"
if (key in target_names_set) or any(
key.endswith(f".{t}") for t in target_names_set
):
create_and_replace_param(module_name, key, param_name)
self.targeted_parameter_names.append(key)
BaseTuner._inject_parameters = _patched_inject_parameters
patch_peft_target_parameters_matching._axolotl_patched = True
LOG.info("Patched PEFT _inject_parameters for parametrized module suffix matching")
LOG.info("Patched PEFT _inject_parameters for consistent ParamWrapper ordering")

View File

@@ -154,3 +154,119 @@ class TestPeftPatchIdempotency:
finally:
BaseTuner._inject_parameters = original
patch_peft_target_parameters_matching._axolotl_patched = False
class TestMoeAdapterTrainMergeRoundtrip:
"""E2E: train adapter on quantized MoE experts, then merge onto plain model.
Verifies that param wrapping order during training matches merge, preventing
size mismatch errors when loading adapters in standard PEFT/vLLM.
"""
@staticmethod
def _make_classes():
"""Return FakeExperts and FakeModel classes shared by both model builders."""
import torch
import torch.nn as nn
class FakeExperts(nn.Module):
def __init__(self):
super().__init__()
# Model definition order: gate_up_proj first, then down_proj.
self.gate_up_proj = nn.Parameter(torch.randn(4, 16, 8))
self.down_proj = nn.Parameter(torch.randn(4, 8, 16))
def forward(self, x):
x = torch.matmul(x, self.gate_up_proj[0].T) # (batch, 16)
x = torch.matmul(x, self.down_proj[0].T) # (batch, 8)
return x
class FakeModel(nn.Module):
def __init__(self):
super().__init__()
self.linear = nn.Linear(8, 8)
self.experts = FakeExperts()
def forward(self, x):
return self.linear(x) + self.experts(x)
return FakeExperts, FakeModel
@staticmethod
def _make_quantized_model():
"""Training model: parametrizations registered in alphabetical order."""
import torch.nn as nn
import torch.nn.utils.parametrize as P
from axolotl.monkeypatch.moe_quant import _moe_load_state
_, FakeModel = TestMoeAdapterTrainMergeRoundtrip._make_classes()
class PassthroughParametrization(nn.Module):
def forward(self, x):
return x
model = FakeModel()
# Record definition order before parametrization (mirrors real loading).
_moe_load_state["expert_param_order"]["experts"] = list(
model.experts._parameters.keys()
)
# Register in alphabetical order to expose the ordering mismatch.
P.register_parametrization(
model.experts, "down_proj", PassthroughParametrization(), unsafe=True
)
P.register_parametrization(
model.experts, "gate_up_proj", PassthroughParametrization(), unsafe=True
)
return model
@staticmethod
def _make_plain_model():
"""Merge model: no parametrizations — standard branch uses definition order."""
_, FakeModel = TestMoeAdapterTrainMergeRoundtrip._make_classes()
return FakeModel()
def test_train_save_merge_no_size_mismatch(self, tmp_path):
"""Train on quantized experts, merge onto plain model — must not raise."""
import torch
from peft import LoraConfig, PeftModel, get_peft_model
from peft.tuners.tuners_utils import BaseTuner
from axolotl.monkeypatch.moe_quant import (
_moe_load_state,
patch_peft_target_parameters_matching,
)
adapter_dir = tmp_path / "adapter"
lora_cfg = LoraConfig(
r=4,
lora_alpha=8,
target_modules=[],
target_parameters=["experts.gate_up_proj", "experts.down_proj"],
lora_dropout=0.0,
bias="none",
)
original_inject = BaseTuner._inject_parameters
# Training phase: quantized model (parametrized branch) with axolotl patch.
_moe_load_state["expert_param_order"] = {}
patch_peft_target_parameters_matching()
try:
peft_model = get_peft_model(self._make_quantized_model(), lora_cfg)
finally:
BaseTuner._inject_parameters = original_inject
patch_peft_target_parameters_matching._axolotl_patched = False
optimizer = torch.optim.SGD(peft_model.parameters(), lr=1e-3)
for _ in range(3):
peft_model(torch.randn(2, 8)).sum().backward()
optimizer.step()
optimizer.zero_grad()
peft_model.save_pretrained(str(adapter_dir))
# Merge with standard PEFT (no axolotl patch) to verify external compatibility.
loaded = PeftModel.from_pretrained(self._make_plain_model(), str(adapter_dir))
merged = loaded.merge_and_unload()
assert merged is not None