handle refactor upstream for flash attention (#2966)

Wing Lian
2025-07-22 20:40:04 -04:00
committed by GitHub
parent 208fb7b8e7
commit 93709eb5ce


@@ -15,7 +15,13 @@ from typing import Optional
 import accelerate
 import torch
 import torch.distributed as dist
-from transformers.modeling_flash_attention_utils import _flash_supports_window_size
+try:
+    from transformers.modeling_flash_attention_utils import _flash_supports_window
+except ImportError:
+    from transformers.modeling_flash_attention_utils import (
+        _flash_supports_window_size as _flash_supports_window,
+    )
 from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
 from axolotl.utils.logging import get_logger
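
The try/except above is the standard shim for an upstream rename: import the new symbol first, and on older transformers releases alias the old _flash_supports_window_size to the new name so the rest of the module uses a single identifier. A minimal, self-contained sketch of the same pattern follows; the _resolve_attr helper is hypothetical, and only the module path and the two symbol names come from this diff.

import importlib


def _resolve_attr(module_name, *candidates):
    # Return the first attribute found under any candidate name, so callers
    # keep one identifier across upstream renames.
    module = importlib.import_module(module_name)
    for name in candidates:
        if hasattr(module, name):
            return getattr(module, name)
    raise ImportError(f"none of {candidates!r} found in {module_name!r}")


# Prefer the post-refactor name, fall back to the pre-refactor one.
_flash_supports_window = _resolve_attr(
    "transformers.modeling_flash_attention_utils",
    "_flash_supports_window",
    "_flash_supports_window_size",
)

The try/except form used in the diff is generally preferable to comparing transformers.__version__, since it tracks the actual API surface rather than a version string.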
@@ -106,7 +112,7 @@ def create_ring_flash_attention_forward(
     # Assuming 4D tensors, key_states.shape[1] is the key/value sequence length (source length).
     use_sliding_windows = (
-        _flash_supports_window_size
+        _flash_supports_window
         and sliding_window is not None
         and key_states.shape[1] > sliding_window
     )
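
For context, a flag like use_sliding_windows typically gates whether sliding-window arguments are forwarded to flash-attn. Below is a hedged sketch of that gating, assuming flash-attn's flash_attn_func is available; the _attend wrapper and its causal=True choice are illustrative, not taken from the patched file.

from flash_attn import flash_attn_func


def _attend(query_states, key_states, value_states, sliding_window, use_sliding_windows):
    # flash-attn treats window_size=(-1, -1) as "no window"; otherwise each
    # query attends to at most `sliding_window` keys on either side.
    # Tensors are (batch, seqlen, nheads, headdim), as flash_attn_func expects.
    window_size = (
        (sliding_window, sliding_window) if use_sliding_windows else (-1, -1)
    )
    return flash_attn_func(
        query_states,
        key_states,
        value_states,
        causal=True,
        window_size=window_size,
    )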