diff --git a/src/axolotl/monkeypatch/ring_attn/patch.py b/src/axolotl/monkeypatch/ring_attn/patch.py
index 41e39e657..9c9ba4553 100644
--- a/src/axolotl/monkeypatch/ring_attn/patch.py
+++ b/src/axolotl/monkeypatch/ring_attn/patch.py
@@ -15,7 +15,13 @@ from typing import Optional
 import accelerate
 import torch
 import torch.distributed as dist
-from transformers.modeling_flash_attention_utils import _flash_supports_window_size
+
+try:
+    from transformers.modeling_flash_attention_utils import _flash_supports_window
+except ImportError:
+    from transformers.modeling_flash_attention_utils import (
+        _flash_supports_window_size as _flash_supports_window,
+    )
 
 from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
 from axolotl.utils.logging import get_logger
@@ -106,7 +112,7 @@ def create_ring_flash_attention_forward(
 
         # Assuming 4D tensors, key_states.shape[1] is the key/value sequence length (source length).
         use_sliding_windows = (
-            _flash_supports_window_size
+            _flash_supports_window
            and sliding_window is not None
            and key_states.shape[1] > sliding_window
        )
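
For context (not part of the patch itself), here is a minimal sketch of the compatibility pattern the diff applies: prefer the newer `_flash_supports_window` name from `transformers.modeling_flash_attention_utils` and fall back to the older `_flash_supports_window_size`, aliasing it so downstream code references a single identifier. The `should_use_sliding_window` helper below is hypothetical and only illustrates the `use_sliding_windows` expression from the second hunk; it assumes one of the two names is importable from the installed transformers version.

```python
# Minimal sketch (assumption: one of the two helper names exists in the
# installed transformers release); same fallback pattern as the patch above.
try:
    # Newer transformers releases expose the renamed helper.
    from transformers.modeling_flash_attention_utils import _flash_supports_window
except ImportError:
    # Older releases still use the original name; alias it so the rest of
    # the module can refer to a single identifier.
    from transformers.modeling_flash_attention_utils import (
        _flash_supports_window_size as _flash_supports_window,
    )


def should_use_sliding_window(key_states, sliding_window):
    """Hypothetical helper mirroring the patched check: sliding windows are
    used only when flash-attention supports them, a window size is set, and
    the key/value sequence length exceeds that window."""
    return (
        _flash_supports_window
        and sliding_window is not None
        and key_states.shape[1] > sliding_window
    )
```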