using existing packed seqlens util

This commit is contained in:
Dan Saunders
2025-04-06 18:35:31 +00:00
parent cefd57cecb
commit c64c881460
3 changed files with 21 additions and 68 deletions

View File

@@ -8,10 +8,10 @@ their sequence parallel version of Flash Attention 2.
import torch import torch
import torch.distributed as dist import torch.distributed as dist
import torch.nn.functional as F
from accelerate.logging import get_logger from accelerate.logging import get_logger
from axolotl.logging_config import configure_logging from axolotl.logging_config import configure_logging
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
configure_logging() configure_logging()
LOG = get_logger(__name__) LOG = get_logger(__name__)
@@ -102,70 +102,25 @@ def register_ring_attn(sequence_parallel_degree: int, heads_k_stride: int | None
) )
def calculate_packed_seq_lens(position_ids: torch.Tensor) -> torch.Tensor: def update_ring_attn_params(batch: dict[str, torch.Tensor]):
"""
Calculates lengths of packed sequences from position IDs tensor.
Args:
position_ids: A tensor of shape `[1, seq_len]` containing position IDs, where
zeros indicate potential sequence starts.
Returns:
A tensor containing the lengths of each sequence in the packed format.
"""
# Batch size must be 1 (checked in Pydantic config model validation)
position_ids = position_ids.flatten()
# Find where the position resets
sequence_starts = torch.cat(
[position_ids.new_ones(1), (position_ids[1:] == 0).to(torch.int)]
)
# Get all indices where sequence_starts
potential_indices = torch.nonzero(sequence_starts).flatten()
# Filter out indices where the next index also has a zero
valid_indices = []
for i, current_pos in enumerate(potential_indices):
# Check if this is the last index or if the next element is not a zero
if i == len(potential_indices) - 1:
break
valid_indices.append(current_pos)
start_indices = torch.tensor(valid_indices, device=potential_indices.device)
# Calculate packed sequence lengths
if len(start_indices) > 1:
packed_seq_lens = torch.diff(
start_indices, append=torch.tensor([len(position_ids)])
)
else:
packed_seq_lens = torch.tensor([len(position_ids)])
return packed_seq_lens
def update_ring_attn_params(packed_seq_lens: torch.Tensor, total_seq_len: int):
""" """
Calculate the cumulative sequence lengths for the current forward pass and pass the Calculate the cumulative sequence lengths for the current forward pass and pass the
value to the substituted ring_flash_attn. value to the substituted `ring_flash_attn`.
Logic borrowed from
https://github.com/zhuzilin/OpenRLHF/blob/47f7cd8fc76de6d057d053251c1b55c00421cc24/openrlhf/models/ring_attn_utils.py#L43.
Args: Args:
packed_seq_lens: Lengths of multipacked sequences. batch: A dictionary with a batch of data. May or may not contain `position_ids`
total_seq_len: Length of the full sequence. data; if not, we compute it.
""" """
cu_seqlens = torch.cumsum(
packed_seq_lens.clone()
.detach()
.to(device=torch.cuda.current_device(), dtype=torch.int32),
dim=-1,
dtype=torch.int32,
)
cu_seqlens = F.pad(F.pad(cu_seqlens, (1, 0), value=0), (0, 1), value=total_seq_len)
from ring_flash_attn import update_ring_flash_attn_params from ring_flash_attn import update_ring_flash_attn_params
input_ids = batch["input_ids"]
position_ids = batch.get("position_ids")
if position_ids is None:
seq_len = input_ids.shape[1]
position_ids = torch.arange(
0, seq_len, dtype=torch.long, device=input_ids.device
).unsqueeze(0)
cu_seqlens, _ = get_cu_seqlens_from_pos_ids(position_ids)
cu_seqlens = cu_seqlens.squeeze().to(device=torch.cuda.current_device())
update_ring_flash_attn_params(cu_seqlens, get_ring_attn_group()) update_ring_flash_attn_params(cu_seqlens, get_ring_attn_group())

View File

@@ -96,7 +96,9 @@ def get_cu_seqlens(attn_mask):
return torch.stack(results).to(dtype=torch.int32), torch.stack(max_seq_lens) return torch.stack(results).to(dtype=torch.int32), torch.stack(max_seq_lens)
def get_cu_seqlens_from_pos_ids(position_ids): def get_cu_seqlens_from_pos_ids(
position_ids: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
"""generate a cumulative sequence length mask for flash attention using pos ids""" """generate a cumulative sequence length mask for flash attention using pos ids"""
if len(position_ids.shape) == 1: if len(position_ids.shape) == 1:
position_ids = position_ids.unsqueeze(0) position_ids = position_ids.unsqueeze(0)

View File

@@ -12,10 +12,7 @@ import torch.distributed as dist
from transformers import PreTrainedTokenizerBase from transformers import PreTrainedTokenizerBase
from transformers.utils import PaddingStrategy from transformers.utils import PaddingStrategy
from axolotl.monkeypatch.attention.ring_attn import ( from axolotl.monkeypatch.attention.ring_attn import update_ring_attn_params
calculate_packed_seq_lens,
update_ring_attn_params,
)
@dataclass @dataclass
@@ -166,13 +163,12 @@ class DataCollatorForSeq2Seq:
end = start + slice_size end = start + slice_size
# Update params for ring attention calculation # Update params for ring attention calculation
packed_seq_lens = calculate_packed_seq_lens(batch["position_ids"]) update_ring_attn_params(batch=batch)
update_ring_attn_params(packed_seq_lens, total_seq_len)
# Slice batch for sequence parallel processing
keys_to_slice = ["input_ids", "attention_mask", "labels", "position_ids"] keys_to_slice = ["input_ids", "attention_mask", "labels", "position_ids"]
for key in keys_to_slice: for key in keys_to_slice:
if key in batch: if key in batch:
# Slice batch for local sequence parallel processing
batch[key] = batch[key][:, start:end] batch[key] = batch[key][:, start:end]
return batch return batch