flex batching WIP
This commit is contained in:
@@ -1,15 +1,17 @@
|
|||||||
'''
|
'''
|
||||||
Taken from https://github.com/pytorch/torchtune/blob/main/torchtune/modules/attention_utils.py
|
Taken from https://github.com/pytorch/torchtune/blob/main/torchtune/modules/attention_utils.py
|
||||||
'''
|
'''
|
||||||
|
from typing import Union
|
||||||
import torch
|
import torch
|
||||||
from torch.nn.attention.flex_attention import (
|
from torch.nn.attention.flex_attention import (
|
||||||
BlockMask,
|
BlockMask,
|
||||||
create_block_mask as create_block_causal_mask_flex,
|
create_block_mask as create_block_causal_mask_flex,
|
||||||
flex_attention,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
_MaskType = Union[torch.Tensor, BlockMask]
|
||||||
|
|
||||||
def _get_document_ids_from_seq_lens(
|
def _get_document_ids_from_seq_lens(
|
||||||
seq_lens: List[torch.Tensor],
|
seq_lens: list[torch.Tensor],
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Convert a batch tensor of seq lens into integer IDs denoting sample ownership.
|
Convert a batch tensor of seq lens into integer IDs denoting sample ownership.
|
||||||
@@ -39,7 +41,7 @@ def _get_document_ids_from_seq_lens(
|
|||||||
return batch_document_ids
|
return batch_document_ids
|
||||||
|
|
||||||
def packed_block_causal_mask(
|
def packed_block_causal_mask(
|
||||||
seq_lens: List[torch.Tensor],
|
seq_lens: list[torch.Tensor],
|
||||||
) -> _MaskType:
|
) -> _MaskType:
|
||||||
"""
|
"""
|
||||||
Create a block causal document mask for a batch of packed sequences. If
|
Create a block causal document mask for a batch of packed sequences. If
|
||||||
|
|||||||
Reference in New Issue
Block a user