This commit is contained in:
Sunny Liu
2025-02-02 20:36:14 -05:00
parent e5b36900e4
commit 8e1adc154d

View File

@@ -115,6 +115,7 @@ def packed_block_causal_mask(
document_ids = _get_document_ids_from_seq_lens(seq_lens)
batch_size , max_seq_len = document_ids.shape
document_ids = document_ids.to("cuda")
totalseqlens = totalseqlens.to("cuda")
# Instead of passing a tensor mask, flex attention requires a mask_mod function
# that determines which elements of QK^T should be included in the attention