nits
This commit is contained in:
@@ -1,4 +1,14 @@
|
|||||||
"""Module containing Dataset functionality"""
|
"""
|
||||||
|
Module containing Dataset functionality
|
||||||
|
|
||||||
|
We want this to be a wrapper for an existing dataset that we have loaded. Let's use the
|
||||||
|
concept of middlewares to wrap each dataset, for example:
|
||||||
|
ConstantLengthDataset(ShuffledDataset([TokenizedPromptDataset(alpaca_dataset)])).
|
||||||
|
Let's check to ensure we don't truncate an item in the middle. We'll use the collators
|
||||||
|
later on to pad the datasets.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from datasets import Dataset, IterableDataset
|
from datasets import Dataset, IterableDataset
|
||||||
@@ -7,12 +17,6 @@ from axolotl.utils.logging import get_logger
|
|||||||
|
|
||||||
from .prompt_tokenizers import PromptTokenizingStrategy
|
from .prompt_tokenizers import PromptTokenizingStrategy
|
||||||
|
|
||||||
# We want this to be a wrapper for an existing dataset that we have loaded
|
|
||||||
# lets use the concept of middlewares to wrap each dataset, for example
|
|
||||||
# ConstantLengthDataset(ShuffledDataset([TokenizedPromptDataset(alpaca_dataset)]))
|
|
||||||
# let's check to ensure we don't truncate an item in the middle, we'll use
|
|
||||||
# the collators later on to pad the datasets
|
|
||||||
|
|
||||||
LOG = get_logger(__name__)
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@@ -42,14 +46,14 @@ class TokenizedPromptDataset(Dataset):
|
|||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
def process(self, dataset):
|
def process(self, dataset: Dataset | IterableDataset) -> Dataset | IterableDataset:
|
||||||
# For IterableDataset, we can't access features upfront
|
# For IterableDataset, we can't access features up front. Anyway, we don't care
|
||||||
# We'll need to infer from the first batch
|
# to remove unused columns from streaming datasets.
|
||||||
features = None
|
features = None
|
||||||
if hasattr(dataset, "features") and dataset.features:
|
if not isinstance(dataset, IterableDataset):
|
||||||
features = dataset.features.keys()
|
features = dataset.features.keys()
|
||||||
|
|
||||||
map_kwargs = {}
|
map_kwargs: dict[str, Any] = {}
|
||||||
if self.prompt_tokenizer.supports_batched:
|
if self.prompt_tokenizer.supports_batched:
|
||||||
map_kwargs["batched"] = True
|
map_kwargs["batched"] = True
|
||||||
map_kwargs["batch_size"] = 1_000
|
map_kwargs["batch_size"] = 1_000
|
||||||
@@ -58,9 +62,8 @@ class TokenizedPromptDataset(Dataset):
|
|||||||
hasattr(self.prompt_tokenizer, "filter_rows")
|
hasattr(self.prompt_tokenizer, "filter_rows")
|
||||||
and self.prompt_tokenizer.filter_rows
|
and self.prompt_tokenizer.filter_rows
|
||||||
):
|
):
|
||||||
filter_kwargs = {"desc": "Strategy Filtering Rows"}
|
filter_kwargs: dict[str, Any] = {"desc": "Strategy Filtering Rows"}
|
||||||
# Only add num_proc for regular datasets
|
if not isinstance(dataset, IterableDataset):
|
||||||
if features is not None:
|
|
||||||
filter_kwargs["num_proc"] = self.process_count
|
filter_kwargs["num_proc"] = self.process_count
|
||||||
|
|
||||||
dataset = dataset.filter(
|
dataset = dataset.filter(
|
||||||
@@ -74,7 +77,7 @@ class TokenizedPromptDataset(Dataset):
|
|||||||
}
|
}
|
||||||
|
|
||||||
# Only add remove_columns for regular datasets
|
# Only add remove_columns for regular datasets
|
||||||
if features is not None:
|
if not isinstance(dataset, IterableDataset):
|
||||||
map_kwargs["remove_columns"] = features
|
map_kwargs["remove_columns"] = features
|
||||||
map_kwargs["num_proc"] = self.process_count
|
map_kwargs["num_proc"] = self.process_count
|
||||||
map_kwargs["keep_in_memory"] = self.keep_in_memory
|
map_kwargs["keep_in_memory"] = self.keep_in_memory
|
||||||
|
|||||||
@@ -111,17 +111,17 @@ def _prepare_standard_dataset(
|
|||||||
"You should set `eval_sample_packing: False` in your config."
|
"You should set `eval_sample_packing: False` in your config."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Set total_num_steps for training
|
||||||
if isinstance(train_dataset, IterableDataset):
|
if isinstance(train_dataset, IterableDataset):
|
||||||
# Streaming case
|
|
||||||
total_num_steps = cfg.max_steps
|
total_num_steps = cfg.max_steps
|
||||||
else:
|
else:
|
||||||
# Non-streaming case
|
|
||||||
if cfg.max_steps:
|
if cfg.max_steps:
|
||||||
total_num_steps = min(
|
total_num_steps = min(
|
||||||
calculate_total_num_steps(cfg, train_dataset), cfg.max_steps
|
calculate_total_num_steps(cfg, train_dataset), cfg.max_steps
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
total_num_steps = calculate_total_num_steps(cfg, train_dataset)
|
total_num_steps = calculate_total_num_steps(cfg, train_dataset)
|
||||||
|
|
||||||
LOG.info(f"Maximum number of steps set at {total_num_steps}")
|
LOG.info(f"Maximum number of steps set at {total_num_steps}")
|
||||||
return train_dataset, eval_dataset, total_num_steps, prompters
|
return train_dataset, eval_dataset, total_num_steps, prompters
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user