better handling of dora merge on Conv layers in Qwen 3.5 (#3599)

* better handling of dora merge on Conv layers in Qwen 3.5

* address issues from code review

* stricter efficient merges for dora since we now have a meta model to reference
Authored by Wing Lian on 2026-04-12 10:57:45 -04:00, committed by GitHub
parent b8358aa5ab, commit 66c3e5a3fd
3 changed files with 229 additions and 11 deletions

View File

@@ -115,6 +115,7 @@ def _do_merge_lora_efficient(*, cfg: DictDefault) -> None:
simulate_nf4_experts=simulate_nf4_experts,
nf4_blocksize=nf4_blocksize,
nf4_double_quant=nf4_double_quant,
trust_remote_code=bool(getattr(cfg, "trust_remote_code", False)),
)
LOG.debug("Memory-efficient LoRA merge completed successfully!")

View File

@@ -17,6 +17,93 @@ from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
def _build_layer_type_map(
base_model_path: Path, trust_remote_code: bool = False
) -> dict[str, str]:
"""Build a map of module_name -> layer_type using a meta-device model.
Instantiates the model architecture on the meta device (zero memory)
to inspect which modules are Linear vs Conv1d/Conv2d/Conv3d.
This avoids relying on weight tensor ndim heuristics.
"""
import json as _json
import torch.nn as nn
from transformers import AutoConfig
config_path = base_model_path / "config.json"
if not config_path.exists():
return {}
try:
with open(config_path) as f:
model_config = _json.load(f)
except (OSError, _json.JSONDecodeError):
return {}
architectures = model_config.get("architectures", [])
if not architectures:
return {}
try:
config = AutoConfig.from_pretrained(
str(base_model_path), trust_remote_code=trust_remote_code
)
except Exception:
LOG.debug("Could not load config for layer type introspection")
return {}
# Try candidate Auto classes in priority order (multimodal first, then causal LM, then bare model)
from transformers import (
AutoModel,
AutoModelForCausalLM,
)
auto_classes = [AutoModelForCausalLM, AutoModel]
try:
from transformers import AutoModelForImageTextToText
auto_classes.insert(0, AutoModelForImageTextToText)
except ImportError:
pass
model = None
for auto_cls in auto_classes:
try:
with torch.device("meta"):
model = auto_cls.from_config(
config, trust_remote_code=trust_remote_code
)
break
except Exception: # noqa: BLE001
LOG.debug(
"Could not instantiate meta model with %s, trying next",
auto_cls.__name__,
)
if model is None:
LOG.debug("Could not instantiate meta model for layer type introspection")
return {}
layer_types = {}
for name, module in model.named_modules():
if isinstance(module, nn.Conv3d):
layer_types[name] = "Conv3d"
elif isinstance(module, nn.Conv2d):
layer_types[name] = "Conv2d"
elif isinstance(module, nn.Conv1d):
layer_types[name] = "Conv1d"
elif isinstance(module, nn.Linear):
layer_types[name] = "Linear"
del model
LOG.debug(
f"Layer type map: {len(layer_types)} modules "
f"({sum(1 for v in layer_types.values() if 'Conv' in v)} conv layers)"
)
return layer_types
def _simulate_nf4_roundtrip(
tensor: torch.Tensor,
blocksize: Optional[int] = None,
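
For illustration, the meta-device introspection that _build_layer_type_map relies on can be reproduced standalone. This is a sketch rather than code from this commit; the checkpoint name is a placeholder, and the real helper additionally tries several Auto classes and falls back gracefully when an architecture cannot be instantiated this way.

import torch
import torch.nn as nn
from transformers import AutoConfig, AutoModelForCausalLM

config = AutoConfig.from_pretrained("Qwen/Qwen2.5-0.5B")  # placeholder checkpoint
with torch.device("meta"):  # weights become meta tensors, so no real memory is allocated
    model = AutoModelForCausalLM.from_config(config)

layer_types = {
    name: type(module).__name__
    for name, module in model.named_modules()
    if isinstance(module, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d))
}
print(f"{len(layer_types)} modules mapped")
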
@@ -191,6 +278,7 @@ def _build_peft_layer_and_get_delta(
adapter_name: str = "default",
is_param_wrapper: bool = False,
magnitude: Optional[torch.Tensor] = None,
layer_type: Optional[str] = None,
) -> torch.Tensor:
"""
Use PEFT's own layer classes to compute the LoRA delta weight.
@@ -211,7 +299,7 @@ def _build_peft_layer_and_get_delta(
out_features = lora_b.shape[0]
lora_alpha = lora_config_dict.get("lora_alpha", lora_config_dict.get("r", 1))
use_rslora = bool(lora_config_dict.get("use_rslora", False))
use_dora = bool(lora_config_dict.get("use_dora", False)) and magnitude is not None
use_dora = bool(lora_config_dict.get("use_dora", False))
if is_param_wrapper:
from peft.tuners.lora.layer import ParamWrapper
@@ -239,6 +327,77 @@ def _build_peft_layer_and_get_delta(
)
layer.lora_A[adapter_name].weight.data = lora_a
layer.lora_B[adapter_name].weight.data = lora_b
return layer.get_delta_weight(adapter_name)
elif (
(layer_type and "Conv" in layer_type)
or (layer_type is None and lora_a.ndim > 2)
):
# Conv layer detected via model introspection (or ndim fallback)
from peft.tuners.lora import layer as peft_lora_layer
# Determine conv type from layer_type map or fall back to ndim
if layer_type and "Conv" in layer_type:
conv_type: str = layer_type
else:
ndim = lora_a.ndim
_conv_map = {3: "Conv1d", 4: "Conv2d", 5: "Conv3d"}
if ndim not in _conv_map:
raise ValueError(
f"Unsupported LoRA weight dimensionality {ndim} for conv layer"
)
conv_type = _conv_map[ndim]
LOG.warning(
f"Using ndim-based fallback for conv detection (ndim={ndim}). "
f"Consider providing layer_type from meta-device introspection."
)
conv_cls_map = {"Conv1d": nn.Conv1d, "Conv2d": nn.Conv2d, "Conv3d": nn.Conv3d}
ConvCls = conv_cls_map[conv_type]
PeftConvCls = getattr(peft_lora_layer, conv_type)
# Reconstruct conv parameters from base tensor and lora_a shapes
# base_tensor: [out_channels, in_channels/groups, *kernel_size]
# lora_a: [r, in_channels/groups, *kernel_size]
# lora_b: [out_channels, r, *ones]
out_channels = base_tensor.shape[0]
in_channels = base_tensor.shape[1]
kernel_size = tuple(base_tensor.shape[2:])
stride = (1,) * (base_tensor.ndim - 2)
padding = (0,) * (base_tensor.ndim - 2)
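# NOTE: stride/padding here are placeholders; the merged delta depends only on the
# weights, so the reconstructed base layer's stride/padding values never matter.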
base_layer = ConvCls(
in_channels,
out_channels,
kernel_size,
stride=stride,
padding=padding,
bias=False,
)
base_layer.weight.data = base_tensor.clone()
layer = PeftConvCls(
base_layer,
adapter_name=adapter_name,
r=r_total,
lora_alpha=lora_alpha,
use_rslora=use_rslora,
use_dora=use_dora,
)
layer.lora_A[adapter_name].weight.data = lora_a
layer.lora_B[adapter_name].weight.data = lora_b
if use_dora:
if magnitude is None:
raise ValueError(
f"DoRA merge requires a magnitude vector but none was found "
f"for conv layer (adapter={adapter_name}). Check that the "
f"adapter checkpoint contains lora_magnitude_vector weights."
)
mag_layer = layer.lora_magnitude_vector[adapter_name]
mag_layer.weight = nn.Parameter(magnitude)
layer.merge(adapter_names=[adapter_name])
return base_layer.weight.data - base_tensor
return layer.get_delta_weight(adapter_name)
else:
from peft.tuners.lora.layer import Linear as LoraLinear
@@ -267,6 +426,12 @@ def _build_peft_layer_and_get_delta(
# DoRA merges magnitude normalization into the weight directly.
# Use PEFT's merge() which handles DoRA internally, then
# compute the delta as merged_weight - original_weight.
if magnitude is None:
raise ValueError(
f"DoRA merge requires a magnitude vector but none was found "
f"for linear layer (adapter={adapter_name}). Check that the "
f"adapter checkpoint contains lora_magnitude_vector weights."
)
mag_layer = layer.lora_magnitude_vector[adapter_name]
mag_layer.weight = nn.Parameter(magnitude)
layer.merge(adapter_names=[adapter_name])
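
For illustration, the "merge, then subtract the original weight" pattern used in both DoRA branches can be sketched in isolation with PEFT's Linear LoRA layer. This is not code from this commit; shapes and values are made up, and PEFT internals may differ slightly between versions.

import torch
import torch.nn as nn
from peft.tuners.lora.layer import Linear as LoraLinear

in_f, out_f, r = 16, 16, 4
base = nn.Linear(in_f, out_f, bias=False)
original = base.weight.data.clone()

layer = LoraLinear(base, adapter_name="default", r=r, lora_alpha=8, use_dora=True)
layer.lora_A["default"].weight.data = torch.randn(r, in_f) * 0.01
layer.lora_B["default"].weight.data = torch.randn(out_f, r) * 0.01
# DoRA magnitude vector: one entry per output feature, as stored in adapter checkpoints.
layer.lora_magnitude_vector["default"].weight = nn.Parameter(torch.ones(out_f))

layer.merge(adapter_names=["default"])
delta = base.weight.data - original  # the delta that the helper above would return
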
@@ -382,6 +547,7 @@ def _merge_tensor_with_lora(
nf4_double_quant: bool = True,
use_dora: bool = False,
weight_renamings: Optional[Dict[str, str]] = None,
layer_type_map: Optional[Dict[str, str]] = None,
) -> tuple[torch.Tensor, bool]:
"""
Helper function to merge a single tensor with its corresponding LoRA weights.
@@ -426,12 +592,30 @@ def _merge_tensor_with_lora(
if use_dora
else None
)
# Look up layer type from meta-device model introspection
_layer_type = None
if layer_type_map:
mod_path = key.rsplit(".weight", 1)[0] if key.endswith(".weight") else key
_layer_type = layer_type_map.get(mod_path)
# Try common prefix variations (e.g. with/without "model." prefix)
if _layer_type is None:
for prefix in [
"model.",
"model.language_model.",
"model.language_model.model.",
]:
_layer_type = layer_type_map.get(prefix + mod_path)
if _layer_type:
break
delta = _build_peft_layer_and_get_delta(
lora_a.to(device),
lora_b.to(device),
lora_config_dict,
tensor.to(device),
magnitude=magnitude.to(device) if magnitude is not None else None,
layer_type=_layer_type,
)
merged_tensor = (
(tensor.to(device).to(torch.float32) + delta.to(torch.float32))
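
The key-to-module-path lookup above is repeated almost verbatim in _fuse_and_unfuse_with_merge below; a hypothetical shared helper (sketch only, not in this commit) would look roughly like this:

from typing import Dict, Optional

_PREFIXES = ("model.", "model.language_model.", "model.language_model.model.")

def resolve_layer_type(key: str, layer_type_map: Dict[str, str]) -> Optional[str]:
    # Strip a trailing ".weight" so checkpoint keys line up with named_modules() paths.
    mod_path = key.rsplit(".weight", 1)[0] if key.endswith(".weight") else key
    layer_type = layer_type_map.get(mod_path)
    if layer_type is None:
        # Adapter keys often omit wrapper prefixes that the full model's module names include.
        for prefix in _PREFIXES:
            layer_type = layer_type_map.get(prefix + mod_path)
            if layer_type:
                break
    return layer_type
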
@@ -556,6 +740,7 @@ def _fuse_and_unfuse_with_merge(
nf4_double_quant: bool = True,
use_dora: bool = False,
weight_renamings: Optional[Dict[str, str]] = None,
layer_type_map: Optional[Dict[str, str]] = None,
) -> tuple[Dict[str, torch.Tensor], int, set]:
"""
For tensors matching WeightConverter patterns (MoE expert weights):
@@ -696,12 +881,32 @@ def _fuse_and_unfuse_with_merge(
if use_dora
else None
)
# Look up layer type for the fused key
_layer_type = None
if layer_type_map:
mod_path = (
fused_key.rsplit(".weight", 1)[0]
if fused_key.endswith(".weight")
else fused_key
)
_layer_type = layer_type_map.get(mod_path)
if _layer_type is None:
for prefix in [
"model.",
"model.language_model.",
"model.language_model.model.",
]:
_layer_type = layer_type_map.get(prefix + mod_path)
if _layer_type:
break
delta = _build_peft_layer_and_get_delta(
lora_a.to(device),
lora_b.to(device),
lora_config_dict,
fused_tensor.to(device),
magnitude=magnitude.to(device) if magnitude is not None else None,
layer_type=_layer_type,
)
fused_tensor = (
(
@@ -740,6 +945,7 @@ def merge_lora_sharded_efficient(
simulate_nf4_experts: bool = False,
nf4_blocksize: Optional[int] = None,
nf4_double_quant: bool = True,
trust_remote_code: bool = False,
) -> None:
"""
Memory-efficient LoRA merging that processes shards individually
@@ -750,6 +956,8 @@ def merge_lora_sharded_efficient(
simulate_nf4_experts: Apply NF4 roundtrip only to MoE expert tensors
(for quantize_moe_experts). Expert tensors are identified by having
"expert" in the key name and ndim >= 3.
trust_remote_code: Whether to trust remote code when loading model
config for layer-type introspection. Defaults to False for safety.
"""
base_model_path = Path(base_model_path)
lora_adapter_path = Path(lora_adapter_path)
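
The expert-tensor rule mentioned in the docstring can be restated as a tiny hypothetical predicate (sketch only, not in this commit):

import torch

def looks_like_moe_expert(key: str, tensor: torch.Tensor) -> bool:
    # Per the docstring: expert tensors have "expert" in the key name and ndim >= 3.
    return "expert" in key and tensor.ndim >= 3
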
@@ -780,6 +988,10 @@ def merge_lora_sharded_efficient(
use_dora = bool(lora_config_dict.get("use_dora", False))
# Build layer type map via meta-device model introspection
layer_type_map = _build_layer_type_map(
base_model_path, trust_remote_code=trust_remote_code
)
unsupported_methods = []
# Check for AdaLoRA (Adaptive LoRA)
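
For a rough sense of what the map contains (module names below are hypothetical, loosely modeled on a multimodal Qwen-style checkpoint):

layer_type_map = {
    "model.language_model.layers.0.self_attn.q_proj": "Linear",
    "model.visual.patch_embed.proj": "Conv3d",
    "model.visual.blocks.0.attn.qkv": "Linear",
}
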
@@ -904,6 +1116,7 @@ def merge_lora_sharded_efficient(
nf4_double_quant=nf4_double_quant,
use_dora=use_dora,
weight_renamings=weight_renamings,
layer_type_map=layer_type_map,
)
merged_count += fused_merged
@@ -926,6 +1139,7 @@ def merge_lora_sharded_efficient(
nf4_double_quant=nf4_double_quant,
use_dora=use_dora,
weight_renamings=weight_renamings,
layer_type_map=layer_type_map,
)
merged_tensors[key] = merged_tensor
if was_merged:

View File

@@ -2,6 +2,7 @@ import json
import math
from unittest.mock import Mock, patch
import pytest
import safetensors.torch
import torch
@@ -773,8 +774,8 @@ class TestEfficientMerge:
"v_proj should be unchanged (no LoRA weights for it)"
)
def test_dora_missing_magnitude_falls_back(self):
"""DoRA without magnitude vector falls back to standard LoRA merge."""
def test_dora_missing_magnitude_raises(self):
"""DoRA with missing magnitude vector raises an explicit error."""
hidden = 16
r = 4
alpha = 8
@@ -791,11 +792,13 @@ class TestEfficientMerge:
}
config = {"r": r, "lora_alpha": alpha, "use_dora": True}
merged, was_merged = _merge_tensor_with_lora(
base, "layer.proj.weight", lora_state, scale, config, "cpu", use_dora=True
)
assert was_merged
# No magnitude vector → PEFT creates DoRA layer but with default magnitude,
# which produces a result different from plain W + scale * B @ A.
# Just verify it was merged (not unchanged).
assert not torch.equal(merged, base)
with pytest.raises(ValueError, match="DoRA merge requires a magnitude vector"):
_merge_tensor_with_lora(
base,
"layer.proj.weight",
lora_state,
scale,
config,
"cpu",
use_dora=True,
)