Compare commits

...

5 Commits

Author              SHA1         Message                              Date
Salman Mohammadi    a7676af44d   hmmm                                 2025-09-12 18:51:10 +01:00
Salman Mohammadi    52e37077fc   Merge branch 'main' into lora_bf16   2025-09-12 18:35:03 +01:00
Salman Mohammadi    850489405b   working?                             2025-09-12 17:34:41 +00:00
Salman Mohammadi    6874d32e0c   more lora handling                   2025-09-12 15:26:12 +00:00
Salman Mohammadi    6daed7d060   dont keep adpater weights in fp32    2025-09-09 17:11:13 +01:00


@@ -180,6 +180,38 @@ def get_state_dict(self, model, unwrap=True):
        return state_dict
def cast_lora_module(module):
    base_layer_dtype = module.base_layer.weight.dtype
    # Linear4bit keeps its bias term in fp32. If the weight dtype is bf16 we are not able to
    # wrap this, so we must ensure the bias has the same dtype as the weight.
    if hasattr(module.base_layer, "bias") and module.base_layer.bias is not None:
        if module.base_layer.weight.dtype != module.base_layer.bias.dtype:
            log_bias_dtype_mismatch = True
            module.base_layer.bias.data = module.base_layer.bias.data.to(
                module.base_layer.weight.dtype
            )

    for active_adapter in module.active_adapters:
        if module.lora_A:
            module.lora_A[active_adapter] = module.lora_A[active_adapter].to(base_layer_dtype)
            if hasattr(module.lora_A[active_adapter], 'bias') and module.lora_A[active_adapter].bias is not None:
                module.lora_A[active_adapter].bias.data = module.lora_A[active_adapter].bias.data.to(base_layer_dtype)
        if module.lora_B:
            module.lora_B[active_adapter] = module.lora_B[active_adapter].to(base_layer_dtype)
            if hasattr(module.lora_B[active_adapter], 'bias') and module.lora_B[active_adapter].bias is not None:
                module.lora_B[active_adapter].bias.data = module.lora_B[active_adapter].bias.data.to(base_layer_dtype)
        if module.lora_embedding_A:
            module.lora_embedding_A[active_adapter] = module.lora_embedding_A[active_adapter].to(base_layer_dtype)
            if hasattr(module.lora_embedding_A[active_adapter], 'bias') and module.lora_embedding_A[active_adapter].bias is not None:
                module.lora_embedding_A[active_adapter].bias.data = module.lora_embedding_A[active_adapter].bias.data.to(base_layer_dtype)
        if module.lora_embedding_B:
            module.lora_embedding_B[active_adapter] = module.lora_embedding_B[active_adapter].to(base_layer_dtype)
            if hasattr(module.lora_embedding_B[active_adapter], 'bias') and module.lora_embedding_B[active_adapter].bias is not None:
                module.lora_embedding_B[active_adapter].bias.data = module.lora_embedding_B[active_adapter].bias.data.to(base_layer_dtype)
        if module.lora_magnitude_vector:
            module.lora_magnitude_vector[active_adapter] = module.lora_magnitude_vector[active_adapter].to(base_layer_dtype)
            if hasattr(module.lora_magnitude_vector[active_adapter], 'bias') and module.lora_magnitude_vector[active_adapter].bias is not None:
                module.lora_magnitude_vector[active_adapter].bias.data = module.lora_magnitude_vector[active_adapter].bias.data.to(base_layer_dtype)

def _process_lora_module_for_fsdp(module, fsdp2_kwargs):
    """Helper function to process LoRA modules for FSDP2."""
@@ -195,18 +227,37 @@ def _process_lora_module_for_fsdp(module, fsdp2_kwargs):
            module.base_layer.bias.data = module.base_layer.bias.data.to(
                module.base_layer.weight.dtype
            )
-   for active_adapter in module.active_adapters:
-       if module.lora_A:
-           fully_shard(module.lora_A[active_adapter], **fsdp2_kwargs)
-       if module.lora_B:
-           fully_shard(module.lora_B[active_adapter], **fsdp2_kwargs)
-       if module.lora_embedding_A:
-           fully_shard(module.lora_embedding_A[active_adapter], **fsdp2_kwargs)
-       if module.lora_embedding_B:
-           fully_shard(module.lora_embedding_B[active_adapter], **fsdp2_kwargs)
-       if module.lora_magnitude_vector:
-           fully_shard(module.lora_magnitude_vector[active_adapter], **fsdp2_kwargs)
+   fully_shard(module, **fsdp2_kwargs)
+   module.set_reshard_after_forward(False)
+   module.set_reshard_after_backward(False)
+   # for active_adapter in module.active_adapters:
+   #     for adapter_name in [
+   #         "lora_A",
+   #         "lora_B",
+   #         "lora_embedding_A",
+   #         "lora_embedding_B",
+   #         "lora_magnitude_vector",
+   #     ]:
+   #         adapter_module = getattr(module, adapter_name, None)
+   #         # print(adapter_module, adapter_name)
+   #         # torch.distributed.breakpoint()
+   #         if not adapter_module:
+   #             continue
+   #         fsdp_adapter_module = fully_shard(adapter_module[active_adapter], **fsdp2_kwargs)
+   #         # fsdp_adapter_module.unshard()
+   #         fsdp_adapter_module.set_reshard_after_backward(False)
+   #         fsdp_adapter_module.set_reshard_after_forward(False)
+   # torch.distributed.breakpoint()
+   # if module.lora_A:
+   #     fully_shard(module.lora_A[active_adapter], **fsdp2_kwargs)
+   # if module.lora_B:
+   #     fully_shard(module.lora_B[active_adapter], **fsdp2_kwargs)
+   # if module.lora_embedding_A:
+   #     fully_shard(module.lora_embedding_A[active_adapter], **fsdp2_kwargs)
+   # if module.lora_embedding_B:
+   #     fully_shard(module.lora_embedding_B[active_adapter], **fsdp2_kwargs)
+   # if module.lora_magnitude_vector:
+   #     fully_shard(module.lora_magnitude_vector[active_adapter], **fsdp2_kwargs)
    return log_bias_dtype_mismatch
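For context on the FSDP2 calls introduced above: the diff now shards the whole LoRA layer once and turns resharding off, rather than sharding each adapter sub-module separately. A minimal sketch of that wrapping pattern, assuming a recent PyTorch where fully_shard and the FSDPModule setters are exposed under torch.distributed.fsdp (the exact import path varies by version):

from torch.distributed.fsdp import fully_shard

def shard_lora_layer(module, **fsdp2_kwargs):
    # After fully_shard, `module` becomes an FSDPModule in place. Keeping the
    # gathered parameters around after forward and backward trades a little
    # memory for fewer all-gathers on the small LoRA layers.
    fully_shard(module, **fsdp2_kwargs)
    module.set_reshard_after_forward(False)
    module.set_reshard_after_backward(False)
    return module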
@@ -320,16 +371,26 @@ def fsdp2_prepare_model(accelerator, model: torch.nn.Module) -> torch.nn.Module:
    model.tie_weights()
    is_peft_model = isinstance(model, PeftModel)
    # TODO - this doesn't actually do anything
    for name, module in model.named_children():
        if name == "experts":
            # torch.distributed.breakpoint()
            for expert in module.children():
                # torch.distributed.breakpoint()
                print(f"expert: {expert}")
                for lora_module in expert.children():
                    print(f"lora {lora_module}")
                    # torch.distributed.breakpoint()
                    cast_lora_module(lora_module)
                    _process_lora_module_for_fsdp(lora_module, fsdp2_kwargs)
    auto_wrap_policy = fsdp2_prepare_auto_wrap_policy(fsdp2_plugin, model)
    log_bias_dtype_mismatch = False
    if auto_wrap_policy is not None:
        for module in get_module_children_bottom_up(model)[:-1]:
-           if is_peft_model and isinstance(module, LoraLayer):
-               module_log_bias_mismatch = _process_lora_module_for_fsdp(
-                   module, fsdp2_kwargs
-               )
-               log_bias_dtype_mismatch |= module_log_bias_mismatch
+           if is_peft_model and isinstance(module, LoraLayer) and not isinstance(module, FSDPModule):
+               # torch.distributed.breakpoint()
+               cast_lora_module(module)
+               # torch.distributed.breakpoint()
            if auto_wrap_policy(module) and not isinstance(module, FSDPModule):
                fully_shard(module, **fsdp2_kwargs)
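The bottom-up wrapping loop now skips LoRA layers that are already FSDP modules and only casts their dtypes, leaving the actual sharding to the auto-wrap policy. A condensed sketch of that control flow; the imports assume recent torch and peft releases (paths may differ by version), and cast_lora_module refers to the helper added earlier in this diff:

from peft.tuners.lora import LoraLayer
from torch.distributed.fsdp import FSDPModule, fully_shard

def wrap_bottom_up(modules, auto_wrap_policy, is_peft_model, **fsdp2_kwargs):
    # `modules` would be something like get_module_children_bottom_up(model)[:-1].
    for module in modules:
        if is_peft_model and isinstance(module, LoraLayer) and not isinstance(module, FSDPModule):
            cast_lora_module(module)  # dtype alignment only; no per-adapter sharding here
        if auto_wrap_policy(module) and not isinstance(module, FSDPModule):
            fully_shard(module, **fsdp2_kwargs)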
@@ -346,6 +407,9 @@ def fsdp2_prepare_model(accelerator, model: torch.nn.Module) -> torch.nn.Module:
        accelerator, model, original_sd, offload_to_cpu=offload_to_cpu
    )

    # for module in model.named_modules():
    #     if "Lora" in
    if fsdp2_plugin.cpu_ram_efficient_loading and not model_has_params4bit:
        # We re-register the buffers, as they may not be in the state_dict
        for fqn, buffer_tensor in original_non_persistent_buffers.items():
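On the cpu_ram_efficient_loading branch at the end of the hunk: non-persistent buffers never appear in a state_dict, so they have to be put back on the model by fully qualified name after reloading. A small sketch of that re-registration step (hypothetical helper name and example FQN; uses only standard torch.nn module APIs):

import torch
import torch.nn as nn

def restore_non_persistent_buffers(model: nn.Module, buffers: dict) -> None:
    # `buffers` maps fully qualified names (e.g. "layers.0.rotary_emb.inv_freq")
    # to the original tensors captured before sharding/reloading.
    for fqn, tensor in buffers.items():
        module_path, _, buffer_name = fqn.rpartition(".")
        submodule = model.get_submodule(module_path) if module_path else model
        # persistent=False keeps the buffer out of future state_dicts, matching
        # how it was registered originally.
        submodule.register_buffer(buffer_name, tensor, persistent=False)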