* fix attention mask with packing
* set position ids and use block diagonal attn mask
* fix expand mask for multiple batch items, make sure we pad position_ids
* don't move masks to cpu
* use multi pack dataloader w random sampler
* add position_ids back
* more fixes for dataloader integration
* est total tokens, fix field loop
* more fixes, position_ids seems broken
* more fixes for sample packing
* use distributed sampler, avoid accelerate prepare
* use accelerator prepare for dataloader
* fix for position_ids w packing
* Update src/axolotl/utils/dataloader.py
* validation for sample packing and doc
* more fixes for 4k and optimizations
* optimized expand mask fn
* better handling of variance in multipack dataloader length and trainer hanging when it runs out of data
* fix rounding of len of batches to int
* better handling so that all devices have the same dataloader len
* fix step calc for packing
* pass sample packing efficiency to training args
* add a test for the mask expansion for sequence packing
* only process eval dataset for packing if not None
* don't split batches when packing
* weighted CE losses
* weighted CEL fixes
* limit packing to sequences of max seq len
* seq_len_multiple for packing
* make sure the chunk size is an int
* sample_packing_seq_len_multiplier config
* use cumulative seq len with var len flash attn v2 w packing
* properly calculate max len
* fix flash-attn, xformers, packing, support chatml
* fix chatml system prompt for openorca, legacy tokenizer opts
* add chatml
* add unit tests for cum seq lens, add ability to build cu_seq_lens from positional ids, fix prompt test
* fix test and pylint checks
* more packing and dataset optimizations and fixes
* filter w multiple cpus
* more fixes and optimizations
* fixes and go back to distributed sampler since batch sampler won't work
* fix counts by accounting for num devices
* fix steps calculation
* previous accelerate is still most performant
* add numba to requirements.
* use custom distributed checks
* fix sampler to prevent overfit w new epochs
* let's not cleanup the cached datasets
* calculate cum seq lens with pos_ids instead of mask, simplify packing params, fix distributed barrier
* speed optimizations and set accelerate fsdp env vars
* optimize dataset concatenation?
* more optimizations for dataset handling
* fix import for annotation
* manual pre-commit fixes
* another sum optimization and bug fix for calc steps
* fix packing estimations
* fix formatting
* pylint problems
* add back flash attention branch for handling unpacked sequences separately
* Address PR feedback
* add optional sample packing config params to readme
53 lines
1.9 KiB
Python
"""
|
|
expands the binary attention mask per 3.2.2 of https://arxiv.org/pdf/2107.02027.pdf
|
|
"""
|
|
from typing import Optional
|
|
|
|
import torch
|
|
|
|
|
|
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
|
|
"""
|
|
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
|
|
This expansion handles packed sequences so that sequences share the same attention mask integer value
|
|
when they attend to each other within that sequence.
|
|
This expansion transforms the mask to lower triangular form to prevent future peeking.
|
|
"""
|
|
bsz, src_len = mask.size()
|
|
tgt_len = tgt_len if tgt_len is not None else src_len
|
|
|
|
mask = mask.unsqueeze(1).unsqueeze(2)
|
|
mask = mask.expand(bsz, 1, tgt_len, src_len)
|
|
|
|
# Create a binary mask from the original mask where zeros remain zeros and all other values are set to one
|
|
binary_mask = torch.where(
|
|
mask != 0,
|
|
torch.tensor(1).to(dtype),
|
|
torch.tensor(0).to(dtype),
|
|
)
|
|
|
|
# Create a block-diagonal mask.
|
|
# we multiply by the binary mask so that 0's in the original mask are correctly excluded
|
|
zero_one_mask = torch.eq(mask, mask.transpose(-1, -2)).int() * binary_mask
|
|
|
|
# Now let's create a lower triangular mask of ones that will zero out the upper triangular part
|
|
lower_triangular_ones = torch.tril(torch.ones((tgt_len, src_len), dtype=dtype)).to(
|
|
mask.device
|
|
)
|
|
|
|
# Use the lower triangular mask to zero out the upper triangular part of the zero_one_mask
|
|
masked_zero_one_mask = zero_one_mask * lower_triangular_ones
|
|
inverted_mask = 1.0 - masked_zero_one_mask
|
|
|
|
return inverted_mask.masked_fill(
|
|
inverted_mask.to(torch.bool), torch.finfo(dtype).min
|
|
)
|
|
|
|
|
|
def hijack_expand_mask():
|
|
import transformers
|
|
|
|
transformers.models.llama.modeling_llama._expand_mask = ( # pylint: disable=protected-access
|
|
_expand_mask
|
|
)
|
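The rule that `_expand_mask` encodes can be cross-checked with a small reference in plain Python. The sketch below is not part of the module: `packed_causal_allowed` is a hypothetical helper name, and it uses nested lists instead of tensors so it runs without torch. It assumes the same convention as the docstring above, namely that each packed sequence carries its own nonzero integer id in the attention mask and that 0 marks padding. Where it returns `True`, the tensor version yields `0`; where it returns `False`, the tensor version yields `torch.finfo(dtype).min`.

```python
from typing import List


def packed_causal_allowed(segment_ids: List[int]) -> List[List[bool]]:
    """Hypothetical reference helper (not in the original module).

    allowed[i][j] is True when query position i may attend to key position j.
    """
    n = len(segment_ids)
    return [
        [
            segment_ids[j] != 0  # the key is not padding
            and segment_ids[i] == segment_ids[j]  # same packed sequence
            and j <= i  # causal: no attending to future tokens
            for j in range(n)
        ]
        for i in range(n)
    ]


# Two packed sequences (ids 1 and 2) followed by one padding token (id 0).
allowed = packed_causal_allowed([1, 1, 2, 2, 0])
for row in allowed:
    print("".join("x" if ok else "." for ok in row))
# x....
# xx...
# ..x..
# ..xx.
# .....
```

The printed pattern is block diagonal and lower triangular within each block, and the padding row is fully blocked, which matches the product of `zero_one_mask` and `lower_triangular_ones` in the function above.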