moe quant patch for merge mismatch (#3483)
* moe quant patch for merge mismatch * lint * revert test + fix moe patch * comment fixes * e2e tests * mismatch fix tested * mismatch fix with vllm compatibility + test * comment lint * fix: missing os import, duplicate no-op * chore: simplify comments --------- Co-authored-by: NanoCode012 <nano@axolotl.ai>
This commit is contained in:
@@ -416,16 +416,21 @@ class PatchManager:
|
|||||||
os.environ["HF_DEACTIVATE_ASYNC_LOAD"] = "1"
|
os.environ["HF_DEACTIVATE_ASYNC_LOAD"] = "1"
|
||||||
|
|
||||||
def _apply_moe_expert_quantization_patch(self):
|
def _apply_moe_expert_quantization_patch(self):
|
||||||
"""Patch transformers weight loading to quantize MoE expert params on-the-fly."""
|
"""Patch transformers weight loading and PEFT for MoE expert quantization."""
|
||||||
if not self.cfg.quantize_moe_experts:
|
has_target_params = bool(getattr(self.cfg, "lora_target_parameters", None))
|
||||||
|
|
||||||
|
if not self.cfg.quantize_moe_experts and not has_target_params:
|
||||||
return
|
return
|
||||||
|
|
||||||
from axolotl.monkeypatch.moe_quant import (
|
from axolotl.monkeypatch.moe_quant import (
|
||||||
patch_moe_quantization_on_load,
|
|
||||||
patch_peft_target_parameters_matching,
|
patch_peft_target_parameters_matching,
|
||||||
)
|
)
|
||||||
|
|
||||||
patch_moe_quantization_on_load(self.cfg)
|
if self.cfg.quantize_moe_experts:
|
||||||
|
from axolotl.monkeypatch.moe_quant import patch_moe_quantization_on_load
|
||||||
|
|
||||||
|
patch_moe_quantization_on_load(self.cfg)
|
||||||
|
|
||||||
patch_peft_target_parameters_matching()
|
patch_peft_target_parameters_matching()
|
||||||
|
|
||||||
def _finalize_moe_expert_quantization(self, model: PreTrainedModel):
|
def _finalize_moe_expert_quantization(self, model: PreTrainedModel):
|
||||||
|
|||||||
@@ -1,11 +1,4 @@
|
|||||||
"""
|
"""Loading-time quantization for MoE expert weights stored as 3D nn.Parameter tensors."""
|
||||||
Loading-time quantization for MoE expert weights stored as 3D nn.Parameter tensors.
|
|
||||||
|
|
||||||
In transformers v5, MoE models store expert weights as fused 3D tensors that BnB
|
|
||||||
skips (only targets nn.Linear). This module patches weight loading to quantize them
|
|
||||||
on-the-fly (4-bit via bitsandbytes parametrize, 8-bit via custom int8 parametrization),
|
|
||||||
reducing peak VRAM from "all experts in bf16" to "one expert at a time."
|
|
||||||
"""
|
|
||||||
|
|
||||||
import bitsandbytes as bnb
|
import bitsandbytes as bnb
|
||||||
import torch
|
import torch
|
||||||
@@ -15,18 +8,20 @@ from axolotl.utils.logging import get_logger
|
|||||||
|
|
||||||
LOG = get_logger(__name__)
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
# Module-level state for the loading-time quantization patch.
|
|
||||||
_moe_load_state = {
|
_moe_load_state = {
|
||||||
"count": 0,
|
"count": 0,
|
||||||
"mode": "4bit",
|
"mode": "4bit",
|
||||||
"quant_type": "nf4",
|
"quant_type": "nf4",
|
||||||
"compress_statistics": True,
|
"compress_statistics": True,
|
||||||
"patched": False,
|
"patched": False,
|
||||||
|
# Module path → param names in definition order, captured before quantization.
|
||||||
|
# Without this, alphabetical loading order would mismatch merge order.
|
||||||
|
"expert_param_order": {},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
class Bnb8bitParametrization(torch.nn.Module):
|
class Bnb8bitParametrization(torch.nn.Module):
|
||||||
"""Parametrization that dequantizes int8 row-wise quantized data on access."""
|
"""Dequantizes int8 row-wise quantized data on access."""
|
||||||
|
|
||||||
def __init__(self, row_stats: torch.Tensor):
|
def __init__(self, row_stats: torch.Tensor):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@@ -34,7 +29,7 @@ class Bnb8bitParametrization(torch.nn.Module):
|
|||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def forward(self, quantized_param: torch.Tensor) -> torch.Tensor:
|
def forward(self, quantized_param: torch.Tensor) -> torch.Tensor:
|
||||||
# Flatten 3D+ to 2D for BnB's dequant, then reshape back.
|
"""Flatten 3D+ to 2D for BnB's dequant, then reshape back."""
|
||||||
orig_shape = quantized_param.shape
|
orig_shape = quantized_param.shape
|
||||||
if quantized_param.ndim > 2:
|
if quantized_param.ndim > 2:
|
||||||
quantized_param = quantized_param.reshape(-1, orig_shape[-1])
|
quantized_param = quantized_param.reshape(-1, orig_shape[-1])
|
||||||
@@ -74,14 +69,11 @@ def replace_parameter_8bit(module, param_name):
|
|||||||
|
|
||||||
|
|
||||||
def patch_moe_quantization_on_load(cfg):
|
def patch_moe_quantization_on_load(cfg):
|
||||||
"""Patch transformers' weight loading to quantize MoE expert params on-the-fly.
|
"""Patch transformers' weight loading to quantize MoE expert params on-the-fly."""
|
||||||
|
|
||||||
Wraps ``set_param_for_module`` so that 3D+ CUDA tensors with "expert" in their
|
|
||||||
name are quantized (4-bit or 8-bit) as they're loaded, keeping peak VRAM low.
|
|
||||||
"""
|
|
||||||
mode = "8bit" if getattr(cfg, "load_in_8bit", False) else "4bit"
|
mode = "8bit" if getattr(cfg, "load_in_8bit", False) else "4bit"
|
||||||
_moe_load_state["mode"] = mode
|
_moe_load_state["mode"] = mode
|
||||||
_moe_load_state["count"] = 0
|
_moe_load_state["count"] = 0
|
||||||
|
_moe_load_state["expert_param_order"] = {}
|
||||||
|
|
||||||
if _moe_load_state["patched"]:
|
if _moe_load_state["patched"]:
|
||||||
LOG.debug("MoE loading-time quantization patch already active")
|
LOG.debug("MoE loading-time quantization patch already active")
|
||||||
@@ -113,7 +105,6 @@ def patch_moe_quantization_on_load(cfg):
|
|||||||
def _patched_set_param_for_module(model, target_name, param_value, *args, **kwargs):
|
def _patched_set_param_for_module(model, target_name, param_value, *args, **kwargs):
|
||||||
original_set_param(model, target_name, param_value, *args, **kwargs)
|
original_set_param(model, target_name, param_value, *args, **kwargs)
|
||||||
|
|
||||||
# Quantize 3D+ expert params that BnB skipped (only on CUDA).
|
|
||||||
if param_value.ndim >= 3 and param_value.is_cuda:
|
if param_value.ndim >= 3 and param_value.is_cuda:
|
||||||
mod_path, _, pname = target_name.rpartition(".")
|
mod_path, _, pname = target_name.rpartition(".")
|
||||||
mod = model.get_submodule(mod_path) if mod_path else model
|
mod = model.get_submodule(mod_path) if mod_path else model
|
||||||
@@ -126,6 +117,13 @@ def patch_moe_quantization_on_load(cfg):
|
|||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
|
# Record definition order before parametrizations override it
|
||||||
|
# with alphabetical order.
|
||||||
|
if mod_path not in _moe_load_state["expert_param_order"]:
|
||||||
|
_moe_load_state["expert_param_order"][mod_path] = list(
|
||||||
|
mod._parameters.keys()
|
||||||
|
)
|
||||||
|
|
||||||
if _moe_load_state["mode"] == "4bit":
|
if _moe_load_state["mode"] == "4bit":
|
||||||
replace_parameter_4bit(
|
replace_parameter_4bit(
|
||||||
mod,
|
mod,
|
||||||
@@ -151,20 +149,28 @@ def get_moe_quantized_count():
|
|||||||
|
|
||||||
|
|
||||||
def patch_peft_target_parameters_matching():
|
def patch_peft_target_parameters_matching():
|
||||||
"""Fix PEFT's _inject_parameters to use suffix matching for parametrized modules."""
|
"""Fix PEFT's _inject_parameters for target_parameters on quantized MoE experts.
|
||||||
|
|
||||||
|
1. Expands short suffixes to full module paths for parametrized modules.
|
||||||
|
2. Iterates params in definition order (not alphabetical order) so saved
|
||||||
|
adapters are compatible with standard PEFT, vLLM, etc.
|
||||||
|
"""
|
||||||
if getattr(patch_peft_target_parameters_matching, "_axolotl_patched", False):
|
if getattr(patch_peft_target_parameters_matching, "_axolotl_patched", False):
|
||||||
return
|
return
|
||||||
from peft.tuners.tuners_utils import BaseTuner
|
|
||||||
|
|
||||||
original_inject = BaseTuner._inject_parameters
|
from contextlib import nullcontext
|
||||||
|
|
||||||
|
from peft.tuners.tuners_utils import BaseTuner, BaseTunerLayer
|
||||||
|
from peft.utils.integrations import init_empty_weights
|
||||||
|
from peft.utils.other import _get_submodules
|
||||||
|
|
||||||
def _patched_inject_parameters(
|
def _patched_inject_parameters(
|
||||||
self, peft_config, model, adapter_name, low_cpu_mem_usage
|
self, peft_config, model, adapter_name, low_cpu_mem_usage
|
||||||
):
|
):
|
||||||
# Patch target_parameters to use full paths for parametrized modules
|
|
||||||
original_targets = list(peft_config.target_parameters)
|
original_targets = list(peft_config.target_parameters)
|
||||||
expanded = set(original_targets)
|
expanded = set(original_targets)
|
||||||
|
|
||||||
|
# Expand short suffixes to full paths for parametrized modules.
|
||||||
for module_name, module in model.named_modules():
|
for module_name, module in model.named_modules():
|
||||||
if not hasattr(module, "parametrizations"):
|
if not hasattr(module, "parametrizations"):
|
||||||
continue
|
continue
|
||||||
@@ -175,14 +181,74 @@ def patch_peft_target_parameters_matching():
|
|||||||
) and hasattr(module, param_name):
|
) and hasattr(module, param_name):
|
||||||
expanded.add(f"{module_name}.{param_name}")
|
expanded.add(f"{module_name}.{param_name}")
|
||||||
|
|
||||||
peft_config.target_parameters = sorted(expanded)
|
target_names_set = expanded
|
||||||
try:
|
|
||||||
return original_inject(
|
def strip_base_layer_from_name(module_name):
|
||||||
self, peft_config, model, adapter_name, low_cpu_mem_usage
|
name = ".base_layer"
|
||||||
)
|
while name in module_name:
|
||||||
finally:
|
prefix, _, suffix = module_name.rpartition(name)
|
||||||
peft_config.target_parameters = original_targets
|
module_name = prefix + suffix
|
||||||
|
return module_name
|
||||||
|
|
||||||
|
def create_and_replace_param(module_name, key, param_name):
|
||||||
|
parent, target, target_name = _get_submodules(model, module_name)
|
||||||
|
unwrapped_module_name = strip_base_layer_from_name(module_name)
|
||||||
|
unwrapped_module = model.get_submodule(unwrapped_module_name)
|
||||||
|
if (
|
||||||
|
isinstance(unwrapped_module, BaseTunerLayer)
|
||||||
|
and unwrapped_module.__class__.__name__ != "ParamWrapper"
|
||||||
|
):
|
||||||
|
raise ValueError(
|
||||||
|
f"Trying to wrap an `nn.Parameter` of layer "
|
||||||
|
f"'{unwrapped_module_name}' of type "
|
||||||
|
f"{type(target).__name__}, which is not a valid target. "
|
||||||
|
f"Make sure that this layer is not also targeted with "
|
||||||
|
f"`target_modules`."
|
||||||
|
)
|
||||||
|
self._check_target_module_compatiblity(peft_config, model, target_name)
|
||||||
|
ctx = init_empty_weights if low_cpu_mem_usage else nullcontext
|
||||||
|
with ctx():
|
||||||
|
self._create_and_replace(
|
||||||
|
peft_config,
|
||||||
|
adapter_name,
|
||||||
|
target,
|
||||||
|
target_name,
|
||||||
|
parent,
|
||||||
|
current_key=key,
|
||||||
|
parameter_name=param_name.rpartition(".")[-1],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Use definition order (not alphabetical order) for parametrized modules
|
||||||
|
# so ParamWrapper nesting matches vanilla PEFT on a plain model.
|
||||||
|
expert_param_order = _moe_load_state.get("expert_param_order", {})
|
||||||
|
|
||||||
|
for module_name, module in model.named_modules():
|
||||||
|
if hasattr(module, "parametrizations"):
|
||||||
|
stored_order = expert_param_order.get(module_name)
|
||||||
|
if stored_order is not None:
|
||||||
|
params_iter = [
|
||||||
|
p for p in stored_order if p in module.parametrizations
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
# Fallback for paths that bypass model loading (e.g. unit tests).
|
||||||
|
params_iter = list(module.parametrizations.keys())
|
||||||
|
for param_name in params_iter:
|
||||||
|
key = f"{module_name}.{param_name}"
|
||||||
|
if (key in target_names_set) or any(
|
||||||
|
key.endswith(f".{t}") for t in target_names_set
|
||||||
|
):
|
||||||
|
create_and_replace_param(module_name, key, param_name)
|
||||||
|
self.targeted_parameter_names.append(key)
|
||||||
|
else:
|
||||||
|
unwrapped_module_name = strip_base_layer_from_name(module_name)
|
||||||
|
for param_name, _ in module.named_parameters(recurse=False):
|
||||||
|
key = f"{unwrapped_module_name}.{param_name}"
|
||||||
|
if (key in target_names_set) or any(
|
||||||
|
key.endswith(f".{t}") for t in target_names_set
|
||||||
|
):
|
||||||
|
create_and_replace_param(module_name, key, param_name)
|
||||||
|
self.targeted_parameter_names.append(key)
|
||||||
|
|
||||||
BaseTuner._inject_parameters = _patched_inject_parameters
|
BaseTuner._inject_parameters = _patched_inject_parameters
|
||||||
patch_peft_target_parameters_matching._axolotl_patched = True
|
patch_peft_target_parameters_matching._axolotl_patched = True
|
||||||
LOG.info("Patched PEFT _inject_parameters for parametrized module suffix matching")
|
LOG.info("Patched PEFT _inject_parameters for consistent ParamWrapper ordering")
|
||||||
|
|||||||
@@ -154,3 +154,119 @@ class TestPeftPatchIdempotency:
|
|||||||
finally:
|
finally:
|
||||||
BaseTuner._inject_parameters = original
|
BaseTuner._inject_parameters = original
|
||||||
patch_peft_target_parameters_matching._axolotl_patched = False
|
patch_peft_target_parameters_matching._axolotl_patched = False
|
||||||
|
|
||||||
|
|
||||||
|
class TestMoeAdapterTrainMergeRoundtrip:
|
||||||
|
"""E2E: train adapter on quantized MoE experts, then merge onto plain model.
|
||||||
|
|
||||||
|
Verifies that param wrapping order during training matches merge, preventing
|
||||||
|
size mismatch errors when loading adapters in standard PEFT/vLLM.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _make_classes():
|
||||||
|
"""Return FakeExperts and FakeModel classes shared by both model builders."""
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
class FakeExperts(nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
# Model definition order: gate_up_proj first, then down_proj.
|
||||||
|
self.gate_up_proj = nn.Parameter(torch.randn(4, 16, 8))
|
||||||
|
self.down_proj = nn.Parameter(torch.randn(4, 8, 16))
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = torch.matmul(x, self.gate_up_proj[0].T) # (batch, 16)
|
||||||
|
x = torch.matmul(x, self.down_proj[0].T) # (batch, 8)
|
||||||
|
return x
|
||||||
|
|
||||||
|
class FakeModel(nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.linear = nn.Linear(8, 8)
|
||||||
|
self.experts = FakeExperts()
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return self.linear(x) + self.experts(x)
|
||||||
|
|
||||||
|
return FakeExperts, FakeModel
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _make_quantized_model():
|
||||||
|
"""Training model: parametrizations registered in alphabetical order."""
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.utils.parametrize as P
|
||||||
|
|
||||||
|
from axolotl.monkeypatch.moe_quant import _moe_load_state
|
||||||
|
|
||||||
|
_, FakeModel = TestMoeAdapterTrainMergeRoundtrip._make_classes()
|
||||||
|
|
||||||
|
class PassthroughParametrization(nn.Module):
|
||||||
|
def forward(self, x):
|
||||||
|
return x
|
||||||
|
|
||||||
|
model = FakeModel()
|
||||||
|
|
||||||
|
# Record definition order before parametrization (mirrors real loading).
|
||||||
|
_moe_load_state["expert_param_order"]["experts"] = list(
|
||||||
|
model.experts._parameters.keys()
|
||||||
|
)
|
||||||
|
|
||||||
|
# Register in alphabetical order to expose the ordering mismatch.
|
||||||
|
P.register_parametrization(
|
||||||
|
model.experts, "down_proj", PassthroughParametrization(), unsafe=True
|
||||||
|
)
|
||||||
|
P.register_parametrization(
|
||||||
|
model.experts, "gate_up_proj", PassthroughParametrization(), unsafe=True
|
||||||
|
)
|
||||||
|
return model
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _make_plain_model():
|
||||||
|
"""Merge model: no parametrizations — standard branch uses definition order."""
|
||||||
|
_, FakeModel = TestMoeAdapterTrainMergeRoundtrip._make_classes()
|
||||||
|
return FakeModel()
|
||||||
|
|
||||||
|
def test_train_save_merge_no_size_mismatch(self, tmp_path):
|
||||||
|
"""Train on quantized experts, merge onto plain model — must not raise."""
|
||||||
|
import torch
|
||||||
|
from peft import LoraConfig, PeftModel, get_peft_model
|
||||||
|
from peft.tuners.tuners_utils import BaseTuner
|
||||||
|
|
||||||
|
from axolotl.monkeypatch.moe_quant import (
|
||||||
|
_moe_load_state,
|
||||||
|
patch_peft_target_parameters_matching,
|
||||||
|
)
|
||||||
|
|
||||||
|
adapter_dir = tmp_path / "adapter"
|
||||||
|
lora_cfg = LoraConfig(
|
||||||
|
r=4,
|
||||||
|
lora_alpha=8,
|
||||||
|
target_modules=[],
|
||||||
|
target_parameters=["experts.gate_up_proj", "experts.down_proj"],
|
||||||
|
lora_dropout=0.0,
|
||||||
|
bias="none",
|
||||||
|
)
|
||||||
|
original_inject = BaseTuner._inject_parameters
|
||||||
|
|
||||||
|
# Training phase: quantized model (parametrized branch) with axolotl patch.
|
||||||
|
_moe_load_state["expert_param_order"] = {}
|
||||||
|
patch_peft_target_parameters_matching()
|
||||||
|
try:
|
||||||
|
peft_model = get_peft_model(self._make_quantized_model(), lora_cfg)
|
||||||
|
finally:
|
||||||
|
BaseTuner._inject_parameters = original_inject
|
||||||
|
patch_peft_target_parameters_matching._axolotl_patched = False
|
||||||
|
|
||||||
|
optimizer = torch.optim.SGD(peft_model.parameters(), lr=1e-3)
|
||||||
|
for _ in range(3):
|
||||||
|
peft_model(torch.randn(2, 8)).sum().backward()
|
||||||
|
optimizer.step()
|
||||||
|
optimizer.zero_grad()
|
||||||
|
peft_model.save_pretrained(str(adapter_dir))
|
||||||
|
|
||||||
|
# Merge with standard PEFT (no axolotl patch) to verify external compatibility.
|
||||||
|
loaded = PeftModel.from_pretrained(self._make_plain_model(), str(adapter_dir))
|
||||||
|
merged = loaded.merge_and_unload()
|
||||||
|
assert merged is not None
|
||||||
|
|||||||
Reference in New Issue
Block a user