From 8e1adc154dcfa03e1a53ff258ee5b6b15b07c0c3 Mon Sep 17 00:00:00 2001 From: Sunny Liu Date: Sun, 2 Feb 2025 20:36:14 -0500 Subject: [PATCH] stuff --- src/axolotl/monkeypatch/flex_attn.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/axolotl/monkeypatch/flex_attn.py b/src/axolotl/monkeypatch/flex_attn.py index 845a6df9e..aa95480e2 100644 --- a/src/axolotl/monkeypatch/flex_attn.py +++ b/src/axolotl/monkeypatch/flex_attn.py @@ -115,6 +115,7 @@ def packed_block_causal_mask( document_ids = _get_document_ids_from_seq_lens(seq_lens) batch_size , max_seq_len = document_ids.shape document_ids = document_ids.to("cuda") + totalseqlens = totalseqlens.to("cuda") # Instead of passing a tensor mask, flex attention requires a mask_mod function # that determines which elements of QK^T should be included in the attention