vanilla mask
This commit is contained in:
@@ -1,10 +1,11 @@
|
|||||||
'''
|
"""
|
||||||
Taken from https://github.com/pytorch/torchtune/blob/main/torchtune/modules/attention_utils.py
|
Taken from https://github.com/pytorch/torchtune/blob/main/torchtune/modules/attention_utils.py
|
||||||
'''
|
"""
|
||||||
from typing import Union
|
from typing import Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
from torch.nn.attention.flex_attention import BlockMask
|
||||||
from torch.nn.attention.flex_attention import (
|
from torch.nn.attention.flex_attention import (
|
||||||
BlockMask,
|
|
||||||
create_block_mask as create_block_causal_mask_flex,
|
create_block_mask as create_block_causal_mask_flex,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -46,7 +47,7 @@ def create_block_causal_mask(seq_lens: list[torch.Tensor]) -> torch.Tensor:
|
|||||||
]
|
]
|
||||||
|
|
||||||
batch_block_attn_masks.append(torch.block_diag(*block_attn_masks))
|
batch_block_attn_masks.append(torch.block_diag(*block_attn_masks))
|
||||||
return torch.stack(batch_block_attn_masks)
|
return torch.stack(batch_block_attn_masks)[:,None,:,:]
|
||||||
|
|
||||||
|
|
||||||
def _get_document_ids_from_seq_lens(
|
def _get_document_ids_from_seq_lens(
|
||||||
@@ -79,6 +80,7 @@ def _get_document_ids_from_seq_lens(
|
|||||||
batch_document_ids = torch.stack(batch_document_ids)
|
batch_document_ids = torch.stack(batch_document_ids)
|
||||||
return batch_document_ids
|
return batch_document_ids
|
||||||
|
|
||||||
|
|
||||||
def packed_block_causal_mask(
|
def packed_block_causal_mask(
|
||||||
seq_lens: list[torch.Tensor],
|
seq_lens: list[torch.Tensor],
|
||||||
) -> _MaskType:
|
) -> _MaskType:
|
||||||
@@ -97,7 +99,7 @@ def packed_block_causal_mask(
|
|||||||
Returns:
|
Returns:
|
||||||
_MaskType: BlockMask or Tensor if torch version < 2.5.0.
|
_MaskType: BlockMask or Tensor if torch version < 2.5.0.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
document_ids = _get_document_ids_from_seq_lens(seq_lens)
|
document_ids = _get_document_ids_from_seq_lens(seq_lens)
|
||||||
batch_size, max_seq_len = document_ids.shape
|
batch_size, max_seq_len = document_ids.shape
|
||||||
document_ids = document_ids.to("cuda")
|
document_ids = document_ids.to("cuda")
|
||||||
@@ -127,5 +129,3 @@ def packed_block_causal_mask(
|
|||||||
max_seq_len,
|
max_seq_len,
|
||||||
device="cuda",
|
device="cuda",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -10,8 +10,11 @@ import torch
|
|||||||
from transformers import PreTrainedTokenizerBase
|
from transformers import PreTrainedTokenizerBase
|
||||||
from transformers.utils import PaddingStrategy
|
from transformers.utils import PaddingStrategy
|
||||||
|
|
||||||
|
from axolotl.monkeypatch.flex_attn import (
|
||||||
|
create_block_causal_mask,
|
||||||
|
packed_block_causal_mask,
|
||||||
|
)
|
||||||
from axolotl.monkeypatch.utils import get_seqlens_from_pos_ids
|
from axolotl.monkeypatch.utils import get_seqlens_from_pos_ids
|
||||||
from axolotl.monkeypatch.flex_attn import create_block_causal_mask, packed_block_causal_mask
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -167,7 +170,7 @@ class FlexBatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
|
|||||||
out_features = [{} for _ in features]
|
out_features = [{} for _ in features]
|
||||||
for i, features_ in enumerate(features):
|
for i, features_ in enumerate(features):
|
||||||
for feature in features_[0].keys():
|
for feature in features_[0].keys():
|
||||||
if feature in {"length" , "attention_mask"}:
|
if feature in {"length", "attention_mask"}:
|
||||||
continue
|
continue
|
||||||
else:
|
else:
|
||||||
arrays = [
|
arrays = [
|
||||||
|
|||||||
Reference in New Issue
Block a user