stuff
This commit is contained in:
@@ -115,6 +115,7 @@ def packed_block_causal_mask(
|
|||||||
document_ids = _get_document_ids_from_seq_lens(seq_lens)
|
document_ids = _get_document_ids_from_seq_lens(seq_lens)
|
||||||
batch_size , max_seq_len = document_ids.shape
|
batch_size , max_seq_len = document_ids.shape
|
||||||
document_ids = document_ids.to("cuda")
|
document_ids = document_ids.to("cuda")
|
||||||
|
totalseqlens = totalseqlens.to("cuda")
|
||||||
|
|
||||||
# Instead of passing a tensor mask, flex attention requires a mask_mod function
|
# Instead of passing a tensor mask, flex attention requires a mask_mod function
|
||||||
# that determines which elements of QK^T should be included in the attention
|
# that determines which elements of QK^T should be included in the attention
|
||||||
|
|||||||
Reference in New Issue
Block a user