refactor and fix multipack seqlens
This commit is contained in:
@@ -235,6 +235,9 @@ class AxolotlTrainer(
|
||||
self.accelerator.even_batches = False
|
||||
|
||||
# Return unprepared dataloader if using sequence parallelism
|
||||
# TODO(djsaunde): We might be able to use `accelerate`'s dataloader preparation
|
||||
# if we use `dispatch_batches` and `slice_fn_for_dispatch` properly (i.e.,
|
||||
# slice each batch along the sequence dimension).
|
||||
if self.args.sequence_parallel_degree > 1:
|
||||
return dataloader
|
||||
|
||||
|
||||
@@ -1,88 +1,22 @@
|
||||
"""Module for Axolotl trainer sequence parallelism mixin"""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn.functional as F
|
||||
from datasets import Dataset
|
||||
from torch import nn
|
||||
from torch.utils.data import DistributedSampler, Sampler
|
||||
|
||||
from axolotl.monkeypatch.attention.ring_attn import get_ring_attn_group
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
try:
|
||||
from ring_flash_attn import update_ring_flash_attn_params
|
||||
except ImportError:
|
||||
# We pass silently here, but raise an ImportError in our Axolotl config validation
|
||||
# if cfg.sequence_parallel_degree > 1 and `ring-flash-attn` is not installed.
|
||||
pass
|
||||
|
||||
|
||||
def calculate_cu_seqlens(position_ids: torch.Tensor, total_seq_len: int) -> torch.Tensor:
|
||||
# Must be batch size 1
|
||||
position_ids = position_ids.flatten()
|
||||
LOG.info(f"position_ids: {position_ids}")
|
||||
|
||||
# Find where the position resets to 0 (indicating a new sequence)
|
||||
# We add position_ids.new_ones(1) to mark the start of the first sequence
|
||||
sequence_starts = torch.cat([position_ids.new_ones(1), (position_ids[1:] == 0).to(torch.int)])
|
||||
|
||||
# Get all indices where sequence_starts
|
||||
potential_indices = torch.nonzero(sequence_starts).flatten()
|
||||
|
||||
# Filter out indices where the next index also has a zero
|
||||
valid_indices = []
|
||||
for i in range(len(potential_indices)):
|
||||
# Get current index position in the original tensor
|
||||
current_pos = potential_indices[i]
|
||||
|
||||
# Check if this is the last index or if the next element is not a zero
|
||||
if i == len(potential_indices) - 1:
|
||||
continue
|
||||
elif potential_indices[i + 1] != current_pos + 1:
|
||||
valid_indices.append(current_pos)
|
||||
|
||||
start_indices = torch.tensor(valid_indices, device=potential_indices.device)
|
||||
LOG.info(f"start_indices: {start_indices}")
|
||||
|
||||
# Calculate individual sequence lengths
|
||||
if len(start_indices) > 1:
|
||||
sequence_lengths = torch.diff(start_indices, append=torch.tensor([len(position_ids)]))
|
||||
else:
|
||||
sequence_lengths = torch.tensor([len(position_ids)])
|
||||
|
||||
LOG.info(f"sequence_lengths: {sequence_lengths}")
|
||||
|
||||
# Calculate cumulative sequence lengths
|
||||
cu_seqlens = torch.cumsum(
|
||||
sequence_lengths.to(torch.cuda.current_device()),
|
||||
dim=0,
|
||||
dtype=torch.int32,
|
||||
)
|
||||
LOG.info(f"cu_seqlens: {cu_seqlens}")
|
||||
|
||||
cu_seqlens = F.pad(F.pad(cu_seqlens, (1, 0), value=0), (0, 1), value=total_seq_len)
|
||||
LOG.info(f"cu_seqlens with padding: {cu_seqlens}")
|
||||
|
||||
import torch.distributed as dist
|
||||
if dist.get_rank() == 1:
|
||||
import ipdb; ipdb.set_trace()
|
||||
dist.barrier()
|
||||
|
||||
return cu_seqlens
|
||||
|
||||
|
||||
class SequenceParallelMixin:
|
||||
"""
|
||||
Mixin class for sequence parallelism support in trainers.
|
||||
|
||||
This mixin provides functionality for handling sequence parallelism,
|
||||
including creating appropriate samplers, managing data partitioning,
|
||||
and updating ring flash attention parameters during training.
|
||||
specifically for creating appropriate data samplers.
|
||||
"""
|
||||
|
||||
args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined]
|
||||
@@ -153,109 +87,3 @@ class SequenceParallelMixin:
|
||||
return self._create_sequence_parallel_sampler(
|
||||
eval_dataset, shuffle=False, is_eval=True
|
||||
)
|
||||
|
||||
def _update_ring_flash_attn_params(self, inputs: dict[str, torch.Tensor | Any]):
|
||||
"""
|
||||
Calculate the cu_seqlens for the current forward pass and pass the value to
|
||||
the substituted ring_flash_attn. This is accomplished by using the passed
|
||||
`input_ids`.
|
||||
|
||||
Args:
|
||||
inputs: Current batch of inputs.
|
||||
"""
|
||||
# At this point, inputs should already be partitioned by the sequence
|
||||
# parallel data collator
|
||||
batch_size = inputs["input_ids"].shape[0]
|
||||
seq_len = inputs["input_ids"].shape[1]
|
||||
packed_seq_lens = [seq_len] * batch_size
|
||||
|
||||
# Calculate the full sequence length across all GPUs in this SP group
|
||||
total_seq_len = seq_len * self.args.sequence_parallel_degree
|
||||
|
||||
# cu_seqlens = torch.cumsum(
|
||||
# torch.tensor(
|
||||
# packed_seq_lens, device=torch.cuda.current_device(), dtype=torch.int32
|
||||
# ),
|
||||
# dim=-1,
|
||||
# dtype=torch.int32,
|
||||
# )
|
||||
# cu_seqlens = F.pad(
|
||||
# F.pad(cu_seqlens, (1, 0), value=0), (0, 1), value=total_seq_len
|
||||
# )
|
||||
|
||||
# packed_seq_lens = []
|
||||
# current_len = 1 # Start counting the first token
|
||||
|
||||
# # Iterate through position IDs starting from the second element
|
||||
# for i in range(1, len(inputs["position_ids"])):
|
||||
# # If current position is less than previous, it's a new sequence
|
||||
# if inputs["position_ids"][i] < inputs["position_ids"][i - 1]:
|
||||
# packed_seq_lens.append(current_len)
|
||||
# current_len = 1
|
||||
# else:
|
||||
# current_len += 1
|
||||
|
||||
# # Add the last sequence length
|
||||
# packed_seq_lens.append(current_len)
|
||||
# LOG.info(f"{packed_seq_lens}: packed_seq_lens")
|
||||
|
||||
# cu_seqlens = torch.cumsum(
|
||||
# torch.tensor(packed_seq_lens, device=torch.cuda.current_device(), dtype=torch.int32),
|
||||
# dim=-1,
|
||||
# dtype=torch.int32,
|
||||
# )
|
||||
# cu_seqlens = F.pad(F.pad(cu_seqlens, (1, 0), value=0), (0, 1), value=total_seq_len)
|
||||
# LOG.info(f"{cu_seqlens}: cu_seqlens")
|
||||
|
||||
cu_seqlens = calculate_cu_seqlens(inputs["position_ids"], total_seq_len)
|
||||
update_ring_flash_attn_params(cu_seqlens, self.ring_attn_group)
|
||||
|
||||
def training_step(
|
||||
self,
|
||||
model: nn.Module,
|
||||
inputs: dict[str, torch.Tensor | Any],
|
||||
num_items_in_batch: int | None = None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Perform a training step on a batch of inputs. Overrides the
|
||||
`transformers.trainer.Trainer` method to handle sequence parallelism if
|
||||
enabled.
|
||||
|
||||
Args:
|
||||
model: Model to perform training step for.
|
||||
inputs: Dictionary mapping.
|
||||
"""
|
||||
# Set up sequence parallelism for this step if enabled
|
||||
if self.args.sequence_parallel_degree > 1:
|
||||
self._update_ring_flash_attn_params(inputs)
|
||||
|
||||
# Proceed with normal training step
|
||||
return super().training_step(model, inputs, num_items_in_batch) # type: ignore
|
||||
|
||||
def prediction_step(
|
||||
self,
|
||||
model: nn.Module,
|
||||
inputs: dict[str, torch.Tensor | Any],
|
||||
prediction_loss_only: bool,
|
||||
ignore_keys: list[str] | None = None,
|
||||
) -> tuple[torch.Tensor | None, torch.Tensor | None, torch.Tensor | None]:
|
||||
"""
|
||||
Perform a prediction step on a batch of inputs. Overrides the
|
||||
`transformers.trainer.Trainer` method to handle sequence parallelism if
|
||||
enabled.
|
||||
|
||||
Args:
|
||||
model: Model to perform prediction step for.
|
||||
inputs: Dictionary mapping of inputs.
|
||||
prediction_loss_only: Whether to return only the loss.
|
||||
ignore_keys: Keys to ignore in the inputs.
|
||||
|
||||
Returns:
|
||||
Tuple of (loss, logits, labels).
|
||||
"""
|
||||
# Set up sequence parallelism for this prediction step if enabled
|
||||
if self.args.sequence_parallel_degree > 1:
|
||||
self._update_ring_flash_attn_params(inputs)
|
||||
|
||||
# Proceed with normal prediction step
|
||||
return super().prediction_step(model, inputs, prediction_loss_only, ignore_keys) # type: ignore
|
||||
|
||||
@@ -6,7 +6,9 @@ package, specifically the `hf_adapter.substitute_hf_flash_attn` function to patc
|
||||
their sequence parallel version of Flash Attention 2.
|
||||
"""
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn.functional as F
|
||||
from accelerate.logging import get_logger
|
||||
|
||||
from axolotl.logging_config import configure_logging
|
||||
@@ -98,3 +100,72 @@ def register_ring_attn(sequence_parallel_degree: int, heads_k_stride: int | None
|
||||
substitute_hf_flash_attn(
|
||||
process_group=get_ring_attn_group(), heads_k_stride=heads_k_stride
|
||||
)
|
||||
|
||||
|
||||
def calculate_packed_seq_lens(position_ids: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Calculates lengths of packed sequences from position IDs tensor.
|
||||
|
||||
Args:
|
||||
position_ids: A tensor of shape `[1, seq_len]` containing position IDs, where
|
||||
zeros indicate potential sequence starts.
|
||||
|
||||
Returns:
|
||||
A tensor containing the lengths of each sequence in the packed format.
|
||||
"""
|
||||
# Batch size must be 1 (checked in Pydantic config model validation)
|
||||
position_ids = position_ids.flatten()
|
||||
|
||||
# Find where the position resets
|
||||
sequence_starts = torch.cat(
|
||||
[position_ids.new_ones(1), (position_ids[1:] == 0).to(torch.int)]
|
||||
)
|
||||
|
||||
# Get all indices where sequence_starts
|
||||
potential_indices = torch.nonzero(sequence_starts).flatten()
|
||||
|
||||
# Filter out indices where the next index also has a zero
|
||||
valid_indices = []
|
||||
for i, current_pos in enumerate(potential_indices):
|
||||
# Check if this is the last index or if the next element is not a zero
|
||||
if i == len(potential_indices) - 1:
|
||||
break
|
||||
valid_indices.append(current_pos)
|
||||
|
||||
start_indices = torch.tensor(valid_indices, device=potential_indices.device)
|
||||
|
||||
# Calculate packed sequence lengths
|
||||
if len(start_indices) > 1:
|
||||
packed_seq_lens = torch.diff(
|
||||
start_indices, append=torch.tensor([len(position_ids)])
|
||||
)
|
||||
else:
|
||||
packed_seq_lens = torch.tensor([len(position_ids)])
|
||||
|
||||
return packed_seq_lens
|
||||
|
||||
|
||||
def update_ring_attn_params(packed_seq_lens: torch.Tensor, total_seq_len: int):
|
||||
"""
|
||||
Calculate the cumulative sequence lengths for the current forward pass and pass the
|
||||
value to the substituted ring_flash_attn.
|
||||
|
||||
Logic borrowed from
|
||||
https://github.com/zhuzilin/OpenRLHF/blob/47f7cd8fc76de6d057d053251c1b55c00421cc24/openrlhf/models/ring_attn_utils.py#L43.
|
||||
|
||||
Args:
|
||||
packed_seq_lens: Lengths of multipacked sequences.
|
||||
total_seq_len: Length of the full sequence.
|
||||
"""
|
||||
cu_seqlens = torch.cumsum(
|
||||
packed_seq_lens.clone()
|
||||
.detach()
|
||||
.to(device=torch.cuda.current_device(), dtype=torch.int32),
|
||||
dim=-1,
|
||||
dtype=torch.int32,
|
||||
)
|
||||
cu_seqlens = F.pad(F.pad(cu_seqlens, (1, 0), value=0), (0, 1), value=total_seq_len)
|
||||
|
||||
from ring_flash_attn import update_ring_flash_attn_params
|
||||
|
||||
update_ring_flash_attn_params(cu_seqlens, get_ring_attn_group())
|
||||
|
||||
@@ -3,7 +3,6 @@ Data collators for axolotl to pad labels and position_ids for packed sequences.
|
||||
includes logic for handling sequence parallelism collation.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
@@ -13,46 +12,10 @@ import torch.distributed as dist
|
||||
from transformers import PreTrainedTokenizerBase
|
||||
from transformers.utils import PaddingStrategy
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def adjust_position_ids_for_slice(
|
||||
position_ids: torch.Tensor, start_idx: int
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Adjust position IDs for a sliced sequence to maintain proper relative positions.
|
||||
This handles the case where position IDs might not be contiguous due to sample
|
||||
packing.
|
||||
"""
|
||||
# Convert to tensor if not already
|
||||
# Find the boundaries between samples (where position_ids reset)
|
||||
adjusted_pos_ids = position_ids.clone()
|
||||
|
||||
# Process each sequence in the batch
|
||||
for i in range(position_ids.shape[0]):
|
||||
seq = position_ids[i]
|
||||
|
||||
# Find sample boundaries
|
||||
boundaries = []
|
||||
for j in range(1, len(seq)):
|
||||
if seq[j] < seq[j - 1]:
|
||||
boundaries.append(j)
|
||||
|
||||
# No need to adjust if there are no boundaries or this is a single sample
|
||||
if not boundaries:
|
||||
adjusted_pos_ids[i] = seq - start_idx
|
||||
continue
|
||||
|
||||
# Adjust each segment separately
|
||||
prev_boundary = 0
|
||||
for boundary in boundaries:
|
||||
adjusted_pos_ids[i, prev_boundary:boundary] -= start_idx
|
||||
prev_boundary = boundary
|
||||
|
||||
# Last segment
|
||||
adjusted_pos_ids[i, prev_boundary:] -= start_idx
|
||||
|
||||
return adjusted_pos_ids
|
||||
from axolotl.monkeypatch.attention.ring_attn import (
|
||||
calculate_packed_seq_lens,
|
||||
update_ring_attn_params,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -196,23 +159,21 @@ class DataCollatorForSeq2Seq:
|
||||
Returns:
|
||||
Sliced batch dictionary.
|
||||
"""
|
||||
keys_to_slice = ["input_ids", "attention_mask", "labels", "position_ids"]
|
||||
# Get local (start, end) for sequence parallelism slicing
|
||||
total_seq_len = batch["input_ids"].shape[1]
|
||||
slice_size = total_seq_len // self.local_world_size
|
||||
start = self.local_rank * slice_size
|
||||
end = start + slice_size
|
||||
|
||||
# Update params for ring attention calculation
|
||||
packed_seq_lens = calculate_packed_seq_lens(batch["position_ids"])
|
||||
update_ring_attn_params(packed_seq_lens, total_seq_len)
|
||||
|
||||
keys_to_slice = ["input_ids", "attention_mask", "labels", "position_ids"]
|
||||
for key in keys_to_slice:
|
||||
if key in batch:
|
||||
seq_len = batch[key].shape[1]
|
||||
slice_size = seq_len // self.local_world_size
|
||||
start_idx = self.local_rank * slice_size
|
||||
end_idx = (
|
||||
start_idx + slice_size
|
||||
if self.local_rank < self.local_world_size - 1
|
||||
else seq_len
|
||||
)
|
||||
batch[key] = batch[key][:, start_idx:end_idx]
|
||||
|
||||
# Special handling for position_ids
|
||||
# if key == "position_ids" and self.local_rank > 0:
|
||||
# batch[key] = adjust_position_ids_for_slice(batch[key], start_idx)
|
||||
# Slice batch for local sequence parallel processing
|
||||
batch[key] = batch[key][:, start:end]
|
||||
|
||||
return batch
|
||||
|
||||
|
||||
Reference in New Issue
Block a user