From 66c3e5a3fd523369b6e1c61925888264e9ab6e64 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Sun, 12 Apr 2026 10:57:45 -0400
Subject: [PATCH] better handling of dora merge on Conv layers in Qwen 3.5 (#3599)

* better handling of dora merge on Conv layers in Qwen 3.5

* address issues from code review

* stricter efficient merges for dora since we now have meta model to reference
---
 src/axolotl/cli/merge_lora.py       |   1 +
 src/axolotl/cli/utils/lora_merge.py | 216 +++++++++++++++++++++++++++-
 tests/utils/lora/test_merge_lora.py |  23 +--
 3 files changed, 229 insertions(+), 11 deletions(-)

diff --git a/src/axolotl/cli/merge_lora.py b/src/axolotl/cli/merge_lora.py
index dae9f317d..00e5303bd 100644
--- a/src/axolotl/cli/merge_lora.py
+++ b/src/axolotl/cli/merge_lora.py
@@ -115,6 +115,7 @@ def _do_merge_lora_efficient(*, cfg: DictDefault) -> None:
         simulate_nf4_experts=simulate_nf4_experts,
         nf4_blocksize=nf4_blocksize,
         nf4_double_quant=nf4_double_quant,
+        trust_remote_code=bool(getattr(cfg, "trust_remote_code", False)),
     )
 
     LOG.debug("Memory-efficient LoRA merge completed successfully!")
diff --git a/src/axolotl/cli/utils/lora_merge.py b/src/axolotl/cli/utils/lora_merge.py
index 339e41e2d..b7c436aed 100644
--- a/src/axolotl/cli/utils/lora_merge.py
+++ b/src/axolotl/cli/utils/lora_merge.py
@@ -17,6 +17,93 @@ from axolotl.utils.logging import get_logger
 LOG = get_logger(__name__)
 
 
