remove unnecessary components
This commit is contained in:
@@ -79,7 +79,6 @@ from axolotl.utils.chat_templates import get_chat_template_from_config
|
|||||||
from axolotl.utils.collators import (
|
from axolotl.utils.collators import (
|
||||||
BatchSamplerDataCollatorForSeq2Seq,
|
BatchSamplerDataCollatorForSeq2Seq,
|
||||||
DataCollatorForSeq2Seq,
|
DataCollatorForSeq2Seq,
|
||||||
FlexBatchSamplerDataCollatorForSeq2Seq,
|
|
||||||
MambaDataCollator,
|
MambaDataCollator,
|
||||||
V2BatchSamplerDataCollatorForSeq2Seq,
|
V2BatchSamplerDataCollatorForSeq2Seq,
|
||||||
)
|
)
|
||||||
@@ -817,7 +816,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
Union[
|
Union[
|
||||||
V2BatchSamplerDataCollatorForSeq2Seq,
|
V2BatchSamplerDataCollatorForSeq2Seq,
|
||||||
BatchSamplerDataCollatorForSeq2Seq,
|
BatchSamplerDataCollatorForSeq2Seq,
|
||||||
FlexBatchSamplerDataCollatorForSeq2Seq,
|
|
||||||
DataCollatorForSeq2Seq,
|
DataCollatorForSeq2Seq,
|
||||||
DataCollatorWithFlattening,
|
DataCollatorWithFlattening,
|
||||||
RewardDataCollatorWithPadding,
|
RewardDataCollatorWithPadding,
|
||||||
|
|||||||
@@ -1,146 +0,0 @@
|
|||||||
"""
|
|
||||||
Taken from https://github.com/pytorch/torchtune/blob/main/torchtune/modules/attention_utils.py
|
|
||||||
"""
|
|
||||||
from typing import Union
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from torch.nn.attention.flex_attention import BlockMask
|
|
||||||
from torch.nn.attention.flex_attention import (
|
|
||||||
create_block_mask as create_block_causal_mask_flex,
|
|
||||||
)
|
|
||||||
|
|
||||||
_MaskType = Union[torch.Tensor, BlockMask]
|
|
||||||
|
|
||||||
|
|
||||||
def create_block_causal_mask(
|
|
||||||
seq_lens: list[torch.Tensor], max_seq_len: int
|
|
||||||
) -> torch.Tensor:
|
|
||||||
"""
|
|
||||||
Given a batch tensor of seq lens defining the lengths of samples in each pack,
|
|
||||||
Construct a 2D block causal mask for each pack in the batch. For example, if
|
|
||||||
a single sample's seq_lens is [3, 2, 1], the mask would be::
|
|
||||||
|
|
||||||
mask = [
|
|
||||||
[1, 0, 0, 0, 0, 0],
|
|
||||||
[1, 1, 0, 0, 0, 0],
|
|
||||||
[1, 1, 1, 0, 0, 0],
|
|
||||||
[0, 0, 0, 1, 0, 0],
|
|
||||||
[0, 0, 0, 1, 1, 0],
|
|
||||||
[0, 0, 0, 0, 0, 1],
|
|
||||||
]
|
|
||||||
|
|
||||||
Args:
|
|
||||||
seq_lens (List[torch.Tensor]): Sequence lengths of samples in each pack in the batch,
|
|
||||||
shape (batch_size, n), where n is the max number of sequences in a pack and can vary
|
|
||||||
across packs.
|
|
||||||
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Tensor: Block causal mask of shape (batch_size, max_seq_len, max_seq_len).
|
|
||||||
"""
|
|
||||||
batch_block_attn_masks = []
|
|
||||||
batch_size = len(seq_lens)
|
|
||||||
for sample_idx in range(batch_size):
|
|
||||||
block_attn_masks = [
|
|
||||||
torch.trilu( # torch.tril(
|
|
||||||
torch.ones(seq_len, seq_len, dtype=torch.bool, device=seq_len.device)
|
|
||||||
)
|
|
||||||
for seq_len in seq_lens[sample_idx]
|
|
||||||
]
|
|
||||||
|
|
||||||
"""residue_len = max_seq_len - torch.sum(seq_lens[sample_idx])
|
|
||||||
block_attn_masks.append(
|
|
||||||
torch.tril(
|
|
||||||
torch.ones(
|
|
||||||
residue_len, residue_len, dtype=torch.bool, device=seq_lens[sample_idx].device
|
|
||||||
)
|
|
||||||
)
|
|
||||||
)"""
|
|
||||||
|
|
||||||
batch_block_attn_masks.append(torch.block_diag(*block_attn_masks))
|
|
||||||
|
|
||||||
return torch.stack(batch_block_attn_masks)[:, None, :, :]
|
|
||||||
|
|
||||||
|
|
||||||
def _get_document_ids_from_seq_lens(
|
|
||||||
seq_lens: list[torch.Tensor],
|
|
||||||
) -> torch.Tensor:
|
|
||||||
"""
|
|
||||||
Convert a batch tensor of seq lens into integer IDs denoting sample ownership.
|
|
||||||
For example, seq_lens = [2, 3, 1] would return [0, 0, 1, 1, 1, 2].
|
|
||||||
|
|
||||||
Args:
|
|
||||||
seq_lens (List[torch.Tensor]): Sequence lengths of samples in each pack in the batch,
|
|
||||||
shape (batch_size, n), where n is the max number of sequences in a pack and can vary
|
|
||||||
across packs.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Tensor: Document IDs of shape (batch_size, max_seq_len).
|
|
||||||
"""
|
|
||||||
batch_size = len(seq_lens)
|
|
||||||
batch_document_ids = []
|
|
||||||
for sample_idx in range(batch_size):
|
|
||||||
# We assume seq lens sum to max seq lens, so document_ids should be of
|
|
||||||
# shape (max_seq_len, )
|
|
||||||
document_ids = torch.cat(
|
|
||||||
[
|
|
||||||
torch.full((seq_len,), i, dtype=torch.long, device=seq_len.device)
|
|
||||||
for i, seq_len in enumerate(seq_lens[sample_idx])
|
|
||||||
]
|
|
||||||
)
|
|
||||||
batch_document_ids.append(document_ids)
|
|
||||||
batch_document_ids = torch.stack(batch_document_ids)
|
|
||||||
return batch_document_ids
|
|
||||||
|
|
||||||
|
|
||||||
def packed_block_causal_mask(
|
|
||||||
seq_lens: list[torch.Tensor], totalseqlens: list[int]
|
|
||||||
) -> _MaskType:
|
|
||||||
"""
|
|
||||||
Create a block causal document mask for a batch of packed sequences. If
|
|
||||||
flex attention is supported by the current hardware, block causal logic and
|
|
||||||
passing this into :func:`torch.nn.attention.flex_attention.create_block_mask`.
|
|
||||||
The resultant BlockMask is a compressed representation of the full block causal
|
|
||||||
mask. If on an older version, a standard 2D block causal mask is created and returned.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
seq_lens (List[torch.Tensor]): Sequence lengths of samples in each pack in the batch,
|
|
||||||
shape (batch_size, n), where n is the max number of sequences in a pack and can vary
|
|
||||||
across packs.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
_MaskType: BlockMask or Tensor if torch version < 2.5.0.
|
|
||||||
"""
|
|
||||||
|
|
||||||
document_ids = _get_document_ids_from_seq_lens(seq_lens)
|
|
||||||
batch_size , max_seq_len = document_ids.shape
|
|
||||||
document_ids = document_ids.to("cuda")
|
|
||||||
totalseqlens = totalseqlens.to("cuda")
|
|
||||||
|
|
||||||
# Instead of passing a tensor mask, flex attention requires a mask_mod function
|
|
||||||
# that determines which elements of QK^T should be included in the attention
|
|
||||||
# computation prior to the softmax. For sample packing, we need both the
|
|
||||||
# logic for both causal mask and document mask. See PyTorch's official
|
|
||||||
# blog post for more details: https://pytorch.org/blog/flexattention/#mask-mods
|
|
||||||
def mask_mod(b, h, q_idx, kv_idx):
|
|
||||||
"""
|
|
||||||
Defines the logic of a block causal mask by combining both a standard causal mask
|
|
||||||
and a block diagonal document mask.
|
|
||||||
|
|
||||||
See :func:`~torchtune.modules.attention_utils.create_block_causal_mask`
|
|
||||||
for an illustration.
|
|
||||||
"""
|
|
||||||
causal_mask = q_idx >= kv_idx
|
|
||||||
document_mask = document_ids[b, q_idx] == document_ids[b, kv_idx]
|
|
||||||
finite_mask = q_idx < totalseqlens[b]
|
|
||||||
return causal_mask & document_mask & finite_mask
|
|
||||||
|
|
||||||
return create_block_causal_mask_flex(
|
|
||||||
mask_mod,
|
|
||||||
batch_size,
|
|
||||||
None,
|
|
||||||
max_seq_len,
|
|
||||||
max_seq_len,
|
|
||||||
device="cuda",
|
|
||||||
BLOCK_SIZE=512,
|
|
||||||
)
|
|
||||||
@@ -4,7 +4,6 @@ shared axolotl collators for multipack, mamba, multimodal
|
|||||||
from .batching import ( # noqa: F401
|
from .batching import ( # noqa: F401
|
||||||
BatchSamplerDataCollatorForSeq2Seq,
|
BatchSamplerDataCollatorForSeq2Seq,
|
||||||
DataCollatorForSeq2Seq,
|
DataCollatorForSeq2Seq,
|
||||||
FlexBatchSamplerDataCollatorForSeq2Seq,
|
|
||||||
PretrainingBatchSamplerDataCollatorForSeq2Seq,
|
PretrainingBatchSamplerDataCollatorForSeq2Seq,
|
||||||
V2BatchSamplerDataCollatorForSeq2Seq,
|
V2BatchSamplerDataCollatorForSeq2Seq,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -3,21 +3,12 @@ DataCollator for axolotl to pad labels and position_ids for packed sequences
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, List, Optional, Union
|
from typing import Any, Optional, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
|
||||||
from transformers import PreTrainedTokenizerBase
|
from transformers import PreTrainedTokenizerBase
|
||||||
from transformers.utils import PaddingStrategy
|
from transformers.utils import PaddingStrategy
|
||||||
|
|
||||||
from axolotl.monkeypatch.flex_attn import (
|
|
||||||
create_block_causal_mask,
|
|
||||||
packed_block_causal_mask,
|
|
||||||
)
|
|
||||||
from axolotl.monkeypatch.utils import (
|
|
||||||
get_packed_mask_from_pos_ids,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class DataCollatorForSeq2Seq:
|
class DataCollatorForSeq2Seq:
|
||||||
@@ -160,42 +151,6 @@ class BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
|
|||||||
return super().__call__(out_features, return_tensors=return_tensors)
|
return super().__call__(out_features, return_tensors=return_tensors)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class FlexBatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
|
|
||||||
"""
|
|
||||||
Collator for multipack specific to Flex Attention using the BatchSampler
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __call__(self, features, return_tensors=None):
|
|
||||||
if not isinstance(features[0], list):
|
|
||||||
features = [features]
|
|
||||||
out_features = [{} for _ in features]
|
|
||||||
for i, features_ in enumerate(features):
|
|
||||||
for feature in features_[0].keys():
|
|
||||||
if feature == "length":
|
|
||||||
continue
|
|
||||||
elif feature == "attention_mask":
|
|
||||||
"""arrays = [
|
|
||||||
i * np.array(item[feature])
|
|
||||||
for i, item in enumerate(features_)
|
|
||||||
if feature in item
|
|
||||||
]
|
|
||||||
out_features[i][feature] = np.concatenate(arrays)"""
|
|
||||||
continue
|
|
||||||
else:
|
|
||||||
arrays = [
|
|
||||||
np.array(item[feature]) for item in features_ if feature in item
|
|
||||||
]
|
|
||||||
out_features[i][feature] = np.concatenate(arrays)
|
|
||||||
out = super().__call__(out_features, return_tensors=return_tensors)
|
|
||||||
|
|
||||||
# collated_seq_lens, totalseqlens = get_seqlens_from_pos_ids(out["position_ids"])
|
|
||||||
# out["attention_mask"] = packed_block_causal_mask(collated_seq_lens, totalseqlens)
|
|
||||||
out["attention_mask"] = get_packed_mask_from_pos_ids(out["position_ids"])
|
|
||||||
# out["attention_mask"] = create_block_causal_mask(collated_seq_lens, max_seq_len)
|
|
||||||
return out
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class V2BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
|
class V2BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
|
||||||
"""
|
"""
|
||||||
|
|||||||
Reference in New Issue
Block a user