From 9a43a0925d2a131458d1962b3416e6a54f50b4d6 Mon Sep 17 00:00:00 2001 From: bursteratom Date: Sun, 2 Feb 2025 00:45:30 -0500 Subject: [PATCH] more test --- src/axolotl/monkeypatch/flex_attn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/axolotl/monkeypatch/flex_attn.py b/src/axolotl/monkeypatch/flex_attn.py index 8089edeae..a1c2de644 100644 --- a/src/axolotl/monkeypatch/flex_attn.py +++ b/src/axolotl/monkeypatch/flex_attn.py @@ -51,7 +51,7 @@ def create_block_causal_mask( residue_len = max_seq_len - torch.sum(seq_lens[sample_idx]) block_attn_masks.append( torch.tril( - torch.zeros( + torch.ones( residue_len, residue_len, dtype=torch.bool, device=seq_lens[sample_idx].device ) )