fix: try match target param properly end with
This commit is contained in:
@@ -177,9 +177,14 @@ class ModelLoader:
|
||||
# Quantize 3D MoE expert nn.Parameter tensors that BnB skips during loading.
|
||||
self.model._moe_experts_quantized = False
|
||||
if self.cfg.adapter in ("qlora", "lora") and self.cfg.load_in_4bit:
|
||||
from axolotl.monkeypatch.moe_quant import quantize_moe_expert_params
|
||||
from axolotl.monkeypatch.moe_quant import (
|
||||
patch_peft_target_parameters_matching,
|
||||
quantize_moe_expert_params,
|
||||
)
|
||||
|
||||
self.model._moe_experts_quantized = quantize_moe_expert_params(self.model)
|
||||
if self.model._moe_experts_quantized:
|
||||
patch_peft_target_parameters_matching()
|
||||
|
||||
PLUGIN_MANAGER.post_model_build(self.cfg, self.model)
|
||||
|
||||
|
||||
@@ -19,6 +19,48 @@ from axolotl.utils.logging import get_logger
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
|
||||
def patch_peft_target_parameters_matching():
|
||||
"""Fix PEFT's _inject_parameters to use suffix matching for parametrized modules.
|
||||
|
||||
PEFT's parametrized-module branch uses exact name match for target_parameters,
|
||||
but the standard branch uses endswith. This means suffix-style paths like
|
||||
"mlp.experts.gate_up_proj" fail to match parametrized modules whose full path
|
||||
is "model.layers.0.mlp.experts.gate_up_proj". This patch makes the parametrized
|
||||
branch consistent with the standard branch.
|
||||
"""
|
||||
from peft.tuners.tuners_utils import BaseTuner
|
||||
|
||||
original_inject = BaseTuner._inject_parameters
|
||||
|
||||
def _patched_inject_parameters(
|
||||
self, peft_config, model, adapter_name, low_cpu_mem_usage
|
||||
):
|
||||
# Patch target_parameters to use full paths for parametrized modules
|
||||
original_targets = list(peft_config.target_parameters)
|
||||
expanded = set(original_targets)
|
||||
|
||||
for module_name, module in model.named_modules():
|
||||
if not hasattr(module, "parametrizations"):
|
||||
continue
|
||||
for target in original_targets:
|
||||
mod_path, _, param_name = target.rpartition(".")
|
||||
if (
|
||||
module_name == mod_path or module_name.endswith("." + mod_path)
|
||||
) and hasattr(module, param_name):
|
||||
expanded.add(f"{module_name}.{param_name}")
|
||||
|
||||
peft_config.target_parameters = sorted(expanded)
|
||||
try:
|
||||
return original_inject(
|
||||
self, peft_config, model, adapter_name, low_cpu_mem_usage
|
||||
)
|
||||
finally:
|
||||
peft_config.target_parameters = original_targets
|
||||
|
||||
BaseTuner._inject_parameters = _patched_inject_parameters
|
||||
LOG.info("Patched PEFT _inject_parameters for parametrized module suffix matching")
|
||||
|
||||
|
||||
def find_unquantized_expert_params(model):
|
||||
"""Find 3D+ nn.Parameter tensors that BnB quantization skipped.
|
||||
|
||||
|
||||
Reference in New Issue
Block a user