From 470ba65c444aae5a09fa5c696da12bf187e0b9dd Mon Sep 17 00:00:00 2001
From: Sunny Liu
Date: Tue, 4 Feb 2025 20:27:39 -0500
Subject: [PATCH] build doc mask instead of the whole block mask in collator

Add get_packed_mask_from_pos_ids, which labels each packed sequence in a
row with its document index (trailing padding gets -100), and have the
flex attention collator emit that mask instead of materializing the whole
packed block-causal mask.
---
 src/axolotl/monkeypatch/utils.py        | 47 ++++++++++++++++++++++++-
 src/axolotl/utils/collators/batching.py | 13 ++++++++-----
 2 files changed, 54 insertions(+), 6 deletions(-)

diff --git a/src/axolotl/monkeypatch/utils.py b/src/axolotl/monkeypatch/utils.py
index 4665a54d4..0dcd023a9 100644
--- a/src/axolotl/monkeypatch/utils.py
+++ b/src/axolotl/monkeypatch/utils.py
@@ -95,6 +95,51 @@ def get_cu_seqlens(attn_mask):
     return torch.stack(results).to(dtype=torch.int32), torch.stack(max_seq_lens)
 
 
+def get_packed_mask_from_pos_ids(position_ids):
+    """generate a per-document id mask from pos ids for doc masking in flex attention"""
+    if len(position_ids.shape) == 1:
+        position_ids = position_ids.unsqueeze(0)
+
+    device = position_ids.device
+    results = []
+
+    for row in position_ids:
+        # Count the number of consecutive zeros from the right side
+        padding_length = (row == 0).int().flip(dims=[0]).cumprod(dim=0).sum().item()
+
+        # Adjust the row to exclude padding
+        adjusted_row = row[:-padding_length] if padding_length else row.clone()
+
+        # Find where the position resets to 0 (indicating a new sequence)
+        seq_starts = torch.cat(
+            [
+                torch.tensor([True], dtype=torch.bool, device=device),
+                adjusted_row[1:] == 0,
+            ]
+        )
+        # Get the indices where the sequences start
+        start_indices = torch.cat(
+            [
+                torch.nonzero(seq_starts).unbind(dim=1)[0],
+                torch.tensor([len(adjusted_row)], dtype=torch.int32, device=device),
+            ]
+        )
+        # Calculate the sequence lengths
+        seq_lengths = start_indices[1:] - start_indices[:-1]
+        # Label every position with the index of the document it belongs to
+        doc_mask = torch.ones(len(row), dtype=torch.int32, device=device)
+        for doc_idx, seq_len in enumerate(seq_lengths):
+            start_id = start_indices[doc_idx]
+            doc_mask[start_id : start_id + seq_len] = doc_idx
+        # Mark trailing padding positions with -100
+        if padding_length:
+            doc_mask[len(adjusted_row) :] = -100
+
+        results.append(doc_mask)
+
+    return torch.stack(results)
+
+
 def get_seqlens_from_pos_ids(position_ids):
     """generate a sequence length set using pos ids for doc mask creation in flex attention"""
     if len(position_ids.shape) == 1:
@@ -144,7 +189,7 @@ def get_seqlens_from_pos_ids(position_ids):
         results.append(seq_lengths)
         totalseqlens.append(len(adjusted_row))
 
-    return results , torch.tensor(totalseqlens, dtype=torch.int32, device=device)
+    return results, torch.tensor(totalseqlens, dtype=torch.int32, device=device)
 
 
 def get_cu_seqlens_from_pos_ids(position_ids):
diff --git a/src/axolotl/utils/collators/batching.py b/src/axolotl/utils/collators/batching.py
index 21dc26945..9e247877b 100644
--- a/src/axolotl/utils/collators/batching.py
+++ b/src/axolotl/utils/collators/batching.py
@@ -14,7 +14,7 @@ from axolotl.monkeypatch.flex_attn import (
     create_block_causal_mask,
     packed_block_causal_mask,
 )
-from axolotl.monkeypatch.utils import get_seqlens_from_pos_ids
+from axolotl.monkeypatch.utils import get_packed_mask_from_pos_ids
 
 
 @dataclass
@@ -170,7 +170,11 @@ class FlexBatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
         out_features = [{} for _ in features]
         for i, features_ in enumerate(features):
             for feature in features_[0].keys():
-                if feature in {"length", "attention_mask"}:
+                if feature == "length":
+                    continue
+                elif feature == "attention_mask":
+                    # the packed attention mask is rebuilt from position_ids
+                    # after collation, so skip concatenating it here
                     continue
                 else:
                     arrays = [
@@ -181,8 +185,7 @@ class FlexBatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
                 out_features[i][feature] = np.concatenate(arrays)
 
         out = super().__call__(out_features, return_tensors=return_tensors)
-        collated_seq_lens, totalseqlens = get_seqlens_from_pos_ids(out["position_ids"])
-        out["attention_mask"] = packed_block_causal_mask(collated_seq_lens, totalseqlens)
+        out["attention_mask"] = get_packed_mask_from_pos_ids(out["position_ids"])
         # out["attention_mask"] = create_block_causal_mask(collated_seq_lens, max_seq_len)
         # raise ValueError(f"{out['attention_mask'].shape}")
         return out
@@ -243,4 +246,4 @@ class PretrainingBatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
             arrays = [np.array(item) for item in features[feature]]
             chunked_data[feature] = np.concatenate(arrays)
         features = [chunked_data]
-        return super().__call__(features, return_tensors=return_tensors)
\ No newline at end of file
+        return super().__call__(features, return_tensors=return_tensors)
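
A quick sanity check of get_packed_mask_from_pos_ids for reviewers. This
snippet is illustrative only, not part of the patch; the position_ids
values are made up and the expected output is hand-derived from the
function's logic above:

    import torch

    from axolotl.monkeypatch.utils import get_packed_mask_from_pos_ids

    # one packed row: two documents (lengths 3 and 2) followed by two
    # right-padding positions, encoded as position ids
    position_ids = torch.tensor([[0, 1, 2, 0, 1, 0, 0]])

    doc_mask = get_packed_mask_from_pos_ids(position_ids)

    # each position is labeled with the index of its document, and
    # trailing padding is marked -100, so flex attention can mask
    # across document boundaries
    assert doc_mask.tolist() == [[0, 0, 0, 1, 1, -100, -100]]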