Attention mask and position id fixes for packing (#285)

* fix attetion mask with packing

* set position ids and use block diagonal attn mask

* fix expand mask for multiple batch items, make sure we pad position_ids

* don't move masks to cpu

* use multi pack dataloader w random sampler

* add position_ids back

* more fixes for dataloader integration

* est total tokens, fix field loop

* more fixes, position_ids seems broken

* more fixes for sample packing

* use distributed sampler, avoid accelerate prepare

* use accelerator prepare for dataloader

* fix for position_ids w packing

* Update src/axolotl/utils/dataloader.py

* validation for sample packing and doc

* more fixes for 4k and optimizations

* optimized expand mask fn

* better handling of variance in multipack dataloader length and trainer hanging when it runs out of data

* fix rounding of len of batches to int

* better handling so that all devices have the same dataloader len

* fix step calc for packing

* pass sample packing efficiency to training args

* add a test for the mask expansion for sequence packing

* only process eval dataset for packing if not None

* don't split batches when packing

* weighted CE losses

* weighted CEL fixes

* limit packing to sequences of max seq len

* seq_len_multiple for packing

* make sure the chunk size is an int

* sample_packing_seq_len_multiplier config

* use cumulative seq len with var len flash attn v2 w packing

* properly calculate max len

* fix flash-attn, xformers, packing, support chatml

* fix chatml system prompt for openorca, legacy tokenizer opts

* add chatml

* add unit tests for cum seq lens, add ability to build cu_seq_lens from positional ids, fix prompt test

* fix test and pylint checks

* more packing and dataset optimizations and fixes

* filter w multiple cpus

* more fixes and optimizations

* fixes and go back to distributed sampler since batch sampler won't work

* fix counts by accounting for num devices

* fix steps calculation

* previous accelerate is still most performant

* add numba to requirements.

* use custom distributed checks

* fix sampler to prevent overfit w new epochs

* let's not cleanup the cached datasets

* calculate cum seq lens with pos_ids instead of mask, simplify packing params, fix distributed barrier

* speed optimizations and set accelerate fsdp env vars

* optimize dataset concatenation?

* more optimizations for dataset handling

* fix import for annotation

* manual pre-commit fixes

* another sum optimization and bug fix for calc steps

* fix packing estimations

* fix formatting

* pylint problems

* add back flash attention branch for handling unpacked sequences seperately

* Address PR feedback

* add optional sample packing config params to readme
This commit is contained in:
Wing Lian
2023-08-12 15:14:56 -04:00
committed by GitHub
parent a276c9c88d
commit 2bb0b78975
23 changed files with 1218 additions and 70 deletions

View File

@@ -5,7 +5,7 @@ import os
from typing import List
import torch
from datasets import IterableDataset
from datasets import Dataset, IterableDataset
from .prompt_tokenizers import PromptTokenizingStrategy
@@ -18,9 +18,9 @@ from .prompt_tokenizers import PromptTokenizingStrategy
LOG = logging.getLogger("axolotl")
class TokenizedPromptDataset(IterableDataset):
class TokenizedPromptDataset(Dataset):
"""
Iterable dataset that returns tokenized prompts from a stream of text files.
Dataset that returns tokenized prompts from a stream of text files.
Args:
prompt_tokenizer (PromptTokenizingStrategy): The prompt tokenizing method for proccessing the data.
dataset (dataset.Dataset): Dataset with text files.
@@ -30,19 +30,18 @@ class TokenizedPromptDataset(IterableDataset):
self,
prompt_tokenizer: PromptTokenizingStrategy,
dataset: IterableDataset,
**kwargs,
):
self.prompt_tokenizer = prompt_tokenizer
self.dataset = dataset
super().__init__(self.process(dataset).data, **kwargs)
def __iter__(self):
features = self.dataset.features.keys()
num_proc = os.cpu_count()
return iter(
self.dataset.map(
self.prompt_tokenizer.tokenize_prompt,
num_proc=num_proc,
remove_columns=features,
)
def process(self, dataset):
features = dataset.features.keys()
num_proc = min(64, os.cpu_count())
return dataset.map(
self.prompt_tokenizer.tokenize_prompt,
num_proc=num_proc,
remove_columns=features,
)
@@ -77,14 +76,21 @@ class ConstantLengthDataset(IterableDataset):
self.tokens_dtype = torch.int64
def __iter__(self):
buffer = {"input_ids": [], "attention_mask": [], "labels": []}
buffer = {
"input_ids": [],
"attention_mask": [],
"labels": [],
"position_ids": [],
}
buffer_len = 0
for dataset in self.datasets:
idx = 0
iterator = iter(dataset)
more_examples = True
while more_examples:
try:
example = next(iterator)
idx += 1
except StopIteration:
more_examples = False
example = None
@@ -106,6 +112,9 @@ class ConstantLengthDataset(IterableDataset):
attention_mask = torch.cat(buffer["attention_mask"], dim=-1)[
: self.seq_length
]
position_ids = torch.cat(buffer["position_ids"], dim=-1)[
: self.seq_length
]
labels = torch.cat(buffer["labels"], dim=-1)[: self.seq_length]
if labels.size() == input_ids.size() and (
attention_mask.size() == input_ids.size()
@@ -114,6 +123,7 @@ class ConstantLengthDataset(IterableDataset):
"input_ids": input_ids,
"labels": labels,
"attention_mask": attention_mask,
"position_ids": position_ids,
}
else:
LOG.warning(
@@ -123,8 +133,10 @@ class ConstantLengthDataset(IterableDataset):
"input_ids": [],
"attention_mask": [],
"labels": [],
"position_ids": [],
}
buffer_len = 0
idx = 1
if example:
# FIXME
@@ -133,11 +145,6 @@ class ConstantLengthDataset(IterableDataset):
input_ids = example["input_ids"]
attention_mask = example["attention_mask"]
labels = example["labels"]
if (
buffer["input_ids"]
and input_ids[0] == self.tokenizer.bos_token_id
):
attention_mask[0] = 0
if add_concat_token:
input_ids.append(self.concat_token_id)
@@ -148,13 +155,17 @@ class ConstantLengthDataset(IterableDataset):
input_ids, dtype=self.tokens_dtype
)
attention_mask_with_concat = torch.tensor(
attention_mask, dtype=self.tokens_dtype
[idx * m for m in attention_mask], dtype=torch.int16
)
labels_with_concat = torch.tensor(
labels, dtype=self.tokens_dtype
)
position_ids = torch.arange(
len(input_ids), dtype=self.tokens_dtype
)
buffer["input_ids"].append(input_ids_with_concat)
buffer["attention_mask"].append(attention_mask_with_concat)
buffer["labels"].append(labels_with_concat)
buffer["position_ids"].append(position_ids)
buffer_len += len(input_ids)