diff --git a/src/axolotl/datasets.py b/src/axolotl/datasets.py index 0979171f7..99ba2522b 100644 --- a/src/axolotl/datasets.py +++ b/src/axolotl/datasets.py @@ -1,4 +1,14 @@ -"""Module containing Dataset functionality""" +""" +Module containing Dataset functionality + +We want this to be a wrapper for an existing dataset that we have loaded. Let's use the +concept of middlewares to wrap each dataset, for example: +ConstantLengthDataset(ShuffledDataset([TokenizedPromptDataset(alpaca_dataset)])). +Let's check to ensure we don't truncate an item in the middle. We'll use the collators +later on to pad the datasets. +""" + +from typing import Any import torch from datasets import Dataset, IterableDataset @@ -7,12 +17,6 @@ from axolotl.utils.logging import get_logger from .prompt_tokenizers import PromptTokenizingStrategy -# We want this to be a wrapper for an existing dataset that we have loaded -# lets use the concept of middlewares to wrap each dataset, for example -# ConstantLengthDataset(ShuffledDataset([TokenizedPromptDataset(alpaca_dataset)])) -# let's check to ensure we don't truncate an item in the middle, we'll use -# the collators later on to pad the datasets - LOG = get_logger(__name__) @@ -42,14 +46,14 @@ class TokenizedPromptDataset(Dataset): **kwargs, ) - def process(self, dataset): - # For IterableDataset, we can't access features upfront - # We'll need to infer from the first batch + def process(self, dataset: Dataset | IterableDataset) -> Dataset | IterableDataset: + # For IterableDataset, we can't access features up front. Anyway, we don't care + # to remove unused columns from streaming datasets.
features = None - if hasattr(dataset, "features") and dataset.features: + if not isinstance(dataset, IterableDataset): features = dataset.features.keys() - map_kwargs = {} + map_kwargs: dict[str, Any] = {} if self.prompt_tokenizer.supports_batched: map_kwargs["batched"] = True map_kwargs["batch_size"] = 1_000 @@ -58,9 +62,8 @@ class TokenizedPromptDataset(Dataset): hasattr(self.prompt_tokenizer, "filter_rows") and self.prompt_tokenizer.filter_rows ): - filter_kwargs = {"desc": "Strategy Filtering Rows"} - # Only add num_proc for regular datasets - if features is not None: + filter_kwargs: dict[str, Any] = {"desc": "Strategy Filtering Rows"} + if not isinstance(dataset, IterableDataset): filter_kwargs["num_proc"] = self.process_count dataset = dataset.filter( @@ -74,7 +77,7 @@ class TokenizedPromptDataset(Dataset): } # Only add remove_columns for regular datasets - if features is not None: + if not isinstance(dataset, IterableDataset): map_kwargs["remove_columns"] = features map_kwargs["num_proc"] = self.process_count map_kwargs["keep_in_memory"] = self.keep_in_memory diff --git a/src/axolotl/utils/data/sft.py b/src/axolotl/utils/data/sft.py index 38f8d6d4a..4c51aa2d1 100644 --- a/src/axolotl/utils/data/sft.py +++ b/src/axolotl/utils/data/sft.py @@ -111,17 +111,17 @@ def _prepare_standard_dataset( "You should set `eval_sample_packing: False` in your config." ) + # Set total_num_steps for training if isinstance(train_dataset, IterableDataset): - # Streaming case total_num_steps = cfg.max_steps else: - # Non-streaming case if cfg.max_steps: total_num_steps = min( calculate_total_num_steps(cfg, train_dataset), cfg.max_steps ) else: total_num_steps = calculate_total_num_steps(cfg, train_dataset) + LOG.info(f"Maximum number of steps set at {total_num_steps}") return train_dataset, eval_dataset, total_num_steps, prompters