From b0871c8d3b6b01ba9fe133b3170c6eb221d09108 Mon Sep 17 00:00:00 2001 From: Sunny Liu Date: Sun, 2 Feb 2025 20:18:49 -0500 Subject: [PATCH] attempt - mask padding --- src/axolotl/monkeypatch/flex_attn.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/axolotl/monkeypatch/flex_attn.py b/src/axolotl/monkeypatch/flex_attn.py index 4e86b65c3..9c9aed055 100644 --- a/src/axolotl/monkeypatch/flex_attn.py +++ b/src/axolotl/monkeypatch/flex_attn.py @@ -131,7 +131,8 @@ def packed_block_causal_mask( """ causal_mask = q_idx >= kv_idx document_mask = document_ids[b, q_idx] == document_ids[b, kv_idx] - return causal_mask & document_mask & (q_idx < totalseqlens[b]) + finite_mask = q_idx < totalseqlens[b] + return causal_mask & document_mask & finite_mask return create_block_causal_mask_flex( mask_mod,