remove iterable CLI arg

This commit is contained in:
Dan Saunders
2025-08-20 00:18:42 +00:00
parent b6431083be
commit 3b2dd05798
4 changed files with 3 additions and 52 deletions

View File

@@ -13,12 +13,6 @@ class PreprocessCliArgs:
debug_num_examples: int = field(default=1)
prompter: Optional[str] = field(default=None)
download: Optional[bool] = field(default=True)
iterable: Optional[bool] = field(
default=None,
metadata={
"help": "Use IterableDataset for streaming processing of large datasets"
},
)
@dataclass

View File

@@ -55,13 +55,11 @@ def load_datasets(
"""
tokenizer = load_tokenizer(cfg)
processor = load_processor(cfg, tokenizer=tokenizer) if cfg.processor_type else None
preprocess_iterable = getattr(cli_args, "iterable", False)
train_dataset, eval_dataset, total_num_steps, prompters = prepare_datasets(
cfg,
tokenizer,
processor=processor,
preprocess_iterable=preprocess_iterable,
)
if (

View File

@@ -44,24 +44,11 @@ from axolotl.utils.trainer import (
LOG = get_logger(__name__)
def _determine_streaming_mode(cfg: DictDefault) -> bool:
"""Determine if we should use streaming mode based on config."""
if cfg.streaming is not None:
return cfg.streaming
# Default to streaming for pretraining datasets
if cfg.pretraining_dataset:
return True
return False
@retry_on_request_exceptions(max_retries=3, delay=5)
def prepare_datasets(
cfg: DictDefault,
tokenizer: PreTrainedTokenizer,
processor: ProcessorMixin | None = None,
preprocess_iterable: bool = False,
) -> tuple[IterableDataset | Dataset, Dataset | None, int, list[Prompter | None]]:
"""Prepare training and evaluation datasets based on configuration.
@@ -69,30 +56,19 @@ def prepare_datasets(
cfg: Dictionary mapping `axolotl` config keys to values.
tokenizer: Tokenizer to use for processing text.
processor: Optional processor for multimodal datasets.
preprocess_iterable: Whether to use iterable preprocessing.
Returns:
Tuple of (train_dataset, eval_dataset, total_steps, prompters).
"""
# Determine streaming mode from config
streaming_mode = _determine_streaming_mode(cfg)
# Override preprocess_iterable parameter with streaming config
if streaming_mode:
preprocess_iterable = True
if cfg.pretraining_dataset:
return _prepare_pretraining_dataset(
cfg, tokenizer, processor, preprocess_iterable
)
return _prepare_standard_dataset(cfg, tokenizer, processor, preprocess_iterable)
return _prepare_pretraining_dataset(cfg, tokenizer, processor)
return _prepare_standard_dataset(cfg, tokenizer, processor)
def _prepare_standard_dataset(
cfg: DictDefault,
tokenizer: PreTrainedTokenizer,
processor: ProcessorMixin | None,
preprocess_iterable: bool,
) -> tuple[Dataset, Dataset | None, int, list[Prompter | None]]:
"""Prepare standard (non-pretraining) datasets."""
@@ -103,7 +79,6 @@ def _prepare_standard_dataset(
cfg,
split="train",
processor=processor,
preprocess_iterable=preprocess_iterable,
)
# Overwrite eval_dataset if test data exists
@@ -113,7 +88,6 @@ def _prepare_standard_dataset(
cfg,
split="test",
processor=processor,
preprocess_iterable=preprocess_iterable,
)
return train_dataset, eval_dataset, prompters
@@ -159,7 +133,6 @@ def _prepare_pretraining_dataset(
cfg: DictDefault,
tokenizer: PreTrainedTokenizer,
processor: ProcessorMixin | None,
preprocess_iterable: bool,
) -> tuple[IterableDataset, Dataset | None, int, list[Prompter | None]]:
"""
Prepare dataset for pretraining mode.
@@ -180,7 +153,6 @@ def _prepare_pretraining_dataset(
cfg,
split="test",
processor=processor,
preprocess_iterable=preprocess_iterable,
)
if cfg.dataset_exact_deduplication:
@@ -283,7 +255,6 @@ def _load_tokenized_prepared_datasets(
cfg: DictDefault,
split: Literal["train", "test"] = "train",
processor: ProcessorMixin | None = None,
preprocess_iterable: bool = False,
) -> tuple[Dataset | DatasetDict, list[Prompter | None]]:
"""Load or create tokenized and prepared datasets for training or testing.
@@ -292,7 +263,6 @@ def _load_tokenized_prepared_datasets(
cfg: Configuration object.
split: Dataset split to load ('train' or 'test').
processor: Optional processor for multimodal datasets.
preprocess_iterable: Whether to use iterable preprocessing.
Returns:
Tuple of (dataset, prompters list).
@@ -323,7 +293,6 @@ def _load_tokenized_prepared_datasets(
tokenizer,
split,
processor,
preprocess_iterable,
)
return dataset, prompters
@@ -335,7 +304,6 @@ def _load_raw_datasets(
tokenizer: PreTrainedTokenizer,
split: str,
processor: ProcessorMixin | None = None,
preprocess_iterable: bool = False,
) -> tuple[Dataset, list[Prompter | None]]:
"""Load, process, merge, and save raw datasets."""
LOG.info("Loading raw datasets...", main_process_only=False)
@@ -356,7 +324,6 @@ def _load_raw_datasets(
split=split,
seed=cfg.seed,
processor=processor,
preprocess_iterable=preprocess_iterable,
)
datasets.append(dataset_wrapper)
prompters.append(dataset_prompter)
@@ -388,15 +355,11 @@ def _load_and_process_single_dataset(
split: str,
seed: int,
processor: ProcessorMixin | None = None,
preprocess_iterable: bool = False,
) -> tuple[Dataset | IterableDataset, Prompter | None]:
"""Load and process a single dataset based on the passed config."""
# Load the dataset
dataset = load_dataset_with_config(
dataset_config, cfg.hf_use_auth_token, streaming=preprocess_iterable
dataset_config, cfg.hf_use_auth_token, cfg.streaming
)
# Parse dataset type
d_base_type, d_prompt_style = _parse_dataset_type(dataset_config.type)
# Select the appropriate split
@@ -515,7 +478,6 @@ def _load_and_prepare_datasets(
cfg: DictDefault,
split: Literal["train", "test"] = "train",
processor: ProcessorMixin | None = None,
preprocess_iterable: bool = False,
) -> tuple[Dataset | None, Dataset | None, list[Prompter | None]]:
"""Load and prepare datasets with optional validation split and sharding.
@@ -524,7 +486,6 @@ def _load_and_prepare_datasets(
cfg: Configuration object.
split: Dataset split to load ('train' or 'test').
processor: Optional processor for multimodal datasets.
preprocess_iterable: Whether to use iterable preprocessing.
Returns:
Tuple of (train_dataset, eval_dataset, prompters).
@@ -535,7 +496,6 @@ def _load_and_prepare_datasets(
cfg,
split=split,
processor=processor,
preprocess_iterable=preprocess_iterable,
)
# Apply dataset sharding if configured using shared function

View File

@@ -962,7 +962,6 @@ class AxolotlInputConfig(
# INTERNALS - document for now, generally not set externally
is_preprocess: bool | None = None
preprocess_iterable: bool | None = None
total_num_tokens: int | None = Field(
default=None,