diff --git a/src/axolotl/monkeypatch/accelerate/fsdp2.py b/src/axolotl/monkeypatch/accelerate/fsdp2.py
index af6f24a63..242bcc667 100644
--- a/src/axolotl/monkeypatch/accelerate/fsdp2.py
+++ b/src/axolotl/monkeypatch/accelerate/fsdp2.py
@@ -182,10 +182,57 @@ def get_state_dict(self, model, unwrap=True):
     return state_dict
 
 
+def patch_peft_param_wrapper_for_fsdp2():
+    """Patch PEFT's _LoraParameterProxy.forward for FSDP2 DTensor compatibility.
+
+    PEFT's ParamWrapper applies LoRA via torch.nn.utils.parametrize, which adds
+    delta_weight to the base weight W inside _LoraParameterProxy.forward().
+    Under FSDP2, W may be a DTensor (from FSDP unshard) while delta_weight is a
+    regular Tensor (or vice versa), causing a RuntimeError on mixed types.
+
+    This patch promotes the non-DTensor operand to match the DTensor's spec
+    using DTensor.from_local(), which is free for Replicate placement (just
+    metadata wrapping, no communication).
+    """
+    from peft.tuners.lora.layer import _LoraParameterProxy
+
+    if getattr(_LoraParameterProxy, "_axolotl_fsdp2_patched", False):
+        return
+
+    _original_forward = _LoraParameterProxy.forward
+
+    def _patched_forward(self, W):
+        from torch.distributed.tensor import DTensor
+
+        delta = self.delta_weight
+        w_is_dt = isinstance(W, DTensor)
+        d_is_dt = isinstance(delta, DTensor)
+
+        with torch.nn.utils.parametrize.cached():
+            if w_is_dt == d_is_dt:
+                return W + delta
+            if w_is_dt:
+                return W + DTensor.from_local(delta, W.device_mesh, W.placements)
+            return DTensor.from_local(W, delta.device_mesh, delta.placements) + delta
+
+    _LoraParameterProxy.forward = _patched_forward
+    _LoraParameterProxy._axolotl_fsdp2_patched = True
+    LOG.info("Patched PEFT _LoraParameterProxy.forward for FSDP2 DTensor compatibility")
+
+
 def _process_lora_module_for_fsdp(module, fsdp2_kwargs):
     """Helper function to process LoRA modules for FSDP2."""
+    from peft.tuners.lora.layer import ParamWrapper
     from torch.distributed.fsdp import fully_shard
 
+    # ParamWrapper's lora_A/B are only accessed via .weight in get_delta_weight(),
+    # not through forward(). Independent sharding leaves them as sharded DTensors
+    # outside the unshard context, causing incorrect reshape/einsum results.
+    # The parent decoder layer's FSDP wrapper manages them instead, keeping them
+    # properly unsharded (Replicate) during forward.
+    if isinstance(module, ParamWrapper):
+        return False
+
     log_bias_dtype_mismatch = False
 
     # Linear4Bit will keep it's bias term in fp32. If the weight dtype is in bf16 we are not able to
@@ -327,6 +374,14 @@ def fsdp2_prepare_model(accelerator, model: torch.nn.Module) -> torch.nn.Module:
 
     is_peft_model = isinstance(model, PeftModel)
 
+    # Patch PEFT's _LoraParameterProxy for DTensor compatibility if any
+    # ParamWrapper modules exist (used for target_parameters / 3D expert params).
+    if is_peft_model:
+        from peft.tuners.lora.layer import ParamWrapper
+
+        if any(isinstance(m, ParamWrapper) for m in model.modules()):
+            patch_peft_param_wrapper_for_fsdp2()
+
     auto_wrap_policy = fsdp2_prepare_auto_wrap_policy(fsdp2_plugin, model)
     log_bias_dtype_mismatch = False
     if auto_wrap_policy is not None:
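
Illustrative sketch (not part of the patch): a minimal, self-contained repro of the mixed DTensor/Tensor addition that _LoraParameterProxy.forward can hit under FSDP2, and the DTensor.from_local() promotion the patch applies. It assumes a single-rank gloo process group on CPU purely for demonstration; the names W, delta, and mesh are placeholders for this sketch, not identifiers from the diff.

# Sketch only: demonstrates the DTensor promotion idea under a 1-rank CPU group.
import os

import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DTensor, Replicate

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

mesh = init_device_mesh("cpu", (1,))
# Stand-in for an FSDP2-unsharded base weight: a replicated DTensor.
W = DTensor.from_local(torch.randn(4, 4), mesh, [Replicate()])
# Stand-in for PEFT's delta_weight: a plain local tensor.
delta = torch.randn(4, 4)

# Mixing a DTensor and a plain Tensor in one op is the failure mode the patch guards against.
try:
    _ = W + delta
except RuntimeError as exc:
    print(f"mixed add failed: {exc}")

# Promote the plain operand onto the DTensor's mesh/placements first. For a
# Replicate placement this only wraps metadata, no collective communication.
promoted = DTensor.from_local(delta, W.device_mesh, W.placements)
result = W + promoted
print(type(result).__name__, tuple(result.shape))

dist.destroy_process_group()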