From c64c881460b9d19b0e48e9f1aec6da9185447c16 Mon Sep 17 00:00:00 2001
From: Dan Saunders
Date: Sun, 6 Apr 2025 18:35:31 +0000
Subject: [PATCH] using existing packed seqlens util

---
 .../monkeypatch/attention/ring_attn.py  | 75 ++++---------------
 src/axolotl/monkeypatch/utils.py        |  4 +-
 src/axolotl/utils/collators/batching.py | 10 +--
 3 files changed, 21 insertions(+), 68 deletions(-)

diff --git a/src/axolotl/monkeypatch/attention/ring_attn.py b/src/axolotl/monkeypatch/attention/ring_attn.py
index 64b7fd84c..30aa78f01 100644
--- a/src/axolotl/monkeypatch/attention/ring_attn.py
+++ b/src/axolotl/monkeypatch/attention/ring_attn.py
@@ -8,10 +8,10 @@ their sequence parallel version of Flash Attention 2.
 
 import torch
 import torch.distributed as dist
-import torch.nn.functional as F
 from accelerate.logging import get_logger
 
 from axolotl.logging_config import configure_logging
+from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
 
 configure_logging()
 LOG = get_logger(__name__)
@@ -102,70 +102,25 @@ def register_ring_attn(sequence_parallel_degree: int, heads_k_stride: int | None
     )
 
 
-def calculate_packed_seq_lens(position_ids: torch.Tensor) -> torch.Tensor:
-    """
-    Calculates lengths of packed sequences from position IDs tensor.
-
-    Args:
-        position_ids: A tensor of shape `[1, seq_len]` containing position IDs, where
-            zeros indicate potential sequence starts.
-
-    Returns:
-        A tensor containing the lengths of each sequence in the packed format.
-    """
-    # Batch size must be 1 (checked in Pydantic config model validation)
-    position_ids = position_ids.flatten()
-
-    # Find where the position resets
-    sequence_starts = torch.cat(
-        [position_ids.new_ones(1), (position_ids[1:] == 0).to(torch.int)]
-    )
-
-    # Get all indices where sequence_starts
-    potential_indices = torch.nonzero(sequence_starts).flatten()
-
-    # Filter out indices where the next index also has a zero
-    valid_indices = []
-    for i, current_pos in enumerate(potential_indices):
-        # Check if this is the last index or if the next element is not a zero
-        if i == len(potential_indices) - 1:
-            break
-        valid_indices.append(current_pos)
-
-    start_indices = torch.tensor(valid_indices, device=potential_indices.device)
-
-    # Calculate packed sequence lengths
-    if len(start_indices) > 1:
-        packed_seq_lens = torch.diff(
-            start_indices, append=torch.tensor([len(position_ids)])
-        )
-    else:
-        packed_seq_lens = torch.tensor([len(position_ids)])
-
-    return packed_seq_lens
-
-
-def update_ring_attn_params(packed_seq_lens: torch.Tensor, total_seq_len: int):
+def update_ring_attn_params(batch: dict[str, torch.Tensor]):
     """
     Calculate the cumulative sequence lengths for the current forward pass and pass the
-    value to the substituted ring_flash_attn.
-
-    Logic borrowed from
-    https://github.com/zhuzilin/OpenRLHF/blob/47f7cd8fc76de6d057d053251c1b55c00421cc24/openrlhf/models/ring_attn_utils.py#L43.
+    value to the substituted `ring_flash_attn`.
 
     Args:
-        packed_seq_lens: Lengths of multipacked sequences.
-        total_seq_len: Length of the full sequence.
+        batch: A dictionary with a batch of data. May or may not contain `position_ids`
+            data; if not, we compute it.
     """
-    cu_seqlens = torch.cumsum(
-        packed_seq_lens.clone()
-        .detach()
-        .to(device=torch.cuda.current_device(), dtype=torch.int32),
-        dim=-1,
-        dtype=torch.int32,
-    )
-    cu_seqlens = F.pad(F.pad(cu_seqlens, (1, 0), value=0), (0, 1), value=total_seq_len)
-
     from ring_flash_attn import update_ring_flash_attn_params
 
+    input_ids = batch["input_ids"]
+    position_ids = batch.get("position_ids")
+    if position_ids is None:
+        seq_len = input_ids.shape[1]
+        position_ids = torch.arange(
+            0, seq_len, dtype=torch.long, device=input_ids.device
+        ).unsqueeze(0)
+
+    cu_seqlens, _ = get_cu_seqlens_from_pos_ids(position_ids)
+    cu_seqlens = cu_seqlens.squeeze().to(device=torch.cuda.current_device())
     update_ring_flash_attn_params(cu_seqlens, get_ring_attn_group())
diff --git a/src/axolotl/monkeypatch/utils.py b/src/axolotl/monkeypatch/utils.py
index 43496c7c8..4c6a4de11 100644
--- a/src/axolotl/monkeypatch/utils.py
+++ b/src/axolotl/monkeypatch/utils.py
@@ -96,7 +96,9 @@ def get_cu_seqlens(attn_mask):
     return torch.stack(results).to(dtype=torch.int32), torch.stack(max_seq_lens)
 
 
-def get_cu_seqlens_from_pos_ids(position_ids):
+def get_cu_seqlens_from_pos_ids(
+    position_ids: torch.Tensor,
+) -> tuple[torch.Tensor, torch.Tensor]:
     """generate a cumulative sequence length mask for flash attention using pos ids"""
     if len(position_ids.shape) == 1:
         position_ids = position_ids.unsqueeze(0)
diff --git a/src/axolotl/utils/collators/batching.py b/src/axolotl/utils/collators/batching.py
index 1cbc702cf..ed445ae56 100644
--- a/src/axolotl/utils/collators/batching.py
+++ b/src/axolotl/utils/collators/batching.py
@@ -12,10 +12,7 @@ import torch.distributed as dist
 from transformers import PreTrainedTokenizerBase
 from transformers.utils import PaddingStrategy
 
-from axolotl.monkeypatch.attention.ring_attn import (
-    calculate_packed_seq_lens,
-    update_ring_attn_params,
-)
+from axolotl.monkeypatch.attention.ring_attn import update_ring_attn_params
 
 
 @dataclass
@@ -166,13 +163,12 @@ class DataCollatorForSeq2Seq:
             end = start + slice_size
 
             # Update params for ring attention calculation
-            packed_seq_lens = calculate_packed_seq_lens(batch["position_ids"])
-            update_ring_attn_params(packed_seq_lens, total_seq_len)
+            update_ring_attn_params(batch=batch)
 
+            # Slice batch for sequence parallel processing
             keys_to_slice = ["input_ids", "attention_mask", "labels", "position_ids"]
            for key in keys_to_slice:
                 if key in batch:
-                    # Slice batch for local sequence parallel processing
                     batch[key] = batch[key][:, start:end]
 
             return batch