BLOCK SIZE

commit e98581f6f5
parent b832b11c8f
Author: bursteratom
Date:   2025-02-02 01:22:23 -05:00

2 changed files with 5 additions and 4 deletions

File 1 of 2:

@@ -42,7 +42,7 @@ def create_block_causal_mask(
     batch_size = len(seq_lens)
     for sample_idx in range(batch_size):
         block_attn_masks = [
-            torch.tril(
+            torch.trilu(  # torch.tril(
                 torch.ones(seq_len, seq_len, dtype=torch.bool, device=seq_len.device)
             )
             for seq_len in seq_lens[sample_idx]
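
Note that torch.trilu is not a public PyTorch function (the library provides torch.tril and torch.triu), so this line raises an AttributeError if this code path is ever hit; the trailing comment preserving the original call reads like a deliberate tripwire left in while debugging, though that is an inference. For reference, a minimal runnable sketch of what the unchanged loop computes for one sample, with hypothetical packed lengths 2 and 3:

    import torch

    # Two packed sequences of (hypothetical) lengths 2 and 3 in one sample.
    seq_lens = [2, 3]
    blocks = [torch.tril(torch.ones(n, n, dtype=torch.bool)) for n in seq_lens]
    mask = torch.block_diag(*blocks)  # 5x5: causal within each block,
                                      # all False across blocks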
@@ -94,7 +94,7 @@ def _get_document_ids_from_seq_lens(
 def packed_block_causal_mask(
-    seq_lens: list[torch.Tensor],
+    seq_lens: list[torch.Tensor], max_seq_len: int
 ) -> _MaskType:
     """
     Create a block causal document mask for a batch of packed sequences. If
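
The signature now takes the collated max_seq_len explicitly instead of re-deriving it inside the function (see the next hunk); the matching caller-side change is in the second file below. A sketch of the new call shape, with made-up lengths:

    # Hypothetical values: one batch row packing sequences of lengths 3 and 2.
    # seq_lens stays a list of per-row tensors, per the annotation above.
    mask = packed_block_causal_mask([torch.tensor([3, 2])], max_seq_len=8)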
@@ -113,7 +113,7 @@ def packed_block_causal_mask(
     """
     document_ids = _get_document_ids_from_seq_lens(seq_lens)
-    batch_size, max_seq_len = document_ids.shape
+    batch_size, _ = document_ids.shape
     document_ids = document_ids.to("cuda")
     # Instead of passing a tensor mask, flex attention requires a mask_mod function
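
The context comment refers to flex attention's mask_mod convention: create_block_mask takes a predicate over (batch, head, q_idx, kv_idx) rather than a dense tensor mask. A sketch of the kind of block-causal document predicate implied here, assuming document_ids as built above (the function name is mine, not from the diff):

    # Attention is allowed iff the query does not look ahead and both
    # positions belong to the same packed document.
    def document_causal_mask(b, h, q_idx, kv_idx):
        causal = q_idx >= kv_idx
        same_doc = document_ids[b, q_idx] == document_ids[b, kv_idx]
        return causal & same_doc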
@@ -140,4 +140,5 @@ def packed_block_causal_mask(
         max_seq_len,
         max_seq_len,
         device="cuda",
+        BLOCK_SIZE=512,
     )
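
This is the change the commit title refers to: torch.nn.attention.flex_attention.create_block_mask accepts a BLOCK_SIZE argument (default 128) that sets the tile granularity of the computed block mask, and the commit pins it to 512. A sketch of the full call after this hunk, reusing the hypothetical mask_mod above:

    from torch.nn.attention.flex_attention import create_block_mask

    block_mask = create_block_mask(
        document_causal_mask,  # hypothetical mask_mod from the earlier sketch
        batch_size,
        None,                  # mask does not depend on the head index
        max_seq_len,
        max_seq_len,
        device="cuda",
        BLOCK_SIZE=512,        # coarser tiles than the 128 default
    )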

File 2 of 2:

@@ -180,7 +180,7 @@ class FlexBatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
         out = super().__call__(out_features, return_tensors=return_tensors)
         collated_seq_lens, max_seq_len = get_seqlens_from_pos_ids(out["position_ids"])
-        out["attention_mask"] = packed_block_causal_mask(collated_seq_lens)
+        out["attention_mask"] = packed_block_causal_mask(collated_seq_lens, max_seq_len)
         # out["attention_mask"] = create_block_causal_mask(collated_seq_lens, max_seq_len)
         # raise ValueError(f"{out['attention_mask'].shape}")
         return out
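
The collator already recovers per-sequence lengths and the collated row length from position_ids, and this change threads that max_seq_len into packed_block_causal_mask rather than letting the mask function re-derive it. get_seqlens_from_pos_ids itself is not shown in the diff; an assumed-equivalent sketch of the recovery, for one batch row:

    import torch

    # Packed position_ids restart at 0 at each sequence boundary.
    position_ids = torch.tensor([[0, 1, 2, 0, 1]])  # lengths 3 and 2, hypothetical
    starts = (position_ids[0] == 0).nonzero(as_tuple=True)[0]
    bounds = torch.cat([starts, torch.tensor([position_ids.shape[1]])])
    seq_lens = bounds[1:] - bounds[:-1]             # tensor([3, 2])
    max_seq_len = position_ids.shape[1]             # 5, the collated row length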