flex batching WIP
This commit is contained in:
@@ -1,15 +1,17 @@
|
||||
'''
|
||||
Taken from https://github.com/pytorch/torchtune/blob/main/torchtune/modules/attention_utils.py
|
||||
'''
|
||||
from typing import Union
|
||||
import torch
|
||||
from torch.nn.attention.flex_attention import (
|
||||
BlockMask,
|
||||
create_block_mask as create_block_causal_mask_flex,
|
||||
flex_attention,
|
||||
)
|
||||
|
||||
_MaskType = Union[torch.Tensor, BlockMask]
|
||||
|
||||
def _get_document_ids_from_seq_lens(
|
||||
seq_lens: List[torch.Tensor],
|
||||
seq_lens: list[torch.Tensor],
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Convert a batch tensor of seq lens into integer IDs denoting sample ownership.
|
||||
@@ -39,7 +41,7 @@ def _get_document_ids_from_seq_lens(
|
||||
return batch_document_ids
|
||||
|
||||
def packed_block_causal_mask(
|
||||
seq_lens: List[torch.Tensor],
|
||||
seq_lens: list[torch.Tensor],
|
||||
) -> _MaskType:
|
||||
"""
|
||||
Create a block causal document mask for a batch of packed sequences. If
|
||||
|
||||
Reference in New Issue
Block a user