Handle the upstream refactor of the flash attention window-support flag (`_flash_supports_window_size` renamed to `_flash_supports_window`) (#2966)
This commit is contained in:
@@ -15,7 +15,13 @@ from typing import Optional
|
|||||||
import accelerate
|
import accelerate
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
from transformers.modeling_flash_attention_utils import _flash_supports_window_size
|
|
||||||
|
try:
|
||||||
|
from transformers.modeling_flash_attention_utils import _flash_supports_window
|
||||||
|
except ImportError:
|
||||||
|
from transformers.modeling_flash_attention_utils import (
|
||||||
|
_flash_supports_window_size as _flash_supports_window,
|
||||||
|
)
|
||||||
|
|
||||||
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
|
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
|
||||||
from axolotl.utils.logging import get_logger
|
from axolotl.utils.logging import get_logger
|
||||||
@@ -106,7 +112,7 @@ def create_ring_flash_attention_forward(
|
|||||||
|
|
||||||
# Assuming 4D tensors, key_states.shape[1] is the key/value sequence length (source length).
|
# Assuming 4D tensors, key_states.shape[1] is the key/value sequence length (source length).
|
||||||
use_sliding_windows = (
|
use_sliding_windows = (
|
||||||
_flash_supports_window_size
|
_flash_supports_window
|
||||||
and sliding_window is not None
|
and sliding_window is not None
|
||||||
and key_states.shape[1] > sliding_window
|
and key_states.shape[1] > sliding_window
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user