From 6daed7d060d758e6c2ed490a52b855fa18866ed3 Mon Sep 17 00:00:00 2001
From: Salman Mohammadi
Date: Tue, 9 Sep 2025 17:11:13 +0100
Subject: [PATCH] don't keep adapter weights in fp32

---
 src/axolotl/monkeypatch/accelerate/fsdp2.py | 35 +++++++++++++++++++++++++++++++----
 1 file changed, 31 insertions(+), 4 deletions(-)

diff --git a/src/axolotl/monkeypatch/accelerate/fsdp2.py b/src/axolotl/monkeypatch/accelerate/fsdp2.py
index 3b38a33b7..935184652 100644
--- a/src/axolotl/monkeypatch/accelerate/fsdp2.py
+++ b/src/axolotl/monkeypatch/accelerate/fsdp2.py
@@ -178,6 +178,36 @@ def get_state_dict(self, model, unwrap=True):
 
     return state_dict
 
+
+def cast_lora_module(module) -> bool:
+    """Cast LoRA adapter weights to the base layer dtype; return True if a bias dtype mismatch was fixed."""
+    log_bias_dtype_mismatch = False
+    base_layer_dtype = module.base_layer.weight.dtype
+    # Linear4bit keeps its bias term in fp32. If the weight dtype is bf16 we are not
+    # able to wrap this, so we must ensure the bias has the same dtype as the weight.
+    if hasattr(module.base_layer, "bias") and module.base_layer.bias is not None:
+        if module.base_layer.weight.dtype != module.base_layer.bias.dtype:
+            log_bias_dtype_mismatch = True
+            module.base_layer.bias.data = module.base_layer.bias.data.to(
+                module.base_layer.weight.dtype
+            )
+
+    for active_adapter in module.active_adapters:
+        if module.lora_A:
+            module.lora_A[active_adapter] = module.lora_A[active_adapter].to(base_layer_dtype)
+            if getattr(module.lora_A[active_adapter], "bias", None) is not None:
+                module.lora_A[active_adapter].bias.data = module.lora_A[active_adapter].bias.data.to(base_layer_dtype)
+        if module.lora_B:
+            module.lora_B[active_adapter] = module.lora_B[active_adapter].to(base_layer_dtype)
+            if getattr(module.lora_B[active_adapter], "bias", None) is not None:
+                module.lora_B[active_adapter].bias.data = module.lora_B[active_adapter].bias.data.to(base_layer_dtype)
+        if module.lora_embedding_A:
+            module.lora_embedding_A[active_adapter] = module.lora_embedding_A[active_adapter].to(base_layer_dtype)
+        if module.lora_embedding_B:
+            module.lora_embedding_B[active_adapter] = module.lora_embedding_B[active_adapter].to(base_layer_dtype)
+        if module.lora_magnitude_vector:
+            module.lora_magnitude_vector[active_adapter] = module.lora_magnitude_vector[active_adapter].to(base_layer_dtype)
+    return log_bias_dtype_mismatch
 
 def _process_lora_module_for_fsdp(module, fsdp2_kwargs):
     """Helper function to process LoRA modules for FSDP2."""
@@ -324,10 +354,7 @@ def fsdp2_prepare_model(accelerator, model: torch.nn.Module) -> torch.nn.Module:
     if auto_wrap_policy is not None:
         for module in get_module_children_bottom_up(model)[:-1]:
             if is_peft_model and isinstance(module, LoraLayer):
-                module_log_bias_mismatch = _process_lora_module_for_fsdp(
-                    module, fsdp2_kwargs
-                )
-                log_bias_dtype_mismatch |= module_log_bias_mismatch
+                log_bias_dtype_mismatch |= cast_lora_module(module)
             if auto_wrap_policy(module) and not isinstance(module, FSDPModule):
                 fully_shard(module, **fsdp2_kwargs)
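
Note for reviewers (not part of the patch): a minimal sketch of why the cast is
needed. Recent PEFT releases upcast adapter weights to fp32 by default
(autocast_adapter_dtype=True in get_peft_model), so a bf16 base model ends up
with mixed-dtype LoraLayers that fully_shard cannot wrap. The toy model below
is hypothetical; only the dtypes matter, and the expected outputs in the
comments assume that PEFT default.

    import torch
    import torch.nn as nn
    from peft import LoraConfig, get_peft_model

    from axolotl.monkeypatch.accelerate.fsdp2 import cast_lora_module

    # A bf16 base model with a single Linear that we target with LoRA.
    base = nn.Sequential(nn.Linear(16, 16)).to(torch.bfloat16)
    peft_model = get_peft_model(base, LoraConfig(target_modules=["0"], r=4))
    layer = peft_model.base_model.model[0]  # the wrapped LoraLayer

    # PEFT's default upcast leaves the adapter in fp32 next to bf16 base weights:
    print(layer.base_layer.weight.dtype)         # torch.bfloat16
    print(layer.lora_A["default"].weight.dtype)  # torch.float32

    # cast_lora_module (added by this patch) aligns everything with the base layer:
    cast_lora_module(layer)
    print(layer.lora_A["default"].weight.dtype)  # torch.bfloat16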