From 6874d32e0c9749997b627cdf5c5a7029b10646c9 Mon Sep 17 00:00:00 2001 From: Salman Mohammadi <“salman.mohammadi@outlook.com”> Date: Fri, 12 Sep 2025 15:26:12 +0000 Subject: [PATCH] more lora handling --- src/axolotl/monkeypatch/accelerate/fsdp2.py | 55 +++++++++++++++------ 1 file changed, 39 insertions(+), 16 deletions(-) diff --git a/src/axolotl/monkeypatch/accelerate/fsdp2.py b/src/axolotl/monkeypatch/accelerate/fsdp2.py index 935184652..68306a689 100644 --- a/src/axolotl/monkeypatch/accelerate/fsdp2.py +++ b/src/axolotl/monkeypatch/accelerate/fsdp2.py @@ -225,18 +225,37 @@ def _process_lora_module_for_fsdp(module, fsdp2_kwargs): module.base_layer.bias.data = module.base_layer.bias.data.to( module.base_layer.weight.dtype ) - - for active_adapter in module.active_adapters: - if module.lora_A: - fully_shard(module.lora_A[active_adapter], **fsdp2_kwargs) - if module.lora_B: - fully_shard(module.lora_B[active_adapter], **fsdp2_kwargs) - if module.lora_embedding_A: - fully_shard(module.lora_embedding_A[active_adapter], **fsdp2_kwargs) - if module.lora_embedding_B: - fully_shard(module.lora_embedding_B[active_adapter], **fsdp2_kwargs) - if module.lora_magnitude_vector: - fully_shard(module.lora_magnitude_vector[active_adapter], **fsdp2_kwargs) + fully_shard(module, **fsdp2_kwargs) + module.set_reshard_after_forward(False) + module.set_reshard_after_backward(False) + # for active_adapter in module.active_adapters: + # for adapter_name in [ + # "lora_A", + # "lora_B", + # "lora_embedding_A", + # "lora_embedding_B", + # "lora_magnitude_vector", + # ]: + # adapter_module = getattr(module, adapter_name, None) + # # print(adapter_module, adapter_name) + # # torch.distributed.breakpoint() + # if not adapter_module: + # continue + # fsdp_adapter_module = fully_shard(adapter_module[active_adapter], **fsdp2_kwargs) + # # fsdp_adapter_module.unshard() + # fsdp_adapter_module.set_reshard_after_backward(False) + # fsdp_adapter_module.set_reshard_after_forward(False) + # torch.distributed.breakpoint() + # if module.lora_A: + # fully_shard(module.lora_A[active_adapter], **fsdp2_kwargs) + # if module.lora_B: + # fully_shard(module.lora_B[active_adapter], **fsdp2_kwargs) + # if module.lora_embedding_A: + # fully_shard(module.lora_embedding_A[active_adapter], **fsdp2_kwargs) + # if module.lora_embedding_B: + # fully_shard(module.lora_embedding_B[active_adapter], **fsdp2_kwargs) + # if module.lora_magnitude_vector: + # fully_shard(module.lora_magnitude_vector[active_adapter], **fsdp2_kwargs) return log_bias_dtype_mismatch @@ -356,11 +375,12 @@ def fsdp2_prepare_model(accelerator, model: torch.nn.Module) -> torch.nn.Module: if auto_wrap_policy is not None: for module in get_module_children_bottom_up(model)[:-1]: if is_peft_model and isinstance(module, LoraLayer): + # torch.distributed.breakpoint() cast_lora_module(module) - # module_log_bias_mismatch = _process_lora_module_for_fsdp( - # module, fsdp2_kwargs - # ) - # log_bias_dtype_mismatch |= module_log_bias_mismatch + module_log_bias_mismatch = _process_lora_module_for_fsdp( + module, fsdp2_kwargs + ) + log_bias_dtype_mismatch |= module_log_bias_mismatch if auto_wrap_policy(module) and not isinstance(module, FSDPModule): fully_shard(module, **fsdp2_kwargs) @@ -377,6 +397,9 @@ def fsdp2_prepare_model(accelerator, model: torch.nn.Module) -> torch.nn.Module: accelerator, model, original_sd, offload_to_cpu=offload_to_cpu ) + # for module in model.named_modules(): + # if "Lora" in + if fsdp2_plugin.cpu_ram_efficient_loading and not model_has_params4bit: # We re-register the buffers, as they may not be in the state_dict for fqn, buffer_tensor in original_non_persistent_buffers.items():