diff --git a/src/axolotl/monkeypatch/llama_patch_multipack.py b/src/axolotl/monkeypatch/llama_patch_multipack.py
index 29c9bf7a2..e8e7a0409 100644
--- a/src/axolotl/monkeypatch/llama_patch_multipack.py
+++ b/src/axolotl/monkeypatch/llama_patch_multipack.py
@@ -2,7 +2,13 @@
 Patched LlamaAttention to use torch.nn.functional.scaled_dot_product_attention
 """
 
+from typing import Optional
+
+import torch
+from transformers.utils import is_torch_bf16_gpu_available
+
 from axolotl.monkeypatch.utils import (
+    mask_2d_to_4d,
     patched_prepare_4d_causal_attention_mask,
     patched_prepare_4d_causal_attention_mask_for_sdpa,
 )
@@ -11,10 +17,32 @@ from axolotl.monkeypatch.utils import (
 def hijack_llama_prepare_4d_mask():
     from transformers import modeling_attn_mask_utils
     from transformers.models.llama.modeling_llama import LlamaModel
+    # keep a handle on the original helper before it is monkeypatched below
+    _prepare_4d_causal_attention_mask_with_cache_position = (
+        LlamaModel._prepare_4d_causal_attention_mask_with_cache_position
+    )
 
     # modeling_llama._prepare_4d_causal_attention_mask_for_sdpa = (  # pylint: disable=protected-access
+    #     patched_prepare_4d_causal_attention_mask_for_sdpa
+    # )
+
+    def llama_patched_prepare_4d_causal_attention_mask_with_cache_position(
+        attention_mask: Optional[torch.Tensor],
+        *args,
+        **kwargs,
+    ):
+        # expand the packed 2D sample-id mask to a block-diagonal 4D mask,
+        # then defer to the original helper saved above
+        dtype = torch.bfloat16 if is_torch_bf16_gpu_available() else torch.float32
+        return _prepare_4d_causal_attention_mask_with_cache_position(
+            mask_2d_to_4d(attention_mask, dtype=dtype),
+            *args,
+            **kwargs,
+        )
+
     LlamaModel._prepare_4d_causal_attention_mask_with_cache_position = (  # pylint: disable=protected-access
-        patched_prepare_4d_causal_attention_mask_for_sdpa
+        # staticmethod so attribute access does not bind the model instance
+        staticmethod(llama_patched_prepare_4d_causal_attention_mask_with_cache_position)
     )
     modeling_attn_mask_utils._prepare_4d_causal_attention_mask_for_sdpa = (  # pylint: disable=protected-access
         patched_prepare_4d_causal_attention_mask_for_sdpa
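
A minimal usage sketch, assuming the usual flow of applying the hijack before the Llama model is instantiated; the model id below is a placeholder, not part of this patch:

```python
import torch
from transformers import AutoModelForCausalLM

from axolotl.monkeypatch.llama_patch_multipack import hijack_llama_prepare_4d_mask

# Patch LlamaModel before the model is built so the multipack-aware
# mask helper is already in place for the first forward pass.
hijack_llama_prepare_4d_mask()

model = AutoModelForCausalLM.from_pretrained(
    "NousResearch/Llama-2-7b-hf",  # placeholder model id
    torch_dtype=torch.bfloat16,
    attn_implementation="sdpa",
)
```

With the patch applied, the 2D attention mask produced by sample packing (per-sample indices rather than plain 0/1) is expanded by `mask_2d_to_4d` into a block-diagonal 4D mask before the stock cache-position helper runs, so packed samples do not attend to one another.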