More LoRA handling

This commit is contained in:
Salman Mohammadi
2025-09-12 15:26:12 +00:00
parent 6daed7d060
commit 6874d32e0c

View File

@@ -225,18 +225,37 @@ def _process_lora_module_for_fsdp(module, fsdp2_kwargs):
module.base_layer.bias.data = module.base_layer.bias.data.to(
module.base_layer.weight.dtype
)
for active_adapter in module.active_adapters:
if module.lora_A:
fully_shard(module.lora_A[active_adapter], **fsdp2_kwargs)
if module.lora_B:
fully_shard(module.lora_B[active_adapter], **fsdp2_kwargs)
if module.lora_embedding_A:
fully_shard(module.lora_embedding_A[active_adapter], **fsdp2_kwargs)
if module.lora_embedding_B:
fully_shard(module.lora_embedding_B[active_adapter], **fsdp2_kwargs)
if module.lora_magnitude_vector:
fully_shard(module.lora_magnitude_vector[active_adapter], **fsdp2_kwargs)
fully_shard(module, **fsdp2_kwargs)
module.set_reshard_after_forward(False)
module.set_reshard_after_backward(False)
# for active_adapter in module.active_adapters:
# for adapter_name in [
# "lora_A",
# "lora_B",
# "lora_embedding_A",
# "lora_embedding_B",
# "lora_magnitude_vector",
# ]:
# adapter_module = getattr(module, adapter_name, None)
# # print(adapter_module, adapter_name)
# # torch.distributed.breakpoint()
# if not adapter_module:
# continue
# fsdp_adapter_module = fully_shard(adapter_module[active_adapter], **fsdp2_kwargs)
# # fsdp_adapter_module.unshard()
# fsdp_adapter_module.set_reshard_after_backward(False)
# fsdp_adapter_module.set_reshard_after_forward(False)
# torch.distributed.breakpoint()
# if module.lora_A:
# fully_shard(module.lora_A[active_adapter], **fsdp2_kwargs)
# if module.lora_B:
# fully_shard(module.lora_B[active_adapter], **fsdp2_kwargs)
# if module.lora_embedding_A:
# fully_shard(module.lora_embedding_A[active_adapter], **fsdp2_kwargs)
# if module.lora_embedding_B:
# fully_shard(module.lora_embedding_B[active_adapter], **fsdp2_kwargs)
# if module.lora_magnitude_vector:
# fully_shard(module.lora_magnitude_vector[active_adapter], **fsdp2_kwargs)
return log_bias_dtype_mismatch
@@ -356,11 +375,12 @@ def fsdp2_prepare_model(accelerator, model: torch.nn.Module) -> torch.nn.Module:
if auto_wrap_policy is not None:
for module in get_module_children_bottom_up(model)[:-1]:
if is_peft_model and isinstance(module, LoraLayer):
# torch.distributed.breakpoint()
cast_lora_module(module)
# module_log_bias_mismatch = _process_lora_module_for_fsdp(
# module, fsdp2_kwargs
# )
# log_bias_dtype_mismatch |= module_log_bias_mismatch
module_log_bias_mismatch = _process_lora_module_for_fsdp(
module, fsdp2_kwargs
)
log_bias_dtype_mismatch |= module_log_bias_mismatch
if auto_wrap_policy(module) and not isinstance(module, FSDPModule):
fully_shard(module, **fsdp2_kwargs)
@@ -377,6 +397,9 @@ def fsdp2_prepare_model(accelerator, model: torch.nn.Module) -> torch.nn.Module:
accelerator, model, original_sd, offload_to_cpu=offload_to_cpu
)
# for module in model.named_modules():
# if "Lora" in
if fsdp2_plugin.cpu_ram_efficient_loading and not model_has_params4bit:
# We re-register the buffers, as they may not be in the state_dict
for fqn, buffer_tensor in original_non_persistent_buffers.items():