fix: handle fsdp2 for paramwrapper dtensor
@@ -182,10 +182,57 @@ def get_state_dict(self, model, unwrap=True):
    return state_dict


def patch_peft_param_wrapper_for_fsdp2():
    """Patch PEFT's _LoraParameterProxy.forward for FSDP2 DTensor compatibility.

    PEFT's ParamWrapper applies LoRA via torch.nn.utils.parametrize, which adds
    delta_weight to the base weight W inside _LoraParameterProxy.forward().
    Under FSDP2, W may be a DTensor (from FSDP unshard) while delta_weight is a
    regular Tensor (or vice versa), causing a RuntimeError on mixed types.

    This patch promotes the non-DTensor operand to match the DTensor's spec
    using DTensor.from_local(), which is free for Replicate placement (just
    metadata wrapping, no communication).
    """
    from peft.tuners.lora.layer import _LoraParameterProxy

    if getattr(_LoraParameterProxy, "_axolotl_fsdp2_patched", False):
        return

    _original_forward = _LoraParameterProxy.forward

    def _patched_forward(self, W):
        from torch.distributed.tensor import DTensor

        delta = self.delta_weight
        w_is_dt = isinstance(W, DTensor)
        d_is_dt = isinstance(delta, DTensor)

        with torch.nn.utils.parametrize.cached():
            if w_is_dt == d_is_dt:
                return W + delta
            if w_is_dt:
                return W + DTensor.from_local(delta, W.device_mesh, W.placements)
            return DTensor.from_local(W, delta.device_mesh, delta.placements) + delta

    _LoraParameterProxy.forward = _patched_forward
    _LoraParameterProxy._axolotl_fsdp2_patched = True
    LOG.info("Patched PEFT _LoraParameterProxy.forward for FSDP2 DTensor compatibility")


def _process_lora_module_for_fsdp(module, fsdp2_kwargs):
    """Helper function to process LoRA modules for FSDP2."""
    from peft.tuners.lora.layer import ParamWrapper
    from torch.distributed.fsdp import fully_shard

    # ParamWrapper's lora_A/B are only accessed via .weight in get_delta_weight(),
    # not through forward(). Independent sharding leaves them as sharded DTensors
    # outside the unshard context, causing incorrect reshape/einsum results.
    # The parent decoder layer's FSDP wrapper manages them instead, keeping them
    # properly unsharded (Replicate) during forward.
    if isinstance(module, ParamWrapper):
        return False

    log_bias_dtype_mismatch = False

    # Linear4Bit will keep its bias term in fp32. If the weight dtype is in bf16 we are not able to
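A minimal standalone sketch, not part of this diff, of the promotion trick the docstring above describes: a plain tensor is wrapped with DTensor.from_local() so it can be added to a replicated DTensor. It assumes a recent PyTorch where torch.distributed.tensor is public and uses a single-process gloo group purely for illustration; in the real code path the mesh and placements come from FSDP2's unshard.

    import os

    import torch
    import torch.distributed as dist
    from torch.distributed.device_mesh import init_device_mesh
    from torch.distributed.tensor import DTensor, Replicate, distribute_tensor

    # Single-process process group, only so a device mesh can be created for the demo.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("gloo", rank=0, world_size=1)
    mesh = init_device_mesh("cpu", (1,))

    W = distribute_tensor(torch.randn(4, 4), mesh, [Replicate()])  # base weight as a DTensor
    delta = torch.randn(4, 4)                                      # LoRA delta as a plain tensor

    # Mixed DTensor/Tensor addition is what fails under FSDP2; promoting the plain
    # operand to a DTensor with the same mesh and placements makes the add legal.
    # For Replicate placement this is metadata wrapping only, no communication.
    delta_dt = DTensor.from_local(delta, mesh, W.placements)
    out = W + delta_dt
    print(type(out).__name__, out.to_local().shape)

    dist.destroy_process_group()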
@@ -327,6 +374,14 @@ def fsdp2_prepare_model(accelerator, model: torch.nn.Module) -> torch.nn.Module:

    is_peft_model = isinstance(model, PeftModel)

    # Patch PEFT's _LoraParameterProxy for DTensor compatibility if any
    # ParamWrapper modules exist (used for target_parameters / 3D expert params).
    if is_peft_model:
        from peft.tuners.lora.layer import ParamWrapper

        if any(isinstance(m, ParamWrapper) for m in model.modules()):
            patch_peft_param_wrapper_for_fsdp2()

    auto_wrap_policy = fsdp2_prepare_auto_wrap_policy(fsdp2_plugin, model)
    log_bias_dtype_mismatch = False
    if auto_wrap_policy is not None:
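A hedged usage sketch, also not part of the diff: the _axolotl_fsdp2_patched guard makes the patch idempotent, so reaching the call site above more than once is harmless. The import path for patch_peft_param_wrapper_for_fsdp2 is an assumption (it lives in whichever module this diff touches), and the check requires a PEFT version that ships _LoraParameterProxy.

    # Hypothetical import path for the helper; adjust to the module this diff modifies.
    from axolotl.monkeypatch.accelerate.fsdp2 import patch_peft_param_wrapper_for_fsdp2

    from peft.tuners.lora.layer import _LoraParameterProxy

    patch_peft_param_wrapper_for_fsdp2()
    patched_forward = _LoraParameterProxy.forward

    patch_peft_param_wrapper_for_fsdp2()  # second call hits the guard and returns early
    assert _LoraParameterProxy.forward is patched_forward
    assert _LoraParameterProxy._axolotl_fsdp2_patched is True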