diff --git a/src/axolotl/monkeypatch/flex_attn.py b/src/axolotl/monkeypatch/flex_attn.py index 9c9aed055..845a6df9e 100644 --- a/src/axolotl/monkeypatch/flex_attn.py +++ b/src/axolotl/monkeypatch/flex_attn.py @@ -113,7 +113,7 @@ def packed_block_causal_mask( """ document_ids = _get_document_ids_from_seq_lens(seq_lens) - batch_size , max_seq_len = document_ids + batch_size , max_seq_len = document_ids.shape document_ids = document_ids.to("cuda") # Instead of passing a tensor mask, flex attention requires a mask_mod function