Dan Saunders
2025-08-20 04:15:06 +00:00
parent 7bb52d00bb
commit 846aa41baa
2 changed files with 21 additions and 18 deletions


@@ -1,4 +1,14 @@
"""Module containing Dataset functionality"""
"""
Module containing Dataset functionality
We want this to be a wrapper for an existing dataset that we have loaded. Let's use the
concept of middlewares to wrap each dataset, for example:
ConstantLengthDataset(ShuffledDataset([TokenizedPromptDataset(alpaca_dataset)])).
Let's check to ensure we don't truncate an item in the middle. We'll use the collators
later on to pad the datasets.
"""
from typing import Any
import torch
from datasets import Dataset, IterableDataset
@@ -7,12 +17,6 @@ from axolotl.utils.logging import get_logger
from .prompt_tokenizers import PromptTokenizingStrategy
# We want this to be a wrapper for an existing dataset that we have loaded
# lets use the concept of middlewares to wrap each dataset, for example
# ConstantLengthDataset(ShuffledDataset([TokenizedPromptDataset(alpaca_dataset)]))
# let's check to ensure we don't truncate an item in the middle, we'll use
# the collators later on to pad the datasets
LOG = get_logger(__name__)
@@ -42,14 +46,14 @@ class TokenizedPromptDataset(Dataset):
**kwargs,
)
def process(self, dataset):
# For IterableDataset, we can't access features upfront
# We'll need to infer from the first batch
def process(self, dataset: Dataset | IterableDataset) -> Dataset | IterableDataset:
# For IterableDataset, we can't access features up front. Anyway, we don't need
# to remove unused columns from streaming datasets.
features = None
if hasattr(dataset, "features") and dataset.features:
if not isinstance(dataset, IterableDataset):
features = dataset.features.keys()
map_kwargs = {}
map_kwargs: dict[str, Any] = {}
if self.prompt_tokenizer.supports_batched:
map_kwargs["batched"] = True
map_kwargs["batch_size"] = 1_000
@@ -58,9 +62,8 @@ class TokenizedPromptDataset(Dataset):
hasattr(self.prompt_tokenizer, "filter_rows")
and self.prompt_tokenizer.filter_rows
):
filter_kwargs = {"desc": "Strategy Filtering Rows"}
# Only add num_proc for regular datasets
if features is not None:
filter_kwargs: dict[str, Any] = {"desc": "Strategy Filtering Rows"}
if not isinstance(dataset, IterableDataset):
filter_kwargs["num_proc"] = self.process_count
dataset = dataset.filter(
@@ -74,7 +77,7 @@ class TokenizedPromptDataset(Dataset):
}
# Only add remove_columns for regular datasets
if features is not None:
if not isinstance(dataset, IterableDataset):
map_kwargs["remove_columns"] = features
map_kwargs["num_proc"] = self.process_count
map_kwargs["keep_in_memory"] = self.keep_in_memory


@@ -111,17 +111,17 @@ def _prepare_standard_dataset(
"You should set `eval_sample_packing: False` in your config."
)
# Set total_num_steps for training
if isinstance(train_dataset, IterableDataset):
# Streaming case
total_num_steps = cfg.max_steps
else:
# Non-streaming case
if cfg.max_steps:
total_num_steps = min(
calculate_total_num_steps(cfg, train_dataset), cfg.max_steps
)
else:
total_num_steps = calculate_total_num_steps(cfg, train_dataset)
LOG.info(f"Maximum number of steps set at {total_num_steps}")
return train_dataset, eval_dataset, total_num_steps, prompters
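
The second file applies the same streaming distinction when deciding the step count: an `IterableDataset` has no knowable length, so `cfg.max_steps` is used directly, while map-style datasets compute the count and cap it at `max_steps` if one is set. A small sketch of that logic, with `cfg` and `calculate_total_num_steps` passed in as arguments rather than imported (a simplification, not the module's real signature):

```python
from datasets import IterableDataset


def resolve_total_num_steps(cfg, train_dataset, calculate_total_num_steps) -> int:
    """Return the training step count for streaming and non-streaming datasets."""
    if isinstance(train_dataset, IterableDataset):
        # Streaming case: length is unknown, so max_steps must come from the config.
        return cfg.max_steps
    if cfg.max_steps:
        # Non-streaming case with an explicit cap on the computed step count.
        return min(calculate_total_num_steps(cfg, train_dataset), cfg.max_steps)
    return calculate_total_num_steps(cfg, train_dataset)
```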