From 2319e5276d0cce82030765aa62ae9272ee5334b6 Mon Sep 17 00:00:00 2001 From: bursteratom Date: Sun, 2 Feb 2025 00:48:15 -0500 Subject: [PATCH] more test --- src/axolotl/monkeypatch/flex_attn.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/axolotl/monkeypatch/flex_attn.py b/src/axolotl/monkeypatch/flex_attn.py index a1c2de644..e83dd2fff 100644 --- a/src/axolotl/monkeypatch/flex_attn.py +++ b/src/axolotl/monkeypatch/flex_attn.py @@ -48,14 +48,14 @@ def create_block_causal_mask( for seq_len in seq_lens[sample_idx] ] - residue_len = max_seq_len - torch.sum(seq_lens[sample_idx]) + """residue_len = max_seq_len - torch.sum(seq_lens[sample_idx]) block_attn_masks.append( torch.tril( torch.ones( residue_len, residue_len, dtype=torch.bool, device=seq_lens[sample_idx].device ) ) - ) + )""" batch_block_attn_masks.append(torch.block_diag(*block_attn_masks))