dont keep adpater weights in fp32
This commit is contained in:
@@ -178,6 +178,38 @@ def get_state_dict(self, model, unwrap=True):
|
|||||||
|
|
||||||
return state_dict
|
return state_dict
|
||||||
|
|
||||||
|
def cast_lora_module(module):
|
||||||
|
base_layer_dtype = module.base_layer.weight.dtype
|
||||||
|
# Linear4Bit will keep it's bias term in fp32. If the weight dtype is in bf16 we are not able to
|
||||||
|
# wrap this. Therefore we must ensure the bias has the same dtype as the weight
|
||||||
|
if hasattr(module.base_layer, "bias") and module.base_layer.bias is not None:
|
||||||
|
if module.base_layer.weight.dtype != module.base_layer.bias.dtype:
|
||||||
|
log_bias_dtype_mismatch = True
|
||||||
|
module.base_layer.bias.data = module.base_layer.bias.data.to(
|
||||||
|
module.base_layer.weight.dtype
|
||||||
|
)
|
||||||
|
|
||||||
|
for active_adapter in module.active_adapters:
|
||||||
|
if module.lora_A:
|
||||||
|
module.lora_A[active_adapter] = module.lora_A[active_adapter].to(base_layer_dtype)
|
||||||
|
if hasattr(module.lora_A[active_adapter], 'bias') and module.lora_A[active_adapter].bias is not None:
|
||||||
|
module.lora_A[active_adapter].bias.data = module.lora_A[active_adapter].bias.data.to(base_layer_dtype)
|
||||||
|
if module.lora_B:
|
||||||
|
module.lora_B[active_adapter] = module.lora_B[active_adapter].to(base_layer_dtype)
|
||||||
|
if hasattr(module.lora_B[active_adapter], 'bias') and module.lora_B[active_adapter].bias is not None:
|
||||||
|
module.lora_B[active_adapter].bias.data = module.lora_B[active_adapter].bias.data.to(base_layer_dtype)
|
||||||
|
if module.lora_embedding_A:
|
||||||
|
module.lora_embedding_A[active_adapter] = module.lora_embedding_A[active_adapter].to(base_layer_dtype)
|
||||||
|
if hasattr(module.lora_embedding_A[active_adapter], 'bias') and module.lora_embedding_A[active_adapter].bias is not None:
|
||||||
|
module.lora_embedding_A[active_adapter].bias.data = module.lora_embedding_A[active_adapter].bias.data.to(base_layer_dtype)
|
||||||
|
if module.lora_embedding_B:
|
||||||
|
module.lora_embedding_B[active_adapter] = module.lora_embedding_B[active_adapter].to(base_layer_dtype)
|
||||||
|
if hasattr(module.lora_embedding_B[active_adapter], 'bias') and module.lora_embedding_B[active_adapter].bias is not None:
|
||||||
|
module.lora_embedding_B[active_adapter].bias.data = module.lora_embedding_B[active_adapter].bias.data.to(base_layer_dtype)
|
||||||
|
if module.lora_magnitude_vector:
|
||||||
|
module.lora_magnitude_vector[active_adapter] = module.lora_magnitude_vector[active_adapter].to(base_layer_dtype)
|
||||||
|
if hasattr(module.lora_magnitude_vector[active_adapter], 'bias') and module.lora_magnitude_vector[active_adapter].bias is not None:
|
||||||
|
module.lora_magnitude_vector[active_adapter].bias.data = module.lora_magnitude_vector[active_adapter].bias.data.to(base_layer_dtype)
|
||||||
|
|
||||||
def _process_lora_module_for_fsdp(module, fsdp2_kwargs):
|
def _process_lora_module_for_fsdp(module, fsdp2_kwargs):
|
||||||
"""Helper function to process LoRA modules for FSDP2."""
|
"""Helper function to process LoRA modules for FSDP2."""
|
||||||
@@ -324,10 +356,11 @@ def fsdp2_prepare_model(accelerator, model: torch.nn.Module) -> torch.nn.Module:
|
|||||||
if auto_wrap_policy is not None:
|
if auto_wrap_policy is not None:
|
||||||
for module in get_module_children_bottom_up(model)[:-1]:
|
for module in get_module_children_bottom_up(model)[:-1]:
|
||||||
if is_peft_model and isinstance(module, LoraLayer):
|
if is_peft_model and isinstance(module, LoraLayer):
|
||||||
module_log_bias_mismatch = _process_lora_module_for_fsdp(
|
cast_lora_module(module)
|
||||||
module, fsdp2_kwargs
|
# module_log_bias_mismatch = _process_lora_module_for_fsdp(
|
||||||
)
|
# module, fsdp2_kwargs
|
||||||
log_bias_dtype_mismatch |= module_log_bias_mismatch
|
# )
|
||||||
|
# log_bias_dtype_mismatch |= module_log_bias_mismatch
|
||||||
if auto_wrap_policy(module) and not isinstance(module, FSDPModule):
|
if auto_wrap_policy(module) and not isinstance(module, FSDPModule):
|
||||||
fully_shard(module, **fsdp2_kwargs)
|
fully_shard(module, **fsdp2_kwargs)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user