better handling of dora merge on Conv layers in Qwen 3.5 (#3599)
* better handling of dora merge on Conv layers in Qwen 3.5 * address issues from code review * stricter efficient merges for dora since we now have meta model to reference
This commit is contained in:
@@ -115,6 +115,7 @@ def _do_merge_lora_efficient(*, cfg: DictDefault) -> None:
|
||||
simulate_nf4_experts=simulate_nf4_experts,
|
||||
nf4_blocksize=nf4_blocksize,
|
||||
nf4_double_quant=nf4_double_quant,
|
||||
trust_remote_code=bool(getattr(cfg, "trust_remote_code", False)),
|
||||
)
|
||||
|
||||
LOG.debug("Memory-efficient LoRA merge completed successfully!")
|
||||
|
||||
@@ -17,6 +17,93 @@ from axolotl.utils.logging import get_logger
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
|
||||
def _build_layer_type_map(
|
||||
base_model_path: Path, trust_remote_code: bool = False
|
||||
) -> dict[str, str]:
|
||||
"""Build a map of module_name -> layer_type using a meta-device model.
|
||||
|
||||
Instantiates the model architecture on the meta device (zero memory)
|
||||
to inspect which modules are Linear vs Conv1d/Conv2d/Conv3d.
|
||||
This avoids relying on weight tensor ndim heuristics.
|
||||
"""
|
||||
import json as _json
|
||||
|
||||
import torch.nn as nn
|
||||
from transformers import AutoConfig
|
||||
|
||||
config_path = base_model_path / "config.json"
|
||||
if not config_path.exists():
|
||||
return {}
|
||||
|
||||
try:
|
||||
with open(config_path) as f:
|
||||
model_config = _json.load(f)
|
||||
except (OSError, _json.JSONDecodeError):
|
||||
return {}
|
||||
|
||||
architectures = model_config.get("architectures", [])
|
||||
if not architectures:
|
||||
return {}
|
||||
|
||||
try:
|
||||
config = AutoConfig.from_pretrained(
|
||||
str(base_model_path), trust_remote_code=trust_remote_code
|
||||
)
|
||||
except Exception:
|
||||
LOG.debug("Could not load config for layer type introspection")
|
||||
return {}
|
||||
|
||||
# Determine the right Auto class from architectures
|
||||
from transformers import (
|
||||
AutoModel,
|
||||
AutoModelForCausalLM,
|
||||
)
|
||||
|
||||
auto_classes = [AutoModelForCausalLM, AutoModel]
|
||||
try:
|
||||
from transformers import AutoModelForImageTextToText
|
||||
|
||||
auto_classes.insert(0, AutoModelForImageTextToText)
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
model = None
|
||||
for auto_cls in auto_classes:
|
||||
try:
|
||||
with torch.device("meta"):
|
||||
model = auto_cls.from_config(
|
||||
config, trust_remote_code=trust_remote_code
|
||||
)
|
||||
break
|
||||
except Exception: # noqa: BLE001
|
||||
LOG.debug(
|
||||
"Could not instantiate meta model with %s, trying next",
|
||||
auto_cls.__name__,
|
||||
)
|
||||
|
||||
if model is None:
|
||||
LOG.debug("Could not instantiate meta model for layer type introspection")
|
||||
return {}
|
||||
|
||||
layer_types = {}
|
||||
for name, module in model.named_modules():
|
||||
if isinstance(module, nn.Conv3d):
|
||||
layer_types[name] = "Conv3d"
|
||||
elif isinstance(module, nn.Conv2d):
|
||||
layer_types[name] = "Conv2d"
|
||||
elif isinstance(module, nn.Conv1d):
|
||||
layer_types[name] = "Conv1d"
|
||||
elif isinstance(module, nn.Linear):
|
||||
layer_types[name] = "Linear"
|
||||
|
||||
del model
|
||||
LOG.debug(
|
||||
f"Layer type map: {len(layer_types)} modules "
|
||||
f"({sum(1 for v in layer_types.values() if 'Conv' in v)} conv layers)"
|
||||
)
|
||||
return layer_types
|
||||
|
||||
|
||||
def _simulate_nf4_roundtrip(
|
||||
tensor: torch.Tensor,
|
||||
blocksize: Optional[int] = None,
|
||||
@@ -191,6 +278,7 @@ def _build_peft_layer_and_get_delta(
|
||||
adapter_name: str = "default",
|
||||
is_param_wrapper: bool = False,
|
||||
magnitude: Optional[torch.Tensor] = None,
|
||||
layer_type: Optional[str] = None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Use PEFT's own layer classes to compute the LoRA delta weight.
|
||||
@@ -211,7 +299,7 @@ def _build_peft_layer_and_get_delta(
|
||||
out_features = lora_b.shape[0]
|
||||
lora_alpha = lora_config_dict.get("lora_alpha", lora_config_dict.get("r", 1))
|
||||
use_rslora = bool(lora_config_dict.get("use_rslora", False))
|
||||
use_dora = bool(lora_config_dict.get("use_dora", False)) and magnitude is not None
|
||||
use_dora = bool(lora_config_dict.get("use_dora", False))
|
||||
|
||||
if is_param_wrapper:
|
||||
from peft.tuners.lora.layer import ParamWrapper
|
||||
@@ -239,6 +327,77 @@ def _build_peft_layer_and_get_delta(
|
||||
)
|
||||
layer.lora_A[adapter_name].weight.data = lora_a
|
||||
layer.lora_B[adapter_name].weight.data = lora_b
|
||||
return layer.get_delta_weight(adapter_name)
|
||||
elif (
|
||||
layer_type and "Conv" in layer_type or (layer_type is None and lora_a.ndim > 2)
|
||||
):
|
||||
# Conv layer detected via model introspection (or ndim fallback)
|
||||
|
||||
from peft.tuners.lora import layer as peft_lora_layer
|
||||
|
||||
# Determine conv type from layer_type map or fall back to ndim
|
||||
if layer_type and "Conv" in layer_type:
|
||||
conv_type: str = layer_type
|
||||
else:
|
||||
ndim = lora_a.ndim
|
||||
_conv_map = {3: "Conv1d", 4: "Conv2d", 5: "Conv3d"}
|
||||
if ndim not in _conv_map:
|
||||
raise ValueError(
|
||||
f"Unsupported LoRA weight dimensionality {ndim} for conv layer"
|
||||
)
|
||||
conv_type = _conv_map[ndim]
|
||||
LOG.warning(
|
||||
f"Using ndim-based fallback for conv detection (ndim={ndim}). "
|
||||
f"Consider providing layer_type from meta-device introspection."
|
||||
)
|
||||
|
||||
conv_cls_map = {"Conv1d": nn.Conv1d, "Conv2d": nn.Conv2d, "Conv3d": nn.Conv3d}
|
||||
ConvCls = conv_cls_map[conv_type]
|
||||
PeftConvCls = getattr(peft_lora_layer, conv_type)
|
||||
|
||||
# Reconstruct conv parameters from base tensor and lora_a shapes
|
||||
# base_tensor: [out_channels, in_channels/groups, *kernel_size]
|
||||
# lora_a: [r, in_channels/groups, *kernel_size]
|
||||
# lora_b: [out_channels, r, *ones]
|
||||
out_channels = base_tensor.shape[0]
|
||||
in_channels = base_tensor.shape[1]
|
||||
kernel_size = tuple(base_tensor.shape[2:])
|
||||
stride = (1,) * (base_tensor.ndim - 2)
|
||||
padding = (0,) * (base_tensor.ndim - 2)
|
||||
|
||||
base_layer = ConvCls(
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
bias=False,
|
||||
)
|
||||
base_layer.weight.data = base_tensor.clone()
|
||||
|
||||
layer = PeftConvCls(
|
||||
base_layer,
|
||||
adapter_name=adapter_name,
|
||||
r=r_total,
|
||||
lora_alpha=lora_alpha,
|
||||
use_rslora=use_rslora,
|
||||
use_dora=use_dora,
|
||||
)
|
||||
layer.lora_A[adapter_name].weight.data = lora_a
|
||||
layer.lora_B[adapter_name].weight.data = lora_b
|
||||
|
||||
if use_dora:
|
||||
if magnitude is None:
|
||||
raise ValueError(
|
||||
f"DoRA merge requires a magnitude vector but none was found "
|
||||
f"for conv layer (adapter={adapter_name}). Check that the "
|
||||
f"adapter checkpoint contains lora_magnitude_vector weights."
|
||||
)
|
||||
mag_layer = layer.lora_magnitude_vector[adapter_name]
|
||||
mag_layer.weight = nn.Parameter(magnitude)
|
||||
layer.merge(adapter_names=[adapter_name])
|
||||
return base_layer.weight.data - base_tensor
|
||||
|
||||
return layer.get_delta_weight(adapter_name)
|
||||
else:
|
||||
from peft.tuners.lora.layer import Linear as LoraLinear
|
||||
@@ -267,6 +426,12 @@ def _build_peft_layer_and_get_delta(
|
||||
# DoRA merges magnitude normalization into the weight directly.
|
||||
# Use PEFT's merge() which handles DoRA internally, then
|
||||
# compute the delta as merged_weight - original_weight.
|
||||
if magnitude is None:
|
||||
raise ValueError(
|
||||
f"DoRA merge requires a magnitude vector but none was found "
|
||||
f"for linear layer (adapter={adapter_name}). Check that the "
|
||||
f"adapter checkpoint contains lora_magnitude_vector weights."
|
||||
)
|
||||
mag_layer = layer.lora_magnitude_vector[adapter_name]
|
||||
mag_layer.weight = nn.Parameter(magnitude)
|
||||
layer.merge(adapter_names=[adapter_name])
|
||||
@@ -382,6 +547,7 @@ def _merge_tensor_with_lora(
|
||||
nf4_double_quant: bool = True,
|
||||
use_dora: bool = False,
|
||||
weight_renamings: Optional[Dict[str, str]] = None,
|
||||
layer_type_map: Optional[Dict[str, str]] = None,
|
||||
) -> tuple[torch.Tensor, bool]:
|
||||
"""
|
||||
Helper function to merge a single tensor with its corresponding LoRA weights.
|
||||
@@ -426,12 +592,30 @@ def _merge_tensor_with_lora(
|
||||
if use_dora
|
||||
else None
|
||||
)
|
||||
|
||||
# Look up layer type from meta-device model introspection
|
||||
_layer_type = None
|
||||
if layer_type_map:
|
||||
mod_path = key.rsplit(".weight", 1)[0] if key.endswith(".weight") else key
|
||||
_layer_type = layer_type_map.get(mod_path)
|
||||
# Try common prefix variations (e.g. with/without "model." prefix)
|
||||
if _layer_type is None:
|
||||
for prefix in [
|
||||
"model.",
|
||||
"model.language_model.",
|
||||
"model.language_model.model.",
|
||||
]:
|
||||
_layer_type = layer_type_map.get(prefix + mod_path)
|
||||
if _layer_type:
|
||||
break
|
||||
|
||||
delta = _build_peft_layer_and_get_delta(
|
||||
lora_a.to(device),
|
||||
lora_b.to(device),
|
||||
lora_config_dict,
|
||||
tensor.to(device),
|
||||
magnitude=magnitude.to(device) if magnitude is not None else None,
|
||||
layer_type=_layer_type,
|
||||
)
|
||||
merged_tensor = (
|
||||
(tensor.to(device).to(torch.float32) + delta.to(torch.float32))
|
||||
@@ -556,6 +740,7 @@ def _fuse_and_unfuse_with_merge(
|
||||
nf4_double_quant: bool = True,
|
||||
use_dora: bool = False,
|
||||
weight_renamings: Optional[Dict[str, str]] = None,
|
||||
layer_type_map: Optional[Dict[str, str]] = None,
|
||||
) -> tuple[Dict[str, torch.Tensor], int, set]:
|
||||
"""
|
||||
For tensors matching WeightConverter patterns (MoE expert weights):
|
||||
@@ -696,12 +881,32 @@ def _fuse_and_unfuse_with_merge(
|
||||
if use_dora
|
||||
else None
|
||||
)
|
||||
# Look up layer type for the fused key
|
||||
_layer_type = None
|
||||
if layer_type_map:
|
||||
mod_path = (
|
||||
fused_key.rsplit(".weight", 1)[0]
|
||||
if fused_key.endswith(".weight")
|
||||
else fused_key
|
||||
)
|
||||
_layer_type = layer_type_map.get(mod_path)
|
||||
if _layer_type is None:
|
||||
for prefix in [
|
||||
"model.",
|
||||
"model.language_model.",
|
||||
"model.language_model.model.",
|
||||
]:
|
||||
_layer_type = layer_type_map.get(prefix + mod_path)
|
||||
if _layer_type:
|
||||
break
|
||||
|
||||
delta = _build_peft_layer_and_get_delta(
|
||||
lora_a.to(device),
|
||||
lora_b.to(device),
|
||||
lora_config_dict,
|
||||
fused_tensor.to(device),
|
||||
magnitude=magnitude.to(device) if magnitude is not None else None,
|
||||
layer_type=_layer_type,
|
||||
)
|
||||
fused_tensor = (
|
||||
(
|
||||
@@ -740,6 +945,7 @@ def merge_lora_sharded_efficient(
|
||||
simulate_nf4_experts: bool = False,
|
||||
nf4_blocksize: Optional[int] = None,
|
||||
nf4_double_quant: bool = True,
|
||||
trust_remote_code: bool = False,
|
||||
) -> None:
|
||||
"""
|
||||
Memory-efficient LoRA merging that processes shards individually
|
||||
@@ -750,6 +956,8 @@ def merge_lora_sharded_efficient(
|
||||
simulate_nf4_experts: Apply NF4 roundtrip only to MoE expert tensors
|
||||
(for quantize_moe_experts). Expert tensors are identified by having
|
||||
"expert" in the key name and ndim >= 3.
|
||||
trust_remote_code: Whether to trust remote code when loading model
|
||||
config for layer-type introspection. Defaults to False for safety.
|
||||
"""
|
||||
base_model_path = Path(base_model_path)
|
||||
lora_adapter_path = Path(lora_adapter_path)
|
||||
@@ -780,6 +988,10 @@ def merge_lora_sharded_efficient(
|
||||
|
||||
use_dora = bool(lora_config_dict.get("use_dora", False))
|
||||
|
||||
# Build layer type map via meta-device model introspection
|
||||
layer_type_map = _build_layer_type_map(
|
||||
base_model_path, trust_remote_code=trust_remote_code
|
||||
)
|
||||
unsupported_methods = []
|
||||
|
||||
# Check for AdaLoRA (Adaptive LoRA)
|
||||
@@ -904,6 +1116,7 @@ def merge_lora_sharded_efficient(
|
||||
nf4_double_quant=nf4_double_quant,
|
||||
use_dora=use_dora,
|
||||
weight_renamings=weight_renamings,
|
||||
layer_type_map=layer_type_map,
|
||||
)
|
||||
merged_count += fused_merged
|
||||
|
||||
@@ -926,6 +1139,7 @@ def merge_lora_sharded_efficient(
|
||||
nf4_double_quant=nf4_double_quant,
|
||||
use_dora=use_dora,
|
||||
weight_renamings=weight_renamings,
|
||||
layer_type_map=layer_type_map,
|
||||
)
|
||||
merged_tensors[key] = merged_tensor
|
||||
if was_merged:
|
||||
|
||||
@@ -2,6 +2,7 @@ import json
|
||||
import math
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
import safetensors.torch
|
||||
import torch
|
||||
|
||||
@@ -773,8 +774,8 @@ class TestEfficientMerge:
|
||||
"v_proj should be unchanged (no LoRA weights for it)"
|
||||
)
|
||||
|
||||
def test_dora_missing_magnitude_falls_back(self):
|
||||
"""DoRA without magnitude vector falls back to standard LoRA merge."""
|
||||
def test_dora_missing_magnitude_raises(self):
|
||||
"""DoRA with missing magnitude vector raises an explicit error."""
|
||||
hidden = 16
|
||||
r = 4
|
||||
alpha = 8
|
||||
@@ -791,11 +792,13 @@ class TestEfficientMerge:
|
||||
}
|
||||
|
||||
config = {"r": r, "lora_alpha": alpha, "use_dora": True}
|
||||
merged, was_merged = _merge_tensor_with_lora(
|
||||
base, "layer.proj.weight", lora_state, scale, config, "cpu", use_dora=True
|
||||
)
|
||||
assert was_merged
|
||||
# No magnitude vector → PEFT creates DoRA layer but with default magnitude,
|
||||
# which produces a result different from plain W + scale * B @ A.
|
||||
# Just verify it was merged (not unchanged).
|
||||
assert not torch.equal(merged, base)
|
||||
with pytest.raises(ValueError, match="DoRA merge requires a magnitude vector"):
|
||||
_merge_tensor_with_lora(
|
||||
base,
|
||||
"layer.proj.weight",
|
||||
lora_state,
|
||||
scale,
|
||||
config,
|
||||
"cpu",
|
||||
use_dora=True,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user