From 7836da9ed90c8f86ebd62054c1b9ab500b39510f Mon Sep 17 00:00:00 2001
From: Dan Saunders
Date: Wed, 20 Aug 2025 19:26:26 +0000
Subject: [PATCH] remove unused code

Drop ConstantLengthDataset (referenced only by a single test) together
with that test, the now-unused torch import, and two redundant comments.
---
 src/axolotl/datasets.py          | 132 -------------------------------
 src/axolotl/utils/data/shared.py |   2 -
 tests/test_packed_dataset.py     |  42 ----------
 3 files changed, 176 deletions(-)

diff --git a/src/axolotl/datasets.py b/src/axolotl/datasets.py
index 87f26275f..d70703d6b 100644
--- a/src/axolotl/datasets.py
+++ b/src/axolotl/datasets.py
@@ -10,7 +10,6 @@ later on to pad the datasets.
 
 from typing import Any
 
-import torch
 from datasets import Dataset, IterableDataset
 
 from axolotl.utils.logging import get_logger
@@ -103,134 +102,3 @@ def wrap_dataset_for_tokenized_prompt(
         **map_kwargs,
     )
     return TokenizedPromptDataset(prompt_tokenizer, dataset, **kwargs)
-
-
-# TODO: this isn't the best since it can't interleave datasets.
-# NOTE: this is only used in a test. Can it be deleted?
-class ConstantLengthDataset(IterableDataset):
-    """Iterable dataset that returns constant length chunks of tokens from stream of
-    text files.
-
-    Args:
-        tokenizer: The processor used for processing the data.
-        dataset: Dataset with text files.
-        seq_length: Length of token sequences to return.
-    """
-
-    def __init__(  # pylint: disable=super-init-not-called
-        self,
-        tokenizer,
-        datasets,
-        seq_length=2048,
-    ):
-        self.tokenizer = tokenizer
-        self.concat_token_id = tokenizer.eos_token_id
-        self.datasets: list[IterableDataset] = datasets
-        self.seq_length = seq_length
-
-        vocab_size = len(tokenizer.get_vocab())
-
-        if vocab_size <= torch.iinfo(torch.int16).max:
-            self.tokens_dtype = torch.int16
-        elif vocab_size <= torch.iinfo(torch.int32).max:
-            self.tokens_dtype = torch.int32
-        else:
-            self.tokens_dtype = torch.int64
-
-    def __iter__(self):
-        buffer = {
-            "input_ids": [],
-            "attention_mask": [],
-            "labels": [],
-            "position_ids": [],
-        }
-        buffer_len = 0
-        for dataset in self.datasets:
-            idx = 0
-            iterator = iter(dataset)
-            more_examples = True
-            while more_examples:
-                try:
-                    example = next(iterator)
-                    idx += 1
-                except StopIteration:
-                    more_examples = False
-                    example = None
-
-                add_concat_token = False
-                if example:
-                    example_len = len(example["input_ids"])
-                    add_concat_token = example["input_ids"][-1] != self.concat_token_id
-                else:
-                    example_len = 0
-
-                if not example_len or (
-                    buffer_len + int(add_concat_token) + example_len > self.seq_length
-                ):
-                    if buffer["input_ids"]:
-                        input_ids = torch.cat(buffer["input_ids"], dim=-1)[
-                            : self.seq_length
-                        ]
-                        attention_mask = torch.cat(buffer["attention_mask"], dim=-1)[
-                            : self.seq_length
-                        ]
-                        position_ids = torch.cat(buffer["position_ids"], dim=-1)[
-                            : self.seq_length
-                        ]
-                        labels = torch.cat(buffer["labels"], dim=-1)[: self.seq_length]
-                        if labels.size() == input_ids.size() and (
-                            attention_mask.size() == input_ids.size()
-                        ):
-                            yield {
-                                "input_ids": input_ids,
-                                "labels": labels,
-                                "attention_mask": attention_mask,
-                                "position_ids": position_ids,
-                            }
-                        else:
-                            LOG.warning(
-                                "Dropping batch due to tensor size mismatch "
-                                f"input_ids: {input_ids.size()}, "
-                                f"labels: {labels.size()}, "
-                                f"attention_mask: {attention_mask.size()}"
-                            )
-                    buffer = {
-                        "input_ids": [],
-                        "attention_mask": [],
-                        "labels": [],
-                        "position_ids": [],
-                    }
-                    buffer_len = 0
-                    idx = 1
-
-                if example:
-                    # FIXME
-                    # just going to drop data points that are too long
-                    if len(example["input_ids"]) <= self.seq_length:
-                        input_ids = example["input_ids"]
-                        attention_mask = example["attention_mask"]
-                        labels = example["labels"]
-
-                        if add_concat_token:
-                            input_ids.append(self.concat_token_id)
-                            attention_mask.append(1)
-                            labels.append(self.concat_token_id)
-
-                        input_ids_with_concat = torch.tensor(
-                            input_ids, dtype=self.tokens_dtype
-                        )
-                        attention_mask_with_concat = torch.tensor(
-                            [idx * m for m in attention_mask], dtype=torch.int16
-                        )
-                        labels_with_concat = torch.tensor(
-                            labels, dtype=self.tokens_dtype
-                        )
-                        position_ids = torch.arange(
-                            len(input_ids), dtype=self.tokens_dtype
-                        )
-
-                        buffer["input_ids"].append(input_ids_with_concat)
-                        buffer["attention_mask"].append(attention_mask_with_concat)
-                        buffer["labels"].append(labels_with_concat)
-                        buffer["position_ids"].append(position_ids)
-                        buffer_len += len(input_ids)
diff --git a/src/axolotl/utils/data/shared.py b/src/axolotl/utils/data/shared.py
index 59408f151..606f47ea6 100644
--- a/src/axolotl/utils/data/shared.py
+++ b/src/axolotl/utils/data/shared.py
@@ -593,7 +593,6 @@ def _merge_datasets_with_strategy(
     LOG.info(f"Merging datasets with mixing strategy: {strategy}...")
 
     if strategy == "concatenate":
-        # Concatenate only works with non-iterable datasets
         if not all(isinstance(ds, Dataset) for ds in datasets):
             raise ValueError(
                 "Cannot concatenate streaming datasets. Use 'round_robin', 'weighted', "
@@ -605,7 +604,6 @@
     if strategy == "weighted":
         return interleave_datasets(datasets, probabilities=weights, seed=cfg.seed)
     if strategy == "random":
-        # Random sampling with equal probability
         equal_weights = [1.0 / len(datasets)] * len(datasets)
         return interleave_datasets(datasets, probabilities=equal_weights, seed=cfg.seed)
     raise ValueError(f"Unknown dataset mixing strategy: {strategy}")
diff --git a/tests/test_packed_dataset.py b/tests/test_packed_dataset.py
index 699d5e6cc..992cfa330 100644
--- a/tests/test_packed_dataset.py
+++ b/tests/test_packed_dataset.py
@@ -1,16 +1,11 @@
 """Module for testing dataset sequence packing"""
 
 import unittest
-from pathlib import Path
 
-from datasets import Dataset, load_dataset
 from transformers import AutoTokenizer
 
 from axolotl.cli.args import TrainerCliArgs
 from axolotl.common.datasets import load_datasets
-from axolotl.datasets import ConstantLengthDataset, TokenizedPromptDataset
-from axolotl.prompt_tokenizers import AlpacaPromptTokenizingStrategy
-from axolotl.prompters import AlpacaPrompter
 from axolotl.train import setup_model_and_trainer
 from axolotl.utils.config import normalize_config, validate_config
 from axolotl.utils.dict import DictDefault
@@ -36,43 +31,6 @@ class TestPacking(unittest.TestCase):
             }
         )
 
