diff --git a/src/axolotl/monkeypatch/flex_attn.py b/src/axolotl/monkeypatch/flex_attn.py
index e2b3670e8..4e86b65c3 100644
--- a/src/axolotl/monkeypatch/flex_attn.py
+++ b/src/axolotl/monkeypatch/flex_attn.py
@@ -94,7 +94,7 @@ def _get_document_ids_from_seq_lens(
 
 
 def packed_block_causal_mask(
-    seq_lens: list[torch.Tensor], max_seq_len: int
+    seq_lens: list[torch.Tensor], totalseqlens: list[int]
 ) -> _MaskType:
     """
     Create a block causal document mask for a batch of packed sequences. If
@@ -113,7 +113,9 @@ def packed_block_causal_mask(
     """
     document_ids = _get_document_ids_from_seq_lens(seq_lens)
 
-    batch_size, _ = document_ids.shape
+    batch_size, max_seq_len = document_ids.shape
     document_ids = document_ids.to("cuda")
+    # mask_mod is vmapped over batch indices, so the lengths must be a tensor, not a list
+    totalseqlens = torch.tensor(totalseqlens, device="cuda")
 
     # Instead of passing a tensor mask, flex attention requires a mask_mod function
@@ -131,7 +131,7 @@ def packed_block_causal_mask(
         """
         causal_mask = q_idx >= kv_idx
         document_mask = document_ids[b, q_idx] == document_ids[b, kv_idx]
-        return causal_mask & document_mask
+        return causal_mask & document_mask & (q_idx < totalseqlens[b])
 
     return create_block_causal_mask_flex(
         mask_mod,
diff --git a/src/axolotl/monkeypatch/utils.py b/src/axolotl/monkeypatch/utils.py
index 363fe78a9..dcc0f5645 100644
--- a/src/axolotl/monkeypatch/utils.py
+++ b/src/axolotl/monkeypatch/utils.py
@@ -103,6 +103,7 @@ def get_seqlens_from_pos_ids(position_ids):
     device = position_ids.device
 
     results = []
+    totalseqlens = []
 
     for row in position_ids:
         # Count the number of consecutive zeros from the right side
@@ -128,7 +129,7 @@ def get_seqlens_from_pos_ids(position_ids):
         # Calculate the sequence lengths
        seq_lengths = start_indices[1:] - start_indices[:-1]
         # Append the padding length to the sequence lengths
-        """if padding_length:
+        if padding_length:
             seq_lengths = torch.cat(
                 [
                     seq_lengths,
@@ -138,11 +139,12 @@ def get_seqlens_from_pos_ids(position_ids):
                     torch.tensor(
                         [padding_length],
                         device=device,
                     ),
                 ]
-            )"""
+            )
         results.append(seq_lengths)
+        totalseqlens.append(len(adjusted_row))
 
-    return results , max_seq_len
+    return results, totalseqlens
 
 def get_cu_seqlens_from_pos_ids(position_ids):
diff --git a/src/axolotl/utils/collators/batching.py b/src/axolotl/utils/collators/batching.py
index 96c87ddbe..5a4d081de 100644
--- a/src/axolotl/utils/collators/batching.py
+++ b/src/axolotl/utils/collators/batching.py
@@ -179,8 +179,8 @@ class FlexBatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
             out_features[i][feature] = np.concatenate(arrays)
 
         out = super().__call__(out_features, return_tensors=return_tensors)
-        collated_seq_lens, max_seq_len = get_seqlens_from_pos_ids(out["position_ids"])
-        out["attention_mask"] = packed_block_causal_mask(collated_seq_lens, max_seq_len)
+        collated_seq_lens, totalseqlens = get_seqlens_from_pos_ids(out["position_ids"])
+        out["attention_mask"] = packed_block_causal_mask(collated_seq_lens, totalseqlens)
         # out["attention_mask"] = create_block_causal_mask(collated_seq_lens, max_seq_len)
         # raise ValueError(f"{out['attention_mask'].shape}")
         return out
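Note on the mask construction: flex attention evaluates `mask_mod` under `torch.vmap`, so the per-row lengths have to be indexable as a tensor (indexing a plain `list[int]` with a batched `b` fails), which is why the patch converts `totalseqlens` to a tensor before closing over it. Below is a minimal, self-contained sketch of the mask this PR builds (causal ∧ same-document ∧ not-padding), written against the public `torch.nn.attention.flex_attention` API (PyTorch ≥ 2.5); `doc_ids`, `total_len`, and the toy lengths are illustrative values, not names from the patch.

```python
import torch
from torch.nn.attention.flex_attention import create_block_mask

# One packed row: documents of lengths 3, 2, and 2, then one padding token.
seq_lens = torch.tensor([3, 2, 2, 1])  # final entry is the padding run
doc_ids = torch.repeat_interleave(torch.arange(len(seq_lens)), seq_lens)
doc_ids = doc_ids.unsqueeze(0)         # shape (batch=1, max_seq_len=8)
total_len = torch.tensor([7])          # unpadded length of each batch row

def mask_mod(b, h, q_idx, kv_idx):
    causal = q_idx >= kv_idx                            # no attending ahead
    same_doc = doc_ids[b, q_idx] == doc_ids[b, kv_idx]  # stay inside one document
    not_pad = q_idx < total_len[b]                      # drop pure-padding queries
    return causal & same_doc & not_pad

# B=1, H=None (broadcast over heads), Q_LEN=KV_LEN=8
block_mask = create_block_mask(mask_mod, 1, None, 8, 8, device="cpu")
print(block_mask)
```

In a real run the tensors would live on `"cuda"`, matching the hard-coded device in `packed_block_causal_mask`.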