moe quant patch for merge mismatch (#3483)
* moe quant patch for merge mismatch
* lint
* revert test + fix moe patch
* comment fixes
* e2e tests
* mismatch fix tested
* mismatch fix with vllm compatibility + test
* comment lint
* fix: missing os import, duplicate no op
* chore: simplify comments

---------

Co-authored-by: NanoCode012 <nano@axolotl.ai>
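Background on the mismatch being fixed (illustration only, not part of the commit): the loading-time quantization registers a torch parametrization per expert parameter in the order checkpoint tensors arrive, which is effectively alphabetical, while the model defines those parameters in a different order. PEFT then wraps the parameters in one sequence during quantized training and in another during a plain-model merge, which is what produced the adapter size-mismatch errors. A minimal torch-only sketch of that divergence, with invented module and parameter names:

import torch
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize


class Experts(nn.Module):
    def __init__(self):
        super().__init__()
        # Definition order: gate_up_proj first, then down_proj.
        self.gate_up_proj = nn.Parameter(torch.randn(4, 16, 8))
        self.down_proj = nn.Parameter(torch.randn(4, 8, 16))


class Identity(nn.Module):
    def forward(self, x):
        return x


experts = Experts()
definition_order = list(experts._parameters.keys())  # ['gate_up_proj', 'down_proj']

# Quantize-on-load registers one parametrization per tensor as the checkpoint is
# read, i.e. effectively in alphabetical name order rather than definition order.
for name in sorted(definition_order):
    parametrize.register_parametrization(experts, name, Identity(), unsafe=True)

registration_order = list(experts.parametrizations.keys())  # ['down_proj', 'gate_up_proj']
assert registration_order != definition_order  # the ordering divergence this PR compensates for

The patch below records the pre-parametrization order in _moe_load_state["expert_param_order"] and makes PEFT's parameter injection iterate in that order, so the saved adapter matches what vanilla PEFT or vLLM expects at merge time.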
@@ -416,16 +416,21 @@ class PatchManager:
         os.environ["HF_DEACTIVATE_ASYNC_LOAD"] = "1"

     def _apply_moe_expert_quantization_patch(self):
-        """Patch transformers weight loading to quantize MoE expert params on-the-fly."""
-        if not self.cfg.quantize_moe_experts:
+        """Patch transformers weight loading and PEFT for MoE expert quantization."""
+        has_target_params = bool(getattr(self.cfg, "lora_target_parameters", None))
+
+        if not self.cfg.quantize_moe_experts and not has_target_params:
             return

-        from axolotl.monkeypatch.moe_quant import (
-            patch_moe_quantization_on_load,
-            patch_peft_target_parameters_matching,
-        )
-
-        patch_moe_quantization_on_load(self.cfg)
+        if self.cfg.quantize_moe_experts:
+            from axolotl.monkeypatch.moe_quant import patch_moe_quantization_on_load
+
+            patch_moe_quantization_on_load(self.cfg)
+
         patch_peft_target_parameters_matching()

     def _finalize_moe_expert_quantization(self, model: PreTrainedModel):
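For orientation, the gating above keys off three config attributes, all visible in this diff. A hypothetical stand-in for the cfg object (not axolotl's real config class) showing which attribute drives which patch:

from types import SimpleNamespace

# Hypothetical cfg stand-in; attribute names are the ones read in the hunk above.
cfg = SimpleNamespace(
    quantize_moe_experts=True,  # gates patch_moe_quantization_on_load(cfg)
    lora_target_parameters=[  # non-empty also triggers the PEFT ordering patch
        "experts.gate_up_proj",
        "experts.down_proj",
    ],
    load_in_8bit=False,  # False selects the default 4-bit (nf4) path in the loader patch
)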
@@ -1,11 +1,4 @@
-"""
-Loading-time quantization for MoE expert weights stored as 3D nn.Parameter tensors.
-
-In transformers v5, MoE models store expert weights as fused 3D tensors that BnB
-skips (only targets nn.Linear). This module patches weight loading to quantize them
-on-the-fly (4-bit via bitsandbytes parametrize, 8-bit via custom int8 parametrization),
-reducing peak VRAM from "all experts in bf16" to "one expert at a time."
-"""
+"""Loading-time quantization for MoE expert weights stored as 3D nn.Parameter tensors."""

 import bitsandbytes as bnb
 import torch
@@ -15,18 +8,20 @@ from axolotl.utils.logging import get_logger

 LOG = get_logger(__name__)

 # Module-level state for the loading-time quantization patch.
 _moe_load_state = {
     "count": 0,
     "mode": "4bit",
     "quant_type": "nf4",
     "compress_statistics": True,
     "patched": False,
+    # Module path → param names in definition order, captured before quantization.
+    # Without this, alphabetical loading order would mismatch merge order.
+    "expert_param_order": {},
 }


 class Bnb8bitParametrization(torch.nn.Module):
-    """Parametrization that dequantizes int8 row-wise quantized data on access."""
+    """Dequantizes int8 row-wise quantized data on access."""

     def __init__(self, row_stats: torch.Tensor):
         super().__init__()
@@ -34,7 +29,7 @@ class Bnb8bitParametrization(torch.nn.Module):

     @torch.no_grad()
     def forward(self, quantized_param: torch.Tensor) -> torch.Tensor:
-        # Flatten 3D+ to 2D for BnB's dequant, then reshape back.
+        """Flatten 3D+ to 2D for BnB's dequant, then reshape back."""
         orig_shape = quantized_param.shape
         if quantized_param.ndim > 2:
             quantized_param = quantized_param.reshape(-1, orig_shape[-1])
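The parametrization above intercepts every read of the quantized parameter. As a rough, assumption-level illustration of the same dequantize-on-access idea without bitsandbytes (plain symmetric per-row int8; every name below is invented for the example, this is not the implementation in this module):

import torch
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize


class Int8RowwiseDequant(nn.Module):
    """Store int8 codes; rebuild an approximate float tensor whenever the param is read."""

    def __init__(self, row_scale: torch.Tensor):
        super().__init__()
        self.register_buffer("row_scale", row_scale)

    @torch.no_grad()
    def forward(self, codes: torch.Tensor) -> torch.Tensor:
        orig_shape = codes.shape
        flat = codes.reshape(-1, orig_shape[-1]).to(torch.float32)
        out = flat * self.row_scale.unsqueeze(1)  # per-row rescale back to float
        return out.reshape(orig_shape)


mod = nn.Module()
mod.experts_weight = nn.Parameter(torch.randn(4, 8, 16))  # fused 3D "expert" tensor

with torch.no_grad():
    flat = mod.experts_weight.reshape(-1, 16)
    row_scale = flat.abs().amax(dim=1) / 127.0
    codes = torch.round(flat / row_scale.unsqueeze(1)).to(torch.int8).reshape(4, 8, 16)

# Replace the fp32 parameter with int8 codes, then dequantize on every access.
mod.experts_weight = nn.Parameter(codes, requires_grad=False)
parametrize.register_parametrization(
    mod, "experts_weight", Int8RowwiseDequant(row_scale), unsafe=True
)
print(mod.experts_weight.dtype)  # float32 tensor reconstructed on access

Outside a parametrize.cached() block, each access recomputes the dequantized tensor, which is the trade-off that keeps only the int8 codes (plus scales) resident in memory.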
@@ -74,14 +69,11 @@ def replace_parameter_8bit(module, param_name):


 def patch_moe_quantization_on_load(cfg):
-    """Patch transformers' weight loading to quantize MoE expert params on-the-fly.
-
-    Wraps ``set_param_for_module`` so that 3D+ CUDA tensors with "expert" in their
-    name are quantized (4-bit or 8-bit) as they're loaded, keeping peak VRAM low.
-    """
+    """Patch transformers' weight loading to quantize MoE expert params on-the-fly."""
     mode = "8bit" if getattr(cfg, "load_in_8bit", False) else "4bit"
     _moe_load_state["mode"] = mode
     _moe_load_state["count"] = 0
+    _moe_load_state["expert_param_order"] = {}

     if _moe_load_state["patched"]:
         LOG.debug("MoE loading-time quantization patch already active")
@@ -113,7 +105,6 @@ def patch_moe_quantization_on_load(cfg):
     def _patched_set_param_for_module(model, target_name, param_value, *args, **kwargs):
         original_set_param(model, target_name, param_value, *args, **kwargs)

         # Quantize 3D+ expert params that BnB skipped (only on CUDA).
         if param_value.ndim >= 3 and param_value.is_cuda:
             mod_path, _, pname = target_name.rpartition(".")
             mod = model.get_submodule(mod_path) if mod_path else model
@@ -126,6 +117,13 @@ def patch_moe_quantization_on_load(cfg):
                 )
                 return

+            # Record definition order before parametrizations override it
+            # with alphabetical order.
+            if mod_path not in _moe_load_state["expert_param_order"]:
+                _moe_load_state["expert_param_order"][mod_path] = list(
+                    mod._parameters.keys()
+                )
+
             if _moe_load_state["mode"] == "4bit":
                 replace_parameter_4bit(
                     mod,
@@ -151,20 +149,28 @@ def get_moe_quantized_count():


 def patch_peft_target_parameters_matching():
-    """Fix PEFT's _inject_parameters to use suffix matching for parametrized modules."""
+    """Fix PEFT's _inject_parameters for target_parameters on quantized MoE experts.
+
+    1. Expands short suffixes to full module paths for parametrized modules.
+    2. Iterates params in definition order (not alphabetical order) so saved
+       adapters are compatible with standard PEFT, vLLM, etc.
+    """
     if getattr(patch_peft_target_parameters_matching, "_axolotl_patched", False):
         return
-    from peft.tuners.tuners_utils import BaseTuner
-
-    original_inject = BaseTuner._inject_parameters
+    from contextlib import nullcontext
+
+    from peft.tuners.tuners_utils import BaseTuner, BaseTunerLayer
+    from peft.utils.integrations import init_empty_weights
+    from peft.utils.other import _get_submodules

     def _patched_inject_parameters(
         self, peft_config, model, adapter_name, low_cpu_mem_usage
     ):
-        # Patch target_parameters to use full paths for parametrized modules
         original_targets = list(peft_config.target_parameters)
         expanded = set(original_targets)

+        # Expand short suffixes to full paths for parametrized modules.
         for module_name, module in model.named_modules():
             if not hasattr(module, "parametrizations"):
                 continue
@@ -175,14 +181,74 @@ def patch_peft_target_parameters_matching():
                 ) and hasattr(module, param_name):
                     expanded.add(f"{module_name}.{param_name}")

-        peft_config.target_parameters = sorted(expanded)
-        try:
-            return original_inject(
-                self, peft_config, model, adapter_name, low_cpu_mem_usage
-            )
-        finally:
-            peft_config.target_parameters = original_targets
+        target_names_set = expanded
+
+        def strip_base_layer_from_name(module_name):
+            name = ".base_layer"
+            while name in module_name:
+                prefix, _, suffix = module_name.rpartition(name)
+                module_name = prefix + suffix
+            return module_name
+
+        def create_and_replace_param(module_name, key, param_name):
+            parent, target, target_name = _get_submodules(model, module_name)
+            unwrapped_module_name = strip_base_layer_from_name(module_name)
+            unwrapped_module = model.get_submodule(unwrapped_module_name)
+            if (
+                isinstance(unwrapped_module, BaseTunerLayer)
+                and unwrapped_module.__class__.__name__ != "ParamWrapper"
+            ):
+                raise ValueError(
+                    f"Trying to wrap an `nn.Parameter` of layer "
+                    f"'{unwrapped_module_name}' of type "
+                    f"{type(target).__name__}, which is not a valid target. "
+                    f"Make sure that this layer is not also targeted with "
+                    f"`target_modules`."
+                )
+            self._check_target_module_compatiblity(peft_config, model, target_name)
+            ctx = init_empty_weights if low_cpu_mem_usage else nullcontext
+            with ctx():
+                self._create_and_replace(
+                    peft_config,
+                    adapter_name,
+                    target,
+                    target_name,
+                    parent,
+                    current_key=key,
+                    parameter_name=param_name.rpartition(".")[-1],
+                )
+
+        # Use definition order (not alphabetical order) for parametrized modules
+        # so ParamWrapper nesting matches vanilla PEFT on a plain model.
+        expert_param_order = _moe_load_state.get("expert_param_order", {})
+
+        for module_name, module in model.named_modules():
+            if hasattr(module, "parametrizations"):
+                stored_order = expert_param_order.get(module_name)
+                if stored_order is not None:
+                    params_iter = [
+                        p for p in stored_order if p in module.parametrizations
+                    ]
+                else:
+                    # Fallback for paths that bypass model loading (e.g. unit tests).
+                    params_iter = list(module.parametrizations.keys())
+                for param_name in params_iter:
+                    key = f"{module_name}.{param_name}"
+                    if (key in target_names_set) or any(
+                        key.endswith(f".{t}") for t in target_names_set
+                    ):
+                        create_and_replace_param(module_name, key, param_name)
+                        self.targeted_parameter_names.append(key)
+            else:
+                unwrapped_module_name = strip_base_layer_from_name(module_name)
+                for param_name, _ in module.named_parameters(recurse=False):
+                    key = f"{unwrapped_module_name}.{param_name}"
+                    if (key in target_names_set) or any(
+                        key.endswith(f".{t}") for t in target_names_set
+                    ):
+                        create_and_replace_param(module_name, key, param_name)
+                        self.targeted_parameter_names.append(key)

     BaseTuner._inject_parameters = _patched_inject_parameters
     patch_peft_target_parameters_matching._axolotl_patched = True
-    LOG.info("Patched PEFT _inject_parameters for parametrized module suffix matching")
+    LOG.info("Patched PEFT _inject_parameters for consistent ParamWrapper ordering")
@@ -154,3 +154,119 @@ class TestPeftPatchIdempotency:
         finally:
             BaseTuner._inject_parameters = original
             patch_peft_target_parameters_matching._axolotl_patched = False
+
+
+class TestMoeAdapterTrainMergeRoundtrip:
+    """E2E: train adapter on quantized MoE experts, then merge onto plain model.
+
+    Verifies that param wrapping order during training matches merge, preventing
+    size mismatch errors when loading adapters in standard PEFT/vLLM.
+    """
+
+    @staticmethod
+    def _make_classes():
+        """Return FakeExperts and FakeModel classes shared by both model builders."""
+        import torch
+        import torch.nn as nn
+
+        class FakeExperts(nn.Module):
+            def __init__(self):
+                super().__init__()
+                # Model definition order: gate_up_proj first, then down_proj.
+                self.gate_up_proj = nn.Parameter(torch.randn(4, 16, 8))
+                self.down_proj = nn.Parameter(torch.randn(4, 8, 16))
+
+            def forward(self, x):
+                x = torch.matmul(x, self.gate_up_proj[0].T)  # (batch, 16)
+                x = torch.matmul(x, self.down_proj[0].T)  # (batch, 8)
+                return x
+
+        class FakeModel(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.linear = nn.Linear(8, 8)
+                self.experts = FakeExperts()
+
+            def forward(self, x):
+                return self.linear(x) + self.experts(x)
+
+        return FakeExperts, FakeModel
+
+    @staticmethod
+    def _make_quantized_model():
+        """Training model: parametrizations registered in alphabetical order."""
+        import torch.nn as nn
+        import torch.nn.utils.parametrize as P
+
+        from axolotl.monkeypatch.moe_quant import _moe_load_state
+
+        _, FakeModel = TestMoeAdapterTrainMergeRoundtrip._make_classes()
+
+        class PassthroughParametrization(nn.Module):
+            def forward(self, x):
+                return x
+
+        model = FakeModel()
+
+        # Record definition order before parametrization (mirrors real loading).
+        _moe_load_state["expert_param_order"]["experts"] = list(
+            model.experts._parameters.keys()
+        )
+
+        # Register in alphabetical order to expose the ordering mismatch.
+        P.register_parametrization(
+            model.experts, "down_proj", PassthroughParametrization(), unsafe=True
+        )
+        P.register_parametrization(
+            model.experts, "gate_up_proj", PassthroughParametrization(), unsafe=True
+        )
+        return model
+
+    @staticmethod
+    def _make_plain_model():
+        """Merge model: no parametrizations — standard branch uses definition order."""
+        _, FakeModel = TestMoeAdapterTrainMergeRoundtrip._make_classes()
+        return FakeModel()
+
+    def test_train_save_merge_no_size_mismatch(self, tmp_path):
+        """Train on quantized experts, merge onto plain model — must not raise."""
+        import torch
+        from peft import LoraConfig, PeftModel, get_peft_model
+        from peft.tuners.tuners_utils import BaseTuner
+
+        from axolotl.monkeypatch.moe_quant import (
+            _moe_load_state,
+            patch_peft_target_parameters_matching,
+        )
+
+        adapter_dir = tmp_path / "adapter"
+        lora_cfg = LoraConfig(
+            r=4,
+            lora_alpha=8,
+            target_modules=[],
+            target_parameters=["experts.gate_up_proj", "experts.down_proj"],
+            lora_dropout=0.0,
+            bias="none",
+        )
+        original_inject = BaseTuner._inject_parameters
+
+        # Training phase: quantized model (parametrized branch) with axolotl patch.
+        _moe_load_state["expert_param_order"] = {}
+        patch_peft_target_parameters_matching()
+        try:
+            peft_model = get_peft_model(self._make_quantized_model(), lora_cfg)
+        finally:
+            BaseTuner._inject_parameters = original_inject
+            patch_peft_target_parameters_matching._axolotl_patched = False
+
+        optimizer = torch.optim.SGD(peft_model.parameters(), lr=1e-3)
+        for _ in range(3):
+            peft_model(torch.randn(2, 8)).sum().backward()
+            optimizer.step()
+            optimizer.zero_grad()
+        peft_model.save_pretrained(str(adapter_dir))
+
+        # Merge with standard PEFT (no axolotl patch) to verify external compatibility.
+        loaded = PeftModel.from_pretrained(self._make_plain_model(), str(adapter_dir))
+        merged = loaded.merge_and_unload()
+        assert merged is not None