remove unuse

This commit is contained in:
Dan Saunders
2025-08-20 19:26:26 +00:00
parent 7eba3795fe
commit 7836da9ed9
3 changed files with 0 additions and 176 deletions

View File

@@ -10,7 +10,6 @@ later on to pad the datasets.
from typing import Any
import torch
from datasets import Dataset, IterableDataset
from axolotl.utils.logging import get_logger
@@ -103,134 +102,3 @@ def wrap_dataset_for_tokenized_prompt(
**map_kwargs,
)
return TokenizedPromptDataset(prompt_tokenizer, dataset, **kwargs)
# TODO: this isn't the best since it can't interleave datasets.
# NOTE: this is only used in a test. Can it be deleted?
class ConstantLengthDataset(IterableDataset):
"""Iterable dataset that returns constant length chunks of tokens from stream of
text files.
Args:
tokenizer: The processor used for processing the data.
dataset: Dataset with text files.
seq_length: Length of token sequences to return.
"""
def __init__( # pylint: disable=super-init-not-called
self,
tokenizer,
datasets,
seq_length=2048,
):
self.tokenizer = tokenizer
self.concat_token_id = tokenizer.eos_token_id
self.datasets: list[IterableDataset] = datasets
self.seq_length = seq_length
vocab_size = len(tokenizer.get_vocab())
if vocab_size <= torch.iinfo(torch.int16).max:
self.tokens_dtype = torch.int16
elif vocab_size <= torch.iinfo(torch.int32).max:
self.tokens_dtype = torch.int32
else:
self.tokens_dtype = torch.int64
def __iter__(self):
buffer = {
"input_ids": [],
"attention_mask": [],
"labels": [],
"position_ids": [],
}
buffer_len = 0
for dataset in self.datasets:
idx = 0
iterator = iter(dataset)
more_examples = True
while more_examples:
try:
example = next(iterator)
idx += 1
except StopIteration:
more_examples = False
example = None
add_concat_token = False
if example:
example_len = len(example["input_ids"])
add_concat_token = example["input_ids"][-1] != self.concat_token_id
else:
example_len = 0
if not example_len or (
buffer_len + int(add_concat_token) + example_len > self.seq_length
):
if buffer["input_ids"]:
input_ids = torch.cat(buffer["input_ids"], dim=-1)[
: self.seq_length
]
attention_mask = torch.cat(buffer["attention_mask"], dim=-1)[
: self.seq_length
]
position_ids = torch.cat(buffer["position_ids"], dim=-1)[
: self.seq_length
]
labels = torch.cat(buffer["labels"], dim=-1)[: self.seq_length]
if labels.size() == input_ids.size() and (
attention_mask.size() == input_ids.size()
):
yield {
"input_ids": input_ids,
"labels": labels,
"attention_mask": attention_mask,
"position_ids": position_ids,
}
else:
LOG.warning(
"Dropping batch due to tensor size mismatch "
f"input_ids: {input_ids.size()}, "
f"labels: {labels.size()}, "
f"attention_mask: {attention_mask.size()}"
)
buffer = {
"input_ids": [],
"attention_mask": [],
"labels": [],
"position_ids": [],
}
buffer_len = 0
idx = 1
if example:
# FIXME
# just going to drop data points that are too long
if len(example["input_ids"]) <= self.seq_length:
input_ids = example["input_ids"]
attention_mask = example["attention_mask"]
labels = example["labels"]
if add_concat_token:
input_ids.append(self.concat_token_id)
attention_mask.append(1)
labels.append(self.concat_token_id)
input_ids_with_concat = torch.tensor(
input_ids, dtype=self.tokens_dtype
)
attention_mask_with_concat = torch.tensor(
[idx * m for m in attention_mask], dtype=torch.int16
)
labels_with_concat = torch.tensor(
labels, dtype=self.tokens_dtype
)
position_ids = torch.arange(
len(input_ids), dtype=self.tokens_dtype
)
buffer["input_ids"].append(input_ids_with_concat)
buffer["attention_mask"].append(attention_mask_with_concat)
buffer["labels"].append(labels_with_concat)
buffer["position_ids"].append(position_ids)
buffer_len += len(input_ids)

View File

@@ -593,7 +593,6 @@ def _merge_datasets_with_strategy(
LOG.info(f"Merging datasets with mixing strategy: {strategy}...")
if strategy == "concatenate":
# Concatenate only works with non-iterable datasets
if not all(isinstance(ds, Dataset) for ds in datasets):
raise ValueError(
"Cannot concatenate streaming datasets. Use 'round_robin', 'weighted', "
@@ -605,7 +604,6 @@ def _merge_datasets_with_strategy(
if strategy == "weighted":
return interleave_datasets(datasets, probabilities=weights, seed=cfg.seed)
if strategy == "random":
# Random sampling with equal probability
equal_weights = [1.0 / len(datasets)] * len(datasets)
return interleave_datasets(datasets, probabilities=equal_weights, seed=cfg.seed)
raise ValueError(f"Unknown dataset mixing strategy: {strategy}")