+def _build_layer_type_map(
+    base_model_path: Path, trust_remote_code: bool = False
+) -> dict[str, str]:
+    """Build a map of module_name -> layer_type using a meta-device model.
+
+    Instantiates the model architecture on the meta device (zero memory)
+    to inspect which modules are Linear vs Conv1d/Conv2d/Conv3d.
+    This avoids relying on weight tensor ndim heuristics.
+    """
+    import json as _json
+
+    import torch.nn as nn
+    from transformers import AutoConfig
+
+    config_path = base_model_path / "config.json"
+    if not config_path.exists():
+        return {}
+
+    try:
+        with open(config_path) as f:
+            model_config = _json.load(f)
+    except (OSError, _json.JSONDecodeError):
+        return {}
+
+    architectures = model_config.get("architectures", [])
+    if not architectures:
+        return {}
+
+    try:
+        config = AutoConfig.from_pretrained(
+            str(base_model_path), trust_remote_code=trust_remote_code
+        )
+    except Exception:
+        LOG.debug("Could not load config for layer type introspection")
+        return {}
+
+    # Determine the right Auto class from architectures
+    from transformers import (
+        AutoModel,
+        AutoModelForCausalLM,
+    )
+
+    auto_classes = [AutoModelForCausalLM, AutoModel]
+    try:
+        from transformers import AutoModelForImageTextToText
+
+        auto_classes.insert(0, AutoModelForImageTextToText)
+    except ImportError:
+        pass
+
+    model = None
+    for auto_cls in auto_classes:
+        try:
+            with torch.device("meta"):
+                model = auto_cls.from_config(
+                    config, trust_remote_code=trust_remote_code
+                )
+            break
+        except Exception:  # noqa: BLE001
+            LOG.debug(
+                "Could not instantiate meta model with %s, trying next",
+                auto_cls.__name__,
+            )
+
+    if model is None:
+        LOG.debug("Could not instantiate meta model for layer type introspection")
+        return {}
+
+    layer_types = {}
+    for name, module in model.named_modules():
+        if isinstance(module, nn.Conv3d):
+            layer_types[name] = "Conv3d"
+        elif isinstance(module, nn.Conv2d):
+            layer_types[name] = "Conv2d"
+        elif isinstance(module, nn.Conv1d):
+            layer_types[name] = "Conv1d"
+        elif isinstance(module, nn.Linear):
+            layer_types[name] = "Linear"
+
+    del model
+    LOG.debug(
+        f"Layer type map: {len(layer_types)} modules "
+        f"({sum(1 for v in layer_types.values() if 'Conv' in v)} conv layers)"
+    )
+    return layer_types
+
+
 def _simulate_nf4_roundtrip(
     tensor: torch.Tensor,
     blocksize: Optional[int] = None,
@@ -191,6 +278,7 @@ def _build_peft_layer_and_get_delta(
     adapter_name: str = "default",
     is_param_wrapper: bool = False,
     magnitude: Optional[torch.Tensor] = None,
+    layer_type: Optional[str] = None,
 ) -> torch.Tensor:
     """
     Use PEFT's own layer classes to compute the LoRA delta weight.
@@ -211,7 +299,7 @@ def _build_peft_layer_and_get_delta(
     out_features = lora_b.shape[0]
     lora_alpha = lora_config_dict.get("lora_alpha", lora_config_dict.get("r", 1))
     use_rslora = bool(lora_config_dict.get("use_rslora", False))
-    use_dora = bool(lora_config_dict.get("use_dora", False)) and magnitude is not None
+    use_dora = bool(lora_config_dict.get("use_dora", False))
 
     if is_param_wrapper:
         from peft.tuners.lora.layer import ParamWrapper
@@ -239,6 +327,77 @@ def _build_peft_layer_and_get_delta(
         )
         layer.lora_A[adapter_name].weight.data = lora_a
         layer.lora_B[adapter_name].weight.data = lora_b
+        return layer.get_delta_weight(adapter_name)
+    elif (
+        layer_type and "Conv" in layer_type or (layer_type is None and lora_a.ndim > 2)
+    ):
+        # Conv layer detected via model introspection (or ndim fallback)
+
+        from peft.tuners.lora import layer as peft_lora_layer
+
+        # Determine conv type from layer_type map or fall back to ndim
+        if layer_type and "Conv" in layer_type:
+            conv_type: str = layer_type
+        else:
+            ndim = lora_a.ndim
+            _conv_map = {3: "Conv1d", 4: "Conv2d", 5: "Conv3d"}
+            if ndim not in _conv_map:
+                raise ValueError(
+                    f"Unsupported LoRA weight dimensionality {ndim} for conv layer"
+                )
+            conv_type = _conv_map[ndim]
+            LOG.warning(
+                f"Using ndim-based fallback for conv detection (ndim={ndim}). "
+                f"Consider providing layer_type from meta-device introspection."
+            )
+
+        conv_cls_map = {"Conv1d": nn.Conv1d, "Conv2d": nn.Conv2d, "Conv3d": nn.Conv3d}
+        ConvCls = conv_cls_map[conv_type]
+        PeftConvCls = getattr(peft_lora_layer, conv_type)
+
+        # Reconstruct conv parameters from base tensor and lora_a shapes
+        # base_tensor: [out_channels, in_channels/groups, *kernel_size]
+        # lora_a: [r, in_channels/groups, *kernel_size]
+        # lora_b: [out_channels, r, *ones]
+        out_channels = base_tensor.shape[0]
+        in_channels = base_tensor.shape[1]
+        kernel_size = tuple(base_tensor.shape[2:])
+        stride = (1,) * (base_tensor.ndim - 2)
+        padding = (0,) * (base_tensor.ndim - 2)
+
+        base_layer = ConvCls(
+            in_channels,
+            out_channels,
+            kernel_size,
+            stride=stride,
+            padding=padding,
+            bias=False,
+        )
+        base_layer.weight.data = base_tensor.clone()
+
+        layer = PeftConvCls(
+            base_layer,
+            adapter_name=adapter_name,
+            r=r_total,
+            lora_alpha=lora_alpha,
+            use_rslora=use_rslora,
+            use_dora=use_dora,
+        )
+        layer.lora_A[adapter_name].weight.data = lora_a
+        layer.lora_B[adapter_name].weight.data = lora_b
+
+        if use_dora:
+            if magnitude is None:
+                raise ValueError(
+                    f"DoRA merge requires a magnitude vector but none was found "
+                    f"for conv layer (adapter={adapter_name}). Check that the "
+                    f"adapter checkpoint contains lora_magnitude_vector weights."
+                )
+            mag_layer = layer.lora_magnitude_vector[adapter_name]
+            mag_layer.weight = nn.Parameter(magnitude)
+            layer.merge(adapter_names=[adapter_name])
+            return base_layer.weight.data - base_tensor
+        return layer.get_delta_weight(adapter_name)
     else:
         from peft.tuners.lora.layer import Linear as LoraLinear
 
@@ -267,6 +426,12 @@ def _build_peft_layer_and_get_delta(
             # DoRA merges magnitude normalization into the weight directly.
             # Use PEFT's merge() which handles DoRA internally, then
             # compute the delta as merged_weight - original_weight.
+            if magnitude is None:
+                raise ValueError(
+                    f"DoRA merge requires a magnitude vector but none was found "
+                    f"for linear layer (adapter={adapter_name}). Check that the "
+                    f"adapter checkpoint contains lora_magnitude_vector weights."
+                )
             mag_layer = layer.lora_magnitude_vector[adapter_name]
             mag_layer.weight = nn.Parameter(magnitude)
             layer.merge(adapter_names=[adapter_name])
@@ -382,6 +547,7 @@ def _merge_tensor_with_lora(
     nf4_double_quant: bool = True,
     use_dora: bool = False,
     weight_renamings: Optional[Dict[str, str]] = None,
+    layer_type_map: Optional[Dict[str, str]] = None,
 ) -> tuple[torch.Tensor, bool]:
     """
     Helper function to merge a single tensor with its corresponding LoRA weights.
@@ -426,12 +592,30 @@ def _merge_tensor_with_lora(
         if use_dora
         else None
     )
+
+    # Look up layer type from meta-device model introspection
+    _layer_type = None
+    if layer_type_map:
+        mod_path = key.rsplit(".weight", 1)[0] if key.endswith(".weight") else key
+        _layer_type = layer_type_map.get(mod_path)
+        # Try common prefix variations (e.g. with/without "model." prefix)
+        if _layer_type is None:
+            for prefix in [
+                "model.",
+                "model.language_model.",
+                "model.language_model.model.",
+            ]:
+                _layer_type = layer_type_map.get(prefix + mod_path)
+                if _layer_type:
+                    break
+
     delta = _build_peft_layer_and_get_delta(
         lora_a.to(device),
         lora_b.to(device),
         lora_config_dict,
         tensor.to(device),
         magnitude=magnitude.to(device) if magnitude is not None else None,
+        layer_type=_layer_type,
     )
     merged_tensor = (
         (tensor.to(device).to(torch.float32) + delta.to(torch.float32))
@@ -556,6 +740,7 @@ def _fuse_and_unfuse_with_merge(
     nf4_double_quant: bool = True,
     use_dora: bool = False,
     weight_renamings: Optional[Dict[str, str]] = None,
+    layer_type_map: Optional[Dict[str, str]] = None,
 ) -> tuple[Dict[str, torch.Tensor], int, set]:
     """
     For tensors matching WeightConverter patterns (MoE expert weights):
@@ -696,12 +881,32 @@ def _fuse_and_unfuse_with_merge(
                 if use_dora
                 else None
             )
+            # Look up layer type for the fused key
+            _layer_type = None
+            if layer_type_map:
+                mod_path = (
+                    fused_key.rsplit(".weight", 1)[0]
+                    if fused_key.endswith(".weight")
+                    else fused_key
+                )
+                _layer_type = layer_type_map.get(mod_path)
+                if _layer_type is None:
+                    for prefix in [
+                        "model.",
+                        "model.language_model.",
+                        "model.language_model.model.",
+                    ]:
+                        _layer_type = layer_type_map.get(prefix + mod_path)
+                        if _layer_type:
+                            break
+
             delta = _build_peft_layer_and_get_delta(
                 lora_a.to(device),
                 lora_b.to(device),
                 lora_config_dict,
                 fused_tensor.to(device),
                 magnitude=magnitude.to(device) if magnitude is not None else None,
+                layer_type=_layer_type,
             )
             fused_tensor = (
                 (
@@ -740,6 +945,7 @@ def merge_lora_sharded_efficient(
     simulate_nf4_experts: bool = False,
     nf4_blocksize: Optional[int] = None,
     nf4_double_quant: bool = True,
+    trust_remote_code: bool = False,
 ) -> None:
     """
     Memory-efficient LoRA merging that processes shards individually
@@ -750,6 +956,8 @@ def merge_lora_sharded_efficient(
         simulate_nf4_experts: Apply NF4 roundtrip only to MoE expert tensors
            (for quantize_moe_experts). Expert tensors are identified by
            having "expert" in the key name and ndim >= 3.
+        trust_remote_code: Whether to trust remote code when loading model
+            config for layer-type introspection. Defaults to False for safety.
     """
     base_model_path = Path(base_model_path)
     lora_adapter_path = Path(lora_adapter_path)
@@ -780,6 +988,10 @@ def merge_lora_sharded_efficient(
 
     use_dora = bool(lora_config_dict.get("use_dora", False))
 
+    # Build layer type map via meta-device model introspection
+    layer_type_map = _build_layer_type_map(
+        base_model_path, trust_remote_code=trust_remote_code
+    )
     unsupported_methods = []
 
     # Check for AdaLoRA (Adaptive LoRA)
@@ -904,6 +1116,7 @@ def merge_lora_sharded_efficient(
             nf4_double_quant=nf4_double_quant,
             use_dora=use_dora,
             weight_renamings=weight_renamings,
+            layer_type_map=layer_type_map,
         )
 
         merged_count += fused_merged
@@ -926,6 +1139,7 @@ def merge_lora_sharded_efficient(
                 nf4_double_quant=nf4_double_quant,
                 use_dora=use_dora,
                 weight_renamings=weight_renamings,
+                layer_type_map=layer_type_map,
             )
             merged_tensors[key] = merged_tensor
             if was_merged:
diff --git a/tests/utils/lora/test_merge_lora.py b/tests/utils/lora/test_merge_lora.py
index e5d7f535d..b66ee8bf4 100644
--- a/tests/utils/lora/test_merge_lora.py
+++ b/tests/utils/lora/test_merge_lora.py
@@ -2,6 +2,7 @@
 import json
 import math
 from unittest.mock import Mock, patch
 
+import pytest
 import safetensors.torch
 import torch
@@ -773,8 +774,8 @@ class TestEfficientMerge:
             "v_proj should be unchanged (no LoRA weights for it)"
         )
 
-    def test_dora_missing_magnitude_falls_back(self):
-        """DoRA without magnitude vector falls back to standard LoRA merge."""
+    def test_dora_missing_magnitude_raises(self):
+        """DoRA with missing magnitude vector raises an explicit error."""
         hidden = 16
         r = 4
         alpha = 8
@@ -791,11 +792,13 @@ class TestEfficientMerge:
         }
         config = {"r": r, "lora_alpha": alpha, "use_dora": True}
 
-        merged, was_merged = _merge_tensor_with_lora(
-            base, "layer.proj.weight", lora_state, scale, config, "cpu", use_dora=True
-        )
-        assert was_merged
-        # No magnitude vector → PEFT creates DoRA layer but with default magnitude,
-        # which produces a result different from plain W + scale * B @ A.
-        # Just verify it was merged (not unchanged).
-        assert not torch.equal(merged, base)
+        with pytest.raises(ValueError, match="DoRA merge requires a magnitude vector"):
+            _merge_tensor_with_lora(
+                base,
+                "layer.proj.weight",
+                lora_state,
+                scale,
+                config,
+                "cpu",
+                use_dora=True,
+            )