From e792b54bab0d9260c0ea3cb12d3d8ef20b99f0d7 Mon Sep 17 00:00:00 2001 From: Sunny Liu Date: Fri, 21 Feb 2025 11:23:21 -0500 Subject: [PATCH] remove unnecessary components --- src/axolotl/core/trainer_builder.py | 2 - src/axolotl/monkeypatch/flex_attn.py | 146 ------------------------ src/axolotl/utils/collators/__init__.py | 1 - src/axolotl/utils/collators/batching.py | 47 +------- 4 files changed, 1 insertion(+), 195 deletions(-) delete mode 100644 src/axolotl/monkeypatch/flex_attn.py diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index 1ca992b83..0c4f895bd 100755 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -79,7 +79,6 @@ from axolotl.utils.chat_templates import get_chat_template_from_config from axolotl.utils.collators import ( BatchSamplerDataCollatorForSeq2Seq, DataCollatorForSeq2Seq, - FlexBatchSamplerDataCollatorForSeq2Seq, MambaDataCollator, V2BatchSamplerDataCollatorForSeq2Seq, ) @@ -817,7 +816,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase): Union[ V2BatchSamplerDataCollatorForSeq2Seq, BatchSamplerDataCollatorForSeq2Seq, - FlexBatchSamplerDataCollatorForSeq2Seq, DataCollatorForSeq2Seq, DataCollatorWithFlattening, RewardDataCollatorWithPadding, diff --git a/src/axolotl/monkeypatch/flex_attn.py b/src/axolotl/monkeypatch/flex_attn.py deleted file mode 100644 index aa95480e2..000000000 --- a/src/axolotl/monkeypatch/flex_attn.py +++ /dev/null @@ -1,146 +0,0 @@ -""" -Taken from https://github.com/pytorch/torchtune/blob/main/torchtune/modules/attention_utils.py -""" -from typing import Union - -import torch -from torch.nn.attention.flex_attention import BlockMask -from torch.nn.attention.flex_attention import ( - create_block_mask as create_block_causal_mask_flex, -) - -_MaskType = Union[torch.Tensor, BlockMask] - - -def create_block_causal_mask( - seq_lens: list[torch.Tensor], max_seq_len: int -) -> torch.Tensor: - """ - Given a batch tensor of seq lens defining the lengths of samples in each pack, - Construct a 2D block causal mask for each pack in the batch. For example, if - a single sample's seq_lens is [3, 2, 1], the mask would be:: - - mask = [ - [1, 0, 0, 0, 0, 0], - [1, 1, 0, 0, 0, 0], - [1, 1, 1, 0, 0, 0], - [0, 0, 0, 1, 0, 0], - [0, 0, 0, 1, 1, 0], - [0, 0, 0, 0, 0, 1], - ] - - Args: - seq_lens (List[torch.Tensor]): Sequence lengths of samples in each pack in the batch, - shape (batch_size, n), where n is the max number of sequences in a pack and can vary - across packs. - - - Returns: - Tensor: Block causal mask of shape (batch_size, max_seq_len, max_seq_len). - """ - batch_block_attn_masks = [] - batch_size = len(seq_lens) - for sample_idx in range(batch_size): - block_attn_masks = [ - torch.trilu( # torch.tril( - torch.ones(seq_len, seq_len, dtype=torch.bool, device=seq_len.device) - ) - for seq_len in seq_lens[sample_idx] - ] - - """residue_len = max_seq_len - torch.sum(seq_lens[sample_idx]) - block_attn_masks.append( - torch.tril( - torch.ones( - residue_len, residue_len, dtype=torch.bool, device=seq_lens[sample_idx].device - ) - ) - )""" - - batch_block_attn_masks.append(torch.block_diag(*block_attn_masks)) - - return torch.stack(batch_block_attn_masks)[:, None, :, :] - - -def _get_document_ids_from_seq_lens( - seq_lens: list[torch.Tensor], -) -> torch.Tensor: - """ - Convert a batch tensor of seq lens into integer IDs denoting sample ownership. - For example, seq_lens = [2, 3, 1] would return [0, 0, 1, 1, 1, 2]. - - Args: - seq_lens (List[torch.Tensor]): Sequence lengths of samples in each pack in the batch, - shape (batch_size, n), where n is the max number of sequences in a pack and can vary - across packs. - - Returns: - Tensor: Document IDs of shape (batch_size, max_seq_len). - """ - batch_size = len(seq_lens) - batch_document_ids = [] - for sample_idx in range(batch_size): - # We assume seq lens sum to max seq lens, so document_ids should be of - # shape (max_seq_len, ) - document_ids = torch.cat( - [ - torch.full((seq_len,), i, dtype=torch.long, device=seq_len.device) - for i, seq_len in enumerate(seq_lens[sample_idx]) - ] - ) - batch_document_ids.append(document_ids) - batch_document_ids = torch.stack(batch_document_ids) - return batch_document_ids - - -def packed_block_causal_mask( - seq_lens: list[torch.Tensor], totalseqlens: list[int] -) -> _MaskType: - """ - Create a block causal document mask for a batch of packed sequences. If - flex attention is supported by the current hardware, block causal logic and - passing this into :func:`torch.nn.attention.flex_attention.create_block_mask`. - The resultant BlockMask is a compressed representation of the full block causal - mask. If on an older version, a standard 2D block causal mask is created and returned. - - Args: - seq_lens (List[torch.Tensor]): Sequence lengths of samples in each pack in the batch, - shape (batch_size, n), where n is the max number of sequences in a pack and can vary - across packs. - - Returns: - _MaskType: BlockMask or Tensor if torch version < 2.5.0. - """ - - document_ids = _get_document_ids_from_seq_lens(seq_lens) - batch_size , max_seq_len = document_ids.shape - document_ids = document_ids.to("cuda") - totalseqlens = totalseqlens.to("cuda") - - # Instead of passing a tensor mask, flex attention requires a mask_mod function - # that determines which elements of QK^T should be included in the attention - # computation prior to the softmax. For sample packing, we need both the - # logic for both causal mask and document mask. See PyTorch's official - # blog post for more details: https://pytorch.org/blog/flexattention/#mask-mods - def mask_mod(b, h, q_idx, kv_idx): - """ - Defines the logic of a block causal mask by combining both a standard causal mask - and a block diagonal document mask. - - See :func:`~torchtune.modules.attention_utils.create_block_causal_mask` - for an illustration. - """ - causal_mask = q_idx >= kv_idx - document_mask = document_ids[b, q_idx] == document_ids[b, kv_idx] - finite_mask = q_idx < totalseqlens[b] - return causal_mask & document_mask & finite_mask - - return create_block_causal_mask_flex( - mask_mod, - batch_size, - None, - max_seq_len, - max_seq_len, - device="cuda", - BLOCK_SIZE=512, - ) diff --git a/src/axolotl/utils/collators/__init__.py b/src/axolotl/utils/collators/__init__.py index 1287dc920..93502b67d 100644 --- a/src/axolotl/utils/collators/__init__.py +++ b/src/axolotl/utils/collators/__init__.py @@ -4,7 +4,6 @@ shared axolotl collators for multipack, mamba, multimodal from .batching import ( # noqa: F401 BatchSamplerDataCollatorForSeq2Seq, DataCollatorForSeq2Seq, - FlexBatchSamplerDataCollatorForSeq2Seq, PretrainingBatchSamplerDataCollatorForSeq2Seq, V2BatchSamplerDataCollatorForSeq2Seq, ) diff --git a/src/axolotl/utils/collators/batching.py b/src/axolotl/utils/collators/batching.py index 72b4a5475..7cf771421 100644 --- a/src/axolotl/utils/collators/batching.py +++ b/src/axolotl/utils/collators/batching.py @@ -3,21 +3,12 @@ DataCollator for axolotl to pad labels and position_ids for packed sequences """ from dataclasses import dataclass -from typing import Any, List, Optional, Union +from typing import Any, Optional, Union import numpy as np -import torch from transformers import PreTrainedTokenizerBase from transformers.utils import PaddingStrategy -from axolotl.monkeypatch.flex_attn import ( - create_block_causal_mask, - packed_block_causal_mask, -) -from axolotl.monkeypatch.utils import ( - get_packed_mask_from_pos_ids, -) - @dataclass class DataCollatorForSeq2Seq: @@ -160,42 +151,6 @@ class BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq): return super().__call__(out_features, return_tensors=return_tensors) -@dataclass -class FlexBatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq): - """ - Collator for multipack specific to Flex Attention using the BatchSampler - """ - - def __call__(self, features, return_tensors=None): - if not isinstance(features[0], list): - features = [features] - out_features = [{} for _ in features] - for i, features_ in enumerate(features): - for feature in features_[0].keys(): - if feature == "length": - continue - elif feature == "attention_mask": - """arrays = [ - i * np.array(item[feature]) - for i, item in enumerate(features_) - if feature in item - ] - out_features[i][feature] = np.concatenate(arrays)""" - continue - else: - arrays = [ - np.array(item[feature]) for item in features_ if feature in item - ] - out_features[i][feature] = np.concatenate(arrays) - out = super().__call__(out_features, return_tensors=return_tensors) - - # collated_seq_lens, totalseqlens = get_seqlens_from_pos_ids(out["position_ids"]) - # out["attention_mask"] = packed_block_causal_mask(collated_seq_lens, totalseqlens) - out["attention_mask"] = get_packed_mask_from_pos_ids(out["position_ids"]) - # out["attention_mask"] = create_block_causal_mask(collated_seq_lens, max_seq_len) - return out - - @dataclass class V2BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq): """