stuff
This commit is contained in:
@@ -115,6 +115,7 @@ def packed_block_causal_mask(
|
||||
document_ids = _get_document_ids_from_seq_lens(seq_lens)
|
||||
batch_size , max_seq_len = document_ids.shape
|
||||
document_ids = document_ids.to("cuda")
|
||||
totalseqlens = totalseqlens.to("cuda")
|
||||
|
||||
# Instead of passing a tensor mask, flex attention requires a mask_mod function
|
||||
# that determines which elements of QK^T should be included in the attention
|
||||
|
||||
Reference in New Issue
Block a user