diff --git a/src/axolotl/monkeypatch/flex_attn.py b/src/axolotl/monkeypatch/flex_attn.py index bd0104e31..97db4d414 100644 --- a/src/axolotl/monkeypatch/flex_attn.py +++ b/src/axolotl/monkeypatch/flex_attn.py @@ -1,15 +1,17 @@ ''' Taken from https://github.com/pytorch/torchtune/blob/main/torchtune/modules/attention_utils.py ''' +from typing import Union import torch from torch.nn.attention.flex_attention import ( BlockMask, create_block_mask as create_block_causal_mask_flex, - flex_attention, ) +_MaskType = Union[torch.Tensor, BlockMask] + def _get_document_ids_from_seq_lens( - seq_lens: List[torch.Tensor], + seq_lens: list[torch.Tensor], ) -> torch.Tensor: """ Convert a batch tensor of seq lens into integer IDs denoting sample ownership. @@ -39,7 +41,7 @@ def _get_document_ids_from_seq_lens( return batch_document_ids def packed_block_causal_mask( - seq_lens: List[torch.Tensor], + seq_lens: list[torch.Tensor], ) -> _MaskType: """ Create a block causal document mask for a batch of packed sequences. If