undo my stupidity
This commit is contained in:
@@ -113,7 +113,7 @@ def packed_block_causal_mask(
|
||||
"""
|
||||
|
||||
document_ids = _get_document_ids_from_seq_lens(seq_lens)
|
||||
batch_size , max_seq_len = document_ids
|
||||
batch_size , max_seq_len = document_ids.shape
|
||||
document_ids = document_ids.to("cuda")
|
||||
|
||||
# Instead of passing a tensor mask, flex attention requires a mask_mod function
|
||||
|
||||
Reference in New Issue
Block a user