using existing packed seqlens util
This commit is contained in:
@@ -8,10 +8,10 @@ their sequence parallel version of Flash Attention 2.
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
import torch.nn.functional as F
|
|
||||||
from accelerate.logging import get_logger
|
from accelerate.logging import get_logger
|
||||||
|
|
||||||
from axolotl.logging_config import configure_logging
|
from axolotl.logging_config import configure_logging
|
||||||
|
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
|
||||||
|
|
||||||
configure_logging()
|
configure_logging()
|
||||||
LOG = get_logger(__name__)
|
LOG = get_logger(__name__)
|
||||||
@@ -102,70 +102,25 @@ def register_ring_attn(sequence_parallel_degree: int, heads_k_stride: int | None
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def calculate_packed_seq_lens(position_ids: torch.Tensor) -> torch.Tensor:
|
def update_ring_attn_params(batch: dict[str, torch.Tensor]):
|
||||||
"""
|
|
||||||
Calculates lengths of packed sequences from position IDs tensor.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
position_ids: A tensor of shape `[1, seq_len]` containing position IDs, where
|
|
||||||
zeros indicate potential sequence starts.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A tensor containing the lengths of each sequence in the packed format.
|
|
||||||
"""
|
|
||||||
# Batch size must be 1 (checked in Pydantic config model validation)
|
|
||||||
position_ids = position_ids.flatten()
|
|
||||||
|
|
||||||
# Find where the position resets
|
|
||||||
sequence_starts = torch.cat(
|
|
||||||
[position_ids.new_ones(1), (position_ids[1:] == 0).to(torch.int)]
|
|
||||||
)
|
|
||||||
|
|
||||||
# Get all indices where sequence_starts
|
|
||||||
potential_indices = torch.nonzero(sequence_starts).flatten()
|
|
||||||
|
|
||||||
# Filter out indices where the next index also has a zero
|
|
||||||
valid_indices = []
|
|
||||||
for i, current_pos in enumerate(potential_indices):
|
|
||||||
# Check if this is the last index or if the next element is not a zero
|
|
||||||
if i == len(potential_indices) - 1:
|
|
||||||
break
|
|
||||||
valid_indices.append(current_pos)
|
|
||||||
|
|
||||||
start_indices = torch.tensor(valid_indices, device=potential_indices.device)
|
|
||||||
|
|
||||||
# Calculate packed sequence lengths
|
|
||||||
if len(start_indices) > 1:
|
|
||||||
packed_seq_lens = torch.diff(
|
|
||||||
start_indices, append=torch.tensor([len(position_ids)])
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
packed_seq_lens = torch.tensor([len(position_ids)])
|
|
||||||
|
|
||||||
return packed_seq_lens
|
|
||||||
|
|
||||||
|
|
||||||
def update_ring_attn_params(packed_seq_lens: torch.Tensor, total_seq_len: int):
|
|
||||||
"""
|
"""
|
||||||
Calculate the cumulative sequence lengths for the current forward pass and pass the
|
Calculate the cumulative sequence lengths for the current forward pass and pass the
|
||||||
value to the substituted ring_flash_attn.
|
value to the substituted `ring_flash_attn`.
|
||||||
|
|
||||||
Logic borrowed from
|
|
||||||
https://github.com/zhuzilin/OpenRLHF/blob/47f7cd8fc76de6d057d053251c1b55c00421cc24/openrlhf/models/ring_attn_utils.py#L43.
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
packed_seq_lens: Lengths of multipacked sequences.
|
batch: A dictionary with a batch of data. May or may not contain `position_ids`
|
||||||
total_seq_len: Length of the full sequence.
|
data; if not, we compute it.
|
||||||
"""
|
"""
|
||||||
cu_seqlens = torch.cumsum(
|
|
||||||
packed_seq_lens.clone()
|
|
||||||
.detach()
|
|
||||||
.to(device=torch.cuda.current_device(), dtype=torch.int32),
|
|
||||||
dim=-1,
|
|
||||||
dtype=torch.int32,
|
|
||||||
)
|
|
||||||
cu_seqlens = F.pad(F.pad(cu_seqlens, (1, 0), value=0), (0, 1), value=total_seq_len)
|
|
||||||
|
|
||||||
from ring_flash_attn import update_ring_flash_attn_params
|
from ring_flash_attn import update_ring_flash_attn_params
|
||||||
|
|
||||||
|
input_ids = batch["input_ids"]
|
||||||
|
position_ids = batch.get("position_ids")
|
||||||
|
if position_ids is None:
|
||||||
|
seq_len = input_ids.shape[1]
|
||||||
|
position_ids = torch.arange(
|
||||||
|
0, seq_len, dtype=torch.long, device=input_ids.device
|
||||||
|
).unsqueeze(0)
|
||||||
|
|
||||||
|
cu_seqlens, _ = get_cu_seqlens_from_pos_ids(position_ids)
|
||||||
|
cu_seqlens = cu_seqlens.squeeze().to(device=torch.cuda.current_device())
|
||||||
update_ring_flash_attn_params(cu_seqlens, get_ring_attn_group())
|
update_ring_flash_attn_params(cu_seqlens, get_ring_attn_group())
|
||||||
|
|||||||
@@ -96,7 +96,9 @@ def get_cu_seqlens(attn_mask):
|
|||||||
return torch.stack(results).to(dtype=torch.int32), torch.stack(max_seq_lens)
|
return torch.stack(results).to(dtype=torch.int32), torch.stack(max_seq_lens)
|
||||||
|
|
||||||
|
|
||||||
def get_cu_seqlens_from_pos_ids(position_ids):
|
def get_cu_seqlens_from_pos_ids(
|
||||||
|
position_ids: torch.Tensor,
|
||||||
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
"""generate a cumulative sequence length mask for flash attention using pos ids"""
|
"""generate a cumulative sequence length mask for flash attention using pos ids"""
|
||||||
if len(position_ids.shape) == 1:
|
if len(position_ids.shape) == 1:
|
||||||
position_ids = position_ids.unsqueeze(0)
|
position_ids = position_ids.unsqueeze(0)
|
||||||
|
|||||||
@@ -12,10 +12,7 @@ import torch.distributed as dist
|
|||||||
from transformers import PreTrainedTokenizerBase
|
from transformers import PreTrainedTokenizerBase
|
||||||
from transformers.utils import PaddingStrategy
|
from transformers.utils import PaddingStrategy
|
||||||
|
|
||||||
from axolotl.monkeypatch.attention.ring_attn import (
|
from axolotl.monkeypatch.attention.ring_attn import update_ring_attn_params
|
||||||
calculate_packed_seq_lens,
|
|
||||||
update_ring_attn_params,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -166,13 +163,12 @@ class DataCollatorForSeq2Seq:
|
|||||||
end = start + slice_size
|
end = start + slice_size
|
||||||
|
|
||||||
# Update params for ring attention calculation
|
# Update params for ring attention calculation
|
||||||
packed_seq_lens = calculate_packed_seq_lens(batch["position_ids"])
|
update_ring_attn_params(batch=batch)
|
||||||
update_ring_attn_params(packed_seq_lens, total_seq_len)
|
|
||||||
|
|
||||||
|
# Slice batch for sequence parallel processing
|
||||||
keys_to_slice = ["input_ids", "attention_mask", "labels", "position_ids"]
|
keys_to_slice = ["input_ids", "attention_mask", "labels", "position_ids"]
|
||||||
for key in keys_to_slice:
|
for key in keys_to_slice:
|
||||||
if key in batch:
|
if key in batch:
|
||||||
# Slice batch for local sequence parallel processing
|
|
||||||
batch[key] = batch[key][:, start:end]
|
batch[key] = batch[key][:, start:end]
|
||||||
|
|
||||||
return batch
|
return batch
|
||||||
|
|||||||
Reference in New Issue
Block a user