remove unnecessary components

This commit is contained in:
Sunny Liu
2025-02-21 11:23:21 -05:00
parent 82d04ea060
commit e792b54bab
4 changed files with 1 additions and 195 deletions

View File

@@ -79,7 +79,6 @@ from axolotl.utils.chat_templates import get_chat_template_from_config
from axolotl.utils.collators import (
BatchSamplerDataCollatorForSeq2Seq,
DataCollatorForSeq2Seq,
FlexBatchSamplerDataCollatorForSeq2Seq,
MambaDataCollator,
V2BatchSamplerDataCollatorForSeq2Seq,
)
@@ -817,7 +816,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
Union[
V2BatchSamplerDataCollatorForSeq2Seq,
BatchSamplerDataCollatorForSeq2Seq,
FlexBatchSamplerDataCollatorForSeq2Seq,
DataCollatorForSeq2Seq,
DataCollatorWithFlattening,
RewardDataCollatorWithPadding,

View File

@@ -1,146 +0,0 @@
"""
Taken from https://github.com/pytorch/torchtune/blob/main/torchtune/modules/attention_utils.py
"""
from typing import Union
import torch
from torch.nn.attention.flex_attention import BlockMask
from torch.nn.attention.flex_attention import (
create_block_mask as create_block_causal_mask_flex,
)
_MaskType = Union[torch.Tensor, BlockMask]
def create_block_causal_mask(
seq_lens: list[torch.Tensor], max_seq_len: int
) -> torch.Tensor:
"""
Given a batch tensor of seq lens defining the lengths of samples in each pack,
Construct a 2D block causal mask for each pack in the batch. For example, if
a single sample's seq_lens is [3, 2, 1], the mask would be::
mask = [
[1, 0, 0, 0, 0, 0],
[1, 1, 0, 0, 0, 0],
[1, 1, 1, 0, 0, 0],
[0, 0, 0, 1, 0, 0],
[0, 0, 0, 1, 1, 0],
[0, 0, 0, 0, 0, 1],
]
Args:
seq_lens (List[torch.Tensor]): Sequence lengths of samples in each pack in the batch,
shape (batch_size, n), where n is the max number of sequences in a pack and can vary
across packs.
Returns:
Tensor: Block causal mask of shape (batch_size, max_seq_len, max_seq_len).
"""
batch_block_attn_masks = []
batch_size = len(seq_lens)
for sample_idx in range(batch_size):
block_attn_masks = [
torch.trilu( # torch.tril(
torch.ones(seq_len, seq_len, dtype=torch.bool, device=seq_len.device)
)
for seq_len in seq_lens[sample_idx]
]
"""residue_len = max_seq_len - torch.sum(seq_lens[sample_idx])
block_attn_masks.append(
torch.tril(
torch.ones(
residue_len, residue_len, dtype=torch.bool, device=seq_lens[sample_idx].device
)
)
)"""
batch_block_attn_masks.append(torch.block_diag(*block_attn_masks))
return torch.stack(batch_block_attn_masks)[:, None, :, :]
def _get_document_ids_from_seq_lens(
seq_lens: list[torch.Tensor],
) -> torch.Tensor:
"""
Convert a batch tensor of seq lens into integer IDs denoting sample ownership.
For example, seq_lens = [2, 3, 1] would return [0, 0, 1, 1, 1, 2].
Args:
seq_lens (List[torch.Tensor]): Sequence lengths of samples in each pack in the batch,
shape (batch_size, n), where n is the max number of sequences in a pack and can vary
across packs.
Returns:
Tensor: Document IDs of shape (batch_size, max_seq_len).
"""
batch_size = len(seq_lens)
batch_document_ids = []
for sample_idx in range(batch_size):
# We assume seq lens sum to max seq lens, so document_ids should be of
# shape (max_seq_len, )
document_ids = torch.cat(
[
torch.full((seq_len,), i, dtype=torch.long, device=seq_len.device)
for i, seq_len in enumerate(seq_lens[sample_idx])
]
)
batch_document_ids.append(document_ids)
batch_document_ids = torch.stack(batch_document_ids)
return batch_document_ids
def packed_block_causal_mask(
seq_lens: list[torch.Tensor], totalseqlens: list[int]
) -> _MaskType:
"""
Create a block causal document mask for a batch of packed sequences. If
flex attention is supported by the current hardware, block causal logic and
passing this into :func:`torch.nn.attention.flex_attention.create_block_mask`.
The resultant BlockMask is a compressed representation of the full block causal
mask. If on an older version, a standard 2D block causal mask is created and returned.
Args:
seq_lens (List[torch.Tensor]): Sequence lengths of samples in each pack in the batch,
shape (batch_size, n), where n is the max number of sequences in a pack and can vary
across packs.
Returns:
_MaskType: BlockMask or Tensor if torch version < 2.5.0.
"""
document_ids = _get_document_ids_from_seq_lens(seq_lens)
batch_size , max_seq_len = document_ids.shape
document_ids = document_ids.to("cuda")
totalseqlens = totalseqlens.to("cuda")
# Instead of passing a tensor mask, flex attention requires a mask_mod function
# that determines which elements of QK^T should be included in the attention
# computation prior to the softmax. For sample packing, we need both the
# logic for both causal mask and document mask. See PyTorch's official
# blog post for more details: https://pytorch.org/blog/flexattention/#mask-mods
def mask_mod(b, h, q_idx, kv_idx):
"""
Defines the logic of a block causal mask by combining both a standard causal mask
and a block diagonal document mask.
See :func:`~torchtune.modules.attention_utils.create_block_causal_mask`
for an illustration.
"""
causal_mask = q_idx >= kv_idx
document_mask = document_ids[b, q_idx] == document_ids[b, kv_idx]
finite_mask = q_idx < totalseqlens[b]
return causal_mask & document_mask & finite_mask
return create_block_causal_mask_flex(
mask_mod,
batch_size,
None,
max_seq_len,
max_seq_len,
device="cuda",
BLOCK_SIZE=512,
)

View File

@@ -4,7 +4,6 @@ shared axolotl collators for multipack, mamba, multimodal
from .batching import ( # noqa: F401
BatchSamplerDataCollatorForSeq2Seq,
DataCollatorForSeq2Seq,
FlexBatchSamplerDataCollatorForSeq2Seq,
PretrainingBatchSamplerDataCollatorForSeq2Seq,
V2BatchSamplerDataCollatorForSeq2Seq,
)

View File

@@ -3,21 +3,12 @@ DataCollator for axolotl to pad labels and position_ids for packed sequences
"""
from dataclasses import dataclass
from typing import Any, List, Optional, Union
from typing import Any, Optional, Union
import numpy as np
import torch
from transformers import PreTrainedTokenizerBase
from transformers.utils import PaddingStrategy
from axolotl.monkeypatch.flex_attn import (
create_block_causal_mask,
packed_block_causal_mask,
)
from axolotl.monkeypatch.utils import (
get_packed_mask_from_pos_ids,
)
@dataclass
class DataCollatorForSeq2Seq:
@@ -160,42 +151,6 @@ class BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
return super().__call__(out_features, return_tensors=return_tensors)
@dataclass
class FlexBatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
"""
Collator for multipack specific to Flex Attention using the BatchSampler
"""
def __call__(self, features, return_tensors=None):
if not isinstance(features[0], list):
features = [features]
out_features = [{} for _ in features]
for i, features_ in enumerate(features):
for feature in features_[0].keys():
if feature == "length":
continue
elif feature == "attention_mask":
"""arrays = [
i * np.array(item[feature])
for i, item in enumerate(features_)
if feature in item
]
out_features[i][feature] = np.concatenate(arrays)"""
continue
else:
arrays = [
np.array(item[feature]) for item in features_ if feature in item
]
out_features[i][feature] = np.concatenate(arrays)
out = super().__call__(out_features, return_tensors=return_tensors)
# collated_seq_lens, totalseqlens = get_seqlens_from_pos_ids(out["position_ids"])
# out["attention_mask"] = packed_block_causal_mask(collated_seq_lens, totalseqlens)
out["attention_mask"] = get_packed_mask_from_pos_ids(out["position_ids"])
# out["attention_mask"] = create_block_causal_mask(collated_seq_lens, max_seq_len)
return out
@dataclass
class V2BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
"""