BLOCK SIZE
This commit is contained in:
@@ -42,7 +42,7 @@ def create_block_causal_mask(
|
||||
batch_size = len(seq_lens)
|
||||
for sample_idx in range(batch_size):
|
||||
block_attn_masks = [
|
||||
torch.tril(
|
||||
torch.trilu( # torch.tril(
|
||||
torch.ones(seq_len, seq_len, dtype=torch.bool, device=seq_len.device)
|
||||
)
|
||||
for seq_len in seq_lens[sample_idx]
|
||||
@@ -94,7 +94,7 @@ def _get_document_ids_from_seq_lens(
|
||||
|
||||
|
||||
def packed_block_causal_mask(
|
||||
seq_lens: list[torch.Tensor],
|
||||
seq_lens: list[torch.Tensor], max_seq_len: int
|
||||
) -> _MaskType:
|
||||
"""
|
||||
Create a block causal document mask for a batch of packed sequences. If
|
||||
@@ -113,7 +113,7 @@ def packed_block_causal_mask(
|
||||
"""
|
||||
|
||||
document_ids = _get_document_ids_from_seq_lens(seq_lens)
|
||||
batch_size, max_seq_len = document_ids.shape
|
||||
batch_size, _ = document_ids.shape
|
||||
document_ids = document_ids.to("cuda")
|
||||
|
||||
# Instead of passing a tensor mask, flex attention requires a mask_mod function
|
||||
@@ -140,4 +140,5 @@ def packed_block_causal_mask(
|
||||
max_seq_len,
|
||||
max_seq_len,
|
||||
device="cuda",
|
||||
BLOCK_SIZE=512,
|
||||
)
|
||||
|
||||
@@ -180,7 +180,7 @@ class FlexBatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
|
||||
out = super().__call__(out_features, return_tensors=return_tensors)
|
||||
|
||||
collated_seq_lens, max_seq_len = get_seqlens_from_pos_ids(out["position_ids"])
|
||||
out["attention_mask"] = packed_block_causal_mask(collated_seq_lens)
|
||||
out["attention_mask"] = packed_block_causal_mask(collated_seq_lens, max_seq_len)
|
||||
# out["attention_mask"] = create_block_causal_mask(collated_seq_lens, max_seq_len)
|
||||
# raise ValueError(f"{out['attention_mask'].shape}")
|
||||
return out
|
||||
|
||||
Reference in New Issue
Block a user