flex batching WIP

This commit is contained in:
Sunny Liu
2025-01-30 12:35:25 -05:00
parent ba88bc7840
commit 96ad741cd5
4 changed files with 98 additions and 12 deletions

View File

@@ -71,6 +71,7 @@ from axolotl.utils.chat_templates import get_chat_template_from_config
from axolotl.utils.collators import (
BatchSamplerDataCollatorForSeq2Seq,
DataCollatorForSeq2Seq,
FlexBatchSamplerDataCollatorForSeq2Seq,
MambaDataCollator,
V2BatchSamplerDataCollatorForSeq2Seq,
)
@@ -1941,6 +1942,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
Union[
V2BatchSamplerDataCollatorForSeq2Seq,
BatchSamplerDataCollatorForSeq2Seq,
FlexBatchSamplerDataCollatorForSeq2Seq,
DataCollatorForSeq2Seq,
DataCollatorWithFlattening,
RewardDataCollatorWithPadding,
@@ -1952,6 +1954,8 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
if "max_length" in kwargs:
kwargs.pop("max_length")
elif use_batch_sampler_collator:
if self.cfg.flex_attention is True:
collator = FlexBatchSamplerDataCollatorForSeq2Seq
if self.cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES:
collator = V2BatchSamplerDataCollatorForSeq2Seq
elif (

View File

@@ -0,0 +1,90 @@
'''
Taken from https://github.com/pytorch/torchtune/blob/main/torchtune/modules/attention_utils.py
'''
import torch
from torch.nn.attention.flex_attention import (
BlockMask,
create_block_mask as create_block_causal_mask_flex,
flex_attention,
)
def _get_document_ids_from_seq_lens(
seq_lens: List[torch.Tensor],
) -> torch.Tensor:
"""
Convert a batch tensor of seq lens into integer IDs denoting sample ownership.
For example, seq_lens = [2, 3, 1] would return [0, 0, 1, 1, 1, 2].
Args:
seq_lens (List[torch.Tensor]): Sequence lengths of samples in each pack in the batch,
shape (batch_size, n), where n is the max number of sequences in a pack and can vary
across packs.
Returns:
Tensor: Document IDs of shape (batch_size, max_seq_len).
"""
batch_size = len(seq_lens)
batch_document_ids = []
for sample_idx in range(batch_size):
# We assume seq lens sum to max seq lens, so document_ids should be of
# shape (max_seq_len, )
document_ids = torch.cat(
[
torch.full((seq_len,), i, dtype=torch.long, device=seq_len.device)
for i, seq_len in enumerate(seq_lens[sample_idx])
]
)
batch_document_ids.append(document_ids)
batch_document_ids = torch.stack(batch_document_ids)
return batch_document_ids
def packed_block_causal_mask(
seq_lens: List[torch.Tensor],
) -> _MaskType:
"""
Create a block causal document mask for a batch of packed sequences. If
flex attention is supported by the current hardware, block causal logic and
passing this into :func:`torch.nn.attention.flex_attention.create_block_mask`.
The resultant BlockMask is a compressed representation of the full block causal
mask. If on an older version, a standard 2D block causal mask is created and returned.
Args:
seq_lens (List[torch.Tensor]): Sequence lengths of samples in each pack in the batch,
shape (batch_size, n), where n is the max number of sequences in a pack and can vary
across packs.
Returns:
_MaskType: BlockMask or Tensor if torch version < 2.5.0.
"""
document_ids = _get_document_ids_from_seq_lens(seq_lens)
batch_size, max_seq_len = document_ids.shape
document_ids = document_ids.to("cuda")
# Instead of passing a tensor mask, flex attention requires a mask_mod function
# that determines which elements of QK^T should be included in the attention
# computation prior to the softmax. For sample packing, we need both the
# logic for both causal mask and document mask. See PyTorch's official
# blog post for more details: https://pytorch.org/blog/flexattention/#mask-mods
def mask_mod(b, h, q_idx, kv_idx):
"""
Defines the logic of a block causal mask by combining both a standard causal mask
and a block diagonal document mask.
See :func:`~torchtune.modules.attention_utils.create_block_causal_mask`
for an illustration.
"""
causal_mask = q_idx >= kv_idx
document_mask = document_ids[b, q_idx] == document_ids[b, kv_idx]
return causal_mask & document_mask
return create_block_causal_mask_flex(
mask_mod,
batch_size,
None,
max_seq_len,
max_seq_len,
device="cuda",
)

View File

@@ -4,6 +4,7 @@ shared axolotl collators for multipack, mamba, multimodal
from .batching import ( # noqa: F401
BatchSamplerDataCollatorForSeq2Seq,
DataCollatorForSeq2Seq,
FlexBatchSamplerDataCollatorForSeq2Seq,
PretrainingBatchSamplerDataCollatorForSeq2Seq,
V2BatchSamplerDataCollatorForSeq2Seq,
)

View File

@@ -11,6 +11,7 @@ from transformers import PreTrainedTokenizerBase
from transformers.utils import PaddingStrategy
from axolotl.monkeypatch.utils import get_seqlens_from_pos_ids
from axolotl.monkeypatch.flex_attn import packed_block_causal_mask
@dataclass
@@ -166,15 +167,8 @@ class FlexBatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
out_features = [{} for _ in features]
for i, features_ in enumerate(features):
for feature in features_[0].keys():
if feature == "length":
if feature in {"length" , "attention_mask"}:
continue
if feature == "attention_mask":
arrays = [
(i + 1) * np.array(item[feature])
for i, item in enumerate(features_)
if feature in item
]
out_features[i][feature] = np.concatenate(arrays)
else:
arrays = [
np.array(item[feature]) for item in features_ if feature in item
@@ -183,10 +177,7 @@ class FlexBatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
out = super().__call__(out_features, return_tensors=return_tensors)
collated_seq_lens = get_seqlens_from_pos_ids(out["position_ids"])
doc_mask =
out["attention_mask"]
out["attention_mask"] = packed_block_causal_mask(collated_seq_lens)
return out