diff --git a/src/axolotl/monkeypatch/flex_attn.py b/src/axolotl/monkeypatch/flex_attn.py
index a157906d5..2894dd5c6 100644
--- a/src/axolotl/monkeypatch/flex_attn.py
+++ b/src/axolotl/monkeypatch/flex_attn.py
@@ -1,10 +1,11 @@
-'''
+"""
 Taken from https://github.com/pytorch/torchtune/blob/main/torchtune/modules/attention_utils.py
-'''
+"""
 from typing import Union
+
 import torch
+from torch.nn.attention.flex_attention import BlockMask
 from torch.nn.attention.flex_attention import (
-    BlockMask,
     create_block_mask as create_block_causal_mask_flex,
 )
 
@@ -46,7 +47,7 @@ def create_block_causal_mask(seq_lens: list[torch.Tensor]) -> torch.Tensor:
         ]
 
         batch_block_attn_masks.append(torch.block_diag(*block_attn_masks))
-    return torch.stack(batch_block_attn_masks)
+    return torch.stack(batch_block_attn_masks)[:,None,:,:]
 
 
 def _get_document_ids_from_seq_lens(
@@ -79,6 +80,7 @@ def _get_document_ids_from_seq_lens(
     batch_document_ids = torch.stack(batch_document_ids)
     return batch_document_ids
 
+
 def packed_block_causal_mask(
     seq_lens: list[torch.Tensor],
 ) -> _MaskType:
@@ -97,7 +99,7 @@ def packed_block_causal_mask(
     Returns:
         _MaskType: BlockMask or Tensor if torch version < 2.5.0.
     """
-    
+
     document_ids = _get_document_ids_from_seq_lens(seq_lens)
     batch_size, max_seq_len = document_ids.shape
     document_ids = document_ids.to("cuda")
@@ -127,5 +129,3 @@ def packed_block_causal_mask(
         max_seq_len,
         device="cuda",
     )
-
-
diff --git a/src/axolotl/utils/collators/batching.py b/src/axolotl/utils/collators/batching.py
index d449b38ad..0531bd2f6 100644
--- a/src/axolotl/utils/collators/batching.py
+++ b/src/axolotl/utils/collators/batching.py
@@ -10,8 +10,11 @@
 import torch
 from transformers import PreTrainedTokenizerBase
 from transformers.utils import PaddingStrategy
 
+from axolotl.monkeypatch.flex_attn import (
+    create_block_causal_mask,
+    packed_block_causal_mask,
+)
 from axolotl.monkeypatch.utils import get_seqlens_from_pos_ids
-from axolotl.monkeypatch.flex_attn import create_block_causal_mask, packed_block_causal_mask
 
 
@@ -167,7 +170,7 @@ class FlexBatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
         out_features = [{} for _ in features]
         for i, features_ in enumerate(features):
             for feature in features_[0].keys():
-                if feature in {"length" , "attention_mask"}:
+                if feature in {"length", "attention_mask"}:
                     continue
                 else:
                     arrays = [
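Aside: a minimal sketch (not the axolotl code) of what the `[:,None,:,:]` change in `create_block_causal_mask` buys. The dense block-diagonal causal mask comes out as `(batch, seq, seq)`; unsqueezing a head axis to `(batch, 1, seq, seq)` lets it broadcast across attention heads, e.g. when passed as `attn_mask` to `scaled_dot_product_attention`. The sequence lengths below are invented for illustration.

```python
import torch
import torch.nn.functional as F

# One packed sample containing three sequences of lengths 3, 2 and 1.
seq_lens = [3, 2, 1]

# One causal (lower-triangular) block per packed sequence, glued into a
# block-diagonal mask, mirroring the torch.block_diag call in the patch above.
blocks = [torch.tril(torch.ones(n, n, dtype=torch.bool)) for n in seq_lens]
mask = torch.block_diag(*blocks)      # (seq, seq), seq = 3 + 2 + 1 = 6

mask = mask[None, None, :, :]         # (batch, 1, seq, seq): broadcasts over heads
q = k = v = torch.randn(1, 8, 6, 16)  # (batch, heads, seq, head_dim)
out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
print(mask.shape, out.shape)          # torch.Size([1, 1, 6, 6]) torch.Size([1, 8, 6, 16])
```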