diff --git a/src/axolotl/monkeypatch/flex_attn.py b/src/axolotl/monkeypatch/flex_attn.py index 8089edeae..a1c2de644 100644 --- a/src/axolotl/monkeypatch/flex_attn.py +++ b/src/axolotl/monkeypatch/flex_attn.py @@ -51,7 +51,7 @@ def create_block_causal_mask( residue_len = max_seq_len - torch.sum(seq_lens[sample_idx]) block_attn_masks.append( torch.tril( - torch.zeros( + torch.ones( residue_len, residue_len, dtype=torch.bool, device=seq_lens[sample_idx].device ) )