This commit is contained in:
Sunny Liu
2025-01-23 14:56:26 -05:00
parent 555aa5772a
commit 8c34c65181

View File

@@ -32,9 +32,10 @@ def hijack_llama_prepare_4d_mask():
attention_mask: Optional[torch.Tensor], **kwargs attention_mask: Optional[torch.Tensor], **kwargs
): ):
dtype = torch.bfloat16 if is_torch_bf16_gpu_available() else torch.float32 dtype = torch.bfloat16 if is_torch_bf16_gpu_available() else torch.float32
return LlamaModel._prepare_4d_causal_attention_mask_with_cache_position( # return LlamaModel._prepare_4d_causal_attention_mask_with_cache_position(
mask_2d_to_4d(attention_mask, dtype=dtype), **kwargs # mask_2d_to_4d(attention_mask, dtype=dtype), **kwargs
) # )
return mask_2d_to_4d(attention_mask, dtype=dtype)
LlamaModel._prepare_4d_causal_attention_mask_with_cache_position = ( # pylint: disable=protected-access LlamaModel._prepare_4d_causal_attention_mask_with_cache_position = ( # pylint: disable=protected-access
llama_patched_prepare_4d_causal_attention_mask_with_cache_position llama_patched_prepare_4d_causal_attention_mask_with_cache_position