better handling of dora merge on Conv layers in Qwen 3.5 (#3599)
* better handling of dora merge on Conv layers in Qwen 3.5
* address issues from code review
* stricter efficient merges for dora since we now have meta model to reference
This commit is contained in:
@@ -115,6 +115,7 @@ def _do_merge_lora_efficient(*, cfg: DictDefault) -> None:
|
|||||||
simulate_nf4_experts=simulate_nf4_experts,
|
simulate_nf4_experts=simulate_nf4_experts,
|
||||||
nf4_blocksize=nf4_blocksize,
|
nf4_blocksize=nf4_blocksize,
|
||||||
nf4_double_quant=nf4_double_quant,
|
nf4_double_quant=nf4_double_quant,
|
||||||
|
trust_remote_code=bool(getattr(cfg, "trust_remote_code", False)),
|
||||||
)
|
)
|
||||||
|
|
||||||
LOG.debug("Memory-efficient LoRA merge completed successfully!")
|
LOG.debug("Memory-efficient LoRA merge completed successfully!")
|
||||||
|
|||||||
@@ -17,6 +17,93 @@ from axolotl.utils.logging import get_logger
|
|||||||
LOG = get_logger(__name__)
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _build_layer_type_map(
|
||||||
|
base_model_path: Path, trust_remote_code: bool = False
|
||||||
|
) -> dict[str, str]:
|
||||||
|
"""Build a map of module_name -> layer_type using a meta-device model.
|
||||||
|
|
||||||
|
Instantiates the model architecture on the meta device (zero memory)
|
||||||
|
to inspect which modules are Linear vs Conv1d/Conv2d/Conv3d.
|
||||||
|
This avoids relying on weight tensor ndim heuristics.
|
||||||
|
"""
|
||||||
|
import json as _json
|
||||||
|
|
||||||
|
import torch.nn as nn
|
||||||
|
from transformers import AutoConfig
|
||||||
|
|
||||||
|
config_path = base_model_path / "config.json"
|
||||||
|
if not config_path.exists():
|
||||||
|
return {}
|
||||||
|
|
||||||
|
try:
|
||||||
|
with open(config_path) as f:
|
||||||
|
model_config = _json.load(f)
|
||||||
|
except (OSError, _json.JSONDecodeError):
|
||||||
|
return {}
|
||||||
|
|
||||||
|
architectures = model_config.get("architectures", [])
|
||||||
|
if not architectures:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
try:
|
||||||
|
config = AutoConfig.from_pretrained(
|
||||||
|
str(base_model_path), trust_remote_code=trust_remote_code
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
LOG.debug("Could not load config for layer type introspection")
|
||||||
|
return {}
|
||||||
|
|
||||||
|
# Determine the right Auto class from architectures
|
||||||
|
from transformers import (
|
||||||
|
AutoModel,
|
||||||
|
AutoModelForCausalLM,
|
||||||
|
)
|
||||||
|
|
||||||
|
auto_classes = [AutoModelForCausalLM, AutoModel]
|
||||||
|
try:
|
||||||
|
from transformers import AutoModelForImageTextToText
|
||||||
|
|
||||||
|
auto_classes.insert(0, AutoModelForImageTextToText)
|
||||||
|
except ImportError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
model = None
|
||||||
|
for auto_cls in auto_classes:
|
||||||
|
try:
|
||||||
|
with torch.device("meta"):
|
||||||
|
model = auto_cls.from_config(
|
||||||
|
config, trust_remote_code=trust_remote_code
|
||||||
|
)
|
||||||
|
break
|
||||||
|
except Exception: # noqa: BLE001
|
||||||
|
LOG.debug(
|
||||||
|
"Could not instantiate meta model with %s, trying next",
|
||||||
|
auto_cls.__name__,
|
||||||
|
)
|
||||||
|
|
||||||
|
if model is None:
|
||||||
|
LOG.debug("Could not instantiate meta model for layer type introspection")
|
||||||
|
return {}
|
||||||
|
|
||||||
|
layer_types = {}
|
||||||
|
for name, module in model.named_modules():
|
||||||
|
if isinstance(module, nn.Conv3d):
|
||||||
|
layer_types[name] = "Conv3d"
|
||||||
|
elif isinstance(module, nn.Conv2d):
|
||||||
|
layer_types[name] = "Conv2d"
|
||||||
|
elif isinstance(module, nn.Conv1d):
|
||||||
|
layer_types[name] = "Conv1d"
|
||||||
|
elif isinstance(module, nn.Linear):
|
||||||
|
layer_types[name] = "Linear"
|
||||||
|
|
||||||
|
del model
|
||||||
|
LOG.debug(
|
||||||
|
f"Layer type map: {len(layer_types)} modules "
|
||||||
|
f"({sum(1 for v in layer_types.values() if 'Conv' in v)} conv layers)"
|
||||||
|
)
|
||||||
|
return layer_types
|
||||||
|
|
||||||
|
|
||||||
def _simulate_nf4_roundtrip(
|
def _simulate_nf4_roundtrip(
|
||||||
tensor: torch.Tensor,
|
tensor: torch.Tensor,
|
||||||
blocksize: Optional[int] = None,
|
blocksize: Optional[int] = None,
|
||||||
@@ -191,6 +278,7 @@ def _build_peft_layer_and_get_delta(
|
|||||||
adapter_name: str = "default",
|
adapter_name: str = "default",
|
||||||
is_param_wrapper: bool = False,
|
is_param_wrapper: bool = False,
|
||||||
magnitude: Optional[torch.Tensor] = None,
|
magnitude: Optional[torch.Tensor] = None,
|
||||||
|
layer_type: Optional[str] = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Use PEFT's own layer classes to compute the LoRA delta weight.
|
Use PEFT's own layer classes to compute the LoRA delta weight.
|
||||||
@@ -211,7 +299,7 @@ def _build_peft_layer_and_get_delta(
|
|||||||
out_features = lora_b.shape[0]
|
out_features = lora_b.shape[0]
|
||||||
lora_alpha = lora_config_dict.get("lora_alpha", lora_config_dict.get("r", 1))
|
lora_alpha = lora_config_dict.get("lora_alpha", lora_config_dict.get("r", 1))
|
||||||
use_rslora = bool(lora_config_dict.get("use_rslora", False))
|
use_rslora = bool(lora_config_dict.get("use_rslora", False))
|
||||||
use_dora = bool(lora_config_dict.get("use_dora", False)) and magnitude is not None
|
use_dora = bool(lora_config_dict.get("use_dora", False))
|
||||||
|
|
||||||
if is_param_wrapper:
|
if is_param_wrapper:
|
||||||
from peft.tuners.lora.layer import ParamWrapper
|
from peft.tuners.lora.layer import ParamWrapper
|
||||||
@@ -239,6 +327,77 @@ def _build_peft_layer_and_get_delta(
|
|||||||
)
|
)
|
||||||
layer.lora_A[adapter_name].weight.data = lora_a
|
layer.lora_A[adapter_name].weight.data = lora_a
|
||||||
layer.lora_B[adapter_name].weight.data = lora_b
|
layer.lora_B[adapter_name].weight.data = lora_b
|
||||||
|
return layer.get_delta_weight(adapter_name)
|
||||||
|
elif (
|
||||||
|
layer_type and "Conv" in layer_type or (layer_type is None and lora_a.ndim > 2)
|
||||||
|
):
|
||||||
|
# Conv layer detected via model introspection (or ndim fallback)
|
||||||
|
|
||||||
|
from peft.tuners.lora import layer as peft_lora_layer
|
||||||
|
|
||||||
|
# Determine conv type from layer_type map or fall back to ndim
|
||||||
|
if layer_type and "Conv" in layer_type:
|
||||||
|
conv_type: str = layer_type
|
||||||
|
else:
|
||||||
|
ndim = lora_a.ndim
|
||||||
|
_conv_map = {3: "Conv1d", 4: "Conv2d", 5: "Conv3d"}
|
||||||
|
if ndim not in _conv_map:
|
||||||
|
raise ValueError(
|
||||||
|
f"Unsupported LoRA weight dimensionality {ndim} for conv layer"
|
||||||
|
)
|
||||||
|
conv_type = _conv_map[ndim]
|
||||||
|
LOG.warning(
|
||||||
|
f"Using ndim-based fallback for conv detection (ndim={ndim}). "
|
||||||
|
f"Consider providing layer_type from meta-device introspection."
|
||||||
|
)
|
||||||
|
|
||||||
|
conv_cls_map = {"Conv1d": nn.Conv1d, "Conv2d": nn.Conv2d, "Conv3d": nn.Conv3d}
|
||||||
|
ConvCls = conv_cls_map[conv_type]
|
||||||
|
PeftConvCls = getattr(peft_lora_layer, conv_type)
|
||||||
|
|
||||||
|
# Reconstruct conv parameters from base tensor and lora_a shapes
|
||||||
|
# base_tensor: [out_channels, in_channels/groups, *kernel_size]
|
||||||
|
# lora_a: [r, in_channels/groups, *kernel_size]
|
||||||
|
# lora_b: [out_channels, r, *ones]
|
||||||
|
out_channels = base_tensor.shape[0]
|
||||||
|
in_channels = base_tensor.shape[1]
|
||||||
|
kernel_size = tuple(base_tensor.shape[2:])
|
||||||
|
stride = (1,) * (base_tensor.ndim - 2)
|
||||||
|
padding = (0,) * (base_tensor.ndim - 2)
|
||||||
|
|
||||||
|
base_layer = ConvCls(
|
||||||
|
in_channels,
|
||||||
|
out_channels,
|
||||||
|
kernel_size,
|
||||||
|
stride=stride,
|
||||||
|
padding=padding,
|
||||||
|
bias=False,
|
||||||
|
)
|
||||||
|
base_layer.weight.data = base_tensor.clone()
|
||||||
|
|
||||||
|
layer = PeftConvCls(
|
||||||
|
base_layer,
|
||||||
|
adapter_name=adapter_name,
|
||||||
|
r=r_total,
|
||||||
|
lora_alpha=lora_alpha,
|
||||||
|
use_rslora=use_rslora,
|
||||||
|
use_dora=use_dora,
|
||||||
|
)
|
||||||
|
layer.lora_A[adapter_name].weight.data = lora_a
|
||||||
|
layer.lora_B[adapter_name].weight.data = lora_b
|
||||||
|
|
||||||
|
if use_dora:
|
||||||
|
if magnitude is None:
|
||||||
|
raise ValueError(
|
||||||
|
f"DoRA merge requires a magnitude vector but none was found "
|
||||||
|
f"for conv layer (adapter={adapter_name}). Check that the "
|
||||||
|
f"adapter checkpoint contains lora_magnitude_vector weights."
|
||||||
|
)
|
||||||
|
mag_layer = layer.lora_magnitude_vector[adapter_name]
|
||||||
|
mag_layer.weight = nn.Parameter(magnitude)
|
||||||
|
layer.merge(adapter_names=[adapter_name])
|
||||||
|
return base_layer.weight.data - base_tensor
|
||||||
|
|
||||||
return layer.get_delta_weight(adapter_name)
|
return layer.get_delta_weight(adapter_name)
|
||||||
else:
|
else:
|
||||||
from peft.tuners.lora.layer import Linear as LoraLinear
|
from peft.tuners.lora.layer import Linear as LoraLinear
|
||||||
@@ -267,6 +426,12 @@ def _build_peft_layer_and_get_delta(
|
|||||||
# DoRA merges magnitude normalization into the weight directly.
|
# DoRA merges magnitude normalization into the weight directly.
|
||||||
# Use PEFT's merge() which handles DoRA internally, then
|
# Use PEFT's merge() which handles DoRA internally, then
|
||||||
# compute the delta as merged_weight - original_weight.
|
# compute the delta as merged_weight - original_weight.
|
||||||
|
if magnitude is None:
|
||||||
|
raise ValueError(
|
||||||
|
f"DoRA merge requires a magnitude vector but none was found "
|
||||||
|
f"for linear layer (adapter={adapter_name}). Check that the "
|
||||||
|
f"adapter checkpoint contains lora_magnitude_vector weights."
|
||||||
|
)
|
||||||
mag_layer = layer.lora_magnitude_vector[adapter_name]
|
mag_layer = layer.lora_magnitude_vector[adapter_name]
|
||||||
mag_layer.weight = nn.Parameter(magnitude)
|
mag_layer.weight = nn.Parameter(magnitude)
|
||||||
layer.merge(adapter_names=[adapter_name])
|
layer.merge(adapter_names=[adapter_name])
|
||||||
@@ -382,6 +547,7 @@ def _merge_tensor_with_lora(
|
|||||||
nf4_double_quant: bool = True,
|
nf4_double_quant: bool = True,
|
||||||
use_dora: bool = False,
|
use_dora: bool = False,
|
||||||
weight_renamings: Optional[Dict[str, str]] = None,
|
weight_renamings: Optional[Dict[str, str]] = None,
|
||||||
|
layer_type_map: Optional[Dict[str, str]] = None,
|
||||||
) -> tuple[torch.Tensor, bool]:
|
) -> tuple[torch.Tensor, bool]:
|
||||||
"""
|
"""
|
||||||
Helper function to merge a single tensor with its corresponding LoRA weights.
|
Helper function to merge a single tensor with its corresponding LoRA weights.
|
||||||
@@ -426,12 +592,30 @@ def _merge_tensor_with_lora(
|
|||||||
if use_dora
|
if use_dora
|
||||||
else None
|
else None
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Look up layer type from meta-device model introspection
|
||||||
|
_layer_type = None
|
||||||
|
if layer_type_map:
|
||||||
|
mod_path = key.rsplit(".weight", 1)[0] if key.endswith(".weight") else key
|
||||||
|
_layer_type = layer_type_map.get(mod_path)
|
||||||
|
# Try common prefix variations (e.g. with/without "model." prefix)
|
||||||
|
if _layer_type is None:
|
||||||
|
for prefix in [
|
||||||
|
"model.",
|
||||||
|
"model.language_model.",
|
||||||
|
"model.language_model.model.",
|
||||||
|
]:
|
||||||
|
_layer_type = layer_type_map.get(prefix + mod_path)
|
||||||
|
if _layer_type:
|
||||||
|
break
|
||||||
|
|
||||||
delta = _build_peft_layer_and_get_delta(
|
delta = _build_peft_layer_and_get_delta(
|
||||||
lora_a.to(device),
|
lora_a.to(device),
|
||||||
lora_b.to(device),
|
lora_b.to(device),
|
||||||
lora_config_dict,
|
lora_config_dict,
|
||||||
tensor.to(device),
|
tensor.to(device),
|
||||||
magnitude=magnitude.to(device) if magnitude is not None else None,
|
magnitude=magnitude.to(device) if magnitude is not None else None,
|
||||||
|
layer_type=_layer_type,
|
||||||
)
|
)
|
||||||
merged_tensor = (
|
merged_tensor = (
|
||||||
(tensor.to(device).to(torch.float32) + delta.to(torch.float32))
|
(tensor.to(device).to(torch.float32) + delta.to(torch.float32))
|
||||||
@@ -556,6 +740,7 @@ def _fuse_and_unfuse_with_merge(
|
|||||||
nf4_double_quant: bool = True,
|
nf4_double_quant: bool = True,
|
||||||
use_dora: bool = False,
|
use_dora: bool = False,
|
||||||
weight_renamings: Optional[Dict[str, str]] = None,
|
weight_renamings: Optional[Dict[str, str]] = None,
|
||||||
|
layer_type_map: Optional[Dict[str, str]] = None,
|
||||||
) -> tuple[Dict[str, torch.Tensor], int, set]:
|
) -> tuple[Dict[str, torch.Tensor], int, set]:
|
||||||
"""
|
"""
|
||||||
For tensors matching WeightConverter patterns (MoE expert weights):
|
For tensors matching WeightConverter patterns (MoE expert weights):
|
||||||
@@ -696,12 +881,32 @@ def _fuse_and_unfuse_with_merge(
|
|||||||
if use_dora
|
if use_dora
|
||||||
else None
|
else None
|
||||||
)
|
)
|
||||||
|
# Look up layer type for the fused key
|
||||||
|
_layer_type = None
|
||||||
|
if layer_type_map:
|
||||||
|
mod_path = (
|
||||||
|
fused_key.rsplit(".weight", 1)[0]
|
||||||
|
if fused_key.endswith(".weight")
|
||||||
|
else fused_key
|
||||||
|
)
|
||||||
|
_layer_type = layer_type_map.get(mod_path)
|
||||||
|
if _layer_type is None:
|
||||||
|
for prefix in [
|
||||||
|
"model.",
|
||||||
|
"model.language_model.",
|
||||||
|
"model.language_model.model.",
|
||||||
|
]:
|
||||||
|
_layer_type = layer_type_map.get(prefix + mod_path)
|
||||||
|
if _layer_type:
|
||||||
|
break
|
||||||
|
|
||||||
delta = _build_peft_layer_and_get_delta(
|
delta = _build_peft_layer_and_get_delta(
|
||||||
lora_a.to(device),
|
lora_a.to(device),
|
||||||
lora_b.to(device),
|
lora_b.to(device),
|
||||||
lora_config_dict,
|
lora_config_dict,
|
||||||
fused_tensor.to(device),
|
fused_tensor.to(device),
|
||||||
magnitude=magnitude.to(device) if magnitude is not None else None,
|
magnitude=magnitude.to(device) if magnitude is not None else None,
|
||||||
|
layer_type=_layer_type,
|
||||||
)
|
)
|
||||||
fused_tensor = (
|
fused_tensor = (
|
||||||
(
|
(
|
||||||
@@ -740,6 +945,7 @@ def merge_lora_sharded_efficient(
|
|||||||
simulate_nf4_experts: bool = False,
|
simulate_nf4_experts: bool = False,
|
||||||
nf4_blocksize: Optional[int] = None,
|
nf4_blocksize: Optional[int] = None,
|
||||||
nf4_double_quant: bool = True,
|
nf4_double_quant: bool = True,
|
||||||
|
trust_remote_code: bool = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Memory-efficient LoRA merging that processes shards individually
|
Memory-efficient LoRA merging that processes shards individually
|
||||||
@@ -750,6 +956,8 @@ def merge_lora_sharded_efficient(
|
|||||||
simulate_nf4_experts: Apply NF4 roundtrip only to MoE expert tensors
|
simulate_nf4_experts: Apply NF4 roundtrip only to MoE expert tensors
|
||||||
(for quantize_moe_experts). Expert tensors are identified by having
|
(for quantize_moe_experts). Expert tensors are identified by having
|
||||||
"expert" in the key name and ndim >= 3.
|
"expert" in the key name and ndim >= 3.
|
||||||
|
trust_remote_code: Whether to trust remote code when loading model
|
||||||
|
config for layer-type introspection. Defaults to False for safety.
|
||||||
"""
|
"""
|
||||||
base_model_path = Path(base_model_path)
|
base_model_path = Path(base_model_path)
|
||||||
lora_adapter_path = Path(lora_adapter_path)
|
lora_adapter_path = Path(lora_adapter_path)
|
||||||
@@ -780,6 +988,10 @@ def merge_lora_sharded_efficient(
|
|||||||
|
|
||||||
use_dora = bool(lora_config_dict.get("use_dora", False))
|
use_dora = bool(lora_config_dict.get("use_dora", False))
|
||||||
|
|
||||||
|
# Build layer type map via meta-device model introspection
|
||||||
|
layer_type_map = _build_layer_type_map(
|
||||||
|
base_model_path, trust_remote_code=trust_remote_code
|
||||||
|
)
|
||||||
unsupported_methods = []
|
unsupported_methods = []
|
||||||
|
|
||||||
# Check for AdaLoRA (Adaptive LoRA)
|
# Check for AdaLoRA (Adaptive LoRA)
|
||||||
@@ -904,6 +1116,7 @@ def merge_lora_sharded_efficient(
|
|||||||
nf4_double_quant=nf4_double_quant,
|
nf4_double_quant=nf4_double_quant,
|
||||||
use_dora=use_dora,
|
use_dora=use_dora,
|
||||||
weight_renamings=weight_renamings,
|
weight_renamings=weight_renamings,
|
||||||
|
layer_type_map=layer_type_map,
|
||||||
)
|
)
|
||||||
merged_count += fused_merged
|
merged_count += fused_merged
|
||||||
|
|
||||||
@@ -926,6 +1139,7 @@ def merge_lora_sharded_efficient(
|
|||||||
nf4_double_quant=nf4_double_quant,
|
nf4_double_quant=nf4_double_quant,
|
||||||
use_dora=use_dora,
|
use_dora=use_dora,
|
||||||
weight_renamings=weight_renamings,
|
weight_renamings=weight_renamings,
|
||||||
|
layer_type_map=layer_type_map,
|
||||||
)
|
)
|
||||||
merged_tensors[key] = merged_tensor
|
merged_tensors[key] = merged_tensor
|
||||||
if was_merged:
|
if was_merged:
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ import json
|
|||||||
import math
|
import math
|
||||||
from unittest.mock import Mock, patch
|
from unittest.mock import Mock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
import safetensors.torch
|
import safetensors.torch
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
@@ -773,8 +774,8 @@ class TestEfficientMerge:
|
|||||||
"v_proj should be unchanged (no LoRA weights for it)"
|
"v_proj should be unchanged (no LoRA weights for it)"
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_dora_missing_magnitude_falls_back(self):
|
def test_dora_missing_magnitude_raises(self):
|
||||||
"""DoRA without magnitude vector falls back to standard LoRA merge."""
|
"""DoRA with missing magnitude vector raises an explicit error."""
|
||||||
hidden = 16
|
hidden = 16
|
||||||
r = 4
|
r = 4
|
||||||
alpha = 8
|
alpha = 8
|
||||||
@@ -791,11 +792,13 @@ class TestEfficientMerge:
|
|||||||
}
|
}
|
||||||
|
|
||||||
config = {"r": r, "lora_alpha": alpha, "use_dora": True}
|
config = {"r": r, "lora_alpha": alpha, "use_dora": True}
|
||||||
merged, was_merged = _merge_tensor_with_lora(
|
with pytest.raises(ValueError, match="DoRA merge requires a magnitude vector"):
|
||||||
base, "layer.proj.weight", lora_state, scale, config, "cpu", use_dora=True
|
_merge_tensor_with_lora(
|
||||||
)
|
base,
|
||||||
assert was_merged
|
"layer.proj.weight",
|
||||||
# No magnitude vector → PEFT creates DoRA layer but with default magnitude,
|
lora_state,
|
||||||
# which produces a result different from plain W + scale * B @ A.
|
scale,
|
||||||
# Just verify it was merged (not unchanged).
|
config,
|
||||||
assert not torch.equal(merged, base)
|
"cpu",
|
||||||
|
use_dora=True,
|
||||||
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user