nits
This commit is contained in:
@@ -1,4 +1,14 @@
|
||||
"""Module containing Dataset functionality"""
|
||||
"""
|
||||
Module containing Dataset functionality
|
||||
|
||||
We want this to be a wrapper for an existing dataset that we have loaded. Let's use the
|
||||
concept of middlewares to wrap each dataset, for example:
|
||||
ConstantLengthDataset(ShuffledDataset([TokenizedPromptDataset(alpaca_dataset)])).
|
||||
Let's check to ensure we don't truncate an item in the middle. We'll use the collators
|
||||
later on to pad the datasets.
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
from datasets import Dataset, IterableDataset
|
||||
@@ -7,12 +17,6 @@ from axolotl.utils.logging import get_logger
|
||||
|
||||
from .prompt_tokenizers import PromptTokenizingStrategy
|
||||
|
||||
# We want this to be a wrapper for an existing dataset that we have loaded
|
||||
# lets use the concept of middlewares to wrap each dataset, for example
|
||||
# ConstantLengthDataset(ShuffledDataset([TokenizedPromptDataset(alpaca_dataset)]))
|
||||
# let's check to ensure we don't truncate an item in the middle, we'll use
|
||||
# the collators later on to pad the datasets
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -42,14 +46,14 @@ class TokenizedPromptDataset(Dataset):
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def process(self, dataset):
|
||||
# For IterableDataset, we can't access features upfront
|
||||
# We'll need to infer from the first batch
|
||||
def process(self, dataset: Dataset | IterableDataset) -> Dataset | IterableDataset:
|
||||
# For IterableDataset, we can't access features up front. Anyway, we don't care
|
||||
# to remove unused columns from streaming datasets.
|
||||
features = None
|
||||
if hasattr(dataset, "features") and dataset.features:
|
||||
if not isinstance(dataset, IterableDataset):
|
||||
features = dataset.features.keys()
|
||||
|
||||
map_kwargs = {}
|
||||
map_kwargs: dict[str, Any] = {}
|
||||
if self.prompt_tokenizer.supports_batched:
|
||||
map_kwargs["batched"] = True
|
||||
map_kwargs["batch_size"] = 1_000
|
||||
@@ -58,9 +62,8 @@ class TokenizedPromptDataset(Dataset):
|
||||
hasattr(self.prompt_tokenizer, "filter_rows")
|
||||
and self.prompt_tokenizer.filter_rows
|
||||
):
|
||||
filter_kwargs = {"desc": "Strategy Filtering Rows"}
|
||||
# Only add num_proc for regular datasets
|
||||
if features is not None:
|
||||
filter_kwargs: dict[str, Any] = {"desc": "Strategy Filtering Rows"}
|
||||
if not isinstance(dataset, IterableDataset):
|
||||
filter_kwargs["num_proc"] = self.process_count
|
||||
|
||||
dataset = dataset.filter(
|
||||
@@ -74,7 +77,7 @@ class TokenizedPromptDataset(Dataset):
|
||||
}
|
||||
|
||||
# Only add remove_columns for regular datasets
|
||||
if features is not None:
|
||||
if not isinstance(dataset, IterableDataset):
|
||||
map_kwargs["remove_columns"] = features
|
||||
map_kwargs["num_proc"] = self.process_count
|
||||
map_kwargs["keep_in_memory"] = self.keep_in_memory
|
||||
|
||||
@@ -111,17 +111,17 @@ def _prepare_standard_dataset(
|
||||
"You should set `eval_sample_packing: False` in your config."
|
||||
)
|
||||
|
||||
# Set total_num_steps for training
|
||||
if isinstance(train_dataset, IterableDataset):
|
||||
# Streaming case
|
||||
total_num_steps = cfg.max_steps
|
||||
else:
|
||||
# Non-streaming case
|
||||
if cfg.max_steps:
|
||||
total_num_steps = min(
|
||||
calculate_total_num_steps(cfg, train_dataset), cfg.max_steps
|
||||
)
|
||||
else:
|
||||
total_num_steps = calculate_total_num_steps(cfg, train_dataset)
|
||||
|
||||
LOG.info(f"Maximum number of steps set at {total_num_steps}")
|
||||
return train_dataset, eval_dataset, total_num_steps, prompters
|
||||
|
||||
|
||||
Reference in New Issue
Block a user