diff --git a/src/axolotl/monkeypatch/flex_attn.py b/src/axolotl/monkeypatch/flex_attn.py index 2894dd5c6..f764e7c16 100644 --- a/src/axolotl/monkeypatch/flex_attn.py +++ b/src/axolotl/monkeypatch/flex_attn.py @@ -12,7 +12,7 @@ from torch.nn.attention.flex_attention import ( _MaskType = Union[torch.Tensor, BlockMask] -def create_block_causal_mask(seq_lens: list[torch.Tensor]) -> torch.Tensor: +def create_block_causal_mask(seq_lens: list[torch.Tensor], max_seq_len: int) -> torch.Tensor: """ Given a batch tensor of seq lens defining the lengths of samples in each pack, Construct a 2D block causal mask for each pack in the batch. For example, if @@ -43,10 +43,14 @@ def create_block_causal_mask(seq_lens: list[torch.Tensor]) -> torch.Tensor: torch.tril( torch.ones(seq_len, seq_len, dtype=torch.bool, device=seq_len.device) ) - for i, seq_len in enumerate(seq_lens[sample_idx]) + for seq_len in seq_lens[sample_idx] ] - + + residue_len = max_seq_len - torch.sum(seq_lens[sample_idx]) + block_attn_masks.append(torch.zeros(residue_len, residue_len, dtype=torch.bool, device=seq_lens[0][0].device)) + batch_block_attn_masks.append(torch.block_diag(*block_attn_masks)) + return torch.stack(batch_block_attn_masks)[:,None,:,:] diff --git a/src/axolotl/monkeypatch/utils.py b/src/axolotl/monkeypatch/utils.py index 5371b3ff3..f23fcf7fa 100644 --- a/src/axolotl/monkeypatch/utils.py +++ b/src/axolotl/monkeypatch/utils.py @@ -99,6 +99,7 @@ def get_seqlens_from_pos_ids(position_ids): """generate a sequence length set using pos ids for doc mask creation in flex attention""" if len(position_ids.shape) == 1: position_ids = position_ids.unsqueeze(0) + max_seq_len = position_ids.shape[1] device = position_ids.device results = [] @@ -126,22 +127,10 @@ def get_seqlens_from_pos_ids(position_ids): ) # Calculate the sequence lengths seq_lengths = start_indices[1:] - start_indices[:-1] - # Append the padding length to the sequence lengths - if padding_length: - seq_lengths = torch.cat( - [ - 
seq_lengths, - torch.tensor( - [len(row) - torch.sum(seq_lengths)], - dtype=torch.int32, - device=device, - ), - ] - ) results.append(seq_lengths) - return results + return results , max_seq_len def get_cu_seqlens_from_pos_ids(position_ids): diff --git a/src/axolotl/utils/collators/batching.py b/src/axolotl/utils/collators/batching.py index 0531bd2f6..5a514b613 100644 --- a/src/axolotl/utils/collators/batching.py +++ b/src/axolotl/utils/collators/batching.py @@ -179,9 +179,9 @@ class FlexBatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq): out_features[i][feature] = np.concatenate(arrays) out = super().__call__(out_features, return_tensors=return_tensors) - collated_seq_lens = get_seqlens_from_pos_ids(out["position_ids"]) + collated_seq_lens, max_seq_len = get_seqlens_from_pos_ids(out["position_ids"]) # out["attention_mask"] = packed_block_causal_mask(collated_seq_lens) - out["attention_mask"] = create_block_causal_mask(collated_seq_lens) + out["attention_mask"] = create_block_causal_mask(collated_seq_lens, max_seq_len) return out