diff --git a/src/axolotl/monkeypatch/flex_attn.py b/src/axolotl/monkeypatch/flex_attn.py
index e83dd2fff..e2b3670e8 100644
--- a/src/axolotl/monkeypatch/flex_attn.py
+++ b/src/axolotl/monkeypatch/flex_attn.py
@@ -42,7 +42,7 @@ def create_block_causal_mask(
     batch_size = len(seq_lens)
     for sample_idx in range(batch_size):
         block_attn_masks = [
-            torch.tril(
+            torch.tril(  # lower-triangular (causal) block for each packed sequence
                 torch.ones(seq_len, seq_len, dtype=torch.bool, device=seq_len.device)
             )
             for seq_len in seq_lens[sample_idx]
@@ -94,7 +94,8 @@ def _get_document_ids_from_seq_lens(


 def packed_block_causal_mask(
-    seq_lens: list[torch.Tensor],
+    seq_lens: list[torch.Tensor],
+    max_seq_len: int,
 ) -> _MaskType:
     """
     Create a block causal document mask for a batch of packed sequences. If
@@ -113,7 +113,7 @@ def packed_block_causal_mask(
     """
     document_ids = _get_document_ids_from_seq_lens(seq_lens)

-    batch_size, max_seq_len = document_ids.shape
+    batch_size, _ = document_ids.shape
     document_ids = document_ids.to("cuda")

     # Instead of passing a tensor mask, flex attention requires a mask_mod function
@@ -140,4 +140,5 @@ def packed_block_causal_mask(
         max_seq_len,
         max_seq_len,
         device="cuda",
+        BLOCK_SIZE=512,
     )
diff --git a/src/axolotl/utils/collators/batching.py b/src/axolotl/utils/collators/batching.py
index 6a04df99e..96c87ddbe 100644
--- a/src/axolotl/utils/collators/batching.py
+++ b/src/axolotl/utils/collators/batching.py
@@ -180,7 +180,7 @@ class FlexBatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
         out = super().__call__(out_features, return_tensors=return_tensors)

         collated_seq_lens, max_seq_len = get_seqlens_from_pos_ids(out["position_ids"])
-        out["attention_mask"] = packed_block_causal_mask(collated_seq_lens)
+        out["attention_mask"] = packed_block_causal_mask(collated_seq_lens, max_seq_len)
         # out["attention_mask"] = create_block_causal_mask(collated_seq_lens, max_seq_len)
         # raise ValueError(f"{out['attention_mask'].shape}")
         return out
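
For context, a minimal sketch of the path this patch changes: the collator now passes max_seq_len through to packed_block_causal_mask, which hands it (plus BLOCK_SIZE=512) to torch's create_block_mask. The snippet below is illustrative only, not the exact axolotl code; document_ids and document_causal_mask are stand-ins for what _get_document_ids_from_seq_lens and the internal mask_mod produce, and it assumes torch >= 2.5 (flex attention) with a CUDA device, since the patched code hard-codes device="cuda".

import torch
from torch.nn.attention.flex_attention import create_block_mask

# Stand-in for _get_document_ids_from_seq_lens output: one batch row with two
# packed sequences of lengths 3 and 5, padded to max_seq_len tokens.
document_ids = torch.tensor([[0, 0, 0, 1, 1, 1, 1, 1]], device="cuda")
max_seq_len = 8  # now passed in explicitly rather than read from document_ids.shape


def document_causal_mask(b, h, q_idx, kv_idx):
    # Allow attention only to earlier positions within the same packed document.
    causal = q_idx >= kv_idx
    same_doc = document_ids[b, q_idx] == document_ids[b, kv_idx]
    return causal & same_doc


block_mask = create_block_mask(
    document_causal_mask,
    B=document_ids.shape[0],
    H=None,          # broadcast across attention heads
    Q_LEN=max_seq_len,
    KV_LEN=max_seq_len,
    device="cuda",
    BLOCK_SIZE=512,  # the block granularity this diff adds
)

The resulting BlockMask is what the collator stores in out["attention_mask"] for flex attention to consume; a larger BLOCK_SIZE means coarser sparsity blocks but less block-index metadata to build and track per mask.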