Attention mask and position id fixes for packing (#285)
* fix attetion mask with packing * set position ids and use block diagonal attn mask * fix expand mask for multiple batch items, make sure we pad position_ids * don't move masks to cpu * use multi pack dataloader w random sampler * add position_ids back * more fixes for dataloader integration * est total tokens, fix field loop * more fixes, position_ids seems broken * more fixes for sample packing * use distributed sampler, avoid accelerate prepare * use accelerator prepare for dataloader * fix for position_ids w packing * Update src/axolotl/utils/dataloader.py * validation for sample packing and doc * more fixes for 4k and optimizations * optimized expand mask fn * better handling of variance in multipack dataloader length and trainer hanging when it runs out of data * fix rounding of len of batches to int * better handling so that all devices have the same dataloader len * fix step calc for packing * pass sample packing efficiency to training args * add a test for the mask expansion for sequence packing * only process eval dataset for packing if not None * don't split batches when packing * weighted CE losses * weighted CEL fixes * limit packing to sequences of max seq len * seq_len_multiple for packing * make sure the chunk size is an int * sample_packing_seq_len_multiplier config * use cumulative seq len with var len flash attn v2 w packing * properly calculate max len * fix flash-attn, xformers, packing, support chatml * fix chatml system prompt for openorca, legacy tokenizer opts * add chatml * add unit tests for cum seq lens, add ability to build cu_seq_lens from positional ids, fix prompt test * fix test and pylint checks * more packing and dataset optimizations and fixes * filter w multiple cpus * more fixes and optimizations * fixes and go back to distributed sampler since batch sampler won't work * fix counts by accounting for num devices * fix steps calculation * previous accelerate is still most performant * add numba to requirements. * use custom distributed checks * fix sampler to prevent overfit w new epochs * let's not cleanup the cached datasets * calculate cum seq lens with pos_ids instead of mask, simplify packing params, fix distributed barrier * speed optimizations and set accelerate fsdp env vars * optimize dataset concatenation? * more optimizations for dataset handling * fix import for annotation * manual pre-commit fixes * another sum optimization and bug fix for calc steps * fix packing estimations * fix formatting * pylint problems * add back flash attention branch for handling unpacked sequences seperately * Address PR feedback * add optional sample packing config params to readme
This commit is contained in:
@@ -5,7 +5,7 @@ import os
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
from datasets import IterableDataset
|
||||
from datasets import Dataset, IterableDataset
|
||||
|
||||
from .prompt_tokenizers import PromptTokenizingStrategy
|
||||
|
||||
@@ -18,9 +18,9 @@ from .prompt_tokenizers import PromptTokenizingStrategy
|
||||
LOG = logging.getLogger("axolotl")
|
||||
|
||||
|
||||
class TokenizedPromptDataset(IterableDataset):
|
||||
class TokenizedPromptDataset(Dataset):
|
||||
"""
|
||||
Iterable dataset that returns tokenized prompts from a stream of text files.
|
||||
Dataset that returns tokenized prompts from a stream of text files.
|
||||
Args:
|
||||
prompt_tokenizer (PromptTokenizingStrategy): The prompt tokenizing method for proccessing the data.
|
||||
dataset (dataset.Dataset): Dataset with text files.
|
||||
@@ -30,19 +30,18 @@ class TokenizedPromptDataset(IterableDataset):
|
||||
self,
|
||||
prompt_tokenizer: PromptTokenizingStrategy,
|
||||
dataset: IterableDataset,
|
||||
**kwargs,
|
||||
):
|
||||
self.prompt_tokenizer = prompt_tokenizer
|
||||
self.dataset = dataset
|
||||
super().__init__(self.process(dataset).data, **kwargs)
|
||||
|
||||
def __iter__(self):
|
||||
features = self.dataset.features.keys()
|
||||
num_proc = os.cpu_count()
|
||||
return iter(
|
||||
self.dataset.map(
|
||||
self.prompt_tokenizer.tokenize_prompt,
|
||||
num_proc=num_proc,
|
||||
remove_columns=features,
|
||||
)
|
||||
def process(self, dataset):
|
||||
features = dataset.features.keys()
|
||||
num_proc = min(64, os.cpu_count())
|
||||
return dataset.map(
|
||||
self.prompt_tokenizer.tokenize_prompt,
|
||||
num_proc=num_proc,
|
||||
remove_columns=features,
|
||||
)
|
||||
|
||||
|
||||
@@ -77,14 +76,21 @@ class ConstantLengthDataset(IterableDataset):
|
||||
self.tokens_dtype = torch.int64
|
||||
|
||||
def __iter__(self):
|
||||
buffer = {"input_ids": [], "attention_mask": [], "labels": []}
|
||||
buffer = {
|
||||
"input_ids": [],
|
||||
"attention_mask": [],
|
||||
"labels": [],
|
||||
"position_ids": [],
|
||||
}
|
||||
buffer_len = 0
|
||||
for dataset in self.datasets:
|
||||
idx = 0
|
||||
iterator = iter(dataset)
|
||||
more_examples = True
|
||||
while more_examples:
|
||||
try:
|
||||
example = next(iterator)
|
||||
idx += 1
|
||||
except StopIteration:
|
||||
more_examples = False
|
||||
example = None
|
||||
@@ -106,6 +112,9 @@ class ConstantLengthDataset(IterableDataset):
|
||||
attention_mask = torch.cat(buffer["attention_mask"], dim=-1)[
|
||||
: self.seq_length
|
||||
]
|
||||
position_ids = torch.cat(buffer["position_ids"], dim=-1)[
|
||||
: self.seq_length
|
||||
]
|
||||
labels = torch.cat(buffer["labels"], dim=-1)[: self.seq_length]
|
||||
if labels.size() == input_ids.size() and (
|
||||
attention_mask.size() == input_ids.size()
|
||||
@@ -114,6 +123,7 @@ class ConstantLengthDataset(IterableDataset):
|
||||
"input_ids": input_ids,
|
||||
"labels": labels,
|
||||
"attention_mask": attention_mask,
|
||||
"position_ids": position_ids,
|
||||
}
|
||||
else:
|
||||
LOG.warning(
|
||||
@@ -123,8 +133,10 @@ class ConstantLengthDataset(IterableDataset):
|
||||
"input_ids": [],
|
||||
"attention_mask": [],
|
||||
"labels": [],
|
||||
"position_ids": [],
|
||||
}
|
||||
buffer_len = 0
|
||||
idx = 1
|
||||
|
||||
if example:
|
||||
# FIXME
|
||||
@@ -133,11 +145,6 @@ class ConstantLengthDataset(IterableDataset):
|
||||
input_ids = example["input_ids"]
|
||||
attention_mask = example["attention_mask"]
|
||||
labels = example["labels"]
|
||||
if (
|
||||
buffer["input_ids"]
|
||||
and input_ids[0] == self.tokenizer.bos_token_id
|
||||
):
|
||||
attention_mask[0] = 0
|
||||
|
||||
if add_concat_token:
|
||||
input_ids.append(self.concat_token_id)
|
||||
@@ -148,13 +155,17 @@ class ConstantLengthDataset(IterableDataset):
|
||||
input_ids, dtype=self.tokens_dtype
|
||||
)
|
||||
attention_mask_with_concat = torch.tensor(
|
||||
attention_mask, dtype=self.tokens_dtype
|
||||
[idx * m for m in attention_mask], dtype=torch.int16
|
||||
)
|
||||
labels_with_concat = torch.tensor(
|
||||
labels, dtype=self.tokens_dtype
|
||||
)
|
||||
position_ids = torch.arange(
|
||||
len(input_ids), dtype=self.tokens_dtype
|
||||
)
|
||||
|
||||
buffer["input_ids"].append(input_ids_with_concat)
|
||||
buffer["attention_mask"].append(attention_mask_with_concat)
|
||||
buffer["labels"].append(labels_with_concat)
|
||||
buffer["position_ids"].append(position_ids)
|
||||
buffer_len += len(input_ids)
|
||||
|
||||
Reference in New Issue
Block a user