flex batching WIP

This commit is contained in:
Sunny Liu
2025-01-30 14:04:59 -05:00
parent 96ad741cd5
commit 065f6d477e

View File

@@ -1,15 +1,17 @@
''' '''
Taken from https://github.com/pytorch/torchtune/blob/main/torchtune/modules/attention_utils.py Taken from https://github.com/pytorch/torchtune/blob/main/torchtune/modules/attention_utils.py
''' '''
from typing import Union
import torch import torch
from torch.nn.attention.flex_attention import ( from torch.nn.attention.flex_attention import (
BlockMask, BlockMask,
create_block_mask as create_block_causal_mask_flex, create_block_mask as create_block_causal_mask_flex,
flex_attention,
) )
# Mask type alias for this module: either a plain ``torch.Tensor`` mask or a
# flex-attention ``BlockMask``. Used as the return annotation of
# ``packed_block_causal_mask``.
_MaskType = Union[torch.Tensor, BlockMask]
def _get_document_ids_from_seq_lens( def _get_document_ids_from_seq_lens(
seq_lens: List[torch.Tensor], seq_lens: list[torch.Tensor],
) -> torch.Tensor: ) -> torch.Tensor:
""" """
Convert a batch tensor of seq lens into integer IDs denoting sample ownership. Convert a batch tensor of seq lens into integer IDs denoting sample ownership.
@@ -39,7 +41,7 @@ def _get_document_ids_from_seq_lens(
return batch_document_ids return batch_document_ids
def packed_block_causal_mask( def packed_block_causal_mask(
seq_lens: List[torch.Tensor], seq_lens: list[torch.Tensor],
) -> _MaskType: ) -> _MaskType:
""" """
Create a block causal document mask for a batch of packed sequences. If Create a block causal document mask for a batch of packed sequences. If