flex batching WIP

Author: Sunny Liu
Date: 2025-01-30 14:04:59 -05:00
parent 96ad741cd5
commit 065f6d477e


@@ -1,15 +1,17 @@
 '''
 Taken from https://github.com/pytorch/torchtune/blob/main/torchtune/modules/attention_utils.py
 '''
-from typing import List, Union
+from typing import Union
 import torch
 from torch.nn.attention.flex_attention import (
     BlockMask,
     create_block_mask as create_block_causal_mask_flex,
     flex_attention,
 )
+
 _MaskType = Union[torch.Tensor, BlockMask]
+
 def _get_document_ids_from_seq_lens(
-    seq_lens: List[torch.Tensor],
+    seq_lens: list[torch.Tensor],
 ) -> torch.Tensor:
     """
     Convert a batch tensor of seq lens into integer IDs denoting sample ownership.
@@ -39,7 +41,7 @@ def _get_document_ids_from_seq_lens(
     return batch_document_ids
 
 def packed_block_causal_mask(
-    seq_lens: List[torch.Tensor],
+    seq_lens: list[torch.Tensor],
 ) -> _MaskType:
     """
     Create a block causal document mask for a batch of packed sequences. If
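
For context, here is a minimal runnable sketch of the document-masking pattern these helpers implement, assuming PyTorch >= 2.5 with flex_attention available. Only the helper name _get_document_ids_from_seq_lens comes from the diff; its body below, the example lengths, and the mask_mod closure are illustrative (following the standard flex_attention document-masking recipe), not this commit's actual implementation.

# Illustrative sketch, not this commit's code: turn per-sample seq lens into
# document IDs, then build a block causal document mask for flex_attention.
# Assumes PyTorch >= 2.5; uses device="cpu" so it runs without a GPU.
import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention


def _get_document_ids_from_seq_lens(seq_lens: list[torch.Tensor]) -> torch.Tensor:
    # Label every token in each packed row with the index of the sample
    # ("document") it came from, e.g. lens [3, 5] -> [0, 0, 0, 1, 1, 1, 1, 1].
    return torch.stack(
        [torch.repeat_interleave(torch.arange(len(lens)), lens) for lens in seq_lens]
    )


# Two packed rows whose samples each sum to 128 tokens.
seq_lens = [torch.tensor([48, 80]), torch.tensor([64, 64])]
document_ids = _get_document_ids_from_seq_lens(seq_lens)  # shape (2, 128)


def mask_mod(b, h, q_idx, kv_idx):
    # Causal within a row, and attention only inside the same document,
    # so packed samples cannot attend to each other.
    causal = q_idx >= kv_idx
    same_doc = document_ids[b, q_idx] == document_ids[b, kv_idx]
    return causal & same_doc


block_mask = create_block_mask(mask_mod, B=2, H=None, Q_LEN=128, KV_LEN=128, device="cpu")

q = k = v = torch.randn(2, 1, 128, 16)  # (batch, num_heads, seq_len, head_dim)
out = flex_attention(q, k, v, block_mask=block_mask)
print(out.shape)  # torch.Size([2, 1, 128, 16])

The BlockMask lets flex_attention skip fully masked tiles entirely, which is what makes packing multiple samples per row cheap compared with materializing a dense (B, 1, S, S) boolean attention mask.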