BLOCK SIZE
This commit is contained in:
@@ -42,7 +42,7 @@ def create_block_causal_mask(
|
|||||||
batch_size = len(seq_lens)
|
batch_size = len(seq_lens)
|
||||||
for sample_idx in range(batch_size):
|
for sample_idx in range(batch_size):
|
||||||
block_attn_masks = [
|
block_attn_masks = [
|
||||||
torch.tril(
|
torch.trilu( # torch.tril(
|
||||||
torch.ones(seq_len, seq_len, dtype=torch.bool, device=seq_len.device)
|
torch.ones(seq_len, seq_len, dtype=torch.bool, device=seq_len.device)
|
||||||
)
|
)
|
||||||
for seq_len in seq_lens[sample_idx]
|
for seq_len in seq_lens[sample_idx]
|
||||||
@@ -94,7 +94,7 @@ def _get_document_ids_from_seq_lens(
|
|||||||
|
|
||||||
|
|
||||||
def packed_block_causal_mask(
|
def packed_block_causal_mask(
|
||||||
seq_lens: list[torch.Tensor],
|
seq_lens: list[torch.Tensor], max_seq_len: int
|
||||||
) -> _MaskType:
|
) -> _MaskType:
|
||||||
"""
|
"""
|
||||||
Create a block causal document mask for a batch of packed sequences. If
|
Create a block causal document mask for a batch of packed sequences. If
|
||||||
@@ -113,7 +113,7 @@ def packed_block_causal_mask(
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
document_ids = _get_document_ids_from_seq_lens(seq_lens)
|
document_ids = _get_document_ids_from_seq_lens(seq_lens)
|
||||||
batch_size, max_seq_len = document_ids.shape
|
batch_size, _ = document_ids.shape
|
||||||
document_ids = document_ids.to("cuda")
|
document_ids = document_ids.to("cuda")
|
||||||
|
|
||||||
# Instead of passing a tensor mask, flex attention requires a mask_mod function
|
# Instead of passing a tensor mask, flex attention requires a mask_mod function
|
||||||
@@ -140,4 +140,5 @@ def packed_block_causal_mask(
|
|||||||
max_seq_len,
|
max_seq_len,
|
||||||
max_seq_len,
|
max_seq_len,
|
||||||
device="cuda",
|
device="cuda",
|
||||||
|
BLOCK_SIZE=512,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -180,7 +180,7 @@ class FlexBatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
|
|||||||
out = super().__call__(out_features, return_tensors=return_tensors)
|
out = super().__call__(out_features, return_tensors=return_tensors)
|
||||||
|
|
||||||
collated_seq_lens, max_seq_len = get_seqlens_from_pos_ids(out["position_ids"])
|
collated_seq_lens, max_seq_len = get_seqlens_from_pos_ids(out["position_ids"])
|
||||||
out["attention_mask"] = packed_block_causal_mask(collated_seq_lens)
|
out["attention_mask"] = packed_block_causal_mask(collated_seq_lens, max_seq_len)
|
||||||
# out["attention_mask"] = create_block_causal_mask(collated_seq_lens, max_seq_len)
|
# out["attention_mask"] = create_block_causal_mask(collated_seq_lens, max_seq_len)
|
||||||
# raise ValueError(f"{out['attention_mask'].shape}")
|
# raise ValueError(f"{out['attention_mask'].shape}")
|
||||||
return out
|
return out
|
||||||
|
|||||||
Reference in New Issue
Block a user