diff --git a/src/axolotl/core/trainers/base.py b/src/axolotl/core/trainers/base.py
index bc3a200d4..fd72cd6db 100644
--- a/src/axolotl/core/trainers/base.py
+++ b/src/axolotl/core/trainers/base.py
@@ -235,6 +235,9 @@ class AxolotlTrainer(
             self.accelerator.even_batches = False
 
         # Return unprepared dataloader if using sequence parallelism
+        # TODO(djsaunde): We might be able to use `accelerate`'s dataloader preparation
+        # if we use `dispatch_batches` and `slice_fn_for_dispatch` properly (i.e.,
+        # slice each batch along the sequence dimension).
         if self.args.sequence_parallel_degree > 1:
             return dataloader
 
diff --git a/src/axolotl/core/trainers/mixins/sequence_parallel.py b/src/axolotl/core/trainers/mixins/sequence_parallel.py
index 0709e2620..3930c6cb3 100644
--- a/src/axolotl/core/trainers/mixins/sequence_parallel.py
+++ b/src/axolotl/core/trainers/mixins/sequence_parallel.py
@@ -1,88 +1,22 @@
 """Module for Axolotl trainer sequence parallelism mixin"""
 
 import logging
-from typing import Any
 
-import torch
 import torch.distributed as dist
-import torch.nn.functional as F
 from datasets import Dataset
-from torch import nn
 from torch.utils.data import DistributedSampler, Sampler
 
 from axolotl.monkeypatch.attention.ring_attn import get_ring_attn_group
 
 LOG = logging.getLogger(__name__)
 
-try:
-    from ring_flash_attn import update_ring_flash_attn_params
-except ImportError:
-    # We pass silently here, but raise an ImportError in our Axolotl config validation
-    # if cfg.sequence_parallel_degree > 1 and `ring-flash-attn` is not installed.
-    pass
-
-
-def calculate_cu_seqlens(position_ids: torch.Tensor, total_seq_len: int) -> torch.Tensor:
-    # Must be batch size 1
-    position_ids = position_ids.flatten()
-    LOG.info(f"position_ids: {position_ids}")
-
-    # Find where the position resets to 0 (indicating a new sequence)
-    # We add position_ids.new_ones(1) to mark the start of the first sequence
-    sequence_starts = torch.cat([position_ids.new_ones(1), (position_ids[1:] == 0).to(torch.int)])
-
-    # Get all indices where sequence_starts
-    potential_indices = torch.nonzero(sequence_starts).flatten()
-
-    # Filter out indices where the next index also has a zero
-    valid_indices = []
-    for i in range(len(potential_indices)):
-        # Get current index position in the original tensor
-        current_pos = potential_indices[i]
-
-        # Check if this is the last index or if the next element is not a zero
-        if i == len(potential_indices) - 1:
-            continue
-        elif potential_indices[i + 1] != current_pos + 1:
-            valid_indices.append(current_pos)
-
-    start_indices = torch.tensor(valid_indices, device=potential_indices.device)
-    LOG.info(f"start_indices: {start_indices}")
-
-    # Calculate individual sequence lengths
-    if len(start_indices) > 1:
-        sequence_lengths = torch.diff(start_indices, append=torch.tensor([len(position_ids)]))
-    else:
-        sequence_lengths = torch.tensor([len(position_ids)])
-
-    LOG.info(f"sequence_lengths: {sequence_lengths}")
-
-    # Calculate cumulative sequence lengths
-    cu_seqlens = torch.cumsum(
-        sequence_lengths.to(torch.cuda.current_device()),
-        dim=0,
-        dtype=torch.int32,
-    )
-    LOG.info(f"cu_seqlens: {cu_seqlens}")
-
-    cu_seqlens = F.pad(F.pad(cu_seqlens, (1, 0), value=0), (0, 1), value=total_seq_len)
-    LOG.info(f"cu_seqlens with padding: {cu_seqlens}")
-
-    import torch.distributed as dist
-    if dist.get_rank() == 1:
-        import ipdb; ipdb.set_trace()
-    dist.barrier()
-
-    return cu_seqlens
-
 
 class SequenceParallelMixin:
     """
     Mixin class for sequence parallelism support in trainers.
 
     This mixin provides functionality for handling sequence parallelism,
-    including creating appropriate samplers, managing data partitioning,
-    and updating ring flash attention parameters during training.
+    specifically for creating appropriate data samplers.
     """
 
     args = None  # type: "AxolotlTrainingArguments"  # type: ignore[name-defined]
@@ -153,109 +87,3 @@ class SequenceParallelMixin:
         return self._create_sequence_parallel_sampler(
             eval_dataset, shuffle=False, is_eval=True
         )
-
-    def _update_ring_flash_attn_params(self, inputs: dict[str, torch.Tensor | Any]):
-        """
-        Calculate the cu_seqlens for the current forward pass and pass the value to
-        the substituted ring_flash_attn. This is accomplished by using the passed
-        `input_ids`.
-
-        Args:
-            inputs: Current batch of inputs.
-        """
-        # At this point, inputs should already be partitioned by the sequence
-        # parallel data collator
-        batch_size = inputs["input_ids"].shape[0]
-        seq_len = inputs["input_ids"].shape[1]
-        packed_seq_lens = [seq_len] * batch_size
-
-        # Calculate the full sequence length across all GPUs in this SP group
-        total_seq_len = seq_len * self.args.sequence_parallel_degree
-
-        # cu_seqlens = torch.cumsum(
-        #     torch.tensor(
-        #         packed_seq_lens, device=torch.cuda.current_device(), dtype=torch.int32
-        #     ),
-        #     dim=-1,
-        #     dtype=torch.int32,
-        # )
-        # cu_seqlens = F.pad(
-        #     F.pad(cu_seqlens, (1, 0), value=0), (0, 1), value=total_seq_len
-        # )
-
-        # packed_seq_lens = []
-        # current_len = 1  # Start counting the first token
-
-        # # Iterate through position IDs starting from the second element
-        # for i in range(1, len(inputs["position_ids"])):
-        #     # If current position is less than previous, it's a new sequence
-        #     if inputs["position_ids"][i] < inputs["position_ids"][i - 1]:
-        #         packed_seq_lens.append(current_len)
-        #         current_len = 1
-        #     else:
-        #         current_len += 1
-
-        # # Add the last sequence length
-        # packed_seq_lens.append(current_len)
-        # LOG.info(f"{packed_seq_lens}: packed_seq_lens")
-
-        # cu_seqlens = torch.cumsum(
-        #     torch.tensor(packed_seq_lens, device=torch.cuda.current_device(), dtype=torch.int32),
-        #     dim=-1,
-        #     dtype=torch.int32,
-        # )
-        # cu_seqlens = F.pad(F.pad(cu_seqlens, (1, 0), value=0), (0, 1), value=total_seq_len)
-        # LOG.info(f"{cu_seqlens}: cu_seqlens")
-
-        cu_seqlens = calculate_cu_seqlens(inputs["position_ids"], total_seq_len)
-        update_ring_flash_attn_params(cu_seqlens, self.ring_attn_group)
-
-    def training_step(
-        self,
-        model: nn.Module,
-        inputs: dict[str, torch.Tensor | Any],
-        num_items_in_batch: int | None = None,
-    ) -> torch.Tensor:
-        """
-        Perform a training step on a batch of inputs. Overrides the
-        `transformers.trainer.Trainer` method to handle sequence parallelism if
-        enabled.
-
-        Args:
-            model: Model to perform training step for.
-            inputs: Dictionary mapping.
-        """
-        # Set up sequence parallelism for this step if enabled
-        if self.args.sequence_parallel_degree > 1:
-            self._update_ring_flash_attn_params(inputs)
-
-        # Proceed with normal training step
-        return super().training_step(model, inputs, num_items_in_batch)  # type: ignore
-
-    def prediction_step(
-        self,
-        model: nn.Module,
-        inputs: dict[str, torch.Tensor | Any],
-        prediction_loss_only: bool,
-        ignore_keys: list[str] | None = None,
-    ) -> tuple[torch.Tensor | None, torch.Tensor | None, torch.Tensor | None]:
-        """
-        Perform a prediction step on a batch of inputs. Overrides the
-        `transformers.trainer.Trainer` method to handle sequence parallelism if
-        enabled.
-
-        Args:
-            model: Model to perform prediction step for.
-            inputs: Dictionary mapping of inputs.
-            prediction_loss_only: Whether to return only the loss.
-            ignore_keys: Keys to ignore in the inputs.
-
-        Returns:
-            Tuple of (loss, logits, labels).
-        """
-        # Set up sequence parallelism for this prediction step if enabled
-        if self.args.sequence_parallel_degree > 1:
-            self._update_ring_flash_attn_params(inputs)
-
-        # Proceed with normal prediction step
-        return super().prediction_step(model, inputs, prediction_loss_only, ignore_keys)  # type: ignore
diff --git a/src/axolotl/monkeypatch/attention/ring_attn.py b/src/axolotl/monkeypatch/attention/ring_attn.py
index 6c9d0b429..64b7fd84c 100644
--- a/src/axolotl/monkeypatch/attention/ring_attn.py
+++ b/src/axolotl/monkeypatch/attention/ring_attn.py
@@ -6,7 +6,9 @@ package, specifically the `hf_adapter.substitute_hf_flash_attn` function to patc
 their sequence parallel version of Flash Attention 2.
 """
 
+import torch
 import torch.distributed as dist
+import torch.nn.functional as F
 from accelerate.logging import get_logger
 
 from axolotl.logging_config import configure_logging
@@ -98,3 +100,72 @@ def register_ring_attn(sequence_parallel_degree: int, heads_k_stride: int | None
     substitute_hf_flash_attn(
         process_group=get_ring_attn_group(), heads_k_stride=heads_k_stride
     )
+
+
+def calculate_packed_seq_lens(position_ids: torch.Tensor) -> torch.Tensor:
+    """
+    Calculates lengths of packed sequences from a position IDs tensor.
+
+    Args:
+        position_ids: A tensor of shape `[1, seq_len]` containing position IDs, where
+            zeros indicate potential sequence starts.
+
+    Returns:
+        A tensor containing the lengths of each sequence in the packed format.
+    """
+    # Batch size must be 1 (checked in Pydantic config model validation)
+    position_ids = position_ids.flatten()
+
+    # Find where the position IDs reset to zero, marking potential sequence starts
+    sequence_starts = torch.cat(
+        [position_ids.new_ones(1), (position_ids[1:] == 0).to(torch.int)]
+    )
+
+    # Get all indices where a sequence potentially starts
+    potential_indices = torch.nonzero(sequence_starts).flatten()
+
+    # Keep every potential start except the last one
+    valid_indices = []
+    for i, current_pos in enumerate(potential_indices):
+        # Stop once we reach the last potential start index
+        if i == len(potential_indices) - 1:
+            break
+        valid_indices.append(current_pos)
+
+    start_indices = torch.tensor(valid_indices, device=potential_indices.device)
+
+    # Calculate packed sequence lengths
+    if len(start_indices) > 1:
+        packed_seq_lens = torch.diff(
+            start_indices, append=torch.tensor([len(position_ids)])
+        )
+    else:
+        packed_seq_lens = torch.tensor([len(position_ids)])
+
+    return packed_seq_lens
+
+
+def update_ring_attn_params(packed_seq_lens: torch.Tensor, total_seq_len: int):
+    """
+    Calculate the cumulative sequence lengths for the current forward pass and pass the
+    value to the substituted ring_flash_attn.
+
+    Logic borrowed from
+    https://github.com/zhuzilin/OpenRLHF/blob/47f7cd8fc76de6d057d053251c1b55c00421cc24/openrlhf/models/ring_attn_utils.py#L43.
+
+    Args:
+        packed_seq_lens: Lengths of multipacked sequences.
+        total_seq_len: Length of the full sequence.
+    """
+    cu_seqlens = torch.cumsum(
+        packed_seq_lens.clone()
+        .detach()
+        .to(device=torch.cuda.current_device(), dtype=torch.int32),
+        dim=-1,
+        dtype=torch.int32,
+    )
+    cu_seqlens = F.pad(F.pad(cu_seqlens, (1, 0), value=0), (0, 1), value=total_seq_len)
+
+    # Deferred import; `ring-flash-attn` is an optional dependency validated in the
+    # Axolotl config when sequence_parallel_degree > 1
+    from ring_flash_attn import update_ring_flash_attn_params
+
+    update_ring_flash_attn_params(cu_seqlens, get_ring_attn_group())
diff --git a/src/axolotl/utils/collators/batching.py b/src/axolotl/utils/collators/batching.py
index df0876021..1cbc702cf 100644
--- a/src/axolotl/utils/collators/batching.py
+++ b/src/axolotl/utils/collators/batching.py
@@ -3,7 +3,6 @@ Data collators for axolotl to pad labels and position_ids for packed sequences.
 includes logic for handling sequence parallelism collation.
 """
 
-import logging
 from dataclasses import dataclass
 from typing import Any, Optional, Union
 
@@ -13,46 +12,10 @@ import torch.distributed as dist
 from transformers import PreTrainedTokenizerBase
 from transformers.utils import PaddingStrategy
 
-logger = logging.getLogger(__name__)
-
-
-def adjust_position_ids_for_slice(
-    position_ids: torch.Tensor, start_idx: int
-) -> torch.Tensor:
-    """
-    Adjust position IDs for a sliced sequence to maintain proper relative positions.
-    This handles the case where position IDs might not be contiguous due to sample
-    packing.
-    """
-    # Convert to tensor if not already
-    # Find the boundaries between samples (where position_ids reset)
-    adjusted_pos_ids = position_ids.clone()
-
-    # Process each sequence in the batch
-    for i in range(position_ids.shape[0]):
-        seq = position_ids[i]
-
-        # Find sample boundaries
-        boundaries = []
-        for j in range(1, len(seq)):
-            if seq[j] < seq[j - 1]:
-                boundaries.append(j)
-
-        # No need to adjust if there are no boundaries or this is a single sample
-        if not boundaries:
-            adjusted_pos_ids[i] = seq - start_idx
-            continue
-
-        # Adjust each segment separately
-        prev_boundary = 0
-        for boundary in boundaries:
-            adjusted_pos_ids[i, prev_boundary:boundary] -= start_idx
-            prev_boundary = boundary
-
-        # Last segment
-        adjusted_pos_ids[i, prev_boundary:] -= start_idx
-
-    return adjusted_pos_ids
+from axolotl.monkeypatch.attention.ring_attn import (
+    calculate_packed_seq_lens,
+    update_ring_attn_params,
+)
 
 
 @dataclass
@@ -196,23 +159,21 @@ class DataCollatorForSeq2Seq:
         Returns:
             Sliced batch dictionary.
         """
-        keys_to_slice = ["input_ids", "attention_mask", "labels", "position_ids"]
+        # Get local (start, end) for sequence parallelism slicing
+        total_seq_len = batch["input_ids"].shape[1]
+        slice_size = total_seq_len // self.local_world_size
+        start = self.local_rank * slice_size
+        end = start + slice_size
+        # Update params for ring attention calculation
+        packed_seq_lens = calculate_packed_seq_lens(batch["position_ids"])
+        update_ring_attn_params(packed_seq_lens, total_seq_len)
+
+        keys_to_slice = ["input_ids", "attention_mask", "labels", "position_ids"]
         for key in keys_to_slice:
             if key in batch:
-                seq_len = batch[key].shape[1]
-                slice_size = seq_len // self.local_world_size
-                start_idx = self.local_rank * slice_size
-                end_idx = (
-                    start_idx + slice_size
-                    if self.local_rank < self.local_world_size - 1
-                    else seq_len
-                )
-                batch[key] = batch[key][:, start_idx:end_idx]
-
-                # Special handling for position_ids
-                # if key == "position_ids" and self.local_rank > 0:
-                #     batch[key] = adjust_position_ids_for_slice(batch[key], start_idx)
+                # Slice batch for local sequence parallel processing
+                batch[key] = batch[key][:, start:end]
 
         return batch