refactor and fix multipack seqlens

Dan Saunders
2025-04-06 00:31:19 +00:00
parent 4188700b7b
commit 741015b3cf
4 changed files with 91 additions and 228 deletions

View File

@@ -235,6 +235,9 @@ class AxolotlTrainer(
self.accelerator.even_batches = False
# Return unprepared dataloader if using sequence parallelism
# TODO(djsaunde): We might be able to use `accelerate`'s dataloader preparation
# if we use `dispatch_batches` and `slice_fn_for_dispatch` properly (i.e.,
# slice each batch along the sequence dimension).
if self.args.sequence_parallel_degree > 1:
return dataloader
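A rough sketch of the idea in this TODO: with dispatched batches, a custom slice function would carve each batch along the sequence dimension rather than the batch dimension. The helper below is illustrative only; `slice_batch_along_sequence`, `sp_rank`, and `sp_world_size` are hypothetical names, and the exact hook signature should be verified against `accelerate`'s dataloader preparation before wiring anything in.

import torch

def slice_batch_along_sequence(batch: dict, sp_rank: int, sp_world_size: int) -> dict:
    # Give each sequence parallel rank a contiguous chunk of the sequence
    # dimension; non-tensor and 1-D entries pass through unchanged.
    sliced = {}
    for key, value in batch.items():
        if torch.is_tensor(value) and value.dim() >= 2:
            chunk = value.shape[1] // sp_world_size
            sliced[key] = value[:, sp_rank * chunk : (sp_rank + 1) * chunk]
        else:
            sliced[key] = value
    return sliced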

View File

@@ -1,88 +1,22 @@
"""Module for Axolotl trainer sequence parallelism mixin"""
import logging
from typing import Any
import torch
import torch.distributed as dist
import torch.nn.functional as F
from datasets import Dataset
from torch import nn
from torch.utils.data import DistributedSampler, Sampler
from axolotl.monkeypatch.attention.ring_attn import get_ring_attn_group
LOG = logging.getLogger(__name__)
try:
from ring_flash_attn import update_ring_flash_attn_params
except ImportError:
# We pass silently here, but raise an ImportError in our Axolotl config validation
# if cfg.sequence_parallel_degree > 1 and `ring-flash-attn` is not installed.
pass
def calculate_cu_seqlens(position_ids: torch.Tensor, total_seq_len: int) -> torch.Tensor:
# Must be batch size 1
position_ids = position_ids.flatten()
LOG.info(f"position_ids: {position_ids}")
# Find where the position resets to 0 (indicating a new sequence)
# We add position_ids.new_ones(1) to mark the start of the first sequence
sequence_starts = torch.cat([position_ids.new_ones(1), (position_ids[1:] == 0).to(torch.int)])
# Get all indices where sequence_starts
potential_indices = torch.nonzero(sequence_starts).flatten()
# Filter out indices where the next index also has a zero
valid_indices = []
for i in range(len(potential_indices)):
# Get current index position in the original tensor
current_pos = potential_indices[i]
# Check if this is the last index or if the next element is not a zero
if i == len(potential_indices) - 1:
continue
elif potential_indices[i + 1] != current_pos + 1:
valid_indices.append(current_pos)
start_indices = torch.tensor(valid_indices, device=potential_indices.device)
LOG.info(f"start_indices: {start_indices}")
# Calculate individual sequence lengths
if len(start_indices) > 1:
sequence_lengths = torch.diff(start_indices, append=torch.tensor([len(position_ids)]))
else:
sequence_lengths = torch.tensor([len(position_ids)])
LOG.info(f"sequence_lengths: {sequence_lengths}")
# Calculate cumulative sequence lengths
cu_seqlens = torch.cumsum(
sequence_lengths.to(torch.cuda.current_device()),
dim=0,
dtype=torch.int32,
)
LOG.info(f"cu_seqlens: {cu_seqlens}")
cu_seqlens = F.pad(F.pad(cu_seqlens, (1, 0), value=0), (0, 1), value=total_seq_len)
LOG.info(f"cu_seqlens with padding: {cu_seqlens}")
import torch.distributed as dist
if dist.get_rank() == 1:
import ipdb; ipdb.set_trace()
dist.barrier()
return cu_seqlens
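For reference, the `cu_seqlens` produced here (and by the replacement helpers below) follow the flash-attention varlen convention: a leading zero followed by the running totals of the packed sequence lengths. A toy illustration, assuming three packed sequences of lengths 4, 3, and 2 in a 9-token pack:

lengths = [4, 3, 2]        # packed sequence lengths
cu_seqlens = [0, 4, 7, 9]  # leading 0, then running totals
# sequence i occupies tokens cu_seqlens[i] : cu_seqlens[i + 1]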
class SequenceParallelMixin:
"""
Mixin class for sequence parallelism support in trainers.
This mixin provides functionality for handling sequence parallelism,
including creating appropriate samplers, managing data partitioning,
and updating ring flash attention parameters during training.
specifically for creating appropriate data samplers.
"""
args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined]
@@ -153,109 +87,3 @@ class SequenceParallelMixin:
return self._create_sequence_parallel_sampler(
eval_dataset, shuffle=False, is_eval=True
)
def _update_ring_flash_attn_params(self, inputs: dict[str, torch.Tensor | Any]):
"""
Calculate the cu_seqlens for the current forward pass and pass the value to
the substituted ring_flash_attn. This is accomplished by using the passed
`input_ids`.
Args:
inputs: Current batch of inputs.
"""
# At this point, inputs should already be partitioned by the sequence
# parallel data collator
batch_size = inputs["input_ids"].shape[0]
seq_len = inputs["input_ids"].shape[1]
packed_seq_lens = [seq_len] * batch_size
# Calculate the full sequence length across all GPUs in this SP group
total_seq_len = seq_len * self.args.sequence_parallel_degree
# cu_seqlens = torch.cumsum(
# torch.tensor(
# packed_seq_lens, device=torch.cuda.current_device(), dtype=torch.int32
# ),
# dim=-1,
# dtype=torch.int32,
# )
# cu_seqlens = F.pad(
# F.pad(cu_seqlens, (1, 0), value=0), (0, 1), value=total_seq_len
# )
# packed_seq_lens = []
# current_len = 1 # Start counting the first token
# # Iterate through position IDs starting from the second element
# for i in range(1, len(inputs["position_ids"])):
# # If current position is less than previous, it's a new sequence
# if inputs["position_ids"][i] < inputs["position_ids"][i - 1]:
# packed_seq_lens.append(current_len)
# current_len = 1
# else:
# current_len += 1
# # Add the last sequence length
# packed_seq_lens.append(current_len)
# LOG.info(f"{packed_seq_lens}: packed_seq_lens")
# cu_seqlens = torch.cumsum(
# torch.tensor(packed_seq_lens, device=torch.cuda.current_device(), dtype=torch.int32),
# dim=-1,
# dtype=torch.int32,
# )
# cu_seqlens = F.pad(F.pad(cu_seqlens, (1, 0), value=0), (0, 1), value=total_seq_len)
# LOG.info(f"{cu_seqlens}: cu_seqlens")
cu_seqlens = calculate_cu_seqlens(inputs["position_ids"], total_seq_len)
update_ring_flash_attn_params(cu_seqlens, self.ring_attn_group)
def training_step(
self,
model: nn.Module,
inputs: dict[str, torch.Tensor | Any],
num_items_in_batch: int | None = None,
) -> torch.Tensor:
"""
Perform a training step on a batch of inputs. Overrides the
`transformers.trainer.Trainer` method to handle sequence parallelism if
enabled.
Args:
model: Model to perform training step for.
inputs: Dictionary mapping of inputs.
"""
# Set up sequence parallelism for this step if enabled
if self.args.sequence_parallel_degree > 1:
self._update_ring_flash_attn_params(inputs)
# Proceed with normal training step
return super().training_step(model, inputs, num_items_in_batch) # type: ignore
def prediction_step(
self,
model: nn.Module,
inputs: dict[str, torch.Tensor | Any],
prediction_loss_only: bool,
ignore_keys: list[str] | None = None,
) -> tuple[torch.Tensor | None, torch.Tensor | None, torch.Tensor | None]:
"""
Perform a prediction step on a batch of inputs. Overrides the
`transformers.trainer.Trainer` method to handle sequence parallelism if
enabled.
Args:
model: Model to perform prediction step for.
inputs: Dictionary mapping of inputs.
prediction_loss_only: Whether to return only the loss.
ignore_keys: Keys to ignore in the inputs.
Returns:
Tuple of (loss, logits, labels).
"""
# Set up sequence parallelism for this prediction step if enabled
if self.args.sequence_parallel_degree > 1:
self._update_ring_flash_attn_params(inputs)
# Proceed with normal prediction step
return super().prediction_step(model, inputs, prediction_loss_only, ignore_keys) # type: ignore

View File

@@ -6,7 +6,9 @@ package, specifically the `hf_adapter.substitute_hf_flash_attn` function to patch in
their sequence parallel version of Flash Attention 2.
"""
import torch
import torch.distributed as dist
import torch.nn.functional as F
from accelerate.logging import get_logger
from axolotl.logging_config import configure_logging
@@ -98,3 +100,72 @@ def register_ring_attn(sequence_parallel_degree: int, heads_k_stride: int | None
substitute_hf_flash_attn(
process_group=get_ring_attn_group(), heads_k_stride=heads_k_stride
)
def calculate_packed_seq_lens(position_ids: torch.Tensor) -> torch.Tensor:
"""
Calculates lengths of packed sequences from position IDs tensor.
Args:
position_ids: A tensor of shape `[1, seq_len]` containing position IDs, where
zeros indicate potential sequence starts.
Returns:
A tensor containing the lengths of each sequence in the packed format.
"""
# Batch size must be 1 (checked in Pydantic config model validation)
position_ids = position_ids.flatten()
# Find where the position resets
sequence_starts = torch.cat(
[position_ids.new_ones(1), (position_ids[1:] == 0).to(torch.int)]
)
# Get all indices where sequence_starts
potential_indices = torch.nonzero(sequence_starts).flatten()
# Keep every potential sequence start except the last one
valid_indices = []
for i, current_pos in enumerate(potential_indices):
# Stop before the final potential start
if i == len(potential_indices) - 1:
break
valid_indices.append(current_pos)
start_indices = torch.tensor(valid_indices, device=potential_indices.device)
# Calculate packed sequence lengths
if len(start_indices) > 1:
packed_seq_lens = torch.diff(
start_indices, append=torch.tensor([len(position_ids)])
)
else:
packed_seq_lens = torch.tensor([len(position_ids)])
return packed_seq_lens
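A quick trace of the function above on a toy packed batch, assuming the pack ends with at least one zero position ID from padding (the final detected start is always dropped, so that trailing run folds into the last reported length):

import torch

position_ids = torch.tensor([[0, 1, 2, 3, 0, 1, 2, 0]])  # two sequences plus one pad position
packed_seq_lens = calculate_packed_seq_lens(position_ids)
# Potential starts are at indices [0, 4, 7]; the last one (7) is dropped, so
# the result is tensor([4, 4]): lengths 4 and 3, with the pad token counted
# toward the second entry.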
def update_ring_attn_params(packed_seq_lens: torch.Tensor, total_seq_len: int):
"""
Calculate the cumulative sequence lengths for the current forward pass and pass the
value to the substituted ring_flash_attn.
Logic borrowed from
https://github.com/zhuzilin/OpenRLHF/blob/47f7cd8fc76de6d057d053251c1b55c00421cc24/openrlhf/models/ring_attn_utils.py#L43.
Args:
packed_seq_lens: Lengths of multipacked sequences.
total_seq_len: Length of the full sequence.
"""
cu_seqlens = torch.cumsum(
packed_seq_lens.clone()
.detach()
.to(device=torch.cuda.current_device(), dtype=torch.int32),
dim=-1,
dtype=torch.int32,
)
cu_seqlens = F.pad(F.pad(cu_seqlens, (1, 0), value=0), (0, 1), value=total_seq_len)
from ring_flash_attn import update_ring_flash_attn_params
update_ring_flash_attn_params(cu_seqlens, get_ring_attn_group())
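As a sanity check of the cumsum-and-pad step, a hand trace with assumed inputs (two packed sequences and a two-token padded tail in this SP group's full sequence):

# packed_seq_lens = tensor([4, 3]), total_seq_len = 9
# torch.cumsum(...)           -> [4, 7]
# F.pad(..., (1, 0), value=0) -> [0, 4, 7]
# F.pad(..., (0, 1), value=9) -> [0, 4, 7, 9]
# The appended total_seq_len means any padded tail shows up as its own final
# segment in the cu_seqlens handed to update_ring_flash_attn_params.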

View File

@@ -3,7 +3,6 @@ Data collators for axolotl to pad labels and position_ids for packed sequences.
includes logic for handling sequence parallelism collation.
"""
import logging
from dataclasses import dataclass
from typing import Any, Optional, Union
@@ -13,46 +12,10 @@ import torch.distributed as dist
from transformers import PreTrainedTokenizerBase
from transformers.utils import PaddingStrategy
logger = logging.getLogger(__name__)
def adjust_position_ids_for_slice(
position_ids: torch.Tensor, start_idx: int
) -> torch.Tensor:
"""
Adjust position IDs for a sliced sequence to maintain proper relative positions.
This handles the case where position IDs might not be contiguous due to sample
packing.
"""
# Convert to tensor if not already
# Find the boundaries between samples (where position_ids reset)
adjusted_pos_ids = position_ids.clone()
# Process each sequence in the batch
for i in range(position_ids.shape[0]):
seq = position_ids[i]
# Find sample boundaries
boundaries = []
for j in range(1, len(seq)):
if seq[j] < seq[j - 1]:
boundaries.append(j)
# No need to adjust if there are no boundaries or this is a single sample
if not boundaries:
adjusted_pos_ids[i] = seq - start_idx
continue
# Adjust each segment separately
prev_boundary = 0
for boundary in boundaries:
adjusted_pos_ids[i, prev_boundary:boundary] -= start_idx
prev_boundary = boundary
# Last segment
adjusted_pos_ids[i, prev_boundary:] -= start_idx
return adjusted_pos_ids
from axolotl.monkeypatch.attention.ring_attn import (
calculate_packed_seq_lens,
update_ring_attn_params,
)
@dataclass
@@ -196,23 +159,21 @@ class DataCollatorForSeq2Seq:
Returns:
Sliced batch dictionary.
"""
keys_to_slice = ["input_ids", "attention_mask", "labels", "position_ids"]
# Get local (start, end) for sequence parallelism slicing
total_seq_len = batch["input_ids"].shape[1]
slice_size = total_seq_len // self.local_world_size
start = self.local_rank * slice_size
end = start + slice_size
# Update params for ring attention calculation
packed_seq_lens = calculate_packed_seq_lens(batch["position_ids"])
update_ring_attn_params(packed_seq_lens, total_seq_len)
keys_to_slice = ["input_ids", "attention_mask", "labels", "position_ids"]
for key in keys_to_slice:
if key in batch:
seq_len = batch[key].shape[1]
slice_size = seq_len // self.local_world_size
start_idx = self.local_rank * slice_size
end_idx = (
start_idx + slice_size
if self.local_rank < self.local_world_size - 1
else seq_len
)
batch[key] = batch[key][:, start_idx:end_idx]
# Special handling for position_ids
# if key == "position_ids" and self.local_rank > 0:
# batch[key] = adjust_position_ids_for_slice(batch[key], start_idx)
# Slice batch for local sequence parallel processing
batch[key] = batch[key][:, start:end]
return batch
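A concrete trace of the slicing above, assuming a sequence parallel group of size 2 and an 8-token packed batch. Note that `calculate_packed_seq_lens` and `update_ring_attn_params` are called on the full, pre-slice batch, so the group-wide `cu_seqlens` are derived before each rank takes its local shard.

total_seq_len, local_world_size = 8, 2
slice_size = total_seq_len // local_world_size  # 4
for local_rank in range(local_world_size):
    start, end = local_rank * slice_size, (local_rank + 1) * slice_size
    print(local_rank, (start, end))  # rank 0 -> (0, 4), rank 1 -> (4, 8)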