undo my stupidity

This commit is contained in:
Sunny Liu
2025-02-02 20:25:53 -05:00
parent b0871c8d3b
commit 9f6c89b12b

View File

@@ -113,7 +113,7 @@ def packed_block_causal_mask(
"""
document_ids = _get_document_ids_from_seq_lens(seq_lens)
batch_size , max_seq_len = document_ids
batch_size , max_seq_len = document_ids.shape
document_ids = document_ids.to("cuda")
# Instead of passing a tensor mask, flex attention requires a mask_mod function