diff --git a/src/axolotl/monkeypatch/flex_attn.py b/src/axolotl/monkeypatch/flex_attn.py
index 76b7d3bc8..8089edeae 100644
--- a/src/axolotl/monkeypatch/flex_attn.py
+++ b/src/axolotl/monkeypatch/flex_attn.py
@@ -50,8 +50,10 @@ def create_block_causal_mask(
         residue_len = max_seq_len - torch.sum(seq_lens[sample_idx])
         block_attn_masks.append(
-            torch.zeros(
-                residue_len, residue_len, dtype=torch.bool, device=seq_lens[0][0].device
+            torch.tril(
+                torch.ones(
+                    residue_len, residue_len, dtype=torch.bool, device=seq_lens[sample_idx].device
+                )
             )
         )
diff --git a/src/axolotl/utils/collators/batching.py b/src/axolotl/utils/collators/batching.py
index b68a31d71..01d9792bc 100644
--- a/src/axolotl/utils/collators/batching.py
+++ b/src/axolotl/utils/collators/batching.py
@@ -182,7 +182,7 @@ class FlexBatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
         collated_seq_lens, max_seq_len = get_seqlens_from_pos_ids(out["position_ids"])
         # out["attention_mask"] = packed_block_causal_mask(collated_seq_lens)
         out["attention_mask"] = create_block_causal_mask(collated_seq_lens, max_seq_len)
-        raise ValueError(f"{out['attention_mask'].shape}")
+        # raise ValueError(f"{out['attention_mask'].shape}")
         return out
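The first hunk changes the residue (tail padding) block of each packed sample from an all-False mask to a lower-triangular causal block, so padding tokens attend at least to themselves and no query row is fully masked, and it reads the device from the current sample's length tensor rather than `seq_lens[0][0]`. The second hunk comments out a leftover debugging `raise`. Below is a minimal, self-contained sketch of the mask construction the patch arrives at; the helper name `block_causal_mask_sketch`, the `seq_lens` layout (a list of one 1-D length tensor per sample), and the use of `torch.block_diag` to assemble the blocks are assumptions for illustration, not the patched `create_block_causal_mask` itself.

```python
import torch


def block_causal_mask_sketch(seq_lens, max_seq_len):
    """Illustrative sketch: build a per-sample block-diagonal causal mask.

    Assumes ``seq_lens`` is a list holding one 1-D tensor of packed
    sequence lengths per sample (an assumption about the real call site).
    """
    batch_masks = []
    for sample_lens in seq_lens:
        # One lower-triangular (causal) block per packed sequence.
        blocks = [
            torch.tril(torch.ones(n, n, dtype=torch.bool, device=sample_lens.device))
            for n in sample_lens.tolist()
        ]
        # Tail padding gets a causal block too; with torch.zeros here (as in
        # the pre-patch code) every padding row would be fully masked.
        residue_len = max_seq_len - int(sample_lens.sum())
        if residue_len > 0:
            blocks.append(
                torch.tril(
                    torch.ones(
                        residue_len,
                        residue_len,
                        dtype=torch.bool,
                        device=sample_lens.device,
                    )
                )
            )
        batch_masks.append(torch.block_diag(*blocks))
    return torch.stack(batch_masks)


# Two sequences of lengths 2 and 3 packed to max_seq_len=6 leave a residue
# of 1, giving a (1, 6, 6) boolean mask with three causal blocks.
mask = block_causal_mask_sketch([torch.tensor([2, 3])], max_seq_len=6)
print(mask.shape)  # torch.Size([1, 6, 6])
```

A fully masked query row is worth avoiding because softmax over a row of all-`-inf` scores is undefined (NaN) in many attention implementations, which is presumably why the residue block is made causal rather than left empty.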