fix: handle fsdp2 for paramwrapper dtensor
This commit is contained in:
@@ -182,10 +182,57 @@ def get_state_dict(self, model, unwrap=True):
|
|||||||
return state_dict
|
return state_dict
|
||||||
|
|
||||||
|
|
||||||
|
def patch_peft_param_wrapper_for_fsdp2():
    """Patch PEFT's _LoraParameterProxy.forward for FSDP2 DTensor compatibility.

    PEFT's ParamWrapper applies LoRA via torch.nn.utils.parametrize, which adds
    delta_weight to the base weight W inside _LoraParameterProxy.forward().
    Under FSDP2, W may be a DTensor (from FSDP unshard) while delta_weight is a
    regular Tensor (or vice versa), causing a RuntimeError on mixed types.

    This patch promotes the non-DTensor operand to match the DTensor's spec
    using DTensor.from_local(), which is free for Replicate placement (just
    metadata wrapping, no communication).

    The patch replaces ``forward`` on the ``_LoraParameterProxy`` class itself,
    so it affects every instance. It is idempotent: a class-level
    ``_axolotl_fsdp2_patched`` flag is set after patching and checked on entry,
    making repeated calls a no-op.

    NOTE(review): the plain operand is wrapped with the *other* operand's
    placements as-is. That is only a pure metadata wrap when those placements
    are Replicate; confirm callers never reach this path with sharded
    placements.
    """
    from peft.tuners.lora.layer import _LoraParameterProxy

    # Already patched earlier in this process; nothing to do.
    if getattr(_LoraParameterProxy, "_axolotl_fsdp2_patched", False):
        return

    def _patched_forward(self, W):
        # Imported lazily so torch.distributed.tensor is only needed once the
        # patched forward actually runs.
        from torch.distributed.tensor import DTensor

        delta = self.delta_weight
        w_is_dt = isinstance(W, DTensor)
        d_is_dt = isinstance(delta, DTensor)

        with torch.nn.utils.parametrize.cached():
            # Both plain tensors or both DTensors: the addition works as-is.
            if w_is_dt == d_is_dt:
                return W + delta
            # Mixed types: wrap the plain operand using the DTensor operand's
            # mesh and placements so both sides of the add are DTensors.
            if w_is_dt:
                return W + DTensor.from_local(delta, W.device_mesh, W.placements)
            return DTensor.from_local(W, delta.device_mesh, delta.placements) + delta

    _LoraParameterProxy.forward = _patched_forward
    # Marker consulted by the guard above on subsequent calls.
    _LoraParameterProxy._axolotl_fsdp2_patched = True
    LOG.info("Patched PEFT _LoraParameterProxy.forward for FSDP2 DTensor compatibility")
|
||||||
|
|
||||||
|
|
||||||
def _process_lora_module_for_fsdp(module, fsdp2_kwargs):
|
def _process_lora_module_for_fsdp(module, fsdp2_kwargs):
|
||||||
"""Helper function to process LoRA modules for FSDP2."""
|
"""Helper function to process LoRA modules for FSDP2."""
|
||||||
|
from peft.tuners.lora.layer import ParamWrapper
|
||||||
from torch.distributed.fsdp import fully_shard
|
from torch.distributed.fsdp import fully_shard
|
||||||
|
|
||||||
|
# ParamWrapper's lora_A/B are only accessed via .weight in get_delta_weight(),
|
||||||
|
# not through forward(). Independent sharding leaves them as sharded DTensors
|
||||||
|
# outside the unshard context, causing incorrect reshape/einsum results.
|
||||||
|
# The parent decoder layer's FSDP wrapper manages them instead — properly
|
||||||
|
# unsharded (Replicate) during forward.
|
||||||
|
if isinstance(module, ParamWrapper):
|
||||||
|
return False
|
||||||
|
|
||||||
log_bias_dtype_mismatch = False
|
log_bias_dtype_mismatch = False
|
||||||
|
|
||||||
# Linear4Bit will keep its bias term in fp32. If the weight dtype is in bf16 we are not able to
|
# Linear4Bit will keep its bias term in fp32. If the weight dtype is in bf16 we are not able to
|
||||||
@@ -327,6 +374,14 @@ def fsdp2_prepare_model(accelerator, model: torch.nn.Module) -> torch.nn.Module:
|
|||||||
|
|
||||||
is_peft_model = isinstance(model, PeftModel)
|
is_peft_model = isinstance(model, PeftModel)
|
||||||
|
|
||||||
|
# Patch PEFT's _LoraParameterProxy for DTensor compatibility if any
|
||||||
|
# ParamWrapper modules exist (used for target_parameters / 3D expert params).
|
||||||
|
if is_peft_model:
|
||||||
|
from peft.tuners.lora.layer import ParamWrapper
|
||||||
|
|
||||||
|
if any(isinstance(m, ParamWrapper) for m in model.modules()):
|
||||||
|
patch_peft_param_wrapper_for_fsdp2()
|
||||||
|
|
||||||
auto_wrap_policy = fsdp2_prepare_auto_wrap_policy(fsdp2_plugin, model)
|
auto_wrap_policy = fsdp2_prepare_auto_wrap_policy(fsdp2_plugin, model)
|
||||||
log_bias_dtype_mismatch = False
|
log_bias_dtype_mismatch = False
|
||||||
if auto_wrap_policy is not None:
|
if auto_wrap_policy is not None:
|
||||||
|
|||||||
Reference in New Issue
Block a user