try vanilla mask
This commit is contained in:
@@ -10,6 +10,45 @@ from torch.nn.attention.flex_attention import (
|
|||||||
|
|
||||||
_MaskType = Union[torch.Tensor, BlockMask]
|
_MaskType = Union[torch.Tensor, BlockMask]
|
||||||
|
|
||||||
|
|
||||||
|
def create_block_causal_mask(seq_lens: list[torch.Tensor]) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Given a batch tensor of seq lens defining the lengths of samples in each pack,
|
||||||
|
Construct a 2D block causal mask for each pack in the batch. For example, if
|
||||||
|
a single sample's seq_lens is [3, 2, 1], the mask would be::
|
||||||
|
|
||||||
|
mask = [
|
||||||
|
[1, 0, 0, 0, 0, 0],
|
||||||
|
[1, 1, 0, 0, 0, 0],
|
||||||
|
[1, 1, 1, 0, 0, 0],
|
||||||
|
[0, 0, 0, 1, 0, 0],
|
||||||
|
[0, 0, 0, 1, 1, 0],
|
||||||
|
[0, 0, 0, 0, 0, 1],
|
||||||
|
]
|
||||||
|
|
||||||
|
Args:
|
||||||
|
seq_lens (List[torch.Tensor]): Sequence lengths of samples in each pack in the batch,
|
||||||
|
shape (batch_size, n), where n is the max number of sequences in a pack and can vary
|
||||||
|
across packs.
|
||||||
|
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tensor: Block causal mask of shape (batch_size, max_seq_len, max_seq_len).
|
||||||
|
"""
|
||||||
|
batch_block_attn_masks = []
|
||||||
|
batch_size = len(seq_lens)
|
||||||
|
for sample_idx in range(batch_size):
|
||||||
|
block_attn_masks = [
|
||||||
|
torch.tril(
|
||||||
|
torch.ones(seq_len, seq_len, dtype=torch.bool, device=seq_len.device)
|
||||||
|
)
|
||||||
|
for i, seq_len in enumerate(seq_lens[sample_idx])
|
||||||
|
]
|
||||||
|
|
||||||
|
batch_block_attn_masks.append(torch.block_diag(*block_attn_masks))
|
||||||
|
return torch.stack(batch_block_attn_masks)
|
||||||
|
|
||||||
|
|
||||||
def _get_document_ids_from_seq_lens(
|
def _get_document_ids_from_seq_lens(
|
||||||
seq_lens: list[torch.Tensor],
|
seq_lens: list[torch.Tensor],
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ from transformers import PreTrainedTokenizerBase
|
|||||||
from transformers.utils import PaddingStrategy
|
from transformers.utils import PaddingStrategy
|
||||||
|
|
||||||
from axolotl.monkeypatch.utils import get_seqlens_from_pos_ids
|
from axolotl.monkeypatch.utils import get_seqlens_from_pos_ids
|
||||||
from axolotl.monkeypatch.flex_attn import packed_block_causal_mask
|
from axolotl.monkeypatch.flex_attn import create_block_causal_mask, packed_block_causal_mask
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -177,7 +177,8 @@ class FlexBatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
|
|||||||
out = super().__call__(out_features, return_tensors=return_tensors)
|
out = super().__call__(out_features, return_tensors=return_tensors)
|
||||||
|
|
||||||
collated_seq_lens = get_seqlens_from_pos_ids(out["position_ids"])
|
collated_seq_lens = get_seqlens_from_pos_ids(out["position_ids"])
|
||||||
out["attention_mask"] = packed_block_causal_mask(collated_seq_lens)
|
# out["attention_mask"] = packed_block_causal_mask(collated_seq_lens)
|
||||||
|
out["attention_mask"] = create_block_causal_mask(collated_seq_lens)
|
||||||
|
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user