remove unused ConstantLengthDataset and its test
This commit is contained in:
@@ -10,7 +10,6 @@ later on to pad the datasets.
|
|||||||
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import torch
|
|
||||||
from datasets import Dataset, IterableDataset
|
from datasets import Dataset, IterableDataset
|
||||||
|
|
||||||
from axolotl.utils.logging import get_logger
|
from axolotl.utils.logging import get_logger
|
||||||
@@ -103,134 +102,3 @@ def wrap_dataset_for_tokenized_prompt(
|
|||||||
**map_kwargs,
|
**map_kwargs,
|
||||||
)
|
)
|
||||||
return TokenizedPromptDataset(prompt_tokenizer, dataset, **kwargs)
|
return TokenizedPromptDataset(prompt_tokenizer, dataset, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
# TODO: this isn't the best since it can't interleave datasets.
|
|
||||||
# NOTE: this is only used in a test. Can it be deleted?
|
|
||||||
class ConstantLengthDataset(IterableDataset):
    """Iterable dataset that packs tokenized examples into bounded-length sequences.

    Examples are read from each source dataset in order and concatenated
    until the next one would push the running total past ``seq_length``
    tokens; the packed sequence is then yielded and a fresh one is started.
    The buffer is also flushed when a source dataset is exhausted, so examples
    from different datasets are never packed together. Yielded sequences are
    truncated to ``seq_length`` but are not padded, so they may be shorter.

    Args:
        tokenizer: Tokenizer supplying ``eos_token_id`` (used as the token
            that separates packed examples) and ``get_vocab()`` (used to
            pick the narrowest integer dtype for token tensors).
        datasets: List of datasets to pack. Each example is read as a mapping
            with ``input_ids``, ``attention_mask`` and ``labels`` entries.
        seq_length: Maximum number of tokens in each yielded sequence.
    """

    def __init__(  # pylint: disable=super-init-not-called
        self,
        tokenizer,
        datasets,
        seq_length=2048,
    ):
        self.tokenizer = tokenizer
        self.concat_token_id = tokenizer.eos_token_id
        self.datasets: list[IterableDataset] = datasets
        self.seq_length = seq_length

        vocab_size = len(tokenizer.get_vocab())

        # Use the smallest integer dtype that can hold the vocabulary size.
        if vocab_size <= torch.iinfo(torch.int16).max:
            self.tokens_dtype = torch.int16
        elif vocab_size <= torch.iinfo(torch.int32).max:
            self.tokens_dtype = torch.int32
        else:
            self.tokens_dtype = torch.int64

    @staticmethod
    def _new_buffer():
        """Return an empty per-field accumulator for one packed sequence."""
        return {
            "input_ids": [],
            "attention_mask": [],
            "labels": [],
            "position_ids": [],
        }

    def __iter__(self):
        """Yield dicts of packed ``input_ids``, ``labels``, ``attention_mask``
        and ``position_ids`` tensors.

        Within a packed sequence the attention mask of the k-th example is
        multiplied by k (1-based), and position ids restart at 0 for each
        example.
        """
        buffer = self._new_buffer()
        buffer_len = 0
        for dataset in self.datasets:
            idx = 0
            iterator = iter(dataset)
            more_examples = True
            while more_examples:
                try:
                    example = next(iterator)
                    idx += 1
                except StopIteration:
                    # End of this dataset: fall through once more with no
                    # example so that the pending buffer gets flushed.
                    more_examples = False
                    example = None

                add_concat_token = False
                if example:
                    example_len = len(example["input_ids"])
                    if not example_len:
                        # Nothing to pack; indexing the last token of an
                        # empty sequence would raise IndexError.
                        continue
                    add_concat_token = example["input_ids"][-1] != self.concat_token_id
                else:
                    example_len = 0

                # Flush when the stream ended or the next example will not fit.
                if not example_len or (
                    buffer_len + int(add_concat_token) + example_len > self.seq_length
                ):
                    if buffer["input_ids"]:
                        input_ids = torch.cat(buffer["input_ids"], dim=-1)[
                            : self.seq_length
                        ]
                        attention_mask = torch.cat(buffer["attention_mask"], dim=-1)[
                            : self.seq_length
                        ]
                        position_ids = torch.cat(buffer["position_ids"], dim=-1)[
                            : self.seq_length
                        ]
                        labels = torch.cat(buffer["labels"], dim=-1)[: self.seq_length]
                        if labels.size() == input_ids.size() and (
                            attention_mask.size() == input_ids.size()
                        ):
                            yield {
                                "input_ids": input_ids,
                                "labels": labels,
                                "attention_mask": attention_mask,
                                "position_ids": position_ids,
                            }
                        else:
                            LOG.warning(
                                "Dropping batch due to tensor size mismatch "
                                f"input_ids: {input_ids.size()}, "
                                f"labels: {labels.size()}, "
                                f"attention_mask: {attention_mask.size()}"
                            )
                    buffer = self._new_buffer()
                    buffer_len = 0
                    # The current example (if any) becomes the first one of
                    # the next packed sequence.
                    idx = 1

                if example:
                    # FIXME
                    # just going to drop data points that are too long
                    if example_len <= self.seq_length:
                        # Copy the sequences so that appending the concat
                        # token does not mutate the source dataset's example.
                        input_ids = list(example["input_ids"])
                        attention_mask = list(example["attention_mask"])
                        labels = list(example["labels"])

                        if add_concat_token:
                            input_ids.append(self.concat_token_id)
                            attention_mask.append(1)
                            labels.append(self.concat_token_id)

                        input_ids_with_concat = torch.tensor(
                            input_ids, dtype=self.tokens_dtype
                        )
                        # Scale the mask by the example's index within the
                        # packed sequence so examples can be told apart.
                        attention_mask_with_concat = torch.tensor(
                            [idx * m for m in attention_mask], dtype=torch.int16
                        )
                        labels_with_concat = torch.tensor(
                            labels, dtype=self.tokens_dtype
                        )
                        position_ids = torch.arange(
                            len(input_ids), dtype=self.tokens_dtype
                        )

                        buffer["input_ids"].append(input_ids_with_concat)
                        buffer["attention_mask"].append(attention_mask_with_concat)
                        buffer["labels"].append(labels_with_concat)
                        buffer["position_ids"].append(position_ids)
                        buffer_len += len(input_ids)
|
|
||||||
|
|||||||
@@ -593,7 +593,6 @@ def _merge_datasets_with_strategy(
|
|||||||
LOG.info(f"Merging datasets with mixing strategy: {strategy}...")
|
LOG.info(f"Merging datasets with mixing strategy: {strategy}...")
|
||||||
|
|
||||||
if strategy == "concatenate":
|
if strategy == "concatenate":
|
||||||
# Concatenate only works with non-iterable datasets
|
|
||||||
if not all(isinstance(ds, Dataset) for ds in datasets):
|
if not all(isinstance(ds, Dataset) for ds in datasets):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Cannot concatenate streaming datasets. Use 'round_robin', 'weighted', "
|
"Cannot concatenate streaming datasets. Use 'round_robin', 'weighted', "
|
||||||
@@ -605,7 +604,6 @@ def _merge_datasets_with_strategy(
|
|||||||
if strategy == "weighted":
|
if strategy == "weighted":
|
||||||
return interleave_datasets(datasets, probabilities=weights, seed=cfg.seed)
|
return interleave_datasets(datasets, probabilities=weights, seed=cfg.seed)
|
||||||
if strategy == "random":
|
if strategy == "random":
|
||||||
# Random sampling with equal probability
|
|
||||||
equal_weights = [1.0 / len(datasets)] * len(datasets)
|
equal_weights = [1.0 / len(datasets)] * len(datasets)
|
||||||
return interleave_datasets(datasets, probabilities=equal_weights, seed=cfg.seed)
|
return interleave_datasets(datasets, probabilities=equal_weights, seed=cfg.seed)
|
||||||
raise ValueError(f"Unknown dataset mixing strategy: {strategy}")
|
raise ValueError(f"Unknown dataset mixing strategy: {strategy}")
|
||||||
|
|||||||
@@ -1,16 +1,11 @@
|
|||||||
"""Module for testing dataset sequence packing"""
|
"""Module for testing dataset sequence packing"""
|
||||||
|
|
||||||
import unittest
|
import unittest
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
from datasets import Dataset, load_dataset
|
|
||||||
from transformers import AutoTokenizer
|
from transformers import AutoTokenizer
|
||||||
|
|
||||||
from axolotl.cli.args import TrainerCliArgs
|
from axolotl.cli.args import TrainerCliArgs
|
||||||
from axolotl.common.datasets import load_datasets
|
from axolotl.common.datasets import load_datasets
|
||||||
from axolotl.datasets import ConstantLengthDataset, TokenizedPromptDataset
|
|
||||||
from axolotl.prompt_tokenizers import AlpacaPromptTokenizingStrategy
|
|
||||||
from axolotl.prompters import AlpacaPrompter
|
|
||||||
from axolotl.train import setup_model_and_trainer
|
from axolotl.train import setup_model_and_trainer
|
||||||
from axolotl.utils.config import normalize_config, validate_config
|
from axolotl.utils.config import normalize_config, validate_config
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
@@ -36,43 +31,6 @@ class TestPacking(unittest.TestCase):
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_increments_attention(self):
    """Packed examples get an incrementing attention mask and restarted position ids."""
    tokenizing_strategy = AlpacaPromptTokenizingStrategy(
        AlpacaPrompter("chat"),
        self.tokenizer,
        False,
        2048,
    )
    fixture_path = Path(__file__).parent / "fixtures/alpaca/alpaca.json"
    raw_rows = load_dataset("json", data_files=str(fixture_path))["train"]
    tokenized_rows = list(TokenizedPromptDataset(tokenizing_strategy, raw_rows))
    tokenized_dataset = Dataset.from_list(tokenized_rows)

    packer = ConstantLengthDataset(
        self.tokenizer,
        [tokenized_dataset],
        seq_length=2048,
    )
    first_packed = Dataset.from_list(list(packer))[0]
    token_ids = first_packed["input_ids"]
    mask = first_packed["attention_mask"]
    positions = first_packed["position_ids"]

    # Locate the start of the second packed example: the first BOS token
    # found after position 0.
    second_start = token_ids.index(self.tokenizer.bos_token_id, 1)

    # The first example starts the sequence with mask value 1.
    assert token_ids[0] == self.tokenizer.bos_token_id
    assert mask[0] == 1
    assert positions[0] == 0
    assert positions[1] == 1

    # The following example has mask value 2 and its positions restart at 0.
    assert token_ids[second_start] == self.tokenizer.bos_token_id
    assert mask[second_start] == 2
    assert positions[second_start] == 0
    assert positions[second_start + 1] == 1
|
|
||||||
|
|
||||||
@with_temp_dir
|
@with_temp_dir
|
||||||
def test_lora_packing(self, temp_dir):
|
def test_lora_packing(self, temp_dir):
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
|
|||||||
Reference in New Issue
Block a user