vanilla mask

This commit is contained in:
Sunny Liu
2025-02-01 14:23:37 -05:00
parent 3ed9c117fb
commit 48c3c47071
2 changed files with 12 additions and 9 deletions

View File

@@ -1,10 +1,11 @@
''' """
Taken from https://github.com/pytorch/torchtune/blob/main/torchtune/modules/attention_utils.py Taken from https://github.com/pytorch/torchtune/blob/main/torchtune/modules/attention_utils.py
''' """
from typing import Union from typing import Union
import torch import torch
from torch.nn.attention.flex_attention import BlockMask
from torch.nn.attention.flex_attention import ( from torch.nn.attention.flex_attention import (
BlockMask,
create_block_mask as create_block_causal_mask_flex, create_block_mask as create_block_causal_mask_flex,
) )
@@ -46,7 +47,7 @@ def create_block_causal_mask(seq_lens: list[torch.Tensor]) -> torch.Tensor:
] ]
batch_block_attn_masks.append(torch.block_diag(*block_attn_masks)) batch_block_attn_masks.append(torch.block_diag(*block_attn_masks))
return torch.stack(batch_block_attn_masks) return torch.stack(batch_block_attn_masks)[:,None,:,:]
def _get_document_ids_from_seq_lens( def _get_document_ids_from_seq_lens(
@@ -79,6 +80,7 @@ def _get_document_ids_from_seq_lens(
batch_document_ids = torch.stack(batch_document_ids) batch_document_ids = torch.stack(batch_document_ids)
return batch_document_ids return batch_document_ids
def packed_block_causal_mask( def packed_block_causal_mask(
seq_lens: list[torch.Tensor], seq_lens: list[torch.Tensor],
) -> _MaskType: ) -> _MaskType:
@@ -97,7 +99,7 @@ def packed_block_causal_mask(
Returns: Returns:
_MaskType: BlockMask or Tensor if torch version < 2.5.0. _MaskType: BlockMask or Tensor if torch version < 2.5.0.
""" """
document_ids = _get_document_ids_from_seq_lens(seq_lens) document_ids = _get_document_ids_from_seq_lens(seq_lens)
batch_size, max_seq_len = document_ids.shape batch_size, max_seq_len = document_ids.shape
document_ids = document_ids.to("cuda") document_ids = document_ids.to("cuda")
@@ -127,5 +129,3 @@ def packed_block_causal_mask(
max_seq_len, max_seq_len,
device="cuda", device="cuda",
) )

View File

@@ -10,8 +10,11 @@ import torch
from transformers import PreTrainedTokenizerBase from transformers import PreTrainedTokenizerBase
from transformers.utils import PaddingStrategy from transformers.utils import PaddingStrategy
from axolotl.monkeypatch.flex_attn import (
create_block_causal_mask,
packed_block_causal_mask,
)
from axolotl.monkeypatch.utils import get_seqlens_from_pos_ids from axolotl.monkeypatch.utils import get_seqlens_from_pos_ids
from axolotl.monkeypatch.flex_attn import create_block_causal_mask, packed_block_causal_mask
@dataclass @dataclass
@@ -167,7 +170,7 @@ class FlexBatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
out_features = [{} for _ in features] out_features = [{} for _ in features]
for i, features_ in enumerate(features): for i, features_ in enumerate(features):
for feature in features_[0].keys(): for feature in features_[0].keys():
if feature in {"length" , "attention_mask"}: if feature in {"length", "attention_mask"}:
continue continue
else: else:
arrays = [ arrays = [