diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index d63a10e74..6f3f9f466 100755 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -71,6 +71,7 @@ from axolotl.utils.chat_templates import get_chat_template_from_config from axolotl.utils.collators import ( BatchSamplerDataCollatorForSeq2Seq, DataCollatorForSeq2Seq, + FlexBatchSamplerDataCollatorForSeq2Seq, MambaDataCollator, V2BatchSamplerDataCollatorForSeq2Seq, ) @@ -1941,6 +1942,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase): Union[ V2BatchSamplerDataCollatorForSeq2Seq, BatchSamplerDataCollatorForSeq2Seq, + FlexBatchSamplerDataCollatorForSeq2Seq, DataCollatorForSeq2Seq, DataCollatorWithFlattening, RewardDataCollatorWithPadding, @@ -1952,6 +1954,8 @@ class HFCausalTrainerBuilder(TrainerBuilderBase): if "max_length" in kwargs: kwargs.pop("max_length") elif use_batch_sampler_collator: + if self.cfg.flex_attention is True: + collator = FlexBatchSamplerDataCollatorForSeq2Seq if self.cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES: collator = V2BatchSamplerDataCollatorForSeq2Seq elif ( diff --git a/src/axolotl/monkeypatch/flex_attn.py b/src/axolotl/monkeypatch/flex_attn.py new file mode 100644 index 000000000..bd0104e31 --- /dev/null +++ b/src/axolotl/monkeypatch/flex_attn.py @@ -0,0 +1,90 @@ +''' +Taken from https://github.com/pytorch/torchtune/blob/main/torchtune/modules/attention_utils.py +''' +import torch +from torch.nn.attention.flex_attention import ( + BlockMask, + create_block_mask as create_block_causal_mask_flex, + flex_attention, +) + +def _get_document_ids_from_seq_lens( + seq_lens: List[torch.Tensor], +) -> torch.Tensor: + """ + Convert a batch tensor of seq lens into integer IDs denoting sample ownership. + For example, seq_lens = [2, 3, 1] would return [0, 0, 1, 1, 1, 2]. + + Args: + seq_lens (List[torch.Tensor]): Sequence lengths of samples in each pack in the batch, + shape (batch_size, n), where n is the max number of sequences in a pack and can vary + across packs. + + Returns: + Tensor: Document IDs of shape (batch_size, max_seq_len). + """ + batch_size = len(seq_lens) + batch_document_ids = [] + for sample_idx in range(batch_size): + # We assume seq lens sum to max seq lens, so document_ids should be of + # shape (max_seq_len, ) + document_ids = torch.cat( + [ + torch.full((seq_len,), i, dtype=torch.long, device=seq_len.device) + for i, seq_len in enumerate(seq_lens[sample_idx]) + ] + ) + batch_document_ids.append(document_ids) + batch_document_ids = torch.stack(batch_document_ids) + return batch_document_ids + +def packed_block_causal_mask( + seq_lens: List[torch.Tensor], +) -> _MaskType: + """ + Create a block causal document mask for a batch of packed sequences. If + flex attention is supported by the current hardware, block causal logic and + passing this into :func:`torch.nn.attention.flex_attention.create_block_mask`. + The resultant BlockMask is a compressed representation of the full block causal + mask. If on an older version, a standard 2D block causal mask is created and returned. + + Args: + seq_lens (List[torch.Tensor]): Sequence lengths of samples in each pack in the batch, + shape (batch_size, n), where n is the max number of sequences in a pack and can vary + across packs. + + Returns: + _MaskType: BlockMask or Tensor if torch version < 2.5.0. + """ + + document_ids = _get_document_ids_from_seq_lens(seq_lens) + batch_size, max_seq_len = document_ids.shape + document_ids = document_ids.to("cuda") + + # Instead of passing a tensor mask, flex attention requires a mask_mod function + # that determines which elements of QK^T should be included in the attention + # computation prior to the softmax. For sample packing, we need both the + # logic for both causal mask and document mask. See PyTorch's official + # blog post for more details: https://pytorch.org/blog/flexattention/#mask-mods + def mask_mod(b, h, q_idx, kv_idx): + """ + Defines the logic of a block causal mask by combining both a standard causal mask + and a block diagonal document mask. + + See :func:`~torchtune.modules.attention_utils.create_block_causal_mask` + for an illustration. + """ + causal_mask = q_idx >= kv_idx + document_mask = document_ids[b, q_idx] == document_ids[b, kv_idx] + return causal_mask & document_mask + + return create_block_causal_mask_flex( + mask_mod, + batch_size, + None, + max_seq_len, + max_seq_len, + device="cuda", + ) + + diff --git a/src/axolotl/utils/collators/__init__.py b/src/axolotl/utils/collators/__init__.py index 93502b67d..1287dc920 100644 --- a/src/axolotl/utils/collators/__init__.py +++ b/src/axolotl/utils/collators/__init__.py @@ -4,6 +4,7 @@ shared axolotl collators for multipack, mamba, multimodal from .batching import ( # noqa: F401 BatchSamplerDataCollatorForSeq2Seq, DataCollatorForSeq2Seq, + FlexBatchSamplerDataCollatorForSeq2Seq, PretrainingBatchSamplerDataCollatorForSeq2Seq, V2BatchSamplerDataCollatorForSeq2Seq, ) diff --git a/src/axolotl/utils/collators/batching.py b/src/axolotl/utils/collators/batching.py index 39f3d0c04..6508602cf 100644 --- a/src/axolotl/utils/collators/batching.py +++ b/src/axolotl/utils/collators/batching.py @@ -11,6 +11,7 @@ from transformers import PreTrainedTokenizerBase from transformers.utils import PaddingStrategy from axolotl.monkeypatch.utils import get_seqlens_from_pos_ids +from axolotl.monkeypatch.flex_attn import packed_block_causal_mask @dataclass @@ -166,15 +167,8 @@ class FlexBatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq): out_features = [{} for _ in features] for i, features_ in enumerate(features): for feature in features_[0].keys(): - if feature == "length": + if feature in {"length" , "attention_mask"}: continue - if feature == "attention_mask": - arrays = [ - (i + 1) * np.array(item[feature]) - for i, item in enumerate(features_) - if feature in item - ] - out_features[i][feature] = np.concatenate(arrays) else: arrays = [ np.array(item[feature]) for item in features_ if feature in item @@ -183,10 +177,7 @@ class FlexBatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq): out = super().__call__(out_features, return_tensors=return_tensors) collated_seq_lens = get_seqlens_from_pos_ids(out["position_ids"]) - - doc_mask = - - out["attention_mask"] + out["attention_mask"] = packed_block_causal_mask(collated_seq_lens) return out