-    def test_increments_attention(self):
-        prompter = AlpacaPrompter("chat")
-        strat = AlpacaPromptTokenizingStrategy(
-            prompter,
-            self.tokenizer,
-            False,
-            2048,
-        )
-        dateset = load_dataset(
-            "json",
-            data_files=str(Path(__file__).parent / "fixtures/alpaca/alpaca.json"),
-        )["train"]
-        dataset = Dataset.from_list(list(TokenizedPromptDataset(strat, dateset)))
-
-        constant_len_dataset = ConstantLengthDataset(
-            self.tokenizer,
-            [dataset],
-            seq_length=2048,
-        )
-        packed_dataset = Dataset.from_list(list(constant_len_dataset))
-        example = packed_dataset[0]
-        next_bos_index = (
-            example["input_ids"][1:].index(self.tokenizer.bos_token_id) + 1
-        )  # add one since we sliced
-
-        # first example doesn't have mask reset
-        assert example["input_ids"][0] == self.tokenizer.bos_token_id
-        assert example["attention_mask"][0] == 1
-        assert example["position_ids"][0] == 0
-        assert example["position_ids"][1] == 1
-
-        # but subsequent one does
-        assert example["input_ids"][next_bos_index] == self.tokenizer.bos_token_id
-        assert example["attention_mask"][next_bos_index] == 2
-        assert example["position_ids"][next_bos_index] == 0
-        assert example["position_ids"][next_bos_index + 1] == 1
-
     @with_temp_dir
     def test_lora_packing(self, temp_dir):
         # pylint: disable=duplicate-code