vanills mask
This commit is contained in:
@@ -1,10 +1,11 @@
|
||||
'''
|
||||
"""
|
||||
Taken from https://github.com/pytorch/torchtune/blob/main/torchtune/modules/attention_utils.py
|
||||
'''
|
||||
"""
|
||||
from typing import Union
|
||||
|
||||
import torch
|
||||
from torch.nn.attention.flex_attention import BlockMask
|
||||
from torch.nn.attention.flex_attention import (
|
||||
BlockMask,
|
||||
create_block_mask as create_block_causal_mask_flex,
|
||||
)
|
||||
|
||||
@@ -46,7 +47,7 @@ def create_block_causal_mask(seq_lens: list[torch.Tensor]) -> torch.Tensor:
|
||||
]
|
||||
|
||||
batch_block_attn_masks.append(torch.block_diag(*block_attn_masks))
|
||||
return torch.stack(batch_block_attn_masks)
|
||||
return torch.stack(batch_block_attn_masks)[:,None,:,:]
|
||||
|
||||
|
||||
def _get_document_ids_from_seq_lens(
|
||||
@@ -79,6 +80,7 @@ def _get_document_ids_from_seq_lens(
|
||||
batch_document_ids = torch.stack(batch_document_ids)
|
||||
return batch_document_ids
|
||||
|
||||
|
||||
def packed_block_causal_mask(
|
||||
seq_lens: list[torch.Tensor],
|
||||
) -> _MaskType:
|
||||
@@ -97,7 +99,7 @@ def packed_block_causal_mask(
|
||||
Returns:
|
||||
_MaskType: BlockMask or Tensor if torch version < 2.5.0.
|
||||
"""
|
||||
|
||||
|
||||
document_ids = _get_document_ids_from_seq_lens(seq_lens)
|
||||
batch_size, max_seq_len = document_ids.shape
|
||||
document_ids = document_ids.to("cuda")
|
||||
@@ -127,5 +129,3 @@ def packed_block_causal_mask(
|
||||
max_seq_len,
|
||||
device="cuda",
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -10,8 +10,11 @@ import torch
|
||||
from transformers import PreTrainedTokenizerBase
|
||||
from transformers.utils import PaddingStrategy
|
||||
|
||||
from axolotl.monkeypatch.flex_attn import (
|
||||
create_block_causal_mask,
|
||||
packed_block_causal_mask,
|
||||
)
|
||||
from axolotl.monkeypatch.utils import get_seqlens_from_pos_ids
|
||||
from axolotl.monkeypatch.flex_attn import create_block_causal_mask, packed_block_causal_mask
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -167,7 +170,7 @@ class FlexBatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
|
||||
out_features = [{} for _ in features]
|
||||
for i, features_ in enumerate(features):
|
||||
for feature in features_[0].keys():
|
||||
if feature in {"length" , "attention_mask"}:
|
||||
if feature in {"length", "attention_mask"}:
|
||||
continue
|
||||
else:
|
||||
arrays = [
|
||||
|
||||
Reference in New Issue
Block a user