vanilla mask

This commit is contained in:
Sunny Liu
2025-02-01 14:23:37 -05:00
parent 3ed9c117fb
commit 48c3c47071
2 changed files with 12 additions and 9 deletions

View File

@@ -1,10 +1,11 @@
'''
"""
Taken from https://github.com/pytorch/torchtune/blob/main/torchtune/modules/attention_utils.py
'''
"""
from typing import Union
import torch
from torch.nn.attention.flex_attention import BlockMask
from torch.nn.attention.flex_attention import (
BlockMask,
create_block_mask as create_block_causal_mask_flex,
)
@@ -46,7 +47,7 @@ def create_block_causal_mask(seq_lens: list[torch.Tensor]) -> torch.Tensor:
]
batch_block_attn_masks.append(torch.block_diag(*block_attn_masks))
return torch.stack(batch_block_attn_masks)
return torch.stack(batch_block_attn_masks)[:,None,:,:]
def _get_document_ids_from_seq_lens(
@@ -79,6 +80,7 @@ def _get_document_ids_from_seq_lens(
batch_document_ids = torch.stack(batch_document_ids)
return batch_document_ids
def packed_block_causal_mask(
seq_lens: list[torch.Tensor],
) -> _MaskType:
@@ -97,7 +99,7 @@ def packed_block_causal_mask(
Returns:
_MaskType: BlockMask or Tensor if torch version < 2.5.0.
"""
document_ids = _get_document_ids_from_seq_lens(seq_lens)
batch_size, max_seq_len = document_ids.shape
document_ids = document_ids.to("cuda")
@@ -127,5 +129,3 @@ def packed_block_causal_mask(
max_seq_len,
device="cuda",
)

View File

@@ -10,8 +10,11 @@ import torch
from transformers import PreTrainedTokenizerBase
from transformers.utils import PaddingStrategy
from axolotl.monkeypatch.flex_attn import (
create_block_causal_mask,
packed_block_causal_mask,
)
from axolotl.monkeypatch.utils import get_seqlens_from_pos_ids
from axolotl.monkeypatch.flex_attn import create_block_causal_mask, packed_block_causal_mask
@dataclass
@@ -167,7 +170,7 @@ class FlexBatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
out_features = [{} for _ in features]
for i, features_ in enumerate(features):
for feature in features_[0].keys():
if feature in {"length" , "attention_mask"}:
if feature in {"length", "attention_mask"}:
continue
else:
arrays = [