Llama SDPA attention patching (WIP)

This commit is contained in:
Sunny Liu
2025-01-22 20:16:27 -05:00
parent cee310dcfa
commit b7deb5241c

View File

@@ -4,7 +4,6 @@ Patched LlamaAttention to use torch.nn.functional.scaled_dot_product_attention
from axolotl.monkeypatch.utils import (
mask_2d_to_4d,
patched_prepare_4d_causal_attention_mask,
patched_prepare_4d_causal_attention_mask_for_sdpa,
)
@@ -21,6 +20,8 @@ def hijack_llama_prepare_4d_mask():
)
from transformers.utils import is_torch_bf16_gpu_available
from axolotl.monkeypatch.utils import mask_2d_to_4d
# modeling_llama._prepare_4d_causal_attention_mask_for_sdpa = ( # pylint: disable=protected-access
# patched_prepare_4d_causal_attention_mask_for_sdpa
# )