From 00cda8cc70ca6a2f501cef7e843ae87931faece7 Mon Sep 17 00:00:00 2001 From: Dan Saunders Date: Tue, 10 Jun 2025 19:53:07 -0400 Subject: [PATCH] Data loader refactor (#2707) * data loading refactor (wip) * updates * progress * pytest * pytest fix * lint * zero_first -> filelock, more simplifications * small simplification * import change * nit * lint * simplify dedup * couldnt resist * review comments WIP * continued wip * minor changes * fix; remove contrived test * further refactor * set default seed in pydantic config * lint * continued simplication * lint * renaming and nits * filelock tests * fix * fix * lint * remove nullable arg * remove unnecessary code * moving dataset save fn to shared module * remove debug print * matching var naming * fn name change * coderabbit comments * naming nit * fix test --- src/axolotl/common/const.py | 4 +- src/axolotl/common/datasets.py | 77 +- src/axolotl/core/builders/causal.py | 2 +- src/axolotl/datasets.py | 43 +- src/axolotl/loaders/tokenizer.py | 9 +- .../prompt_strategies/messages/__init__.py | 1 - src/axolotl/prompt_tokenizers.py | 11 + src/axolotl/train.py | 4 +- src/axolotl/utils/data/__init__.py | 25 +- src/axolotl/utils/data/lock.py | 66 ++ src/axolotl/utils/data/pretraining.py | 2 +- src/axolotl/utils/data/rl.py | 385 +++---- src/axolotl/utils/data/sft.py | 978 ++++++++---------- src/axolotl/utils/data/shared.py | 686 ++++++++---- src/axolotl/utils/data/utils.py | 165 ++- src/axolotl/utils/data/wrappers.py | 425 ++++++++ src/axolotl/utils/schemas/config.py | 10 +- tests/core/test_builders.py | 12 +- .../integrations/test_cut_cross_entropy.py | 10 +- tests/e2e/integrations/test_hooks.py | 4 +- tests/e2e/integrations/test_kd.py | 7 +- tests/e2e/integrations/test_liger.py | 7 +- tests/e2e/integrations/test_llm_compressor.py | 4 +- tests/e2e/multigpu/solo/test_grpo.py | 2 +- tests/e2e/multigpu/test_locking.py | 192 ++++ tests/e2e/patched/test_4d_multipack_llama.py | 7 +- .../patched/test_activation_checkpointing.py | 
4 +- tests/e2e/patched/test_fa_xentropy.py | 4 +- tests/e2e/patched/test_falcon_samplepack.py | 7 +- tests/e2e/patched/test_fused_llama.py | 4 +- tests/e2e/patched/test_llama_s2_attention.py | 7 +- .../e2e/patched/test_lora_llama_multipack.py | 7 +- tests/e2e/patched/test_mistral_samplepack.py | 7 +- tests/e2e/patched/test_mixtral_samplepack.py | 7 +- tests/e2e/patched/test_phi_multipack.py | 7 +- tests/e2e/patched/test_resume.py | 5 +- tests/e2e/patched/test_unsloth_qlora.py | 10 +- tests/e2e/solo/test_flex.py | 4 +- tests/e2e/solo/test_relora_llama.py | 4 +- tests/e2e/test_deepseekv3.py | 7 +- tests/e2e/test_dpo.py | 4 +- tests/e2e/test_embeddings_lr.py | 7 +- tests/e2e/test_falcon.py | 10 +- tests/e2e/test_gemma2.py | 7 +- tests/e2e/test_gemma3_text.py | 7 +- tests/e2e/test_llama.py | 13 +- tests/e2e/test_llama_pretrain.py | 12 +- tests/e2e/test_llama_vision.py | 7 +- tests/e2e/test_lora_llama.py | 4 +- tests/e2e/test_mamba.py | 4 +- tests/e2e/test_mistral.py | 7 +- tests/e2e/test_mixtral.py | 16 +- tests/e2e/test_optimizers.py | 16 +- tests/e2e/test_packing_loss.py | 4 +- tests/e2e/test_phi.py | 7 +- .../e2e/test_process_reward_model_smollm2.py | 4 +- tests/e2e/test_qat.py | 4 +- tests/e2e/test_reward_model_smollm2.py | 4 +- tests/e2e/test_schedulers.py | 4 +- tests/prompt_strategies/test_dpo_chatml.py | 6 +- tests/test_datasets.py | 81 +- tests/test_exact_deduplication.py | 111 +- 62 files changed, 2125 insertions(+), 1436 deletions(-) create mode 100644 src/axolotl/utils/data/lock.py create mode 100644 src/axolotl/utils/data/wrappers.py create mode 100644 tests/e2e/multigpu/test_locking.py diff --git a/src/axolotl/common/const.py b/src/axolotl/common/const.py index fd34ad469..8aae06e99 100644 --- a/src/axolotl/common/const.py +++ b/src/axolotl/common/const.py @@ -1,5 +1,3 @@ -""" -Various shared constants -""" +"""Various shared constants""" DEFAULT_DATASET_PREPARED_PATH = "last_run_prepared" diff --git a/src/axolotl/common/datasets.py 
b/src/axolotl/common/datasets.py index d9c384112..4d64958b6 100644 --- a/src/axolotl/common/datasets.py +++ b/src/axolotl/common/datasets.py @@ -3,15 +3,13 @@ import math import random from dataclasses import dataclass -from typing import Optional, Union from datasets import Dataset import axolotl.monkeypatch.data.batch_dataset_fetcher # pylint: disable=unused-import # noqa: F401 from axolotl.cli.args import PreprocessCliArgs, TrainerCliArgs from axolotl.loaders import load_processor, load_tokenizer -from axolotl.utils.data import prepare_dataset -from axolotl.utils.data.rl import load_prepare_preference_datasets +from axolotl.utils.data import prepare_datasets, prepare_preference_datasets from axolotl.utils.dict import DictDefault from axolotl.utils.logging import get_logger from axolotl.utils.schemas.enums import RLType @@ -30,16 +28,7 @@ class TrainDatasetMeta: def sample_dataset(dataset: Dataset, num_samples: int) -> Dataset: - """ - Randomly sample `num_samples` samples from `dataset`. - - Args: - dataset: Dataset. - num_samples: Number of samples to return. - - Returns: - Random sample (with replacement) of examples in `dataset`. - """ + """Randomly sample `num_samples` samples with replacement from `dataset`.""" return dataset.select( [random.randrange(0, len(dataset) - 1) for _ in range(num_samples)] # nosec ) @@ -51,44 +40,37 @@ def load_datasets( cli_args: PreprocessCliArgs | TrainerCliArgs | None = None, debug: bool = False, ) -> TrainDatasetMeta: - """ - Loads one or more training or evaluation datasets, calling - `axolotl.utils.data.prepare_dataset`. Optionally, logs out debug information. + """Loads one or more training or evaluation datasets, calling + `axolotl.utils.data.prepare_datasets`. Optionally, logs out debug information. Args: cfg: Dictionary mapping `axolotl` config keys to values. cli_args: Command-specific CLI arguments. - debug: Whether to print out tokenization of sample + debug: Whether to print out tokenization of sample. 
This is duplicated in + `cfg` and `cli_args`, but is kept due to use in our Colab notebooks. Returns: Dataclass with fields for training and evaluation datasets and the computed - `total_num_steps`. + `total_num_steps`. """ tokenizer = load_tokenizer(cfg) processor = load_processor(cfg, tokenizer=tokenizer) if cfg.processor_type else None - preprocess_iterable = ( - cli_args - and hasattr(cli_args, "iterable") - and cli_args.iterable is not None - and cli_args.iterable - ) + preprocess_iterable = getattr(cli_args, "iterable", False) - train_dataset, eval_dataset, total_num_steps, prompters = prepare_dataset( + train_dataset, eval_dataset, total_num_steps, prompters = prepare_datasets( cfg, tokenizer, processor=processor, preprocess_iterable=preprocess_iterable, ) - if ( # pylint: disable=too-many-boolean-expressions - cli_args - and ( - cli_args.debug - or cfg.debug - or cli_args.debug_text_only - or int(cli_args.debug_num_examples) > 0 - ) - ) or debug: + if ( + cfg.debug + or getattr(cli_args, "debug", False) + or getattr(cli_args, "debug_text_only", False) + or getattr(cli_args, "debug_num_examples", 0) > 0 + or debug + ): LOG.info("check_dataset_labels...") num_examples = cli_args.debug_num_examples if cli_args else 1 @@ -113,13 +95,10 @@ def load_datasets( def load_preference_datasets( - *, - cfg: DictDefault, - cli_args: Union[PreprocessCliArgs, TrainerCliArgs], + *, cfg: DictDefault, cli_args: PreprocessCliArgs | TrainerCliArgs ) -> TrainDatasetMeta: - """ - Loads one or more training or evaluation datasets for RL training using paired - preference data, calling `axolotl.utils.data.rl.load_prepare_preference_datasets`. + """Loads one or more training or evaluation datasets for RL training using paired + preference data, calling `axolotl.utils.data.rl.prepare_preference_datasets`. Optionally, logs out debug information. 
Args: @@ -130,12 +109,14 @@ def load_preference_datasets( Dataclass with fields for training and evaluation datasets and the computed `total_num_steps`. """ - train_dataset, eval_dataset = load_prepare_preference_datasets(cfg) - total_num_steps: Optional[int] = int( - math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size) - ) - if cfg.rl is RLType.GRPO: - total_num_steps = None + tokenizer = load_tokenizer(cfg) + train_dataset, eval_dataset = prepare_preference_datasets(cfg, tokenizer) + + total_num_steps: int | None = None + if cfg.rl is not RLType.GRPO: + total_num_steps = int( + math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size) + ) if cli_args.debug or cfg.debug: LOG.info("check_dataset_labels...") @@ -143,8 +124,8 @@ def load_preference_datasets( tokenizer = load_tokenizer(cfg) train_samples = sample_dataset(train_dataset, cli_args.debug_num_examples) check_dataset_labels( - train_samples, - tokenizer, + dataset=train_samples, + tokenizer=tokenizer, num_examples=cli_args.debug_num_examples, text_only=cli_args.debug_text_only, rl_mode=True, diff --git a/src/axolotl/core/builders/causal.py b/src/axolotl/core/builders/causal.py index 7a81616ba..8ff565dbb 100644 --- a/src/axolotl/core/builders/causal.py +++ b/src/axolotl/core/builders/causal.py @@ -381,7 +381,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase): elif "tokenizer" in sig.parameters: trainer_kwargs["tokenizer"] = self.tokenizer if ( - not (trainer_cls in [AxolotlRewardTrainer, AxolotlPRMTrainer]) + trainer_cls not in [AxolotlRewardTrainer, AxolotlPRMTrainer] and self.cfg.datasets is not None ): trainer_kwargs["dataset_tags"] = [ diff --git a/src/axolotl/datasets.py b/src/axolotl/datasets.py index 9f1d9500d..7c112c59e 100644 --- a/src/axolotl/datasets.py +++ b/src/axolotl/datasets.py @@ -1,7 +1,6 @@ """Module containing Dataset functionality""" import os -from typing import List, Optional, Union import torch from datasets import Dataset, IterableDataset @@ -20,21 +19,21 @@ LOG = 
get_logger(__name__) class TokenizedPromptDataset(Dataset): - """ - Dataset that returns tokenized prompts from a stream of text files. - Args: - prompt_tokenizer (PromptTokenizingStrategy): The prompt tokenizing method for processing the data. - dataset (dataset.Dataset): Dataset with text files. - process_count (int): Number of processes to use for tokenizing. - keep_in_memory (bool): Whether to keep the tokenized dataset in memory. + """Dataset that returns tokenized prompts from a stream of text files. + + Args: + prompt_tokenizer: The prompt tokenizing method for processing the data. + dataset: Dataset with text files. + process_count: Number of processes to use for tokenizing. + keep_in_memory: Whether to keep the tokenized dataset in memory. """ def __init__( # pylint: disable=super-init-not-called self, prompt_tokenizer: PromptTokenizingStrategy, dataset: Dataset, - process_count: Optional[int] = None, - keep_in_memory: Optional[bool] = False, + process_count: int | None = None, + keep_in_memory: bool | None = False, **kwargs, ): self.prompt_tokenizer = prompt_tokenizer @@ -76,14 +75,14 @@ class TokenizedPromptDataset(Dataset): def wrap_dataset_for_tokenized_prompt( prompt_tokenizer: PromptTokenizingStrategy, - dataset: Union[Dataset, IterableDataset], + dataset: Dataset | IterableDataset, **kwargs, ): if isinstance(dataset, IterableDataset): map_kwargs = {} if prompt_tokenizer.supports_batched: map_kwargs["batched"] = True - features = dataset.features.keys() + features = list(dataset.features.keys()) return dataset.map( prompt_tokenizer.tokenize_prompt, remove_columns=features, @@ -94,12 +93,13 @@ def wrap_dataset_for_tokenized_prompt( # TODO this isn't the best since it can't interleave datasets class ConstantLengthDataset(IterableDataset): - """ - Iterable dataset that returns constant length chunks of tokens from stream of text files. - Args: - tokenizer (Tokenizer): The processor used for processing the data. 
- dataset (dataset.Dataset): Dataset with text files. - seq_length (int): Length of token sequences to return. + """Iterable dataset that returns constant length chunks of tokens from stream of + text files. + + Args: + tokenizer: The processor used for processing the data. + dataset: Dataset with text files. + seq_length: Length of token sequences to return. """ def __init__( # pylint: disable=super-init-not-called @@ -110,7 +110,7 @@ class ConstantLengthDataset(IterableDataset): ): self.tokenizer = tokenizer self.concat_token_id = tokenizer.eos_token_id - self.datasets: List[IterableDataset] = datasets + self.datasets: list[IterableDataset] = datasets self.seq_length = seq_length vocab_size = len(tokenizer.get_vocab()) @@ -174,7 +174,10 @@ class ConstantLengthDataset(IterableDataset): } else: LOG.warning( - f"dropping batch due to tensor size mismatch input_ids: {input_ids.size()}, labels: {labels.size()}, attention_mask: {attention_mask.size()}" + "Dropping batch due to tensor size mismatch " + f"input_ids: {input_ids.size()}, " + f"labels: {labels.size()}, " + f"attention_mask: {attention_mask.size()}" ) buffer = { "input_ids": [], diff --git a/src/axolotl/loaders/tokenizer.py b/src/axolotl/loaders/tokenizer.py index c311d5247..5a174186d 100644 --- a/src/axolotl/loaders/tokenizer.py +++ b/src/axolotl/loaders/tokenizer.py @@ -7,12 +7,14 @@ import transformers from transformers import ( AddedToken, AutoTokenizer, + PreTrainedTokenizer, ) from axolotl.integrations.base import PluginManager from axolotl.loaders.utils import get_linear_embedding_layers, load_model_config from axolotl.prompt_tokenizers import LLAMA_DEFAULT_EOS_TOKEN from axolotl.utils.chat_templates import get_chat_template_from_config +from axolotl.utils.dict import DictDefault from axolotl.utils.distributed import ( barrier, is_local_main_process, @@ -117,7 +119,7 @@ def modify_tokenizer_files( return tokenizer_dir -def load_tokenizer(cfg): +def load_tokenizer(cfg: DictDefault) -> 
PreTrainedTokenizer: """Load and configure the tokenizer based on the provided config.""" model_config = load_model_config(cfg) tokenizer_kwargs = {} @@ -207,11 +209,12 @@ def load_tokenizer(cfg): ) and k != "pad_token" ): - lora_modules_to_save = ", ".join( + lora_modules_to_save_str = ", ".join( [f"`{x}`" for x in lora_modules_to_save] ) raise ValueError( - f"Please set lora_modules_to_save to [{lora_modules_to_save}] when using an adapter and changing the special tokens." + f"Please set lora_modules_to_save to [{lora_modules_to_save_str}] " + "when using an adapter and changing the special tokens." ) tokenizer.add_special_tokens( diff --git a/src/axolotl/prompt_strategies/messages/__init__.py b/src/axolotl/prompt_strategies/messages/__init__.py index cc7b84da1..6eae9dfd8 100644 --- a/src/axolotl/prompt_strategies/messages/__init__.py +++ b/src/axolotl/prompt_strategies/messages/__init__.py @@ -32,4 +32,3 @@ def load(tokenizer, cfg, ds_cfg, processor=None): except Exception as exc: # pylint: disable=broad-exception-caught LOG.error(f"Failed to load prompt strategy `{strategy}`: {str(exc)}") raise exc - return None diff --git a/src/axolotl/prompt_tokenizers.py b/src/axolotl/prompt_tokenizers.py index cb1a1ba4e..9ca645de3 100644 --- a/src/axolotl/prompt_tokenizers.py +++ b/src/axolotl/prompt_tokenizers.py @@ -3,6 +3,7 @@ import abc from typing import Callable, Dict, List, Optional, Tuple, Union +from datasets import Dataset from transformers import BatchEncoding, PreTrainedTokenizer from axolotl.prompters import Prompter @@ -28,6 +29,16 @@ class DatasetWrappingStrategy(abc.ABC): Abstract class for wrapping datasets for Chat Messages """ + @abc.abstractmethod + def wrap_dataset( + self, + dataset, + process_count: int | None = None, + keep_in_memory: bool | None = False, + **kwargs, + ) -> Dataset: + pass + class PromptTokenizingStrategy(abc.ABC): """ diff --git a/src/axolotl/train.py b/src/axolotl/train.py index 866a9c454..13ac8ec0d 100644 --- 
a/src/axolotl/train.py +++ b/src/axolotl/train.py @@ -53,8 +53,8 @@ def setup_model_and_tokenizer( ) -> tuple[ PreTrainedModel, PreTrainedTokenizer, PeftConfig | None, ProcessorMixin | None ]: - """ - Load the tokenizer, processor (for multimodal models), and model based on configuration. + """Load the tokenizer, processor (for multimodal models), and model based on + configuration. Args: cfg: Dictionary mapping `axolotl` config keys to values. diff --git a/src/axolotl/utils/data/__init__.py b/src/axolotl/utils/data/__init__.py index 8dedcbe69..d162a7d0b 100644 --- a/src/axolotl/utils/data/__init__.py +++ b/src/axolotl/utils/data/__init__.py @@ -1,16 +1,21 @@ -""" -Data processing modules -""" +"""Init for `axolotl.utils.data` module.""" -from axolotl.utils.data.pretraining import ( # noqa: F401 +from axolotl.utils.data.pretraining import ( encode_pretraining, wrap_pretraining_dataset, ) -from axolotl.utils.data.rl import load_prepare_preference_datasets # noqa: F401 -from axolotl.utils.data.sft import ( # noqa: F401 +from axolotl.utils.data.rl import prepare_preference_datasets +from axolotl.utils.data.sft import ( get_dataset_wrapper, - load_prepare_datasets, - load_tokenized_prepared_datasets, - prepare_dataset, + prepare_datasets, ) -from axolotl.utils.data.utils import md5 # noqa: F401 +from axolotl.utils.data.utils import md5 + +__all__ = [ + "encode_pretraining", + "wrap_pretraining_dataset", + "prepare_preference_datasets", + "get_dataset_wrapper", + "prepare_datasets", + "md5", +] diff --git a/src/axolotl/utils/data/lock.py b/src/axolotl/utils/data/lock.py new file mode 100644 index 000000000..f5ec1679b --- /dev/null +++ b/src/axolotl/utils/data/lock.py @@ -0,0 +1,66 @@ +"""Logic for loading / preparing a dataset once over all processes.""" + +import time +from pathlib import Path +from typing import Any, Callable + +from filelock import FileLock + +from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH +from axolotl.utils.dict import DictDefault + 
+LOCK_FILE_NAME = "datasets_prep.lock" +READY_FILE_NAME = "datasets_ready.flag" +PROCESS_COUNTER_FILE_NAME = "process_counter.txt" + + +class FileLockLoader: + """ + Simple class for abstracting single process data loading / processing. The first + process that creates a lock file does the work; the remaining processes simply load + the preprocessed dataset once the first process is done. + """ + + def __init__(self, cfg: DictDefault): + self.cfg = cfg + self.dataset_prepared_path = ( + cfg.dataset_prepared_path or DEFAULT_DATASET_PREPARED_PATH + ) + self.lock_file_path = Path(self.dataset_prepared_path) / LOCK_FILE_NAME + self.ready_flag_path = Path(self.dataset_prepared_path) / READY_FILE_NAME + self.counter_path = Path(self.dataset_prepared_path) / PROCESS_COUNTER_FILE_NAME + + def load(self, load_fn: Callable[[], Any]) -> Any: + with FileLock(str(self.lock_file_path)): + self._increment_counter() + + if not self.ready_flag_path.exists(): + result = load_fn() + self.ready_flag_path.touch() + return result + + while not self.ready_flag_path.exists(): + time.sleep(1) + return load_fn() + + def _increment_counter(self): + """Safely increment the process counter.""" + if self.counter_path.exists(): + count = int(self.counter_path.read_text().strip()) + else: + count = 0 + self.counter_path.write_text(str(count + 1)) + + def cleanup(self): + """Clean up ready flag when last process is done.""" + with FileLock(str(self.lock_file_path)): + count = int(self.counter_path.read_text().strip()) + count -= 1 + + if count == 0: + # Last process cleans everything up + self.ready_flag_path.unlink(missing_ok=True) + self.counter_path.unlink(missing_ok=True) + else: + # Still have active processes + self.counter_path.write_text(str(count)) diff --git a/src/axolotl/utils/data/pretraining.py b/src/axolotl/utils/data/pretraining.py index 44d8d6fed..4ff108aee 100644 --- a/src/axolotl/utils/data/pretraining.py +++ b/src/axolotl/utils/data/pretraining.py @@ -250,7 +250,7 @@ def 
encode_packed_pretraining( # pylint: disable=duplicate-code # tokenize all the examples # rows get split with stride (overlap) - train_dataset = ds_wrapper(Dataset.from_dict(examples))[0] + train_dataset = ds_wrapper(dataset=Dataset.from_dict(examples))[0] train_dataset = process_pretraining_datasets_for_packing( train_dataset, diff --git a/src/axolotl/utils/data/rl.py b/src/axolotl/utils/data/rl.py index 9264c86ab..6fd539758 100644 --- a/src/axolotl/utils/data/rl.py +++ b/src/axolotl/utils/data/rl.py @@ -1,75 +1,117 @@ -"""data handling specific to DPO""" +"""Data handling specific to RL trainers.""" import inspect from functools import partial -from pathlib import Path -from typing import Any, List, Union +from typing import Any, Callable, Literal -import yaml -from datasets import Dataset, DatasetDict, concatenate_datasets, load_from_disk +from datasets import Dataset, DatasetDict +from transformers import PreTrainedTokenizer -from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH from axolotl.loaders import load_tokenizer from axolotl.prompt_strategies.dpo import load as load_dpo from axolotl.prompt_strategies.kto import load as load_kto from axolotl.prompt_strategies.orpo import load as load_orpo -from axolotl.utils.data.shared import datasets_w_name_generator, load_dataset_w_config -from axolotl.utils.data.utils import deduplicate_and_log_datasets, md5 +from axolotl.utils.data.lock import FileLockLoader +from axolotl.utils.data.shared import ( + create_train_validation_split, + datasets_with_name_generator, + generate_dataset_hash_from_config, + load_dataset_with_config, + load_preprocessed_dataset, + merge_datasets, + save_preprocessed_dataset, + try_load_from_hub, +) +from axolotl.utils.data.utils import ( + deduplicate_and_log_datasets, + retry_on_request_exceptions, +) from axolotl.utils.dict import DictDefault -from axolotl.utils.distributed import is_main_process, zero_first from axolotl.utils.logging import get_logger from 
axolotl.utils.schemas.enums import RLType LOG = get_logger(__name__) -def _get_path(ds_hash, cfg): - prepared_ds_path = ( - Path(cfg.dataset_prepared_path) / ds_hash - if cfg.dataset_prepared_path - else Path(DEFAULT_DATASET_PREPARED_PATH) / ds_hash - ) +@retry_on_request_exceptions(max_retries=3, delay=5) +def prepare_preference_datasets( + cfg: DictDefault, tokenizer: PreTrainedTokenizer +) -> tuple[Dataset, Dataset | None]: + """Load and prepare preference datasets for RL training. - return prepared_ds_path + Loads training and evaluation datasets, handling preprocessing, caching, and + deduplication as configured. Uses FileLock for distributed coordination. + + Args: + cfg: Configuration object containing dataset and training settings. + tokenizer: Tokenizer to use for processing text. + + Returns: + Tuple of (train_dataset, eval_dataset). eval_dataset may be None + if no evaluation dataset is configured. + """ + + def _load_datasets(): + # Load training dataset + train_dataset = _load_or_create_dataset_split(cfg, tokenizer, split="train") + + # Load or create evaluation dataset + eval_dataset: Dataset | None = None + if cfg.test_datasets: + eval_dataset = _load_or_create_dataset_split(cfg, tokenizer, split="test") + elif cfg.val_set_size: + # Create validation split from training data + train_dataset, eval_dataset = create_train_validation_split( + train_dataset, cfg, cfg.val_set_size + ) + + return train_dataset, eval_dataset + + # Prepare datasets (with file locking logic for multiple ranks) + loader = FileLockLoader(cfg) + try: + train_dataset, eval_dataset = loader.load(_load_datasets) + finally: + loader.cleanup() + + # Apply deduplication if configured + if cfg.dataset_exact_deduplication: + train_dataset, eval_dataset = deduplicate_and_log_datasets( + dataset=train_dataset, other_dataset=eval_dataset + ) + + return train_dataset, eval_dataset -def _load_preprocessed_ds(cfg, sub_cfg): - ds_hash = md5(yaml.dump(sub_cfg, Dumper=yaml.Dumper)) - 
prepared_ds_path = _get_path(ds_hash, cfg) - dataset = None +def _map_dataset( + cfg: DictDefault, + dataset: Dataset | DatasetDict, + ds_transform_fn: Callable[..., Any], + tokenizer: Any | None = None, + **map_kwargs: Any, +) -> Dataset: + """Apply transformation function to dataset. - # pylint: disable=duplicate-code - if ( - cfg.dataset_prepared_path - and any(prepared_ds_path.glob("*")) - and not cfg.is_preprocess - ): - LOG.info(f"Loading prepared dataset from disk at {prepared_ds_path}...") - dataset = load_from_disk(str(prepared_ds_path)) + Args: + cfg: Configuration object. + dataset: Dataset to transform. + ds_transform_fn: Transformation function to apply. + tokenizer: Optional tokenizer for transformation. + **map_kwargs: Additional arguments for dataset mapping. - return dataset - - -def _save_preprocessed_ds(cfg, sub_cfg, dataset): - ds_hash = md5(yaml.dump(sub_cfg, Dumper=yaml.Dumper)) - prepared_ds_path = _get_path(ds_hash, cfg) - - if cfg.is_preprocess and is_main_process(): - LOG.info(f"Loading prepared dataset from disk at {prepared_ds_path}...") - dataset.save_to_disk(str(prepared_ds_path)) - - -def map_dataset(cfg, data_set, ds_transform_fn, tokenizer, **map_kwargs): + Returns: + Transformed dataset. 
+ """ sig = inspect.signature(ds_transform_fn) if "tokenizer" in sig.parameters: if not tokenizer: tokenizer = load_tokenizer(cfg) ds_transform_fn = partial(ds_transform_fn, tokenizer=tokenizer) - if isinstance(data_set, DatasetDict): - data_set = data_set["train"] + if isinstance(dataset, DatasetDict): + dataset = dataset["train"] - data_set = data_set.map( + dataset = dataset.map( ds_transform_fn, num_proc=cfg.dataset_processes, load_from_cache_file=not cfg.is_preprocess, @@ -77,13 +119,27 @@ def map_dataset(cfg, data_set, ds_transform_fn, tokenizer, **map_kwargs): **map_kwargs, ) - return data_set + return dataset -def drop_long_rl_seq( - sample, rl, tokenizer, sequence_len # pylint: disable=invalid-name -): - if rl in (RLType.DPO, RLType.IPO, RLType.ORPO, RLType.SIMPO): +def _drop_long_sequences( + sample: dict[str, Any], rl: RLType, tokenizer: Any, sequence_len: int +) -> bool: + """Filter out samples that exceed maximum sequence length. + + Args: + sample: Dataset sample to check. + rl: Reinforcement learning type. + tokenizer: Tokenizer for length calculation. + sequence_len: Maximum allowed sequence length. + + Returns: + True if sample should be kept, False if it should be dropped. + + Raises: + ValueError: If required keys are missing or RL type is unknown. 
+ """ + if rl in {RLType.DPO, RLType.IPO, RLType.ORPO, RLType.SIMPO}: if not ( sample.get("prompt") and sample.get("chosen") and sample.get("rejected") ): @@ -123,132 +179,115 @@ def drop_long_rl_seq( raise ValueError("Unknown RL type") -def load_prepare_preference_datasets(cfg): - def load_split(dataset_cfgs, _cfg): - split_datasets: List[Any] = [] - use_auth_token = _cfg.hf_use_auth_token - for config_dataset in datasets_w_name_generator(dataset_cfgs): - ds: Union[Dataset, DatasetDict] = load_dataset_w_config( - config_dataset, use_auth_token, streaming=False - ) - split_datasets.append(ds) +def _load_split(cfg: DictDefault, split: Literal["train", "test"]) -> Dataset: + """Load and process dataset split for RL training. - tokenizer = load_tokenizer(cfg) + Args: + cfg: Configuration object containing dataset settings. + split: Dataset split to load ("train" or "test"). - for i, data_set in enumerate(split_datasets): - _type = dataset_cfgs[i]["type"] - if _type: - if isinstance(_type, DictDefault): - _type = "user_defined.default" - if _cfg.rl is RLType.ORPO: - ds_transform_fn = load_orpo(_type, _cfg, dataset_idx=i) - elif _cfg.rl is RLType.KTO: - ds_transform_fn = load_kto(_type, _cfg, dataset_idx=i) - else: - ds_transform_fn = load_dpo(_type, _cfg, dataset_idx=i) + Returns: + Combined and processed dataset for the specified split. 
+ """ + datasets_configs = cfg.datasets if split == "train" else cfg.test_datasets + split_datasets: list[Dataset | DatasetDict] = [] - map_kwargs = {} - if isinstance(ds_transform_fn, tuple): - ds_transform_fn, map_kwargs = ds_transform_fn - split_datasets[i] = map_dataset( - cfg, data_set, ds_transform_fn, tokenizer, **map_kwargs - ) - elif _cfg.rl is RLType.KTO: - ds_transform_fn = load_kto(_type, _cfg, dataset_idx=i) - map_kwargs = {} - if isinstance(ds_transform_fn, tuple): - ds_transform_fn, map_kwargs = ds_transform_fn - split_datasets[i] = map_dataset( - cfg, data_set, ds_transform_fn, tokenizer, **map_kwargs - ) - else: - # If no `type` is provided, assume the dataset is already in the expected format with - # "prompt", "chosen" and "rejected" already preprocessed - split_datasets[i] = data_set - - if not cfg.skip_prepare_dataset: - drop_long = partial( - drop_long_rl_seq, - rl=_cfg.rl, - tokenizer=tokenizer, - sequence_len=cfg.sequence_len, - ) - - prior_len = len(split_datasets[i]) - split_datasets[i] = split_datasets[i].filter( - drop_long, - num_proc=cfg.dataset_processes, - load_from_cache_file=not cfg.is_preprocess, - desc="Dropping Long Sequences", - ) - dropped = prior_len - len(split_datasets[i]) - if dropped: - LOG.warning( - f"Dropped {dropped} long samples from dataset index {i}" - ) - - combined_datasets = concatenate_datasets(split_datasets) - combined_datasets = combined_datasets.shuffle(seed=cfg.seed or 42) - - return combined_datasets - - with zero_first(is_main_process()): - train_is_preprocessed = False - eval_is_preprocessed = False - if train_dataset := _load_preprocessed_ds(cfg, cfg.datasets): - train_is_preprocessed = True - else: - train_dataset = load_split(cfg.datasets, cfg) - - eval_dataset = None - if cfg.test_datasets: - if eval_dataset := _load_preprocessed_ds(cfg, cfg.test_datasets): - eval_is_preprocessed = True - else: - eval_dataset = load_split(cfg.test_datasets, cfg) - if not eval_dataset: - if cfg.val_set_size: - seed = 
cfg.seed if cfg.seed is not None else 42 - - # ensure we end up with the same fingerprint by doing rank0 first and being able to cache - to_hash_train = ( - train_dataset._fingerprint # pylint: disable=protected-access - + "|" - + str(cfg.val_set_size) - + "|" - + "train" - + "|" - + str(cfg.seed or 42) - ) - to_hash_test = ( - train_dataset._fingerprint # pylint: disable=protected-access - + "|" - + str(cfg.val_set_size) - + "|" - + "test" - + "|" - + str(cfg.seed or 42) - ) - train_fingerprint = md5(to_hash_train) - test_fingerprint = md5(to_hash_test) - ds_w_test_split = train_dataset.train_test_split( - test_size=cfg.val_set_size, - seed=seed, - shuffle=False, - train_new_fingerprint=train_fingerprint, - test_new_fingerprint=test_fingerprint, - ) - eval_dataset = ds_w_test_split["test"] - train_dataset = ds_w_test_split["train"] - - if not train_is_preprocessed: - _save_preprocessed_ds(cfg, cfg.datasets, train_dataset) - if eval_dataset and not eval_is_preprocessed: - _save_preprocessed_ds(cfg, cfg.test_datasets, eval_dataset) - - if cfg.dataset_exact_deduplication: - train_dataset, eval_dataset, _ = deduplicate_and_log_datasets( - train_dataset=train_dataset, eval_dataset=eval_dataset + for dataset_config in datasets_with_name_generator(datasets_configs): + dataset: Dataset | DatasetDict = load_dataset_with_config( + dataset_config, cfg.hf_use_auth_token, streaming=False ) + split_datasets.append(dataset) - return train_dataset, eval_dataset + tokenizer = load_tokenizer(cfg) + + for i, dataset in enumerate(split_datasets): + _type = datasets_configs[i]["type"] + if _type: + if isinstance(_type, DictDefault): + _type = "user_defined.default" + if cfg.rl is RLType.ORPO: + ds_transform_fn = load_orpo(_type, cfg, dataset_idx=i) + elif cfg.rl is RLType.KTO: + ds_transform_fn = load_kto(_type, cfg, dataset_idx=i) + else: + ds_transform_fn = load_dpo(_type, cfg, dataset_idx=i) + + map_kwargs: dict[str, Any] = {} + if isinstance(ds_transform_fn, tuple): + 
ds_transform_fn, map_kwargs = ds_transform_fn + split_datasets[i] = _map_dataset( + cfg, dataset, ds_transform_fn, tokenizer, **map_kwargs + ) + else: + # If no `type` is provided, assume the dataset is already in the expected format with + # "prompt", "chosen", and "rejected" already preprocessed + split_datasets[i] = dataset + + if not cfg.skip_prepare_dataset: + drop_long = partial( + _drop_long_sequences, + rl=cfg.rl, + tokenizer=tokenizer, + sequence_len=cfg.sequence_len, + ) + + prior_len = len(split_datasets[i]) + split_datasets[i] = split_datasets[i].filter( + drop_long, + num_proc=cfg.dataset_processes, + load_from_cache_file=not cfg.is_preprocess, + desc="Dropping Long Sequences", + ) + dropped = prior_len - len(split_datasets[i]) + if dropped: + LOG.warning(f"Dropped {dropped} long samples from dataset index {i}") + + # Merge datasets + dataset = merge_datasets(split_datasets, cfg) + + if not cfg.skip_prepare_dataset: + # Save preprocessed dataset + dataset_hash = generate_dataset_hash_from_config( + cfg, datasets_configs, tokenizer.name_or_path + ) + save_preprocessed_dataset(cfg, dataset, dataset_hash, split) + + return dataset + + +# pylint: disable=duplicate-code +def _load_or_create_dataset_split( + cfg: DictDefault, tokenizer: PreTrainedTokenizer, split: Literal["train", "test"] +) -> Dataset: + """Load preprocessed dataset or create new one for given split. + + Args: + cfg: Configuration object. + tokenizer: Tokenizer to use for processing text. + split: Dataset split to load. + + Returns: + The loaded or newly created dataset for the given split. 
+ """ + # Select correct dataset configuration based on split + datasets_config = cfg.datasets if split == "train" else cfg.test_datasets + + # Generate dataset hash for caching + dataset_hash = generate_dataset_hash_from_config( + cfg, datasets_config, tokenizer.name_or_path + ) + + # Try loading from hub if push_dataset_to_hub is configured + dataset = None + if cfg.push_dataset_to_hub: + dataset = try_load_from_hub(cfg, dataset_hash, split) + + # Attempt to load preprocessed dataset + if dataset is None: + dataset = load_preprocessed_dataset(cfg, dataset_hash) + + # Otherwise, load it + if dataset is None: + dataset = _load_split(cfg, split=split) + + return dataset diff --git a/src/axolotl/utils/data/sft.py b/src/axolotl/utils/data/sft.py index 88c78174b..d0b8ab743 100644 --- a/src/axolotl/utils/data/sft.py +++ b/src/axolotl/utils/data/sft.py @@ -1,58 +1,38 @@ -"""data handling specific to SFT""" +"""Data handling specific to SFT.""" import functools -import os import tempfile -from pathlib import Path -from typing import List, Optional, Tuple, Union +from typing import Literal from datasets import ( Dataset, DatasetDict, IterableDataset, - Sequence, - Value, - concatenate_datasets, load_dataset, - load_from_disk, ) -from transformers import PreTrainedTokenizerBase +from transformers import PreTrainedTokenizer, ProcessorMixin -from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH -from axolotl.datasets import TokenizedPromptDataset, wrap_dataset_for_tokenized_prompt -from axolotl.prompt_strategies import load -from axolotl.prompt_strategies.bradley_terry import load as bradley_terry_load -from axolotl.prompt_tokenizers import ( - AlpacaMultipleChoicePromptTokenizingStrategy, - AlpacaPromptTokenizingStrategy, - AlpacaReflectionPTStrategy, - DatasetWrappingStrategy, - GPTeacherPromptTokenizingStrategy, - JeopardyPromptTokenizingStrategy, - OpenAssistantPromptTokenizingStrategy, - SummarizeTLDRPromptTokenizingStrategy, -) -from axolotl.prompters import ( 
- AlpacaPrompter, - GPTeacherPrompter, - JeopardyPrompter, - MultipleChoiceConcisePrompter, - MultipleChoiceExplainPrompter, - Prompter, - ReflectAlpacaPrompter, - SummarizeTLDRPrompter, - UnsupportedPrompter, -) +from axolotl.prompters import Prompter +from axolotl.utils.data.lock import FileLockLoader from axolotl.utils.data.pretraining import wrap_pretraining_dataset -from axolotl.utils.data.shared import datasets_w_name_generator, load_dataset_w_config +from axolotl.utils.data.shared import ( + create_train_validation_split, + datasets_with_name_generator, + generate_dataset_hash_from_config, + load_dataset_with_config, + load_preprocessed_dataset, + merge_datasets, + save_preprocessed_dataset, + try_load_from_hub, +) from axolotl.utils.data.utils import ( deduplicate_and_log_datasets, drop_long_seq_in_dataset, - md5, retry_on_request_exceptions, ) +from axolotl.utils.data.wrappers import get_dataset_wrapper from axolotl.utils.dict import DictDefault -from axolotl.utils.distributed import is_local_main_process, zero_first +from axolotl.utils.distributed import is_local_main_process from axolotl.utils.logging import get_logger from axolotl.utils.trainer import ( calculate_total_num_steps, @@ -63,121 +43,77 @@ LOG = get_logger(__name__) @retry_on_request_exceptions(max_retries=3, delay=5) -def prepare_dataset(cfg, tokenizer, processor=None, preprocess_iterable=None): - prompters = [] - if not cfg.pretraining_dataset: - with zero_first(is_local_main_process()): - if cfg.test_datasets: - train_dataset, _, prompters = load_prepare_datasets( - tokenizer, - cfg, - DEFAULT_DATASET_PREPARED_PATH, - split="train", - processor=processor, - preprocess_iterable=preprocess_iterable, - ) - _, eval_dataset, _ = load_prepare_datasets( - tokenizer, - cfg, - DEFAULT_DATASET_PREPARED_PATH, - split="test", - processor=processor, - preprocess_iterable=preprocess_iterable, - ) - else: - train_dataset, eval_dataset, prompters = load_prepare_datasets( - tokenizer, - cfg, - 
DEFAULT_DATASET_PREPARED_PATH, - processor=processor, - preprocess_iterable=preprocess_iterable, - ) - else: - # Load streaming dataset if pretraining_dataset is given - path = cfg.pretraining_dataset - split = "train" - name = None - data_files = None - skip = 0 - if isinstance(cfg.pretraining_dataset, list) and isinstance( - cfg.pretraining_dataset[0], dict - ): - path = cfg.pretraining_dataset[0]["path"] - name = cfg.pretraining_dataset[0]["name"] - skip = cfg.pretraining_dataset[0]["skip"] - if "split" in cfg.pretraining_dataset[0]: - split = cfg.pretraining_dataset[0]["split"] +def prepare_datasets( + cfg: DictDefault, + tokenizer: PreTrainedTokenizer, + processor: ProcessorMixin | None = None, + preprocess_iterable: bool = False, +) -> tuple[IterableDataset | Dataset, Dataset | None, int, list[Prompter | None]]: + """Prepare training and evaluation datasets based on configuration. - data_files = cfg.pretraining_dataset[0].get("data_files") + Args: + cfg: Dictionary mapping `axolotl` config keys to values. + tokenizer: Tokenizer to use for processing text. + processor: Optional processor for multimodal datasets. + preprocess_iterable: Whether to use iterable preprocessing. - ds_wrapper_partial = functools.partial( - get_dataset_wrapper, - cfg.pretraining_dataset[0], + Returns: + Tuple of (train_dataset, eval_dataset, total_steps, prompters). 
+ """ + if cfg.pretraining_dataset: + return _prepare_pretraining_dataset( + cfg, tokenizer, processor, preprocess_iterable + ) + return _prepare_standard_dataset(cfg, tokenizer, processor, preprocess_iterable) + + +def _prepare_standard_dataset( + cfg: DictDefault, + tokenizer: PreTrainedTokenizer, + processor: ProcessorMixin | None, + preprocess_iterable: bool, +) -> tuple[Dataset, Dataset | None, int, list[Prompter | None]]: + """Prepare standard (non-pretraining) datasets.""" + + def _load_datasets(): + # Always load training dataset + train_dataset, eval_dataset, prompters = _load_and_prepare_datasets( tokenizer, cfg, - cfg.pretraining_dataset[0]["type"] or "pretrain", + split="train", + processor=processor, + preprocess_iterable=preprocess_iterable, ) - # when letting accelerator dispatch batches from the main process, we don't need to load the dataset from - # other ranks, we just need to present a fake dataset - if ( - cfg.accelerator_config - and cfg.accelerator_config.dispatch_batches - and not is_local_main_process() - ): - with tempfile.NamedTemporaryFile(mode="w+", delete=False) as f: - f.write("text\n") - f.write("lorem ipsum dolor sit amet\n") - # rewind the file pointer to the beginning so we can read it again - f.seek(0) - iter_ds = load_dataset( - "csv", data_files=f.name, split="train", streaming=True - ) - else: - iter_ds = load_dataset( - path, streaming=True, split=split, name=name, data_files=data_files - ) - - if skip: - LOG.info(f"Skipping {skip} samples from the dataset") - iter_ds = iter_ds.skip(skip) - train_dataset = wrap_pretraining_dataset( - iter_ds, - tokenizer, - cfg, - ds_wrapper_partial, - max_tokens=cfg.sequence_len, - batch_size=cfg.micro_batch_size, - seed=cfg.seed if cfg.seed is not None else 42, - buffer_size=cfg.pretrain_multipack_buffer_size or 10_000, - ) - # https://discuss.huggingface.co/t/how-to-use-huggingface-trainer-streaming-datasets-without-wrapping-it-with-torchdatas-iterablewrapper/25230 - train_dataset = 
train_dataset.with_format("torch") - - # Load eval dataset (non-streaming) if specified - eval_dataset = None + # Overwrite eval_dataset if test data exists if cfg.test_datasets: - _, eval_dataset, _ = load_prepare_datasets( + _, eval_dataset, _ = _load_and_prepare_datasets( tokenizer, cfg, - DEFAULT_DATASET_PREPARED_PATH, split="test", processor=processor, preprocess_iterable=preprocess_iterable, ) - if cfg.dataset_exact_deduplication: - LOG.info("Deduplication not available for pretrained datasets") + return train_dataset, eval_dataset, prompters - return train_dataset, eval_dataset, cfg.max_steps, prompters + # Prepare datasets (with file locking logic for multiple ranks) + loader = FileLockLoader(cfg) + try: + train_dataset, eval_dataset, prompters = loader.load(_load_datasets) + finally: + loader.cleanup() + # Validate sample packing configuration for evaluation if eval_dataset and cfg.sample_packing and cfg.eval_sample_packing is not False: total_eval_steps = calculate_total_num_steps(cfg, eval_dataset, update=False) if total_eval_steps == 0: raise ValueError( - "eval dataset split is too small for sample_packing. You should set `eval_sample_packing: False`. " + "eval dataset split is too small for sample_packing. " + "You should set `eval_sample_packing: False` in your config." 
) + # Calculate total number of training steps if cfg.max_steps: total_num_steps = min( calculate_total_num_steps(cfg, train_dataset), cfg.max_steps @@ -188,219 +124,338 @@ def prepare_dataset(cfg, tokenizer, processor=None, preprocess_iterable=None): return train_dataset, eval_dataset, total_num_steps, prompters -def load_tokenized_prepared_datasets( - tokenizer, - cfg, - default_dataset_prepared_path, - split="train", - processor=None, - preprocess_iterable: Optional[bool] = None, -) -> Tuple[DatasetDict, List[Prompter]]: - cfg_datasets = cfg.test_datasets if split == "test" else cfg.datasets - tokenizer_name = cfg.tokenizer_config +def _prepare_pretraining_dataset( + cfg: DictDefault, + tokenizer: PreTrainedTokenizer, + processor: ProcessorMixin | None, + preprocess_iterable: bool, +) -> tuple[IterableDataset, Dataset | None, int, list[Prompter | None]]: + """ + Prepare dataset for pretraining mode. - ds_hash = str( - md5( - ( - str(cfg.sequence_len) - + "@" - + str(cfg.sample_packing) - + "@" - + str(cfg.eval_sample_packing) - + "@" - + str(cfg.group_by_length) - + "@" - + str(cfg.kd_temperature or 1.0) - + "|".join( - sorted( - [ - f"{d.path}:{d.type}:{d.shards}:{d.conversation}:{d.split}:{d.temperature or 1.0}" - for d in cfg_datasets - ] - ) - ) - + "|" - + tokenizer_name - ) + Note: Pre-training datasets are streamed from the HuggingFace Hub. 
+ """ + # Extract pretraining dataset configuration + pretraining_config = _extract_pretraining_config(cfg) + + # Load streaming dataset for training + train_dataset = _load_pretraining_dataset(pretraining_config, cfg, tokenizer) + + # Load evaluation dataset if specified + eval_dataset = None + if cfg.test_datasets: + _, eval_dataset, _ = _load_and_prepare_datasets( + tokenizer, + cfg, + split="test", + processor=processor, + preprocess_iterable=preprocess_iterable, ) - ) - prepared_ds_path = ( - Path(cfg.dataset_prepared_path) / ds_hash - if cfg.dataset_prepared_path - else Path(default_dataset_prepared_path) / ds_hash - ) - dataset = None - prompters = [] - use_auth_token = cfg.hf_use_auth_token - try: - if cfg.push_dataset_to_hub: - LOG.info( - f"Attempting to load prepared dataset from Huggingface hub at {cfg.push_dataset_to_hub} (version {ds_hash})..." - ) - dataset = load_dataset( - cfg.push_dataset_to_hub, - ds_hash, - token=use_auth_token, - ) - dataset = dataset[split] - except Exception: # pylint: disable=broad-except # nosec - pass - # pylint: disable=duplicate-code - if dataset: - # This is for the case where we already loaded a pretokenized dataset from the hub - ... 
- elif ( - cfg.dataset_prepared_path - and any(prepared_ds_path.glob("*")) - and not cfg.is_preprocess - and not cfg.skip_prepare_dataset + if cfg.dataset_exact_deduplication: + LOG.info("Deduplication not available for pretrained datasets") + + # For pretraining, we return max_steps directly from config + return train_dataset, eval_dataset, cfg.max_steps, [] + + +def _extract_pretraining_config(cfg: DictDefault) -> DictDefault: + """Extract pretraining configuration from the main config.""" + if isinstance(cfg.pretraining_dataset, list) and isinstance( + cfg.pretraining_dataset[0], dict ): - LOG.info(f"Loading prepared dataset from disk at {prepared_ds_path}...") - dataset = load_from_disk(str(prepared_ds_path)) - LOG.info("Prepared dataset loaded from disk...") + config = cfg.pretraining_dataset[0] + return DictDefault( + { + "path": config["path"], + "name": config["name"], + "skip": config["skip"], + "split": config.get("split", "train"), + "data_files": config.get("data_files"), + "type": config.get("type", "pretrain"), + } + ) + # Simple string path case + return DictDefault( + { + "path": cfg.pretraining_dataset, + "name": None, + "skip": 0, + "split": "train", + "data_files": None, + "type": "pretrain", + } + ) + + +def _load_pretraining_dataset( + pretraining_config: DictDefault, cfg: DictDefault, tokenizer: PreTrainedTokenizer +) -> IterableDataset: + """Load and prepare a streaming dataset for pretraining.""" + # Create dataset wrapper partial function + dataset_wrapper_partial = functools.partial( + get_dataset_wrapper, + dataset_config=pretraining_config, + tokenizer=tokenizer, + cfg=cfg, + dataset_base_type=pretraining_config["type"], + ) + + # Load the actual dataset + if ( + cfg.accelerator_config + and cfg.accelerator_config.dispatch_batches + and not is_local_main_process() + ): + iter_dataset = _create_placeholder_dataset() else: - if cfg.push_dataset_to_hub: - LOG.info("Unable to find prepared dataset in Huggingface hub") - if cfg.is_preprocess: 
- LOG.info( - f"Skipping prepared dataset in {prepared_ds_path} for pre-processing..." - ) - else: - LOG.info(f"Unable to find prepared dataset in {prepared_ds_path}") - LOG.info("Loading raw datasets...") - if not cfg.is_preprocess: - LOG.warning( - "Processing datasets during training can lead to VRAM instability. Please pre-process your dataset." - ) + iter_dataset = load_dataset( + pretraining_config["path"], + streaming=True, + split=pretraining_config["split"], + name=pretraining_config["name"], + data_files=pretraining_config["data_files"], + ) - if cfg.seed: - seed = cfg.seed - else: - LOG.info("No seed provided, using default seed of 42") - seed = 42 + # Apply skip if specified + if pretraining_config["skip"]: + LOG.info(f"Skipping {pretraining_config['skip']} samples from the dataset") + iter_dataset = iter_dataset.skip(pretraining_config["skip"]) - datasets = [] + # Wrap the dataset for pretraining + train_dataset = wrap_pretraining_dataset( + iter_dataset, + tokenizer, + cfg, + dataset_wrapper_partial, + max_tokens=cfg.sequence_len, + batch_size=cfg.micro_batch_size, + seed=cfg.seed, + buffer_size=cfg.pretrain_multipack_buffer_size or 10_000, + ) - streaming_ds = False - if preprocess_iterable: - streaming_ds = True - # pylint: disable=invalid-name - for config_dataset in datasets_w_name_generator(cfg_datasets): - ds: Union[Dataset, DatasetDict] = load_dataset_w_config( - config_dataset, use_auth_token, streaming=streaming_ds - ) + # Format for PyTorch + return train_dataset.with_format("torch") - d_base_type = d_prompt_style = None - d_type = config_dataset.type - if isinstance(d_type, str): - d_type_split = d_type.split(":") - d_base_type = d_type_split[0] - d_prompt_style = d_type_split[1] if len(d_type_split) > 1 else None - if isinstance(ds, DatasetDict): - if config_dataset.split and config_dataset.split in ds: - ds = ds[config_dataset.split] - elif split in ds: - ds = ds[split] - else: - raise ValueError( - f"no {split} split found for dataset 
{config_dataset.path}, you may specify a split with 'split: `" - ) +def _create_placeholder_dataset() -> IterableDataset: + """Create a minimal placeholder dataset for non-main processes.""" + with tempfile.NamedTemporaryFile(mode="w+", delete=False) as f: + f.write("text\n") + f.write("lorem ipsum dolor sit amet\n") + f.seek(0) + return load_dataset("csv", data_files=f.name, split="train", streaming=True) - # support for using a subset of the data - if config_dataset.shards: - shards_idx = config_dataset.get("shards_idx", 0) - ds = ds.shuffle(seed=seed).shard( - num_shards=config_dataset.shards, index=shards_idx - ) - dataset_wrapper, dataset_prompter = get_dataset_wrapper( - config_dataset=config_dataset, - tokenizer=tokenizer, - cfg=cfg, - d_base_type=d_base_type, - dataset=ds, - d_prompt_style=d_prompt_style, - processor=processor, - ) - datasets.append(dataset_wrapper) - prompters.append(dataset_prompter) +def _load_tokenized_prepared_datasets( + tokenizer: PreTrainedTokenizer, + cfg: DictDefault, + split: Literal["train", "test"] = "train", + processor: ProcessorMixin | None = None, + preprocess_iterable: bool = False, +) -> tuple[Dataset | DatasetDict, list[Prompter | None]]: + """Load or create tokenized and prepared datasets for training or testing. - if len(datasets) == 1: - dataset = datasets[0] - else: - LOG.info("Merging datasets...") - dataset = concatenate_datasets(datasets) + Args: + tokenizer: Tokenizer for processing text. + cfg: Configuration object. + split: Dataset split to load ('train' or 'test'). + processor: Optional processor for multimodal datasets. + preprocess_iterable: Whether to use iterable preprocessing. - if len(datasets) > 1: - if cfg.shuffle_merged_datasets: - LOG.debug("Shuffling merged datasets...") - dataset = dataset.shuffle(seed=seed) - else: - LOG.debug("NOT shuffling merged datasets") + Returns: + Tuple of (dataset, prompters list). 
+ """ + # Select correct dataset configuration based on split + datasets_configs = cfg.datasets if split == "train" else cfg.test_datasets - if not cfg.skip_prepare_dataset: - dataset = drop_long_seq_in_dataset(dataset, cfg) + # Generate dataset hash for caching + dataset_hash = generate_dataset_hash_from_config( + cfg, datasets_configs, tokenizer.name_or_path + ) - if cfg.sample_packing: - dataset, _ = process_datasets_for_packing(cfg, dataset, None) + # Try loading from hub if push_dataset_to_hub is configured + dataset = None + if cfg.push_dataset_to_hub: + dataset = try_load_from_hub(cfg, dataset_hash, split) - if cfg.local_rank == 0 and not cfg.skip_prepare_dataset: - LOG.info(f"Saving merged prepared dataset to disk... {prepared_ds_path}") - if isinstance(dataset, IterableDataset): - num_workers = cfg.dataset_processes + # If not found on hub, try loading from disk + if dataset is None: + dataset = load_preprocessed_dataset(cfg, dataset_hash) - def gen_from_iter_ds(_ds, worker_id: List[int], num_workers: List[int]): - """Generator function to correctly splice the dataset for each worker""" - for i, item in enumerate(_ds): - if i % num_workers[0] == worker_id[0]: - yield item - - ds_from_iter = Dataset.from_generator( - functools.partial(gen_from_iter_ds, dataset), - features=dataset.features, - num_proc=num_workers, - split=split, - gen_kwargs={ - "worker_id": list(range(num_workers)), - "num_workers": [num_workers] * num_workers, - }, - ) - ds_from_iter.save_to_disk(str(prepared_ds_path)) - else: - os.makedirs(prepared_ds_path, exist_ok=True) - dataset.save_to_disk(str(prepared_ds_path)) - if cfg.push_dataset_to_hub: - LOG.info( - f"Pushing merged prepared dataset to Huggingface hub at {cfg.push_dataset_to_hub} (version {ds_hash})..." 
- ) - dataset.push_to_hub( - cfg.push_dataset_to_hub, - ds_hash, - private=True, - ) + # If not found on disk or skipping prepared dataset, load and process raw datasets + prompters: list[Prompter | None] = [] + if dataset is None: + dataset, prompters = _load_raw_datasets( + cfg, + datasets_configs, + tokenizer, + split, + processor, + preprocess_iterable, + ) return dataset, prompters -def load_prepare_datasets( - tokenizer: PreTrainedTokenizerBase, - cfg, - default_dataset_prepared_path, - split="train", - processor=None, - preprocess_iterable: Optional[bool] = False, -) -> Tuple[Dataset, Dataset, List[Prompter]]: - dataset, prompters = load_tokenized_prepared_datasets( - tokenizer, - cfg, - default_dataset_prepared_path, - split=split, - processor=processor, - preprocess_iterable=preprocess_iterable, +def _load_raw_datasets( + cfg: DictDefault, + datasets_configs: list, + tokenizer: PreTrainedTokenizer, + split: str, + processor: ProcessorMixin | None = None, + preprocess_iterable: bool = False, +) -> tuple[Dataset, list[Prompter | None]]: + """Load, process, merge, and save raw datasets.""" + LOG.info("Loading raw datasets...", main_process_only=False) + if not cfg.is_preprocess: + LOG.warning( + "Processing datasets during training can lead to VRAM instability. Please " + "pre-process your dataset using `axolotl preprocess path/to/config.yml`." 
+ ) + + # Load and process individual datasets + datasets = [] + prompters = [] + for dataset_config in datasets_with_name_generator(datasets_configs): + dataset_wrapper, dataset_prompter = _load_and_process_single_dataset( + dataset_config=dataset_config, + cfg=cfg, + tokenizer=tokenizer, + split=split, + seed=cfg.seed, + processor=processor, + preprocess_iterable=preprocess_iterable, + ) + datasets.append(dataset_wrapper) + prompters.append(dataset_prompter) + + # Merge datasets + dataset = merge_datasets(datasets, cfg) + + if not cfg.skip_prepare_dataset: + dataset = drop_long_seq_in_dataset(dataset, cfg) + if cfg.sample_packing: + dataset, _ = process_datasets_for_packing(cfg, dataset, None) + + # Save the prepared dataset + dataset_hash = generate_dataset_hash_from_config( + cfg, datasets_configs, tokenizer.name_or_path + ) + save_preprocessed_dataset(cfg, dataset, dataset_hash, split) + + return dataset, prompters + + +def _load_and_process_single_dataset( + dataset_config: DictDefault, + cfg: DictDefault, + tokenizer: PreTrainedTokenizer, + split: str, + seed: int, + processor: ProcessorMixin | None = None, + preprocess_iterable: bool = False, +) -> tuple[Dataset | IterableDataset, Prompter | None]: + """Load and process a single dataset based on the passed config.""" + # Load the dataset + dataset = load_dataset_with_config( + dataset_config, cfg.hf_use_auth_token, streaming=preprocess_iterable ) + # Parse dataset type + d_base_type, d_prompt_style = _parse_dataset_type(dataset_config.type) + + # Select the appropriate split + if isinstance(dataset, DatasetDict): + if dataset_config.split and dataset_config.split in dataset: + dataset = dataset[dataset_config.split] + elif split in dataset: + dataset = dataset[split] + else: + raise ValueError( + f"no {split} split found for dataset {dataset_config.path}, you may " + "specify a split with 'split: ...'" + ) + + # Apply sharding if configured + if dataset_config.shards: + shards_idx = 
dataset_config.get("shards_idx", 0) + dataset = dataset.shuffle(seed=seed).shard( + num_shards=dataset_config.shards, index=shards_idx + ) + + # Apply dataset wrapper + dataset_wrapper, dataset_prompter = get_dataset_wrapper( + dataset_config=dataset_config, + tokenizer=tokenizer, + cfg=cfg, + dataset_base_type=d_base_type, + dataset=dataset, + dataset_prompt_style=d_prompt_style, + processor=processor, + ) + + return dataset_wrapper, dataset_prompter + + +def _parse_dataset_type(d_type: str) -> tuple[str | None, str | None]: + """Parse the dataset type string into base type and prompt style.""" + if not isinstance(d_type, str): + return None, None + + d_type_split = d_type.split(":") + d_base_type = d_type_split[0] + d_prompt_style = d_type_split[1] if len(d_type_split) > 1 else None + + return d_base_type, d_prompt_style + + +def _handle_train_dataset_split( + dataset: Dataset, cfg: DictDefault +) -> tuple[Dataset, Dataset | None]: + """Handle processing for train split, including validation set creation.""" + val_set_size = ( + int(cfg.val_set_size) if cfg.val_set_size > 1 else float(cfg.val_set_size) + ) + + if val_set_size: + # Create train/validation split + train_dataset, eval_dataset = create_train_validation_split( + dataset, cfg, val_set_size + ) + return train_dataset, eval_dataset + + # No validation split - apply deduplication if needed and return as train dataset + if cfg.dataset_exact_deduplication: + train_dataset, _ = deduplicate_and_log_datasets(dataset=dataset) + else: + train_dataset = dataset + + return train_dataset, None + + +def _handle_test_dataset_split( + dataset: Dataset, cfg: DictDefault +) -> tuple[None, Dataset | None]: + """Handle processing for test split.""" + if cfg.dataset_exact_deduplication: + eval_dataset, _ = deduplicate_and_log_datasets(dataset=dataset) + else: + eval_dataset = dataset + + return None, eval_dataset + + +def _apply_dataset_sharding(dataset: Dataset, cfg: DictDefault) -> Dataset: + """Apply dataset sharding if 
configured. + + Args: + dataset: Dataset to shard. + cfg: Configuration object containing shard settings. + + Returns: + Sharded dataset or original dataset if no sharding configured. + """ if cfg.dataset_shard_num and cfg.dataset_shard_idx is not None: LOG.info( f"Using index #{cfg.dataset_shard_idx} of {cfg.dataset_shard_num} shards" @@ -409,259 +464,44 @@ def load_prepare_datasets( num_shards=cfg.dataset_shard_num, index=cfg.dataset_shard_idx, ) + return dataset - val_set_size = ( - int(cfg.val_set_size) if cfg.val_set_size > 1 else float(cfg.val_set_size) + +def _load_and_prepare_datasets( + tokenizer: PreTrainedTokenizer, + cfg: DictDefault, + split: Literal["train", "test"] = "train", + processor: ProcessorMixin | None = None, + preprocess_iterable: bool = False, +) -> tuple[Dataset | None, Dataset | None, list[Prompter | None]]: + """Load and prepare datasets with optional validation split and sharding. + + Args: + tokenizer: Tokenizer for processing text. + cfg: Configuration object. + split: Dataset split to load ('train' or 'test'). + processor: Optional processor for multimodal datasets. + preprocess_iterable: Whether to use iterable preprocessing. + + Returns: + Tuple of (train_dataset, eval_dataset, prompters). 
+ """ + # Load the base dataset + dataset, prompters = _load_tokenized_prepared_datasets( + tokenizer, + cfg, + split=split, + processor=processor, + preprocess_iterable=preprocess_iterable, ) - if split == "train" and val_set_size: - seed = cfg.seed if cfg.seed is not None else 42 + # Apply dataset sharding if configured using shared function + dataset = _apply_dataset_sharding(dataset, cfg) - # ensure we end up with the same fingerprint by doing rank0 first and being able to cache - to_hash_train = ( - dataset._fingerprint # pylint: disable=protected-access - + "|" - + str(val_set_size) - + "|" - + "train" - + "|" - + str(cfg.seed or 42) - ) - to_hash_test = ( - dataset._fingerprint # pylint: disable=protected-access - + "|" - + str(val_set_size) - + "|" - + "test" - + "|" - + str(cfg.seed or 42) - ) - train_fingerprint = md5(to_hash_train) - test_fingerprint = md5(to_hash_test) - if cfg.dataset_exact_deduplication: - _, _, dataset = deduplicate_and_log_datasets(dataset=dataset) - dataset = dataset.train_test_split( - test_size=val_set_size, - shuffle=False, - seed=seed, - train_new_fingerprint=train_fingerprint, - test_new_fingerprint=test_fingerprint, - ) - - train_dataset = dataset["train"] - eval_dataset = dataset["test"] - elif split == "test": - if cfg.dataset_exact_deduplication: - _, eval_dataset, _ = deduplicate_and_log_datasets(eval_dataset=dataset) - else: - eval_dataset = dataset - train_dataset = None + # Apply deduplication and create train / validation splits based on the split type + if split == "train": + train_dataset, eval_dataset = _handle_train_dataset_split(dataset, cfg) else: - if cfg.dataset_exact_deduplication: - train_dataset, _, _ = deduplicate_and_log_datasets(train_dataset=dataset) - else: - train_dataset = dataset - eval_dataset = None + train_dataset, eval_dataset = _handle_test_dataset_split(dataset, cfg) + return train_dataset, eval_dataset, prompters - - -def get_dataset_wrapper( - config_dataset, - tokenizer, - cfg, - 
d_base_type, - dataset, - d_prompt_style=None, - processor=None, # pylint: disable=unused-argument -): - dataset_wrapper = None - dataset_prompter = None - - ds_kwargs = { - "process_count": cfg.dataset_processes, - "keep_in_memory": cfg.dataset_keep_in_memory is True, - } - - LOG.info( - f"Loading dataset: {config_dataset['path']} with base_type: {d_base_type} and prompt_style: {d_prompt_style}" - ) - - if ( - isinstance(dataset, Dataset) - and "input_ids" in dataset.features - and "attention_mask" in dataset.features - and "labels" in dataset.features - ): - # dataset is already tokenized, just drop it straight in - dataset_prompter = UnsupportedPrompter() - dataset_wrapper = dataset - elif isinstance(config_dataset.type, DictDefault): - ds_strategy = load( - "user_defined", tokenizer, cfg, config_dataset.type.to_dict() - ) - dataset_prompter = UnsupportedPrompter() - dataset_wrapper = wrap_dataset_for_tokenized_prompt( - ds_strategy, - dataset, - **ds_kwargs, - ) - elif cfg.skip_prepare_dataset: - dataset_wrapper = dataset - elif ds_strategy := config_dataset.type.startswith( - "bradley_terry" - ) and bradley_terry_load( - config_dataset.type.split(".", 1)[1], tokenizer, cfg, config_dataset - ): - dataset_prompter = UnsupportedPrompter() - dataset_wrapper = wrap_dataset_for_tokenized_prompt( - ds_strategy, - dataset, - **ds_kwargs, - ) - elif config_dataset.type.startswith("stepwise_supervised"): - dataset_prompter = UnsupportedPrompter() - ds_strategy = load(config_dataset.type, tokenizer, cfg, config_dataset) - # we need to explicitly cast boolean labels to int - # for compatibility with how trl's PRMTrainer works - dataset = dataset.cast_column("labels", Sequence(Value("int64"))) - dataset_wrapper = TokenizedPromptDataset( - ds_strategy, - dataset, - **ds_kwargs, - ) - elif ds_strategy := load( - config_dataset.type, tokenizer, cfg, config_dataset, processor=processor - ): - if isinstance(ds_strategy, DatasetWrappingStrategy): - dataset_wrapper = 
ds_strategy.wrap_dataset(dataset, **ds_kwargs) - else: - dataset_prompter = UnsupportedPrompter() - dataset_wrapper = wrap_dataset_for_tokenized_prompt( - ds_strategy, - dataset, - **ds_kwargs, - ) - elif d_base_type == "alpaca": - dataset_prompter = AlpacaPrompter(d_prompt_style) - ds_strategy = AlpacaPromptTokenizingStrategy( - dataset_prompter, - tokenizer, - cfg.train_on_inputs, - cfg.sequence_len, - ) - ds_wrapper = wrap_dataset_for_tokenized_prompt( - ds_strategy, - dataset, - **ds_kwargs, - ) - dataset_wrapper = ds_wrapper - elif d_base_type == "explainchoice": - dataset_prompter = MultipleChoiceExplainPrompter(d_prompt_style) - ds_strategy = AlpacaMultipleChoicePromptTokenizingStrategy( - dataset_prompter, - tokenizer, - cfg.train_on_inputs, - cfg.sequence_len, - ) - ds_wrapper = wrap_dataset_for_tokenized_prompt( - ds_strategy, - dataset, - **ds_kwargs, - ) - dataset_wrapper = ds_wrapper - elif d_base_type == "concisechoice": - dataset_prompter = MultipleChoiceConcisePrompter(d_prompt_style) - ds_strategy = AlpacaMultipleChoicePromptTokenizingStrategy( - dataset_prompter, - tokenizer, - cfg.train_on_inputs, - cfg.sequence_len, - ) - ds_wrapper = wrap_dataset_for_tokenized_prompt( - ds_strategy, - dataset, - **ds_kwargs, - ) - dataset_wrapper = ds_wrapper - elif d_base_type == "summarizetldr": - dataset_prompter = SummarizeTLDRPrompter(d_prompt_style) - ds_strategy = SummarizeTLDRPromptTokenizingStrategy( - dataset_prompter, - tokenizer, - cfg.train_on_inputs, - cfg.sequence_len, - ) - ds_wrapper = wrap_dataset_for_tokenized_prompt( - ds_strategy, - dataset, - **ds_kwargs, - ) - dataset_wrapper = ds_wrapper - elif d_base_type == "jeopardy": - dataset_prompter = JeopardyPrompter(d_prompt_style) - ds_strategy = JeopardyPromptTokenizingStrategy( - dataset_prompter, - tokenizer, - cfg.train_on_inputs, - cfg.sequence_len, - ) - ds_wrapper = wrap_dataset_for_tokenized_prompt( - ds_strategy, - dataset, - **ds_kwargs, - ) - dataset_wrapper = ds_wrapper - elif 
d_base_type == "oasst": - dataset_prompter = AlpacaPrompter(d_prompt_style) - ds_strategy = OpenAssistantPromptTokenizingStrategy( - dataset_prompter, - tokenizer, - cfg.train_on_inputs, - cfg.sequence_len, - ) - ds_wrapper = wrap_dataset_for_tokenized_prompt( - ds_strategy, - dataset, - **ds_kwargs, - ) - dataset_wrapper = ds_wrapper - elif d_base_type == "gpteacher": - dataset_prompter = GPTeacherPrompter(d_prompt_style) - ds_strategy = GPTeacherPromptTokenizingStrategy( - dataset_prompter, - tokenizer, - cfg.train_on_inputs, - cfg.sequence_len, - ) - ds_wrapper = wrap_dataset_for_tokenized_prompt( - ds_strategy, - dataset, - **ds_kwargs, - ) - dataset_wrapper = ds_wrapper - elif d_base_type == "reflection": - dataset_prompter = ReflectAlpacaPrompter(d_prompt_style) - ds_strategy = AlpacaReflectionPTStrategy( - dataset_prompter, - tokenizer, - cfg.train_on_inputs, - cfg.sequence_len, - ) - ds_wrapper = wrap_dataset_for_tokenized_prompt( - ds_strategy, - dataset, - **ds_kwargs, - ) - dataset_wrapper = ds_wrapper - else: - suffix = "" - if ":load_" in config_dataset.type: - suffix = f" Did you mean {config_dataset.type.replace(':load_', '.load_')}?" - LOG.error( - f"unhandled prompt tokenization strategy: {config_dataset.type}. 
{suffix}" - ) - raise ValueError( - f"unhandled prompt tokenization strategy: {config_dataset.type} {suffix}" - ) - - return dataset_wrapper, dataset_prompter diff --git a/src/axolotl/utils/data/shared.py b/src/axolotl/utils/data/shared.py index d2e119f77..3c58b4c85 100644 --- a/src/axolotl/utils/data/shared.py +++ b/src/axolotl/utils/data/shared.py @@ -1,11 +1,21 @@ -""" -dataset loading shared utils -""" +"""Dataset loading shared utils.""" +from __future__ import annotations + +import functools +import os from pathlib import Path -from typing import Optional, Union +from typing import TYPE_CHECKING, Any, Generator -from datasets import Dataset, DatasetDict, load_dataset, load_from_disk +from datasets import ( + Dataset, + DatasetDict, + IterableDataset, + IterableDatasetDict, + concatenate_datasets, + load_dataset, + load_from_disk, +) from huggingface_hub import hf_hub_download, snapshot_download from huggingface_hub.errors import ( HFValidationError, @@ -13,78 +23,141 @@ from huggingface_hub.errors import ( RevisionNotFoundError, ) +from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH +from axolotl.utils.data.utils import deduplicate_and_log_datasets, md5 from axolotl.utils.dict import DictDefault +from axolotl.utils.logging import get_logger + +if TYPE_CHECKING: + from adlfs import AzureBlobFileSystem + from gcsfs import GCSFileSystem + from ocifs import OCIFileSystem + from s3fs import S3FileSystem + +LOG = get_logger(__name__) + +EXTENSIONS_TO_DATASET_TYPES = { + ".parquet": "parquet", + ".arrow": "arrow", + ".csv": "csv", + ".txt": "text", +} -def get_ds_type(config_dataset: DictDefault): - """ - Get the dataset type from the path if it's not specified - """ - ds_type = "json" - if config_dataset.ds_type: - ds_type = config_dataset.ds_type - elif ".parquet" in config_dataset.path: - ds_type = "parquet" - elif ".arrow" in config_dataset.path: - ds_type = "arrow" - elif ".csv" in config_dataset.path: - ds_type = "csv" - elif ".txt" in 
config_dataset.path: - ds_type = "text" - return ds_type +def get_dataset_type(dataset_config: DictDefault) -> str: + """Get the dataset type from the path if it's not specified.""" + if dataset_config.ds_type: + return dataset_config.ds_type + + for extension, dataset_type in EXTENSIONS_TO_DATASET_TYPES.items(): + if extension in dataset_config.path: + return dataset_type + + return "json" -def datasets_w_name_generator(dataset_configs: list[DictDefault]): - """ - Yields dataset configs handling multiple names or preprocess_shards +def datasets_with_name_generator( + dataset_configs: list[DictDefault], +) -> Generator[DictDefault, None, None]: + """Yields expanded dataset configurations based on multiple names or preprocessing + shards. + + When a dataset config has a list of names, it yields separate configs for each + name. When a dataset config specifies preprocessing shards, it yields configs for + each shard. Args: - dataset_configs: list of dataset configs (equivalent to cfg.datasets) + dataset_configs: List of dataset configuration objects. + + Yields: + Individual dataset configurations, expanded as needed for names or shards. 
""" - for dataset in dataset_configs: - if dataset.name and isinstance(dataset.name, list): - # load_dataset doesn't properly handle multiple named configurations - # at the same time for a given dataset - for name in dataset.name: - yield DictDefault({**dataset, "name": name}) - elif dataset.preprocess_shards and not dataset.shards: - for shard in range(dataset.preprocess_shards): + for config in dataset_configs: + if config.name and isinstance(config.name, list): + for name in config.name: + yield DictDefault({**config, "name": name}) + elif config.preprocess_shards and not config.shards: + for shard_idx in range(config.preprocess_shards): yield DictDefault( { - **dataset, - "shards": dataset.preprocess_shards, - "shards_idx": shard, + **config, + "shards": config.preprocess_shards, + "shards_idx": shard_idx, } ) else: - yield dataset + yield config -def load_dataset_w_config( - config_dataset: DictDefault, use_auth_token: bool, streaming=False -) -> Union[Dataset, DatasetDict]: - """ - Load a dataset from a config +def load_dataset_with_config( + dataset_config: DictDefault, use_auth_token: bool, streaming=False +) -> Dataset | IterableDataset: + """Load a dataset from a config. Handles datasets that are stored locally, in the + HuggingFace Hub, in a remote filesystem (S3, GCS, Azure, OCI), a URL, or + `data_files`. Args: - config_dataset: single dataset config - use_auth_token: whether to use HF auth token - streaming: whether to stream the dataset + dataset_config: Single dataset config. + use_auth_token: Whether to use HF auth token. + streaming: Whether to stream the dataset. + + Returns: + Loaded dataset. 
""" - # pylint: disable=invalid-name - ds: Optional[Union[Dataset, DatasetDict]] = None # pylint: disable=invalid-name - ds_from_hub = False + # Set up common kwargs for dataset loading + load_dataset_kwargs = { + "split": dataset_config.split if dataset_config.split else None, + "name": dataset_config.name, + "streaming": streaming, + "trust_remote_code": dataset_config.trust_remote_code, + } + + # First check if it's a local path + if Path(dataset_config.path).exists(): + return _load_from_local_path(dataset_config, load_dataset_kwargs) + + # Check if it's a HuggingFace dataset + is_hub_dataset = _check_if_hub_dataset(dataset_config, use_auth_token) + + # Check if it's a cloud storage path and get appropriate filesystem + remote_fs, storage_options = _get_remote_filesystem(dataset_config.path) + is_cloud_dataset = False + if remote_fs: + try: + is_cloud_dataset = remote_fs.exists(dataset_config.path) + except (FileNotFoundError, ConnectionError): + pass + + # Load from appropriate source + if is_hub_dataset: + return _load_from_hub(dataset_config, use_auth_token, load_dataset_kwargs) + if is_cloud_dataset: + return _load_from_cloud( + dataset_config, remote_fs, storage_options, load_dataset_kwargs + ) + if dataset_config.path.startswith("https://"): + return _load_from_url(dataset_config, load_dataset_kwargs) + if dataset_config.data_files: + return _load_from_data_files(dataset_config, load_dataset_kwargs) + + raise ValueError( + f"The dataset could not be loaded. This could be due to a misconfigured dataset path " + f"({dataset_config.path}). Try double-check your path / name / data_files. " + f"This is not caused by the dataset type." 
+ ) + + +def _check_if_hub_dataset(dataset_config: DictDefault, use_auth_token: bool) -> bool: + """Check if a dataset exists on the HuggingFace Hub.""" try: - # this is just a basic check to see if the path is a - # valid HF dataset that's loadable snapshot_download( - repo_id=config_dataset.path, + repo_id=dataset_config.path, repo_type="dataset", token=use_auth_token, - revision=config_dataset.revision, + revision=dataset_config.revision, ignore_patterns=["*"], ) - ds_from_hub = True + return True except ( RepositoryNotFoundError, RevisionNotFoundError, @@ -93,198 +166,373 @@ def load_dataset_w_config( HFValidationError, ValueError, ): - pass + return False - ds_from_cloud = False - storage_options: dict = {} - remote_file_system = None - if config_dataset.path.startswith("s3://"): + +def _get_remote_filesystem( + path: str, +) -> tuple[ + S3FileSystem | GCSFileSystem | AzureBlobFileSystem | OCIFileSystem | None, dict +]: + """Get the appropriate filesystem for a remote path.""" + if path.startswith("s3://"): try: - import s3fs # type: ignore + import s3fs + + storage_options = {"anon": False} + return s3fs.S3FileSystem(**storage_options), storage_options except ImportError as exc: raise ImportError("s3:// paths require s3fs to be installed") from exc - # Reads env, credentials from ~/.aws/credentials, or IAM metadata provider - # https://s3fs.readthedocs.io/en/latest/index.html?highlight=storage_options#credentials - storage_options = {"anon": False} - remote_file_system = s3fs.S3FileSystem(**storage_options) - elif config_dataset.path.startswith("gs://") or config_dataset.path.startswith( - "gcs://" - ): + elif path.startswith(("gs://", "gcs://")): try: - import gcsfs # type: ignore + import gcsfs + + storage_options = {"token": None} # type: ignore + return gcsfs.GCSFileSystem(**storage_options), storage_options except ImportError as exc: raise ImportError( "gs:// or gcs:// paths require gcsfs to be installed" ) from exc - # gcsfs will use default credentials 
from the environment else anon - # https://gcsfs.readthedocs.io/en/latest/#credentials - storage_options = {"token": None} - remote_file_system = gcsfs.GCSFileSystem(**storage_options) - elif ( - config_dataset.path.startswith("adl://") - or config_dataset.path.startswith("abfs://") - or config_dataset.path.startswith("az://") - ): + elif path.startswith(("adl://", "abfs://", "az://")): try: import adlfs + + storage_options = {"anon": False} + return adlfs.AzureBlobFileSystem(**storage_options), storage_options except ImportError as exc: raise ImportError( "adl:// or abfs:// paths require adlfs to be installed" ) from exc - # # Ensure you have the following environment variables set: - # # Gen 1 - # storage_options = { - # "tenant_id": AZURE_STORAGE_TENANT_ID, - # "client_id": AZURE_STORAGE_CLIENT_ID, - # "client_secret": AZURE_STORAGE_CLIENT_SECRET, - # } - # # Gen 2 - # storage_options = { - # "account_name": AZURE_STORAGE_ACCOUNT_NAME, - # "account_key": AZURE_STORAGE_ACCOUNT_KEY, - # } - - # Reads env - # https://github.com/fsspec/adlfs?tab=readme-ov-file#setting-credentials - storage_options = {"anon": False} - remote_file_system = adlfs.AzureBlobFileSystem(**storage_options) - elif config_dataset.path.startswith("oci://"): + elif path.startswith("oci://"): try: import ocifs + + storage_options = {} + return ocifs.OCIFileSystem(**storage_options), storage_options except ImportError as exc: raise ImportError("oci:// paths require ocifs to be installed") from exc - # https://ocifs.readthedocs.io/en/latest/getting-connected.html#Using-Environment-Variables - remote_file_system = ocifs.OCIFileSystem(**storage_options) + return None, {} - try: - if remote_file_system and remote_file_system.exists(config_dataset.path): - ds_from_cloud = True - except (FileNotFoundError, ConnectionError): - pass - # gather extra args from the config - load_ds_kwargs = {} - if config_dataset.split: - load_ds_kwargs["split"] = config_dataset.split +def _load_from_local_path( + 
dataset_config: DictDefault, load_dataset_kwargs: dict +) -> Dataset | IterableDataset | DatasetDict | IterableDatasetDict: + """Load a dataset from a local path.""" + local_path = Path(dataset_config.path) + + if local_path.is_dir(): + if dataset_config.data_files: + dataset_type = get_dataset_type(dataset_config) + return load_dataset( + dataset_type, + data_files=dataset_config.data_files, + **load_dataset_kwargs, + ) + try: + return load_from_disk(dataset_config.path) + except FileNotFoundError: + load_dataset_kwargs["streaming"] = False + return load_dataset(dataset_config.path, **load_dataset_kwargs) + elif local_path.is_file(): + dataset_type = get_dataset_type(dataset_config) + load_dataset_kwargs["streaming"] = False + return load_dataset( + dataset_type, + data_files=dataset_config.path, + **load_dataset_kwargs, + ) else: - load_ds_kwargs["split"] = None - - # prefer local dataset, even if hub exists - local_path = Path(config_dataset.path) - if local_path.exists(): - if local_path.is_dir(): - if config_dataset.data_files: - ds_type = get_ds_type(config_dataset) - ds = load_dataset( # pylint: disable=invalid-name - ds_type, - name=config_dataset.name, - data_files=config_dataset.data_files, - streaming=streaming, - **load_ds_kwargs, - ) - else: - try: - ds = load_from_disk( - config_dataset.path - ) # pylint: disable=invalid-name - except FileNotFoundError: - ds = load_dataset( - config_dataset.path, - name=config_dataset.name, - streaming=False, - **load_ds_kwargs, - ) - elif local_path.is_file(): - ds_type = get_ds_type(config_dataset) - - ds = load_dataset( # pylint: disable=invalid-name - ds_type, - name=config_dataset.name, - data_files=config_dataset.path, - streaming=False, - **load_ds_kwargs, - ) - else: - raise ValueError( - "unhandled dataset load: local path exists, but is neither a directory or a file" - ) - elif ds_from_hub: - ds = load_dataset( - config_dataset.path, - name=config_dataset.name, - streaming=streaming, - 
data_files=config_dataset.data_files, - token=use_auth_token, - revision=config_dataset.revision, - trust_remote_code=config_dataset.trust_remote_code, - **load_ds_kwargs, - ) - elif ds_from_cloud and remote_file_system: - if remote_file_system.isdir(config_dataset.path): - ds = load_from_disk( - config_dataset.path, - storage_options=storage_options, - ) - elif remote_file_system.isfile(config_dataset.path): - ds_type = get_ds_type(config_dataset) - ds = load_dataset( - ds_type, - name=config_dataset.name, - data_files=config_dataset.path, - streaming=streaming, - storage_options=storage_options, - trust_remote_code=config_dataset.trust_remote_code, - **load_ds_kwargs, - ) - elif config_dataset.path.startswith("https://"): - ds_type = get_ds_type(config_dataset) - ds = load_dataset( - ds_type, - name=config_dataset.name, - data_files=config_dataset.path, - streaming=streaming, - storage_options=storage_options, - trust_remote_code=config_dataset.trust_remote_code, - **load_ds_kwargs, - ) - elif config_dataset.data_files: - fp: str | list[str] | None = None - if isinstance(config_dataset.data_files, str): - fp = hf_hub_download( - repo_id=config_dataset.path, - repo_type="dataset", - filename=config_dataset.data_files, - revision=config_dataset.revision, - ) - elif isinstance(config_dataset.data_files, list): - fp = [] - for file in config_dataset.data_files: - fp.append( - hf_hub_download( - repo_id=config_dataset.path, - repo_type="dataset", - filename=file, - revision=config_dataset.revision, - ) - ) - else: - raise ValueError("data_files must be either a string or list of strings") - ds = load_dataset( - "json", - name=config_dataset.name, - data_files=fp, - streaming=streaming, - **load_ds_kwargs, - ) - if not ds: raise ValueError( - "The dataset could not be loaded. This could be due to a misconfigured dataset path " - f"({config_dataset.path}). Try double-check your path / name / data_files. " - "This is not caused by the dataset type." 
+ "Unhandled dataset load: local path exists, but is neither a directory or a file" ) - return ds + +def _load_from_hub( + dataset_config: DictDefault, use_auth_token: bool, load_dataset_kwargs: dict +) -> Dataset | IterableDataset | DatasetDict | IterableDatasetDict: + """Load a dataset from the HuggingFace Hub.""" + return load_dataset( + dataset_config.path, + data_files=dataset_config.data_files, + token=use_auth_token, + revision=dataset_config.revision, + **load_dataset_kwargs, + ) + + +def _load_from_cloud( + dataset_config: DictDefault, + remote_fs: S3FileSystem | GCSFileSystem | AzureBlobFileSystem | OCIFileSystem, + storage_options: dict, + load_dataset_kwargs: dict, +) -> Dataset | IterableDataset | DatasetDict | IterableDatasetDict: + """Load a dataset from cloud storage.""" + if remote_fs.isdir(dataset_config.path): + return load_from_disk( + dataset_config.path, + storage_options=storage_options, + ) + + if remote_fs.isfile(dataset_config.path): + dataset_type = get_dataset_type(dataset_config) + return load_dataset( + dataset_type, + data_files=dataset_config.path, + storage_options=storage_options, + **load_dataset_kwargs, + ) + + raise ValueError( + f"Cloud path {dataset_config.path} is neither a directory nor a file" + ) + + +def _load_from_url( + dataset_config: DictDefault, load_dataset_kwargs: dict +) -> Dataset | IterableDataset | DatasetDict | IterableDatasetDict: + """Load a dataset from a URL.""" + dataset_type = get_dataset_type(dataset_config) + return load_dataset( + dataset_type, + data_files=dataset_config.path, + **load_dataset_kwargs, + ) + + +def _load_from_data_files( + dataset_config: DictDefault, load_dataset_kwargs: dict +) -> Dataset | IterableDataset | DatasetDict | IterableDatasetDict: + """Load a dataset from data files.""" + file_path = None + + if isinstance(dataset_config.data_files, str): + file_path = hf_hub_download( + repo_id=dataset_config.path, + repo_type="dataset", + filename=dataset_config.data_files, + 
revision=dataset_config.revision, + ) + elif isinstance(dataset_config.data_files, list): + file_path = [ + hf_hub_download( + repo_id=dataset_config.path, + repo_type="dataset", + filename=file, + revision=dataset_config.revision, + ) + for file in dataset_config.data_files + ] + else: + raise ValueError("data_files must be either a string or list of strings") + + return load_dataset("json", data_files=file_path, **load_dataset_kwargs) + + +def generate_split_fingerprints( + dataset: Dataset, val_set_size: int | float, seed: int +) -> tuple[str, str]: + """Generate consistent fingerprints for train/test splits.""" + fingerprint = dataset._fingerprint # pylint: disable=protected-access + + train_hash_input = f"{fingerprint}|{val_set_size}|train|{seed}" + test_hash_input = f"{fingerprint}|{val_set_size}|test|{seed}" + + train_fingerprint = md5(train_hash_input) + test_fingerprint = md5(test_hash_input) + + return train_fingerprint, test_fingerprint + + +def get_prepared_dataset_path(cfg: DictDefault, dataset_hash: str) -> Path: + """Get standardized path for prepared datasets. + + Args: + cfg: Configuration object. + dataset_hash: Hash identifying the specific dataset configuration. + + Returns: + Path where the prepared dataset should be stored. + """ + base_path = cfg.dataset_prepared_path or DEFAULT_DATASET_PREPARED_PATH + return Path(base_path) / dataset_hash + + +def create_train_validation_split( + dataset: Dataset, cfg: DictDefault, val_set_size: int | float +) -> tuple[Dataset, Dataset]: + """Create train/validation split with consistent fingerprinting. + + Args: + dataset: Dataset to split. + cfg: Configuration object containing seed and other settings. + val_set_size: Size of validation set (absolute number or fraction). + + Returns: + Tuple of (train_dataset, eval_dataset). 
+ """ + train_fingerprint, test_fingerprint = generate_split_fingerprints( + dataset, val_set_size, cfg.seed + ) + + # Apply deduplication before splitting if configured + if cfg.dataset_exact_deduplication: + dataset, _ = deduplicate_and_log_datasets(dataset=dataset) + + split_dataset = dataset.train_test_split( + test_size=val_set_size, + shuffle=False, + seed=cfg.seed, + train_new_fingerprint=train_fingerprint, + test_new_fingerprint=test_fingerprint, + ) + + return split_dataset["train"], split_dataset["test"] + + +def _generate_from_iterable_dataset( + dataset: IterableDataset, worker_id: list[int], num_workers: list[int] +) -> Generator[Any, None, None]: + """Generator function to correctly split the dataset for each worker""" + for i, item in enumerate(dataset): + if i % num_workers[0] == worker_id[0]: + yield item + + +def save_preprocessed_dataset( + cfg: DictDefault, + dataset: Dataset, + dataset_hash: str, + split: str, +) -> None: + """Save preprocessed dataset to disk and optionally push to the HF Hub.""" + prepared_ds_path = get_prepared_dataset_path(cfg, dataset_hash) + if isinstance(dataset, IterableDataset): + num_workers = cfg.dataset_processes + + ds_from_iter = Dataset.from_generator( + functools.partial(_generate_from_iterable_dataset, dataset), + features=dataset.features, + num_proc=num_workers, + split=split, + gen_kwargs={ + "worker_id": list(range(num_workers)), + "num_workers": [num_workers] * num_workers, + }, + ) + ds_from_iter.save_to_disk(str(prepared_ds_path)) + else: + os.makedirs(prepared_ds_path, exist_ok=True) + dataset.save_to_disk(str(prepared_ds_path)) + if cfg.push_dataset_to_hub: + LOG.info( + "Pushing merged prepared dataset to Huggingface hub at " + f"{cfg.push_dataset_to_hub} (version {dataset_hash})...", + main_process_only=False, + ) + dataset.push_to_hub( + cfg.push_dataset_to_hub, + dataset_hash, + private=True, + ) + + +def load_preprocessed_dataset(cfg: DictDefault, dataset_hash: str) -> Dataset | None: + """Load 
preprocessed dataset from disk if available. + + Args: + cfg: Configuration object. + dataset_hash: Hash identifying the dataset configuration. + + Returns: + Loaded dataset if found and conditions are met, None otherwise. + """ + prepared_ds_path = get_prepared_dataset_path(cfg, dataset_hash) + + if ( + cfg.dataset_prepared_path + and any(prepared_ds_path.glob("*")) + and not cfg.skip_prepare_dataset + and not cfg.is_preprocess + ): + LOG.info( + f"Loading prepared dataset from disk at {prepared_ds_path}...", + main_process_only=False, + ) + return load_from_disk(str(prepared_ds_path)) + + LOG.info( + f"Unable to find prepared dataset in {prepared_ds_path}", + main_process_only=False, + ) + return None + + +def try_load_from_hub( + cfg: DictDefault, dataset_hash: str, split: str +) -> Dataset | None: + """Try to load the prepared dataset from HuggingFace Hub.""" + try: + LOG.info( + "Attempting to load prepared dataset from HuggingFace Hub at " + f"{cfg.push_dataset_to_hub} (version {dataset_hash})..." + ) + dataset = load_dataset( + cfg.push_dataset_to_hub, + dataset_hash, + token=cfg.hf_use_auth_token, + ) + return dataset[split] + except Exception: # pylint: disable=broad-except # nosec + LOG.info("Unable to find prepared dataset in HuggingFace Hub") + return None + + +def generate_dataset_hash_from_config( + cfg: DictDefault, cfg_datasets: list, tokenizer_name: str +) -> str: + """Generate a hash to uniquely identify a dataset configuration for SFT. + + Args: + cfg: Main configuration object. + cfg_datasets: List of dataset configurations. + tokenizer_name: Name of the tokenizer being used. + + Returns: + MD5 hash string representing the configuration. 
+ """ + config_str = ( + f"{cfg.sequence_len}@{cfg.sample_packing}@{cfg.eval_sample_packing}@" + f"{cfg.group_by_length}@{cfg.kd_temperature or 1.0}|" + f"{'|'.join(sorted([f'{d.path}:{d.type}:{d.shards}:{d.conversation}:{d.split}:{d.temperature or 1.0}' for d in cfg_datasets]))}" + f"|{tokenizer_name}" + ) + return str(md5(config_str)) + + +def merge_datasets(datasets: list[Dataset], cfg: DictDefault) -> Dataset: + """Merge multiple datasets into one with optional shuffling. + + Args: + datasets: List of datasets to merge. + cfg: Configuration object containing shuffle settings. + + Returns: + Merged dataset. + """ + if len(datasets) == 1: + return datasets[0] + + LOG.info("Merging datasets...") + merged_dataset = concatenate_datasets(datasets) + + if cfg.shuffle_merged_datasets: + LOG.debug("Shuffling merged datasets...") + merged_dataset = merged_dataset.shuffle(seed=cfg.seed) + else: + LOG.debug("Not shuffling merged datasets.") + + return merged_dataset diff --git a/src/axolotl/utils/data/utils.py b/src/axolotl/utils/data/utils.py index 5f3b8d3cc..0ffaa932f 100644 --- a/src/axolotl/utils/data/utils.py +++ b/src/axolotl/utils/data/utils.py @@ -1,9 +1,11 @@ -"""data handling helpers""" +"""Data handling helpers""" +import contextlib import functools import hashlib import time from enum import Enum +from typing import Callable import huggingface_hub import numpy as np @@ -19,9 +21,7 @@ LOG = get_logger(__name__) class RetryStrategy(Enum): - """ - Enum for retry strategies. - """ + """Enum for retry strategies.""" CONSTANT = 1 LINEAR = 2 @@ -30,7 +30,18 @@ class RetryStrategy(Enum): def retry_on_request_exceptions( max_retries=3, delay=1, retry_strategy: RetryStrategy = RetryStrategy.LINEAR -): +) -> Callable: + """Decorator that retries function calls on specific request exceptions. + + Args: + max_retries: Maximum number of retry attempts. + delay: Base delay between retries in seconds. + retry_strategy: Strategy for calculating retry delays. 
+ + Returns: + Decorated function with retry logic. + """ + def decorator(func): @functools.wraps(func) def wrapper(*args, **kwargs): # pylint: disable=inconsistent-return-statements @@ -59,6 +70,7 @@ def retry_on_request_exceptions( def md5(to_hash: str, encoding: str = "utf-8") -> str: + """Generate MD5 hash of a string.""" try: return hashlib.md5(to_hash.encode(encoding), usedforsecurity=False).hexdigest() except TypeError: @@ -66,102 +78,89 @@ def md5(to_hash: str, encoding: str = "utf-8") -> str: def sha256(to_hash: str, encoding: str = "utf-8") -> str: + """Generate SHA256 hash of a string.""" return hashlib.sha256(to_hash.encode(encoding)).hexdigest() -def deduplicate_dataset( - dataset: Dataset, seen_hashes: dict[str, list[int]], other_dataset: Dataset = None -) -> Dataset: - unique_indices = [] +def _deduplicate_dataset( + dataset: Dataset, + seen_hashes: set[str] | None = None, +) -> tuple[Dataset, set[str]]: + """Remove duplicate rows from a dataset using SHA256 hashes. + Args: + dataset: Dataset to deduplicate. + seen_hashes: Set of previously seen row hashes (for cross-deduplication). + + Returns: + Tuple of deduplicated dataset and the set of seen hashes. + """ + if seen_hashes is None: + seen_hashes = set() + + unique_indices = [] for idx, row in enumerate(dataset): - row_hash = sha256(str(row)) # Using SHA256 for collision resistance. 
+ row_hash = sha256(str(row)) # Using SHA256 for collision resistance if row_hash not in seen_hashes: - seen_hashes[row_hash] = [idx] + seen_hashes.add(row_hash) unique_indices.append(idx) - else: - # Check for collision by looking up the original dataset indices - original_indices = seen_hashes[row_hash] - is_duplicate = False - for original_idx in original_indices: - if ( - not idx == original_idx - and original_idx < len(dataset) - and str(dataset[original_idx]) == str(row) - ): - is_duplicate = True - break - # Check in the other dataset if provided - if other_dataset is not None: - if original_idx < len(other_dataset) and str( - other_dataset[original_idx] - ) == str(row): - is_duplicate = True - break - if not is_duplicate: - seen_hashes[row_hash].append(idx) - unique_indices.append(idx) - continue - return dataset.select(unique_indices) + + return dataset.select(unique_indices), seen_hashes def deduplicate_and_log_datasets( - *, - train_dataset: Dataset = None, - eval_dataset: Dataset = None, - dataset: Dataset = None, -) -> tuple[Dataset, Dataset, Dataset]: - """ - Deduplicates train, eval, and an optional dataset if provided, logging original and new sizes. + dataset: Dataset, + other_dataset: Dataset | None = None, + dataset_name: str | None = "train", + other_name: str | None = "eval", +) -> tuple[Dataset, Dataset | None]: + """Deduplicate datasets, with optional cross-dataset deduplication. + + Args: + dataset: Primary dataset to deduplicate. + other_dataset: Optional second dataset to deduplicate against the first. + dataset_name: Name for the primary dataset (for logging). + other_name: Name for the second dataset (for logging). Returns: - tuple: Deduplicated train, eval, and additional datasets. + Tuple of (deduplicated_dataset, deduplicated_other_dataset). """ - seen_hashes: dict[str, list[int]] = {} + # Deduplicate primary dataset + LOG.info( + f"Starting deduplication for {dataset_name} dataset. 
Original size: {len(dataset)}" + ) + dataset, seen_rows = _deduplicate_dataset(dataset) + LOG.info( + f"Deduplication complete for {dataset_name} dataset. New size: {len(dataset)}" + ) - # Handle cases where datasets are None - if train_dataset is not None: + # Deduplicate second dataset if provided + if other_dataset is not None: LOG.info( - f"Starting deduplication for train dataset. Original size: {len(train_dataset)}" - ) - train_dataset = deduplicate_dataset( - dataset=train_dataset, seen_hashes=seen_hashes + f"Starting deduplication for {other_name} dataset. Original size: {len(other_dataset)}" ) + other_dataset, _ = _deduplicate_dataset(other_dataset, seen_rows) LOG.info( - f"Deduplication complete for train dataset. New size: {len(train_dataset)}" - ) - else: - LOG.info("Train dataset is None. Skipping deduplication.") - - if eval_dataset is not None: - LOG.info( - f"Starting deduplication for eval dataset. Original size: {len(eval_dataset)}" - ) - eval_dataset = deduplicate_dataset( - dataset=eval_dataset, seen_hashes=seen_hashes, other_dataset=train_dataset - ) - LOG.info( - f"Deduplication complete for eval dataset. New size: {len(eval_dataset)}" - ) - else: - LOG.info("Eval dataset is None. Skipping deduplication.") - - if dataset is not None and (eval_dataset is None and train_dataset is None): - LOG.info( - f"Starting deduplication for combined dataset. Original size: {len(dataset)}" - ) - dataset = deduplicate_dataset(dataset=dataset, seen_hashes=seen_hashes) - LOG.info( - f"Deduplication complete for combined dataset. New size: {len(dataset)}" + f"Deduplication complete for {other_name} dataset. New size: {len(other_dataset)}" ) - return train_dataset, eval_dataset, dataset + return dataset, other_dataset -def drop_long_seq_in_dataset(dataset: Dataset, cfg: DictDefault): +def drop_long_seq_in_dataset(dataset: Dataset, cfg: DictDefault) -> Dataset: + """Remove sequences longer than configured maximum from dataset. 
+ + Args: + dataset: Dataset to filter. + cfg: Dictionary mapping `axolotl` config keys to values. + + Returns: + Filtered dataset with long sequences removed. + """ if "input_ids" not in dataset.column_names: LOG.warning( - "Dataset does not contain 'input_ids' column. Skip drop long seq. This is expected for RewardModeling." + "Dataset does not contain 'input_ids' column. Skip drop long seq. This is " + "expected for reward modeling." ) return dataset @@ -171,20 +170,14 @@ def drop_long_seq_in_dataset(dataset: Dataset, cfg: DictDefault): min_sequence_len=cfg.min_sample_len, ) - try: + with contextlib.suppress(AttributeError): ds_lengths = get_dataset_lengths(dataset, from_arrow=True) min_input_len = np.min(ds_lengths) LOG.info(f"min_input_len: {min_input_len}") max_input_len = np.max(ds_lengths) LOG.info(f"max_input_len: {max_input_len}") - except AttributeError: - pass - try: - prior_len = len(dataset) - except TypeError: - # handle iterable datasets case - prior_len = None + prior_len = len(dataset) if hasattr(dataset, "__len__") else None filter_map_kwargs = {} if not isinstance(dataset, IterableDataset): diff --git a/src/axolotl/utils/data/wrappers.py b/src/axolotl/utils/data/wrappers.py new file mode 100644 index 000000000..b6dc42c71 --- /dev/null +++ b/src/axolotl/utils/data/wrappers.py @@ -0,0 +1,425 @@ +"""Data handling specific to SFT.""" + +import logging +from typing import Any, NoReturn, cast + +from datasets import ( + Dataset, + IterableDataset, + Sequence, + Value, +) +from transformers import PreTrainedTokenizer +from transformers.processing_utils import ProcessorMixin + +from axolotl.datasets import TokenizedPromptDataset, wrap_dataset_for_tokenized_prompt +from axolotl.prompt_strategies import load +from axolotl.prompt_strategies.bradley_terry import load as bradley_terry_load +from axolotl.prompt_tokenizers import ( + AlpacaMultipleChoicePromptTokenizingStrategy, + AlpacaPromptTokenizingStrategy, + AlpacaReflectionPTStrategy, + 
DatasetWrappingStrategy, + GPTeacherPromptTokenizingStrategy, + JeopardyPromptTokenizingStrategy, + OpenAssistantPromptTokenizingStrategy, + PromptTokenizingStrategy, + SummarizeTLDRPromptTokenizingStrategy, +) +from axolotl.prompters import ( + AlpacaPrompter, + GPTeacherPrompter, + JeopardyPrompter, + MultipleChoiceConcisePrompter, + MultipleChoiceExplainPrompter, + Prompter, + ReflectAlpacaPrompter, + SummarizeTLDRPrompter, + UnsupportedPrompter, +) +from axolotl.utils.dict import DictDefault + +LOG = logging.getLogger(__name__) + + +def handle_unknown_dataset_strategy(dataset_config: DictDefault) -> NoReturn: + """Raise error for unknown dataset strategy.""" + ds_type = dataset_config.type + suffix = "" + if ":load_" in ds_type: + suffix = f"Did you mean {ds_type.replace(':load_', '.load_')}?" + + error_message = f"unhandled prompt tokenization strategy: {ds_type}. {suffix}" + LOG.error(error_message) + raise ValueError(error_message) + + +# pylint: disable=too-many-return-statements +def get_dataset_wrapper( + dataset_config: DictDefault, + tokenizer: PreTrainedTokenizer, + cfg: DictDefault, + dataset_base_type: str | None, + dataset: Dataset | IterableDataset, + dataset_prompt_style: str | None = None, + processor: ProcessorMixin | None = None, # pylint: disable=unused-argument +) -> tuple[Dataset | IterableDataset, Prompter | None]: + """Create an appropriate dataset wrapper and prompter based on dataset + configuration. + + Args: + dataset_config: Configuration for the dataset. + tokenizer: Tokenizer to use for processing text. + cfg: Global configuration object. + dataset_base_type: The base type of the dataset. + dataset: The actual dataset object. + dataset_prompt_style: Optional prompt style specification. + processor: Optional processor for multimodal datasets. + + Returns: + tuple of (dataset_wrapper, dataset_prompter). 
+ """ + # Common parameters for dataset wrapping + dataset_kwargs: dict[str, Any] = { + "process_count": cfg.dataset_processes, + "keep_in_memory": cfg.dataset_keep_in_memory is True, + } + + LOG.info( + f"Loading dataset: {dataset_config['path']} with base_type: " + f"{dataset_base_type} and prompt_style: {dataset_prompt_style}" + ) + + # Dataset is already tokenized + if _is_dataset_already_tokenized(dataset): + return dataset, UnsupportedPrompter() + + # Custom dataset type definition + if isinstance(dataset_config.type, DictDefault): + return _handle_custom_dataset_type( + dataset_config, tokenizer, cfg, dataset, dataset_kwargs + ) + + # Skip preparation if configured + if cfg.skip_prepare_dataset: + return dataset, None + + # Bradley-Terry dataset + if dataset_config.type.startswith("bradley_terry"): + return _handle_bradley_terry_dataset( + dataset_config, tokenizer, cfg, dataset, dataset_kwargs + ) + + # Stepwise supervised dataset + if dataset_config.type.startswith("stepwise_supervised"): + return _handle_stepwise_supervised_dataset( + dataset_config, tokenizer, cfg, dataset, dataset_kwargs + ) + + # Try to load prompt tokenizer / dataset wrapper strategy from registry + dataset_strategy = load( + dataset_config.type, tokenizer, cfg, dataset_config, processor=processor + ) + if dataset_strategy: + return _handle_loaded_strategy(dataset_strategy, dataset, dataset_kwargs) + + # Known dataset types with specific handling + if dataset_base_type in DATASET_HANDLERS: + handler = DATASET_HANDLERS[dataset_base_type] + return handler(dataset_prompt_style, tokenizer, cfg, dataset, dataset_kwargs) + + # Unhandled dataset type + handle_unknown_dataset_strategy(dataset_config) + + +def _is_dataset_already_tokenized(dataset: Dataset | IterableDataset) -> bool: + """Check if the dataset is already tokenized.""" + return ( + isinstance(dataset, Dataset) + and "input_ids" in dataset.features + and "attention_mask" in dataset.features + and "labels" in dataset.features + ) 
+ + +def _handle_custom_dataset_type( + dataset_config: DictDefault, + tokenizer: PreTrainedTokenizer, + cfg: DictDefault, + dataset: Dataset | IterableDataset, + dataset_kwargs: dict[str, Any], +) -> tuple[Dataset | IterableDataset, Prompter]: + """Handle a custom dataset type defined in the configuration.""" + dataset_strategy = cast( + PromptTokenizingStrategy, + load("user_defined", tokenizer, cfg, dataset_config.type.to_dict()), + ) + dataset_prompter = UnsupportedPrompter() + dataset_wrapper = wrap_dataset_for_tokenized_prompt( + dataset_strategy, + dataset, + **dataset_kwargs, + ) + return dataset_wrapper, dataset_prompter + + +def _handle_bradley_terry_dataset( + dataset_config: DictDefault, + tokenizer: PreTrainedTokenizer, + cfg: DictDefault, + dataset: Dataset | IterableDataset, + dataset_kwargs: dict[str, Any], +) -> tuple[Dataset | IterableDataset, Prompter | None]: + """Handle a Bradley-Terry dataset.""" + bt_type = dataset_config.type.split(".", 1)[1] + dataset_strategy = bradley_terry_load(bt_type, tokenizer, cfg, dataset_config) + + if not dataset_strategy: + handle_unknown_dataset_strategy(dataset_config) + + dataset_prompter = UnsupportedPrompter() + dataset_wrapper = wrap_dataset_for_tokenized_prompt( + dataset_strategy, + dataset, + **dataset_kwargs, + ) + + return dataset_wrapper, dataset_prompter + + +def _handle_stepwise_supervised_dataset( + dataset_config: DictDefault, + tokenizer: PreTrainedTokenizer, + cfg: DictDefault, + dataset: Dataset | IterableDataset, + dataset_kwargs: dict[str, Any], +) -> tuple[Dataset | IterableDataset, Prompter]: + """Handle a stepwise supervised dataset.""" + dataset_prompter = UnsupportedPrompter() + dataset_strategy = load(dataset_config.type, tokenizer, cfg, dataset_config) + + # We need to explicitly cast boolean labels to int + # for compatibility with how trl's PRMTrainer works + if isinstance(dataset, Dataset): + dataset = dataset.cast_column("labels", Sequence(Value("int64"))) + + dataset_wrapper = 
TokenizedPromptDataset( + dataset_strategy, + dataset, + **dataset_kwargs, + ) + return dataset_wrapper, dataset_prompter + + +def _handle_loaded_strategy( + dataset_strategy: PromptTokenizingStrategy | DatasetWrappingStrategy, + dataset: Dataset | IterableDataset, + dataset_kwargs: dict[str, Any], +) -> tuple[Dataset | IterableDataset, Prompter | None]: + """Handle a dataset with a strategy loaded from the registry.""" + if isinstance(dataset_strategy, DatasetWrappingStrategy): + return dataset_strategy.wrap_dataset(dataset, **dataset_kwargs), None + + dataset_prompter = UnsupportedPrompter() + dataset_wrapper = wrap_dataset_for_tokenized_prompt( + dataset_strategy, + dataset, + **dataset_kwargs, + ) + return dataset_wrapper, dataset_prompter + + +def _handle_alpaca_dataset( + dataset_prompt_style: str | None, + tokenizer: PreTrainedTokenizer, + cfg: DictDefault, + dataset: Dataset | IterableDataset, + dataset_kwargs: dict[str, Any], +) -> tuple[Dataset | IterableDataset, Prompter]: + """Handle an Alpaca dataset.""" + dataset_prompter = AlpacaPrompter(dataset_prompt_style) + dataset_strategy = AlpacaPromptTokenizingStrategy( + dataset_prompter, + tokenizer, + cfg.train_on_inputs, + cfg.sequence_len, + ) + dataset_wrapper = wrap_dataset_for_tokenized_prompt( + dataset_strategy, + dataset, + **dataset_kwargs, + ) + return dataset_wrapper, dataset_prompter + + +def _handle_explainchoice_dataset( + dataset_prompt_style: str | None, + tokenizer: PreTrainedTokenizer, + cfg: DictDefault, + dataset: Dataset | IterableDataset, + dataset_kwargs: dict[str, Any], +) -> tuple[Dataset | IterableDataset, Prompter]: + """Handle an ExplainChoice dataset.""" + dataset_prompter = MultipleChoiceExplainPrompter(dataset_prompt_style) + dataset_strategy = AlpacaMultipleChoicePromptTokenizingStrategy( + dataset_prompter, + tokenizer, + cfg.train_on_inputs, + cfg.sequence_len, + ) + dataset_wrapper = wrap_dataset_for_tokenized_prompt( + dataset_strategy, + dataset, + **dataset_kwargs, + ) 
+ return dataset_wrapper, dataset_prompter + + +def _handle_concisechoice_dataset( + dataset_prompt_style: str | None, + tokenizer: PreTrainedTokenizer, + cfg: DictDefault, + dataset: Dataset | IterableDataset, + dataset_kwargs: dict[str, Any], +) -> tuple[Dataset | IterableDataset, Prompter]: + """Handle a ConciseChoice dataset.""" + dataset_prompter = MultipleChoiceConcisePrompter(dataset_prompt_style) + dataset_strategy = AlpacaMultipleChoicePromptTokenizingStrategy( + dataset_prompter, + tokenizer, + cfg.train_on_inputs, + cfg.sequence_len, + ) + dataset_wrapper = wrap_dataset_for_tokenized_prompt( + dataset_strategy, + dataset, + **dataset_kwargs, + ) + return dataset_wrapper, dataset_prompter + + +def _handle_summarizetldr_dataset( + dataset_prompt_style: str | None, + tokenizer: PreTrainedTokenizer, + cfg: DictDefault, + dataset: Dataset | IterableDataset, + dataset_kwargs: dict[str, Any], +) -> tuple[Dataset | IterableDataset, Prompter]: + """Handle a SummarizeTLDR dataset.""" + dataset_prompter = SummarizeTLDRPrompter(dataset_prompt_style) + dataset_strategy = SummarizeTLDRPromptTokenizingStrategy( + dataset_prompter, + tokenizer, + cfg.train_on_inputs, + cfg.sequence_len, + ) + dataset_wrapper = wrap_dataset_for_tokenized_prompt( + dataset_strategy, + dataset, + **dataset_kwargs, + ) + return dataset_wrapper, dataset_prompter + + +def _handle_jeopardy_dataset( + dataset_prompt_style: str | None, + tokenizer: PreTrainedTokenizer, + cfg: DictDefault, + dataset: Dataset | IterableDataset, + dataset_kwargs: dict[str, Any], +) -> tuple[Dataset | IterableDataset, Prompter]: + """Handle a Jeopardy dataset.""" + dataset_prompter = JeopardyPrompter(dataset_prompt_style) + dataset_strategy = JeopardyPromptTokenizingStrategy( + dataset_prompter, + tokenizer, + cfg.train_on_inputs, + cfg.sequence_len, + ) + dataset_wrapper = wrap_dataset_for_tokenized_prompt( + dataset_strategy, + dataset, + **dataset_kwargs, + ) + return dataset_wrapper, dataset_prompter + + +def 
_handle_oasst_dataset( + dataset_prompt_style: str | None, + tokenizer: PreTrainedTokenizer, + cfg: DictDefault, + dataset: Dataset | IterableDataset, + dataset_kwargs: dict[str, Any], +) -> tuple[Dataset | IterableDataset, Prompter]: + """Handle an OpenAssistant dataset.""" + dataset_prompter = AlpacaPrompter(dataset_prompt_style) + dataset_strategy = OpenAssistantPromptTokenizingStrategy( + dataset_prompter, + tokenizer, + cfg.train_on_inputs, + cfg.sequence_len, + ) + dataset_wrapper = wrap_dataset_for_tokenized_prompt( + dataset_strategy, + dataset, + **dataset_kwargs, + ) + return dataset_wrapper, dataset_prompter + + +def _handle_gpteacher_dataset( + dataset_prompt_style: str | None, + tokenizer: PreTrainedTokenizer, + cfg: DictDefault, + dataset: Dataset | IterableDataset, + dataset_kwargs: dict[str, Any], +) -> tuple[Dataset | IterableDataset, Prompter]: + """Handle a GPTeacher dataset.""" + dataset_prompter = GPTeacherPrompter(dataset_prompt_style) + dataset_strategy = GPTeacherPromptTokenizingStrategy( + dataset_prompter, + tokenizer, + cfg.train_on_inputs, + cfg.sequence_len, + ) + dataset_wrapper = wrap_dataset_for_tokenized_prompt( + dataset_strategy, + dataset, + **dataset_kwargs, + ) + return dataset_wrapper, dataset_prompter + + +def _handle_reflection_dataset( + dataset_prompt_style: str | None, + tokenizer: PreTrainedTokenizer, + cfg: DictDefault, + dataset: Dataset | IterableDataset, + dataset_kwargs: dict[str, Any], +) -> tuple[Dataset | IterableDataset, Prompter]: + """Handle a Reflection dataset.""" + dataset_prompter = ReflectAlpacaPrompter(dataset_prompt_style) + dataset_strategy = AlpacaReflectionPTStrategy( + dataset_prompter, + tokenizer, + cfg.train_on_inputs, + cfg.sequence_len, + ) + dataset_wrapper = wrap_dataset_for_tokenized_prompt( + dataset_strategy, + dataset, + **dataset_kwargs, + ) + return dataset_wrapper, dataset_prompter + + +DATASET_HANDLERS = { + "alpaca": _handle_alpaca_dataset, + "explainchoice": 
_handle_explainchoice_dataset, + "concisechoice": _handle_concisechoice_dataset, + "summarizetldr": _handle_summarizetldr_dataset, + "jeopardy": _handle_jeopardy_dataset, + "oasst": _handle_oasst_dataset, + "gpteacher": _handle_gpteacher_dataset, + "reflection": _handle_reflection_dataset, +} diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py index e5f105053..dad6aac62 100644 --- a/src/axolotl/utils/schemas/config.py +++ b/src/axolotl/utils/schemas/config.py @@ -336,6 +336,14 @@ class AxolotlInputConfig( plugins: list[str] | None = Field(default=None) + @field_validator("seed", mode="after") + @classmethod + def set_default_seed(cls, seed): + if seed is None: + LOG.info("`seed` not set in config; setting to 42") + seed = 42 + return seed + @field_validator("datasets", mode="before") @classmethod def deprecate_sharegpt_datasets(cls, datasets): @@ -1199,7 +1207,7 @@ class AxolotlInputConfig( "flash_attention: true must be set with sequence_parallel_degree > 1" ) - if self.sample_packing and self.micro_batch_size > 1: + if self.sample_packing and getattr(self, "micro_batch_size", 1) > 1: raise ValueError( "micro_batch_size must be set to 1 when sample_packing is enabled " "due to a `ring-flash-attn` requirement" diff --git a/tests/core/test_builders.py b/tests/core/test_builders.py index cde7b74ce..e66b8e009 100644 --- a/tests/core/test_builders.py +++ b/tests/core/test_builders.py @@ -12,7 +12,7 @@ from axolotl.common.datasets import load_datasets from axolotl.core.builders import HFCausalTrainerBuilder, HFRLTrainerBuilder from axolotl.loaders import ModelLoader, load_tokenizer from axolotl.utils.config import normalize_config -from axolotl.utils.data.rl import load_prepare_preference_datasets +from axolotl.utils.data import prepare_preference_datasets from axolotl.utils.dict import DictDefault from axolotl.utils.schemas.enums import RLType @@ -451,15 +451,19 @@ def rand_reward_func(prompts, completions) -> list[float]: # Only use 
mock for the commented out configs if dataset_name is not None: with patch( - "axolotl.utils.data.rl.load_dataset_w_config" + "axolotl.utils.data.rl.load_dataset_with_config" ) as mock_load_dataset: mock_load_dataset.return_value = request.getfixturevalue( dataset_name ) - train_dataset, eval_dataset = load_prepare_preference_datasets(cfg) + train_dataset, eval_dataset = prepare_preference_datasets( + cfg, tokenizer + ) else: # Load actual datasets for orpo_cfg and kto_cfg - train_dataset, eval_dataset = load_prepare_preference_datasets(cfg) + train_dataset, eval_dataset = prepare_preference_datasets( + cfg, tokenizer + ) builder.train_dataset = train_dataset builder.eval_dataset = eval_dataset diff --git a/tests/e2e/integrations/test_cut_cross_entropy.py b/tests/e2e/integrations/test_cut_cross_entropy.py index 2ae59a15a..790b34f3e 100644 --- a/tests/e2e/integrations/test_cut_cross_entropy.py +++ b/tests/e2e/integrations/test_cut_cross_entropy.py @@ -4,7 +4,6 @@ Simple end-to-end test for Cut Cross Entropy integration import pytest -from axolotl.cli.args import TrainerCliArgs from axolotl.common.datasets import load_datasets from axolotl.train import train from axolotl.utils import get_pytorch_version @@ -59,8 +58,7 @@ class TestCutCrossEntropyIntegration: cfg = validate_config(cfg) prepare_plugins(cfg) normalize_config(cfg) - cli_args = TrainerCliArgs() - dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + dataset_meta = load_datasets(cfg=cfg) major, minor, _ = get_pytorch_version() if (major, minor) < (2, 4): @@ -105,8 +103,7 @@ class TestCutCrossEntropyIntegration: cfg = validate_config(cfg) prepare_plugins(cfg) normalize_config(cfg) - cli_args = TrainerCliArgs() - dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + dataset_meta = load_datasets(cfg=cfg) major, minor, _ = get_pytorch_version() if (major, minor) < (2, 4): @@ -134,8 +131,7 @@ class TestCutCrossEntropyIntegration: cfg = validate_config(cfg) prepare_plugins(cfg) normalize_config(cfg) - 
cli_args = TrainerCliArgs() - dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + dataset_meta = load_datasets(cfg=cfg) major, minor, _ = get_pytorch_version() if (major, minor) < (2, 4): diff --git a/tests/e2e/integrations/test_hooks.py b/tests/e2e/integrations/test_hooks.py index 45d7200fb..4734449fe 100644 --- a/tests/e2e/integrations/test_hooks.py +++ b/tests/e2e/integrations/test_hooks.py @@ -5,7 +5,6 @@ e2e tests to make sure all the hooks are fired on the plugin import os from pathlib import Path -from axolotl.cli.args import TrainerCliArgs from axolotl.common.datasets import load_datasets from axolotl.integrations.base import BasePlugin from axolotl.train import train @@ -160,8 +159,7 @@ class TestPluginHooks: cfg = validate_config(cfg) prepare_plugins(cfg) normalize_config(cfg) - cli_args = TrainerCliArgs() - dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + dataset_meta = load_datasets(cfg=cfg) train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) diff --git a/tests/e2e/integrations/test_kd.py b/tests/e2e/integrations/test_kd.py index dad777947..2bd1fbf3d 100644 --- a/tests/e2e/integrations/test_kd.py +++ b/tests/e2e/integrations/test_kd.py @@ -6,7 +6,6 @@ from pathlib import Path import pytest -from axolotl.cli.args import TrainerCliArgs from axolotl.common.datasets import load_datasets from axolotl.train import train from axolotl.utils.config import normalize_config, prepare_plugins, validate_config @@ -84,8 +83,7 @@ class TestKnowledgeDistillation: cfg = validate_config(cfg) prepare_plugins(cfg) normalize_config(cfg) - cli_args = TrainerCliArgs() - dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + dataset_meta = load_datasets(cfg=cfg) train(cfg=cfg, dataset_meta=dataset_meta) assert (Path(temp_dir) / "model.safetensors").exists() @@ -115,8 +113,7 @@ class TestKnowledgeDistillation: cfg = validate_config(cfg) prepare_plugins(cfg) normalize_config(cfg) - cli_args = TrainerCliArgs() - dataset_meta = 
load_datasets(cfg=cfg, cli_args=cli_args) + dataset_meta = load_datasets(cfg=cfg) train(cfg=cfg, dataset_meta=dataset_meta) assert (Path(temp_dir) / "adapter_model.safetensors").exists() diff --git a/tests/e2e/integrations/test_liger.py b/tests/e2e/integrations/test_liger.py index 8ecfc4746..6ab3d7ab8 100644 --- a/tests/e2e/integrations/test_liger.py +++ b/tests/e2e/integrations/test_liger.py @@ -2,7 +2,6 @@ Simple end-to-end test for Liger integration """ -from axolotl.cli.args import TrainerCliArgs from axolotl.common.datasets import load_datasets from axolotl.train import train from axolotl.utils.config import normalize_config, prepare_plugins, validate_config @@ -57,8 +56,7 @@ class LigerIntegrationTestCase: cfg = validate_config(cfg) prepare_plugins(cfg) normalize_config(cfg) - cli_args = TrainerCliArgs() - dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + dataset_meta = load_datasets(cfg=cfg) train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) @@ -104,8 +102,7 @@ class LigerIntegrationTestCase: cfg = validate_config(cfg) prepare_plugins(cfg) normalize_config(cfg) - cli_args = TrainerCliArgs() - dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + dataset_meta = load_datasets(cfg=cfg) train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) diff --git a/tests/e2e/integrations/test_llm_compressor.py b/tests/e2e/integrations/test_llm_compressor.py index 20bf821bf..247ae3bac 100644 --- a/tests/e2e/integrations/test_llm_compressor.py +++ b/tests/e2e/integrations/test_llm_compressor.py @@ -6,7 +6,6 @@ from pathlib import Path import pytest -from axolotl.cli.args import TrainerCliArgs from axolotl.common.datasets import load_datasets from axolotl.train import train from axolotl.utils.config import normalize_config, prepare_plugins, validate_config @@ -88,8 +87,7 @@ class TestLLMCompressorIntegration: prepare_plugins(cfg) cfg = validate_config(cfg) normalize_config(cfg) - cli_args = 
TrainerCliArgs() - dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + dataset_meta = load_datasets(cfg=cfg) try: train(cfg=cfg, dataset_meta=dataset_meta) diff --git a/tests/e2e/multigpu/solo/test_grpo.py b/tests/e2e/multigpu/solo/test_grpo.py index 8ea2e3ce4..1daf58472 100644 --- a/tests/e2e/multigpu/solo/test_grpo.py +++ b/tests/e2e/multigpu/solo/test_grpo.py @@ -105,7 +105,7 @@ def start_vllm( print(f"{i}: VLLM server failed to start: {str(exc)}") # also check if the process.pid is still running - if not process.poll() is None: + if process.poll() is not None: break time.sleep(period_seconds) diff --git a/tests/e2e/multigpu/test_locking.py b/tests/e2e/multigpu/test_locking.py new file mode 100644 index 000000000..42502dfa3 --- /dev/null +++ b/tests/e2e/multigpu/test_locking.py @@ -0,0 +1,192 @@ +"""Tests for FileLockLoader class.""" + +import tempfile +import threading +import time +from pathlib import Path +from unittest.mock import MagicMock, Mock, patch + +import pytest + +from axolotl.utils.data.lock import FileLockLoader +from axolotl.utils.dict import DictDefault + + +class TestFileLockLoader: + """Class with tests for FileLockLoader.""" + + @pytest.fixture + def temp_dir(self): + """Create a temporary directory for testing.""" + with tempfile.TemporaryDirectory() as tmp_dir: + yield Path(tmp_dir) + + @pytest.fixture + def cfg(self, temp_dir): + """Create a test configuration.""" + return DictDefault({"dataset_prepared_path": str(temp_dir)}) + + @pytest.fixture + def loader(self, cfg): + """Create a FileLockLoader instance for testing.""" + return FileLockLoader(cfg) + + def test_load_first_process(self, loader): + """Test load() when no ready flag exists (first process).""" + mock_load_fn = Mock(return_value="test_data") + + result = loader.load(mock_load_fn) + + # Should call the load function + mock_load_fn.assert_called_once() + assert result == "test_data" + + # Should create the ready flag + assert loader.ready_flag_path.exists() + + def 
test_load_subsequent_process(self, loader): + """Test load() when ready flag already exists (subsequent process).""" + # Create ready flag first + loader.ready_flag_path.touch() + + mock_load_fn = Mock(return_value="loaded_data") + + result = loader.load(mock_load_fn) + + # Should still call load function (to load the prepared data) + mock_load_fn.assert_called_once() + assert result == "loaded_data" + + def test_load_concurrent_processes(self, cfg): + """Test that concurrent processes coordinate correctly.""" + results = [] + call_count = 0 + + def slow_load_fn(): + nonlocal call_count + call_count += 1 + time.sleep(0.1) # Simulate slow loading + return f"data_{call_count}" + + def worker(): + loader = FileLockLoader(cfg) + result = loader.load(slow_load_fn) + results.append(result) + + # Start multiple threads simultaneously + threads = [threading.Thread(target=worker) for _ in range(3)] + for t in threads: + t.start() + for t in threads: + t.join() + + # Only one thread should have done the initial loading + # All should return data, but the load function should be called + # once by the first process and once by each subsequent process + assert len(results) == 3 + assert all(result.startswith("data_") for result in results) + + @patch("time.sleep") + def test_load_waiting_for_ready_flag(self, mock_sleep, loader): + """Test that processes wait for the ready flag to appear.""" + mock_load_fn = Mock(return_value="waiting_data") + mock_ready_flag_path = Mock() + exists_call_count = 0 + + def mock_exists(): + nonlocal exists_call_count + exists_call_count += 1 + + if exists_call_count == 1: + # First check: ready flag exists (not first process) + return True + if exists_call_count <= 3: + # While loop checks: flag doesn't exist yet + return False + return True + + mock_ready_flag_path.exists.side_effect = mock_exists + + # Replace the ready_flag_path with our mock + original_path = loader.ready_flag_path + loader.ready_flag_path = mock_ready_flag_path + + try: + 
result = loader.load(mock_load_fn) + finally: + # Restore original path + loader.ready_flag_path = original_path + + # Should have slept twice while waiting + assert mock_sleep.call_count == 2 + mock_sleep.assert_called_with(1) + + # Should eventually call load function + mock_load_fn.assert_called_once() + assert result == "waiting_data" + + def test_complete_workflow_with_cleanup(self, loader): + """Test the complete load -> cleanup workflow.""" + mock_load_fn = Mock(return_value="test_data") + + # First process calls load (this should set up counter) + result = loader.load(mock_load_fn) + assert result == "test_data" + assert loader.ready_flag_path.exists() + assert loader.counter_path.exists() + + # Cleanup should remove everything since there's only one process + loader.cleanup() + assert not loader.ready_flag_path.exists() + assert not loader.counter_path.exists() + + def test_multiple_processes_workflow(self, loader): + """Test workflow with multiple processes.""" + # Simulate multiple processes by manually setting up counter + loader.ready_flag_path.touch() + loader.counter_path.write_text("3") # 3 processes + + # First process cleanup + loader.cleanup() + assert loader.ready_flag_path.exists() + assert loader.counter_path.read_text().strip() == "2" + + # Second process cleanup + loader.cleanup() + assert loader.ready_flag_path.exists() + assert loader.counter_path.read_text().strip() == "1" + + # Last process cleanup + loader.cleanup() + assert not loader.ready_flag_path.exists() + assert not loader.counter_path.exists() + + def test_load_exception_handling(self, loader): + """Test behavior when load_fn raises an exception.""" + + def failing_load_fn(): + raise ValueError("Load failed") + + with pytest.raises(ValueError, match="Load failed"): + loader.load(failing_load_fn) + + # Ready flag should not be created on failure + assert not loader.ready_flag_path.exists() + + def test_file_lock_called(self, loader): + """Test that FileLock is properly used.""" + 
mock_load_fn = Mock(return_value="locked_data") + + with patch("axolotl.utils.data.lock.FileLock") as mock_filelock: + mock_context = MagicMock() + mock_filelock.return_value.__enter__ = Mock(return_value=mock_context) + mock_filelock.return_value.__exit__ = Mock(return_value=None) + + loader.load(mock_load_fn) + + # Verify FileLock was called with correct path + mock_filelock.assert_called_once_with(str(loader.lock_file_path)) + + # Verify context manager was used + mock_filelock.return_value.__enter__.assert_called_once() + mock_filelock.return_value.__exit__.assert_called_once() diff --git a/tests/e2e/patched/test_4d_multipack_llama.py b/tests/e2e/patched/test_4d_multipack_llama.py index 490ce77fb..08b62accc 100644 --- a/tests/e2e/patched/test_4d_multipack_llama.py +++ b/tests/e2e/patched/test_4d_multipack_llama.py @@ -4,7 +4,6 @@ E2E tests for multipack fft llama using 4d attention masks import unittest -from axolotl.cli.args import TrainerCliArgs from axolotl.common.datasets import load_datasets from axolotl.train import train from axolotl.utils.config import normalize_config, validate_config @@ -60,8 +59,7 @@ class Test4dMultipackLlama(unittest.TestCase): ) cfg = validate_config(cfg) normalize_config(cfg) - cli_args = TrainerCliArgs() - dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + dataset_meta = load_datasets(cfg=cfg) train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) @@ -108,8 +106,7 @@ class Test4dMultipackLlama(unittest.TestCase): ) cfg = validate_config(cfg) normalize_config(cfg) - cli_args = TrainerCliArgs() - dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + dataset_meta = load_datasets(cfg=cfg) train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) diff --git a/tests/e2e/patched/test_activation_checkpointing.py b/tests/e2e/patched/test_activation_checkpointing.py index 45107b871..d494ed1eb 100644 --- a/tests/e2e/patched/test_activation_checkpointing.py +++ 
b/tests/e2e/patched/test_activation_checkpointing.py @@ -6,7 +6,6 @@ import pytest import transformers from torch.utils.checkpoint import checkpoint -from axolotl.cli.args import TrainerCliArgs from axolotl.common.datasets import load_datasets from axolotl.train import train from axolotl.utils.config import normalize_config, validate_config @@ -75,8 +74,7 @@ class TestActivationCheckpointing: cfg = validate_config(cfg) normalize_config(cfg) - cli_args = TrainerCliArgs() - dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + dataset_meta = load_datasets(cfg=cfg) train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) diff --git a/tests/e2e/patched/test_fa_xentropy.py b/tests/e2e/patched/test_fa_xentropy.py index e66b67e6d..4e3cbc50d 100644 --- a/tests/e2e/patched/test_fa_xentropy.py +++ b/tests/e2e/patched/test_fa_xentropy.py @@ -5,7 +5,6 @@ E2E tests for lora llama import pytest from transformers.utils import is_torch_bf16_gpu_available -from axolotl.cli.args import TrainerCliArgs from axolotl.common.datasets import load_datasets from axolotl.train import train from axolotl.utils.config import normalize_config, validate_config @@ -73,8 +72,7 @@ class TestFAXentropyLlama: cfg = validate_config(cfg) normalize_config(cfg) - cli_args = TrainerCliArgs() - dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + dataset_meta = load_datasets(cfg=cfg) train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) diff --git a/tests/e2e/patched/test_falcon_samplepack.py b/tests/e2e/patched/test_falcon_samplepack.py index bd80221ce..a593b0791 100644 --- a/tests/e2e/patched/test_falcon_samplepack.py +++ b/tests/e2e/patched/test_falcon_samplepack.py @@ -6,7 +6,6 @@ import unittest import pytest -from axolotl.cli.args import TrainerCliArgs from axolotl.common.datasets import load_datasets from axolotl.train import train from axolotl.utils.config import normalize_config, validate_config @@ -63,8 +62,7 @@ class 
TestFalconPatched(unittest.TestCase): ) cfg = validate_config(cfg) normalize_config(cfg) - cli_args = TrainerCliArgs() - dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + dataset_meta = load_datasets(cfg=cfg) train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) @@ -105,8 +103,7 @@ class TestFalconPatched(unittest.TestCase): ) cfg = validate_config(cfg) normalize_config(cfg) - cli_args = TrainerCliArgs() - dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + dataset_meta = load_datasets(cfg=cfg) train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) diff --git a/tests/e2e/patched/test_fused_llama.py b/tests/e2e/patched/test_fused_llama.py index 49478f10c..1bbc82a38 100644 --- a/tests/e2e/patched/test_fused_llama.py +++ b/tests/e2e/patched/test_fused_llama.py @@ -7,7 +7,6 @@ import unittest import pytest from transformers.utils import is_torch_bf16_gpu_available -from axolotl.cli.args import TrainerCliArgs from axolotl.common.datasets import load_datasets from axolotl.train import train from axolotl.utils.config import normalize_config, validate_config @@ -62,8 +61,7 @@ class TestFusedLlama(unittest.TestCase): cfg.fp16 = True cfg = validate_config(cfg) normalize_config(cfg) - cli_args = TrainerCliArgs() - dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + dataset_meta = load_datasets(cfg=cfg) train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) diff --git a/tests/e2e/patched/test_llama_s2_attention.py b/tests/e2e/patched/test_llama_s2_attention.py index 327bb13f8..d2dcc5e4b 100644 --- a/tests/e2e/patched/test_llama_s2_attention.py +++ b/tests/e2e/patched/test_llama_s2_attention.py @@ -6,7 +6,6 @@ import unittest import pytest -from axolotl.cli.args import TrainerCliArgs from axolotl.common.datasets import load_datasets from axolotl.train import train from axolotl.utils.config import normalize_config, validate_config @@ -64,8 +63,7 @@ class 
TestLlamaShiftedSparseAttention(unittest.TestCase): cfg = validate_config(cfg) normalize_config(cfg) - cli_args = TrainerCliArgs() - dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + dataset_meta = load_datasets(cfg=cfg) train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) @@ -107,8 +105,7 @@ class TestLlamaShiftedSparseAttention(unittest.TestCase): cfg = validate_config(cfg) normalize_config(cfg) - cli_args = TrainerCliArgs() - dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + dataset_meta = load_datasets(cfg=cfg) train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) diff --git a/tests/e2e/patched/test_lora_llama_multipack.py b/tests/e2e/patched/test_lora_llama_multipack.py index 1bad677b9..5df6bfecc 100644 --- a/tests/e2e/patched/test_lora_llama_multipack.py +++ b/tests/e2e/patched/test_lora_llama_multipack.py @@ -7,7 +7,6 @@ import unittest import pytest from transformers.utils import is_auto_gptq_available, is_torch_bf16_gpu_available -from axolotl.cli.args import TrainerCliArgs from axolotl.common.datasets import load_datasets from axolotl.train import train from axolotl.utils.config import normalize_config, validate_config @@ -65,8 +64,7 @@ class TestLoraLlama(unittest.TestCase): cfg = validate_config(cfg) normalize_config(cfg) - cli_args = TrainerCliArgs() - dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + dataset_meta = load_datasets(cfg=cfg) train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) @@ -114,8 +112,7 @@ class TestLoraLlama(unittest.TestCase): ) cfg = validate_config(cfg) normalize_config(cfg) - cli_args = TrainerCliArgs() - dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + dataset_meta = load_datasets(cfg=cfg) train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) diff --git a/tests/e2e/patched/test_mistral_samplepack.py b/tests/e2e/patched/test_mistral_samplepack.py index 994b9dfca..2de9cc96f 
100644 --- a/tests/e2e/patched/test_mistral_samplepack.py +++ b/tests/e2e/patched/test_mistral_samplepack.py @@ -4,7 +4,6 @@ E2E tests for lora llama import unittest -from axolotl.cli.args import TrainerCliArgs from axolotl.common.datasets import load_datasets from axolotl.train import train from axolotl.utils.config import normalize_config, validate_config @@ -60,8 +59,7 @@ class TestMistral(unittest.TestCase): ) cfg = validate_config(cfg) normalize_config(cfg) - cli_args = TrainerCliArgs() - dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + dataset_meta = load_datasets(cfg=cfg) train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) @@ -102,8 +100,7 @@ class TestMistral(unittest.TestCase): ) cfg = validate_config(cfg) normalize_config(cfg) - cli_args = TrainerCliArgs() - dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + dataset_meta = load_datasets(cfg=cfg) train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) diff --git a/tests/e2e/patched/test_mixtral_samplepack.py b/tests/e2e/patched/test_mixtral_samplepack.py index 6a84069ef..5f778660b 100644 --- a/tests/e2e/patched/test_mixtral_samplepack.py +++ b/tests/e2e/patched/test_mixtral_samplepack.py @@ -4,7 +4,6 @@ E2E tests for mixtral import unittest -from axolotl.cli.args import TrainerCliArgs from axolotl.common.datasets import load_datasets from axolotl.train import train from axolotl.utils.config import normalize_config, validate_config @@ -57,8 +56,7 @@ class TestMixtral(unittest.TestCase): ) cfg = validate_config(cfg) normalize_config(cfg) - cli_args = TrainerCliArgs() - dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + dataset_meta = load_datasets(cfg=cfg) train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) @@ -96,8 +94,7 @@ class TestMixtral(unittest.TestCase): ) cfg = validate_config(cfg) normalize_config(cfg) - cli_args = TrainerCliArgs() - dataset_meta = load_datasets(cfg=cfg, 
cli_args=cli_args) + dataset_meta = load_datasets(cfg=cfg) train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) diff --git a/tests/e2e/patched/test_phi_multipack.py b/tests/e2e/patched/test_phi_multipack.py index ee2a3ffb4..d241ce185 100644 --- a/tests/e2e/patched/test_phi_multipack.py +++ b/tests/e2e/patched/test_phi_multipack.py @@ -4,7 +4,6 @@ E2E tests for lora llama import unittest -from axolotl.cli.args import TrainerCliArgs from axolotl.common.datasets import load_datasets from axolotl.train import train from axolotl.utils.config import normalize_config, validate_config @@ -60,8 +59,7 @@ class TestPhiMultipack(unittest.TestCase): cfg = validate_config(cfg) normalize_config(cfg) - cli_args = TrainerCliArgs() - dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + dataset_meta = load_datasets(cfg=cfg) train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) @@ -112,8 +110,7 @@ class TestPhiMultipack(unittest.TestCase): cfg = validate_config(cfg) normalize_config(cfg) - cli_args = TrainerCliArgs() - dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + dataset_meta = load_datasets(cfg=cfg) train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) diff --git a/tests/e2e/patched/test_resume.py b/tests/e2e/patched/test_resume.py index cc1f3ddee..363956733 100644 --- a/tests/e2e/patched/test_resume.py +++ b/tests/e2e/patched/test_resume.py @@ -7,7 +7,6 @@ import subprocess from transformers.utils import is_torch_bf16_gpu_available -from axolotl.cli.args import TrainerCliArgs from axolotl.common.datasets import load_datasets from axolotl.train import train from axolotl.utils.config import normalize_config, validate_config @@ -67,8 +66,7 @@ class TestResumeLlama: cfg.fp16 = True cfg = validate_config(cfg) normalize_config(cfg) - cli_args = TrainerCliArgs() - dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + dataset_meta = load_datasets(cfg=cfg) train(cfg=cfg, 
dataset_meta=dataset_meta) @@ -78,7 +76,6 @@ class TestResumeLlama: } ) normalize_config(resume_cfg) - cli_args = TrainerCliArgs() train(cfg=resume_cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) diff --git a/tests/e2e/patched/test_unsloth_qlora.py b/tests/e2e/patched/test_unsloth_qlora.py index 46f5b6614..9567c0b18 100644 --- a/tests/e2e/patched/test_unsloth_qlora.py +++ b/tests/e2e/patched/test_unsloth_qlora.py @@ -4,7 +4,6 @@ e2e tests for unsloth qlora import pytest -from axolotl.cli.args import TrainerCliArgs from axolotl.common.datasets import load_datasets from axolotl.train import train from axolotl.utils.config import normalize_config, validate_config @@ -68,8 +67,7 @@ class TestUnslothQLoRA: cfg = validate_config(cfg) normalize_config(cfg) - cli_args = TrainerCliArgs() - dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + dataset_meta = load_datasets(cfg=cfg) train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) @@ -119,8 +117,7 @@ class TestUnslothQLoRA: cfg = validate_config(cfg) normalize_config(cfg) - cli_args = TrainerCliArgs() - dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + dataset_meta = load_datasets(cfg=cfg) train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) @@ -175,8 +172,7 @@ class TestUnslothQLoRA: cfg = validate_config(cfg) normalize_config(cfg) - cli_args = TrainerCliArgs() - dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + dataset_meta = load_datasets(cfg=cfg) train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) diff --git a/tests/e2e/solo/test_flex.py b/tests/e2e/solo/test_flex.py index b33869b1c..8d1a0c7d1 100644 --- a/tests/e2e/solo/test_flex.py +++ b/tests/e2e/solo/test_flex.py @@ -6,7 +6,6 @@ import unittest from transformers.utils import is_torch_bf16_gpu_available -from axolotl.cli.args import TrainerCliArgs from axolotl.common.datasets import load_datasets from axolotl.train import train 
from axolotl.utils.config import normalize_config, validate_config @@ -59,8 +58,7 @@ class TestPackedFlex(unittest.TestCase): cfg = validate_config(cfg) normalize_config(cfg) - cli_args = TrainerCliArgs() - dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + dataset_meta = load_datasets(cfg=cfg) train(cfg=cfg, dataset_meta=dataset_meta) diff --git a/tests/e2e/solo/test_relora_llama.py b/tests/e2e/solo/test_relora_llama.py index cff8313f3..7af550496 100644 --- a/tests/e2e/solo/test_relora_llama.py +++ b/tests/e2e/solo/test_relora_llama.py @@ -5,7 +5,6 @@ E2E tests for relora llama import unittest from pathlib import Path -from axolotl.cli.args import TrainerCliArgs from axolotl.common.datasets import load_datasets from axolotl.train import train from axolotl.utils.config import normalize_config, validate_config @@ -71,8 +70,7 @@ class TestReLoraLlama(unittest.TestCase): cfg = validate_config(cfg) normalize_config(cfg) - cli_args = TrainerCliArgs() - dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + dataset_meta = load_datasets(cfg=cfg) train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(Path(temp_dir) / "checkpoint-100/adapter", cfg) diff --git a/tests/e2e/test_deepseekv3.py b/tests/e2e/test_deepseekv3.py index d882286cc..7dfc4ae15 100644 --- a/tests/e2e/test_deepseekv3.py +++ b/tests/e2e/test_deepseekv3.py @@ -6,7 +6,6 @@ from pathlib import Path import pytest -from axolotl.cli.args import TrainerCliArgs from axolotl.common.datasets import load_datasets from axolotl.train import train from axolotl.utils.config import normalize_config, validate_config @@ -72,8 +71,7 @@ class TestDeepseekV3: ) cfg = validate_config(cfg) normalize_config(cfg) - cli_args = TrainerCliArgs() - dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + dataset_meta = load_datasets(cfg=cfg) train(cfg=cfg, dataset_meta=dataset_meta) assert (Path(temp_dir) / "adapter_model.safetensors").exists() @@ -122,8 +120,7 @@ class TestDeepseekV3: ) cfg = 
validate_config(cfg) normalize_config(cfg) - cli_args = TrainerCliArgs() - dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + dataset_meta = load_datasets(cfg=cfg) train(cfg=cfg, dataset_meta=dataset_meta) assert (Path(temp_dir) / "model.safetensors").exists() diff --git a/tests/e2e/test_dpo.py b/tests/e2e/test_dpo.py index e9f70758b..2cdb57689 100644 --- a/tests/e2e/test_dpo.py +++ b/tests/e2e/test_dpo.py @@ -1,6 +1,4 @@ -""" -E2E tests for lora llama -""" +"""E2E tests for lora llama""" import unittest from pathlib import Path diff --git a/tests/e2e/test_embeddings_lr.py b/tests/e2e/test_embeddings_lr.py index f1297fcf3..9b65f8feb 100644 --- a/tests/e2e/test_embeddings_lr.py +++ b/tests/e2e/test_embeddings_lr.py @@ -4,7 +4,6 @@ E2E tests for llama pretrain import unittest -from axolotl.cli.args import TrainerCliArgs from axolotl.common.datasets import load_datasets from axolotl.train import train from axolotl.utils.config import normalize_config, validate_config @@ -54,8 +53,7 @@ class TestEmbeddingsLrScale(unittest.TestCase): cfg = validate_config(cfg) normalize_config(cfg) - cli_args = TrainerCliArgs() - dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + dataset_meta = load_datasets(cfg=cfg) train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) @@ -99,8 +97,7 @@ class TestEmbeddingsLrScale(unittest.TestCase): ) cfg = validate_config(cfg) normalize_config(cfg) - cli_args = TrainerCliArgs() - dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + dataset_meta = load_datasets(cfg=cfg) train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) diff --git a/tests/e2e/test_falcon.py b/tests/e2e/test_falcon.py index 7ea7e30f4..4f88e740c 100644 --- a/tests/e2e/test_falcon.py +++ b/tests/e2e/test_falcon.py @@ -6,7 +6,6 @@ import unittest import pytest -from axolotl.cli.args import TrainerCliArgs from axolotl.common.datasets import load_datasets from axolotl.train import train from 
axolotl.utils.config import normalize_config, validate_config @@ -66,8 +65,7 @@ class TestFalcon(unittest.TestCase): cfg = validate_config(cfg) normalize_config(cfg) - cli_args = TrainerCliArgs() - dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + dataset_meta = load_datasets(cfg=cfg) train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) @@ -122,8 +120,7 @@ class TestFalcon(unittest.TestCase): cfg = validate_config(cfg) normalize_config(cfg) - cli_args = TrainerCliArgs() - dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + dataset_meta = load_datasets(cfg=cfg) train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) @@ -164,8 +161,7 @@ class TestFalcon(unittest.TestCase): cfg = validate_config(cfg) normalize_config(cfg) - cli_args = TrainerCliArgs() - dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + dataset_meta = load_datasets(cfg=cfg) train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) diff --git a/tests/e2e/test_gemma2.py b/tests/e2e/test_gemma2.py index 65732a737..c0eba72a7 100644 --- a/tests/e2e/test_gemma2.py +++ b/tests/e2e/test_gemma2.py @@ -6,7 +6,6 @@ from pathlib import Path import pytest -from axolotl.cli.args import TrainerCliArgs from axolotl.common.datasets import load_datasets from axolotl.train import train from axolotl.utils.config import normalize_config, validate_config @@ -69,8 +68,7 @@ class TestGemma2: ) cfg = validate_config(cfg) normalize_config(cfg) - cli_args = TrainerCliArgs() - dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + dataset_meta = load_datasets(cfg=cfg) train(cfg=cfg, dataset_meta=dataset_meta) assert (Path(temp_dir) / "adapter_model.safetensors").exists() @@ -121,8 +119,7 @@ class TestGemma2: ) cfg = validate_config(cfg) normalize_config(cfg) - cli_args = TrainerCliArgs() - dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + dataset_meta = load_datasets(cfg=cfg) train(cfg=cfg, 
dataset_meta=dataset_meta) assert (Path(temp_dir) / "model.safetensors").exists() diff --git a/tests/e2e/test_gemma3_text.py b/tests/e2e/test_gemma3_text.py index d790fa156..3f00a1384 100644 --- a/tests/e2e/test_gemma3_text.py +++ b/tests/e2e/test_gemma3_text.py @@ -6,7 +6,6 @@ from pathlib import Path import pytest -from axolotl.cli.args import TrainerCliArgs from axolotl.common.datasets import load_datasets from axolotl.train import train from axolotl.utils.config import normalize_config, validate_config @@ -68,8 +67,7 @@ class TestGemma3Text: ) cfg = validate_config(cfg) normalize_config(cfg) - cli_args = TrainerCliArgs() - dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + dataset_meta = load_datasets(cfg=cfg) train(cfg=cfg, dataset_meta=dataset_meta) assert (Path(temp_dir) / "adapter_model.safetensors").exists() @@ -119,8 +117,7 @@ class TestGemma3Text: ) cfg = validate_config(cfg) normalize_config(cfg) - cli_args = TrainerCliArgs() - dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + dataset_meta = load_datasets(cfg=cfg) train(cfg=cfg, dataset_meta=dataset_meta) assert (Path(temp_dir) / "model.safetensors").exists() diff --git a/tests/e2e/test_llama.py b/tests/e2e/test_llama.py index 455e17532..2b180029c 100644 --- a/tests/e2e/test_llama.py +++ b/tests/e2e/test_llama.py @@ -2,7 +2,6 @@ E2E tests for llama """ -from axolotl.cli.args import TrainerCliArgs from axolotl.common.datasets import load_datasets from axolotl.train import train from axolotl.utils.config import normalize_config, validate_config @@ -51,8 +50,7 @@ class TestLlama: cfg = validate_config(cfg) normalize_config(cfg) - cli_args = TrainerCliArgs() - dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + dataset_meta = load_datasets(cfg=cfg) train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) @@ -99,8 +97,7 @@ class TestLlama: cfg = validate_config(cfg) normalize_config(cfg) - cli_args = TrainerCliArgs() - dataset_meta = load_datasets(cfg=cfg, 
cli_args=cli_args) + dataset_meta = load_datasets(cfg=cfg) train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) @@ -144,8 +141,7 @@ class TestLlama: cfg = validate_config(cfg) normalize_config(cfg) - cli_args = TrainerCliArgs() - dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + dataset_meta = load_datasets(cfg=cfg) train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) @@ -185,8 +181,7 @@ class TestLlama: cfg = validate_config(cfg) normalize_config(cfg) - cli_args = TrainerCliArgs() - dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + dataset_meta = load_datasets(cfg=cfg) train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) diff --git a/tests/e2e/test_llama_pretrain.py b/tests/e2e/test_llama_pretrain.py index ec1e164a4..47d4b4839 100644 --- a/tests/e2e/test_llama_pretrain.py +++ b/tests/e2e/test_llama_pretrain.py @@ -1,10 +1,7 @@ -""" -E2E tests for llama pretrain -""" +"""E2E tests for llama pretrain""" import pytest -from axolotl.cli.args import TrainerCliArgs from axolotl.common.datasets import load_datasets from axolotl.train import train from axolotl.utils.config import normalize_config, validate_config @@ -14,9 +11,7 @@ from .utils import check_model_output_exists, check_tensorboard class TestPretrainLlama: - """ - Test case for Llama models w pretraining - """ + """Test case for Llama models w pretraining""" @pytest.mark.parametrize( "sample_packing", @@ -66,8 +61,7 @@ class TestPretrainLlama: cfg = validate_config(cfg) normalize_config(cfg) - cli_args = TrainerCliArgs() - dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + dataset_meta = load_datasets(cfg=cfg) train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) diff --git a/tests/e2e/test_llama_vision.py b/tests/e2e/test_llama_vision.py index 32657c156..ad4a83c6a 100644 --- a/tests/e2e/test_llama_vision.py +++ b/tests/e2e/test_llama_vision.py @@ -4,7 +4,6 @@ E2E 
tests for lora llama import unittest -from axolotl.cli.args import TrainerCliArgs from axolotl.common.datasets import load_datasets from axolotl.train import train from axolotl.utils.config import normalize_config, validate_config @@ -60,8 +59,7 @@ class TestLlamaVision(unittest.TestCase): cfg = validate_config(cfg) normalize_config(cfg) - cli_args = TrainerCliArgs() - dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + dataset_meta = load_datasets(cfg=cfg) train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) @@ -106,8 +104,7 @@ class TestLlamaVision(unittest.TestCase): ) cfg = validate_config(cfg) normalize_config(cfg) - cli_args = TrainerCliArgs() - dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + dataset_meta = load_datasets(cfg=cfg) train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) diff --git a/tests/e2e/test_lora_llama.py b/tests/e2e/test_lora_llama.py index 999625070..301565302 100644 --- a/tests/e2e/test_lora_llama.py +++ b/tests/e2e/test_lora_llama.py @@ -4,7 +4,6 @@ E2E tests for lora llama import unittest -from axolotl.cli.args import TrainerCliArgs from axolotl.common.datasets import load_datasets from axolotl.train import train from axolotl.utils.config import normalize_config, validate_config @@ -55,8 +54,7 @@ class TestLoraLlama(unittest.TestCase): cfg = validate_config(cfg) normalize_config(cfg) - cli_args = TrainerCliArgs() - dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + dataset_meta = load_datasets(cfg=cfg) train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) diff --git a/tests/e2e/test_mamba.py b/tests/e2e/test_mamba.py index efffb4547..1824619a6 100644 --- a/tests/e2e/test_mamba.py +++ b/tests/e2e/test_mamba.py @@ -6,7 +6,6 @@ import unittest import pytest -from axolotl.cli.args import TrainerCliArgs from axolotl.common.datasets import load_datasets from axolotl.train import train from axolotl.utils.config import 
normalize_config, validate_config @@ -57,8 +56,7 @@ class TestMamba(unittest.TestCase): cfg = validate_config(cfg) normalize_config(cfg) - cli_args = TrainerCliArgs() - dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + dataset_meta = load_datasets(cfg=cfg) train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) diff --git a/tests/e2e/test_mistral.py b/tests/e2e/test_mistral.py index 98a82a5f0..5d9b8ba8c 100644 --- a/tests/e2e/test_mistral.py +++ b/tests/e2e/test_mistral.py @@ -6,7 +6,6 @@ import unittest from transformers.utils import is_torch_bf16_gpu_available -from axolotl.cli.args import TrainerCliArgs from axolotl.common.datasets import load_datasets from axolotl.train import train from axolotl.utils.config import normalize_config, validate_config @@ -61,8 +60,7 @@ class TestMistral(unittest.TestCase): cfg = validate_config(cfg) normalize_config(cfg) - cli_args = TrainerCliArgs() - dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + dataset_meta = load_datasets(cfg=cfg) train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) @@ -106,8 +104,7 @@ class TestMistral(unittest.TestCase): cfg = validate_config(cfg) normalize_config(cfg) - cli_args = TrainerCliArgs() - dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + dataset_meta = load_datasets(cfg=cfg) train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) diff --git a/tests/e2e/test_mixtral.py b/tests/e2e/test_mixtral.py index b551e431a..761e59391 100644 --- a/tests/e2e/test_mixtral.py +++ b/tests/e2e/test_mixtral.py @@ -7,7 +7,6 @@ import unittest import torch from transformers.utils import is_torch_bf16_gpu_available -from axolotl.cli.args import TrainerCliArgs from axolotl.common.datasets import load_datasets from axolotl.train import train from axolotl.utils.config import normalize_config, validate_config @@ -67,8 +66,7 @@ class TestMixtral(unittest.TestCase): cfg = validate_config(cfg) 
normalize_config(cfg) - cli_args = TrainerCliArgs() - dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + dataset_meta = load_datasets(cfg=cfg) model, _, _ = train(cfg=cfg, dataset_meta=dataset_meta) assert ( @@ -123,8 +121,7 @@ class TestMixtral(unittest.TestCase): cfg = validate_config(cfg) normalize_config(cfg) - cli_args = TrainerCliArgs() - dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + dataset_meta = load_datasets(cfg=cfg) model, _, _ = train(cfg=cfg, dataset_meta=dataset_meta) assert ( @@ -182,8 +179,7 @@ class TestMixtral(unittest.TestCase): cfg = validate_config(cfg) normalize_config(cfg) - cli_args = TrainerCliArgs() - dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + dataset_meta = load_datasets(cfg=cfg) model, _, _ = train(cfg=cfg, dataset_meta=dataset_meta) assert ( @@ -241,8 +237,7 @@ class TestMixtral(unittest.TestCase): cfg.bf16 = True else: cfg.fp16 = True - cli_args = TrainerCliArgs() - dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + dataset_meta = load_datasets(cfg=cfg) model, _, _ = train(cfg=cfg, dataset_meta=dataset_meta) assert ( @@ -287,8 +282,7 @@ class TestMixtral(unittest.TestCase): cfg = validate_config(cfg) normalize_config(cfg) - cli_args = TrainerCliArgs() - dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + dataset_meta = load_datasets(cfg=cfg) train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) diff --git a/tests/e2e/test_optimizers.py b/tests/e2e/test_optimizers.py index e812a5f7e..53ef86022 100644 --- a/tests/e2e/test_optimizers.py +++ b/tests/e2e/test_optimizers.py @@ -4,7 +4,6 @@ E2E tests for custom optimizers using Llama import unittest -from axolotl.cli.args import TrainerCliArgs from axolotl.common.datasets import load_datasets from axolotl.train import train from axolotl.utils.config import normalize_config, validate_config @@ -61,8 +60,7 @@ class TestCustomOptimizers(unittest.TestCase): cfg = validate_config(cfg) normalize_config(cfg) - cli_args 
= TrainerCliArgs() - dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + dataset_meta = load_datasets(cfg=cfg) _, _, trainer = train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) @@ -107,8 +105,7 @@ class TestCustomOptimizers(unittest.TestCase): cfg = validate_config(cfg) normalize_config(cfg) - cli_args = TrainerCliArgs() - dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + dataset_meta = load_datasets(cfg=cfg) _, _, trainer = train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) @@ -154,8 +151,7 @@ class TestCustomOptimizers(unittest.TestCase): cfg = validate_config(cfg) normalize_config(cfg) - cli_args = TrainerCliArgs() - dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + dataset_meta = load_datasets(cfg=cfg) _, _, trainer = train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) @@ -194,8 +190,7 @@ class TestCustomOptimizers(unittest.TestCase): cfg = validate_config(cfg) normalize_config(cfg) - cli_args = TrainerCliArgs() - dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + dataset_meta = load_datasets(cfg=cfg) train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) @@ -242,8 +237,7 @@ class TestCustomOptimizers(unittest.TestCase): cfg = validate_config(cfg) normalize_config(cfg) - cli_args = TrainerCliArgs() - dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + dataset_meta = load_datasets(cfg=cfg) train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) diff --git a/tests/e2e/test_packing_loss.py b/tests/e2e/test_packing_loss.py index 12e272888..463f7c838 100644 --- a/tests/e2e/test_packing_loss.py +++ b/tests/e2e/test_packing_loss.py @@ -6,7 +6,6 @@ import unittest from transformers.utils import is_torch_bf16_gpu_available -from axolotl.cli.args import TrainerCliArgs from axolotl.common.datasets import load_datasets from axolotl.train import train from axolotl.utils.config import 
normalize_config, validate_config @@ -58,8 +57,7 @@ class TestPackedLlama(unittest.TestCase): cfg = validate_config(cfg) normalize_config(cfg) - cli_args = TrainerCliArgs() - dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + dataset_meta = load_datasets(cfg=cfg) train(cfg=cfg, dataset_meta=dataset_meta) diff --git a/tests/e2e/test_phi.py b/tests/e2e/test_phi.py index f8b43ad32..88fda9191 100644 --- a/tests/e2e/test_phi.py +++ b/tests/e2e/test_phi.py @@ -4,7 +4,6 @@ E2E tests for lora llama import unittest -from axolotl.cli.args import TrainerCliArgs from axolotl.common.datasets import load_datasets from axolotl.train import train from axolotl.utils.config import normalize_config, validate_config @@ -58,8 +57,7 @@ class TestPhi(unittest.TestCase): ) cfg = validate_config(cfg) normalize_config(cfg) - cli_args = TrainerCliArgs() - dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + dataset_meta = load_datasets(cfg=cfg) train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) @@ -108,8 +106,7 @@ class TestPhi(unittest.TestCase): ) cfg = validate_config(cfg) normalize_config(cfg) - cli_args = TrainerCliArgs() - dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + dataset_meta = load_datasets(cfg=cfg) train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) diff --git a/tests/e2e/test_process_reward_model_smollm2.py b/tests/e2e/test_process_reward_model_smollm2.py index eb81959a2..abfe1b0c5 100644 --- a/tests/e2e/test_process_reward_model_smollm2.py +++ b/tests/e2e/test_process_reward_model_smollm2.py @@ -4,7 +4,6 @@ E2E tests for process reward model w/ lora llama import unittest -from axolotl.cli.args import TrainerCliArgs from axolotl.common.datasets import load_datasets from axolotl.train import train from axolotl.utils.config import normalize_config, validate_config @@ -54,8 +53,7 @@ class TestProcessRewardSmolLM2(unittest.TestCase): ) cfg = validate_config(cfg) normalize_config(cfg) - 
cli_args = TrainerCliArgs() - dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + dataset_meta = load_datasets(cfg=cfg) train(cfg=cfg, dataset_meta=dataset_meta) check_tensorboard( diff --git a/tests/e2e/test_qat.py b/tests/e2e/test_qat.py index f9e7993be..2a7cd1459 100644 --- a/tests/e2e/test_qat.py +++ b/tests/e2e/test_qat.py @@ -5,7 +5,6 @@ E2E tests for QAT import unittest from pathlib import Path -from axolotl.cli.args import TrainerCliArgs from axolotl.common.datasets import load_datasets from axolotl.train import train from axolotl.utils.config import normalize_config, validate_config @@ -64,8 +63,7 @@ class TestQATLlama(unittest.TestCase): ) cfg = validate_config(cfg) normalize_config(cfg) - cli_args = TrainerCliArgs() - dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + dataset_meta = load_datasets(cfg=cfg) train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(Path(temp_dir) / "checkpoint-5", cfg) diff --git a/tests/e2e/test_reward_model_smollm2.py b/tests/e2e/test_reward_model_smollm2.py index 55405d58c..304fda1cc 100644 --- a/tests/e2e/test_reward_model_smollm2.py +++ b/tests/e2e/test_reward_model_smollm2.py @@ -4,7 +4,6 @@ E2E tests for reward model lora llama import unittest -from axolotl.cli.args import TrainerCliArgs from axolotl.common.datasets import load_datasets from axolotl.train import train from axolotl.utils.config import normalize_config, validate_config @@ -63,8 +62,7 @@ class TestRewardModelLoraSmolLM2(unittest.TestCase): ) cfg = validate_config(cfg) normalize_config(cfg) - cli_args = TrainerCliArgs() - dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + dataset_meta = load_datasets(cfg=cfg) train(cfg=cfg, dataset_meta=dataset_meta) check_tensorboard( diff --git a/tests/e2e/test_schedulers.py b/tests/e2e/test_schedulers.py index e468081b1..e98378f08 100644 --- a/tests/e2e/test_schedulers.py +++ b/tests/e2e/test_schedulers.py @@ -4,7 +4,6 @@ E2E tests for custom schedulers using Llama import unittest 
-from axolotl.cli.args import TrainerCliArgs from axolotl.common.datasets import load_datasets from axolotl.train import train from axolotl.utils.config import normalize_config, validate_config @@ -57,8 +56,7 @@ class TestCustomSchedulers(unittest.TestCase): cfg = validate_config(cfg) normalize_config(cfg) - cli_args = TrainerCliArgs() - dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + dataset_meta = load_datasets(cfg=cfg) train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) diff --git a/tests/prompt_strategies/test_dpo_chatml.py b/tests/prompt_strategies/test_dpo_chatml.py index b313a4b64..2c089067f 100644 --- a/tests/prompt_strategies/test_dpo_chatml.py +++ b/tests/prompt_strategies/test_dpo_chatml.py @@ -6,8 +6,9 @@ import unittest import pytest +from axolotl.loaders.tokenizer import load_tokenizer from axolotl.prompt_strategies.dpo import load as load_dpo -from axolotl.utils.data.rl import load_prepare_preference_datasets +from axolotl.utils.data.rl import prepare_preference_datasets from axolotl.utils.dict import DictDefault from tests.hf_offline_utils import enable_hf_offline @@ -55,7 +56,8 @@ class TestDPOChatml: # test that dpo.load works load_dpo("chatml", cfg) # now actually load the datasets with the strategy - train_ds, _ = load_prepare_preference_datasets(cfg) + tokenizer = load_tokenizer(cfg) + train_ds, _ = prepare_preference_datasets(cfg, tokenizer) assert train_ds[0]["prompt"].startswith("<|im_start|>") assert train_ds[0]["prompt"].endswith("<|im_start|>assistant\n") assert "chosen" in train_ds[0] diff --git a/tests/test_datasets.py b/tests/test_datasets.py index bd77591cf..f4730f0f1 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -1,10 +1,9 @@ -""" -Test dataset loading under various conditions. 
-""" +"""Test dataset loading under various conditions.""" import shutil import tempfile from pathlib import Path +from typing import Any, Generator from unittest.mock import patch import pytest @@ -12,8 +11,9 @@ from datasets import Dataset from huggingface_hub import snapshot_download from transformers import PreTrainedTokenizer -from axolotl.utils.data import load_tokenized_prepared_datasets -from axolotl.utils.data.rl import load_prepare_preference_datasets +from axolotl.loaders.tokenizer import load_tokenizer +from axolotl.utils.data.rl import prepare_preference_datasets +from axolotl.utils.data.sft import _load_tokenized_prepared_datasets from axolotl.utils.dict import DictDefault from tests.constants import ( @@ -28,7 +28,9 @@ class TestDatasetPreparation: """Test a configured dataloader.""" @pytest.fixture - def tokenizer(self, tokenizer_huggyllama) -> PreTrainedTokenizer: + def tokenizer( + self, tokenizer_huggyllama + ) -> Generator[PreTrainedTokenizer, Any, Any]: tokenizer_huggyllama.add_special_tokens(SPECIAL_TOKENS) yield tokenizer_huggyllama @@ -63,7 +65,10 @@ class TestDatasetPreparation: } ) - dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path) + with patch( + "axolotl.common.const.DEFAULT_DATASET_PREPARED_PATH", str(prepared_path) + ): + dataset, _ = _load_tokenized_prepared_datasets(tokenizer, cfg) assert len(dataset) == 2000 assert "input_ids" in dataset.features @@ -107,7 +112,10 @@ class TestDatasetPreparation: } ) - dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path) + with patch( + "axolotl.common.const.DEFAULT_DATASET_PREPARED_PATH", str(prepared_path) + ): + dataset, _ = _load_tokenized_prepared_datasets(tokenizer, cfg) assert len(dataset) == 2000 assert "input_ids" in dataset.features @@ -136,7 +144,10 @@ class TestDatasetPreparation: } ) - dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path) + with patch( + "axolotl.common.const.DEFAULT_DATASET_PREPARED_PATH", 
str(prepared_path) + ): + dataset, _ = _load_tokenized_prepared_datasets(tokenizer, cfg) assert len(dataset) == 1 assert "input_ids" in dataset.features @@ -145,7 +156,7 @@ class TestDatasetPreparation: @enable_hf_offline def test_load_from_dir_of_parquet(self, tokenizer, dataset_fixture): - """Usual use case. Verify a directory of parquet files can be loaded.""" + """Usual use case. Verify a directory of parquet files can be loaded.""" with tempfile.TemporaryDirectory() as tmp_dir: tmp_ds_dir = Path(tmp_dir) / "tmp_dataset" tmp_ds_dir.mkdir() @@ -171,7 +182,10 @@ class TestDatasetPreparation: } ) - dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path) + with patch( + "axolotl.common.const.DEFAULT_DATASET_PREPARED_PATH", str(prepared_path) + ): + dataset, _ = _load_tokenized_prepared_datasets(tokenizer, cfg) assert len(dataset) == 1 assert "input_ids" in dataset.features @@ -206,7 +220,10 @@ class TestDatasetPreparation: } ) - dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path) + with patch( + "axolotl.common.const.DEFAULT_DATASET_PREPARED_PATH", str(prepared_path) + ): + dataset, _ = _load_tokenized_prepared_datasets(tokenizer, cfg) assert len(dataset) == 1 assert "input_ids" in dataset.features @@ -235,7 +252,10 @@ class TestDatasetPreparation: } ) - dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path) + with patch( + "axolotl.common.const.DEFAULT_DATASET_PREPARED_PATH", str(prepared_path) + ): + dataset, _ = _load_tokenized_prepared_datasets(tokenizer, cfg) assert len(dataset) == 1 assert "input_ids" in dataset.features @@ -264,7 +284,10 @@ class TestDatasetPreparation: } ) - dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path) + with patch( + "axolotl.common.const.DEFAULT_DATASET_PREPARED_PATH", str(prepared_path) + ): + dataset, _ = _load_tokenized_prepared_datasets(tokenizer, cfg) assert len(dataset) == 1 assert "input_ids" in dataset.features @@ -286,7 +309,8 @@ 
class TestDatasetPreparation: } ) - train_dataset, _ = load_prepare_preference_datasets(cfg) + tokenizer = load_tokenizer(cfg) + train_dataset, _ = prepare_preference_datasets(cfg, tokenizer) assert len(train_dataset) == 1800 assert "conversation" not in train_dataset.features @@ -318,7 +342,10 @@ class TestDatasetPreparation: } ) - dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path) + with patch( + "axolotl.common.const.DEFAULT_DATASET_PREPARED_PATH", str(prepared_path) + ): + dataset, _ = _load_tokenized_prepared_datasets(tokenizer, cfg) assert len(dataset) == 2000 assert "input_ids" in dataset.features @@ -342,13 +369,16 @@ class TestDatasetPreparation: ) # pylint: disable=duplicate-code - with patch("axolotl.utils.data.rl.load_dataset_w_config") as mock_load_dataset: + with patch( + "axolotl.utils.data.rl.load_dataset_with_config" + ) as mock_load_dataset: # Set up the mock to return different values on successive calls mock_load_dataset.return_value = ( dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff ) - train_dataset, _ = load_prepare_preference_datasets(cfg) + tokenizer = load_tokenizer(cfg) + train_dataset, _ = prepare_preference_datasets(cfg, tokenizer) assert len(train_dataset) == 1800 assert "conversation" not in train_dataset.features @@ -393,16 +423,18 @@ class TestDatasetPreparation: ) with patch( - "axolotl.utils.data.shared.load_dataset_w_config" + "axolotl.utils.data.shared.load_dataset_with_config" ) as mock_load_dataset: # Set up the mock to return different values on successive calls mock_load_dataset.return_value = ( dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff ) - dataset, _ = load_tokenized_prepared_datasets( - tokenizer, cfg, prepared_path - ) + with patch( + "axolotl.common.const.DEFAULT_DATASET_PREPARED_PATH", + str(prepared_path), + ): + dataset, _ = _load_tokenized_prepared_datasets(tokenizer, cfg) assert len(dataset) == 2000 assert "input_ids" in dataset.features @@ -437,7 +469,10 
@@ class TestDatasetPreparation: } ) - dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path) + with patch( + "axolotl.common.const.DEFAULT_DATASET_PREPARED_PATH", str(prepared_path) + ): + dataset, _ = _load_tokenized_prepared_datasets(tokenizer, cfg) assert len(dataset) == 2000 assert "input_ids" in dataset.features diff --git a/tests/test_exact_deduplication.py b/tests/test_exact_deduplication.py index 29672c9e5..45a327a40 100644 --- a/tests/test_exact_deduplication.py +++ b/tests/test_exact_deduplication.py @@ -5,7 +5,6 @@ Additionally, this test suite includes tests for functions that indirectly call `deduplicate_and_log_datasets` during the execution of the preprocess command. """ -import hashlib import unittest from unittest.mock import patch @@ -14,8 +13,7 @@ from datasets import Dataset from axolotl.loaders import load_processor, load_tokenizer from axolotl.utils.config import normalize_config, validate_config -from axolotl.utils.data import prepare_dataset -from axolotl.utils.data.rl import load_prepare_preference_datasets +from axolotl.utils.data import prepare_datasets, prepare_preference_datasets from axolotl.utils.data.utils import deduplicate_and_log_datasets from axolotl.utils.dict import DictDefault @@ -71,36 +69,14 @@ class TestDeduplicateIndividualFunctions(unittest.TestCase): self.expected_dataset = Dataset.from_dict(self.expected_data) def test_deduplication(self): - train_dataset, _, _ = deduplicate_and_log_datasets(train_dataset=self.dataset) - _, eval_dataset, _ = deduplicate_and_log_datasets(eval_dataset=self.dataset) + train_dataset, _ = deduplicate_and_log_datasets(dataset=self.dataset) + eval_dataset, _ = deduplicate_and_log_datasets( + dataset=self.dataset, dataset_name="eval" + ) verify_deduplication(train_dataset, self.expected_dataset, "train_dataset") verify_deduplication(eval_dataset, self.expected_dataset, "eval_dataset") - def test_datasets_are_none(self): - # Test when both datasets are None - 
train_dataset, eval_dataset, _ = deduplicate_and_log_datasets( - train_dataset=None, eval_dataset=None - ) - self.assertIsNone(train_dataset, "Expected train_dataset to be None") - self.assertIsNone(eval_dataset, "Expected eval_dataset to be None") - - def test_only_train_is_none(self): - # Test when only train_dataset is None - train_dataset, eval_dataset, _ = deduplicate_and_log_datasets( - train_dataset=None, eval_dataset=self.dataset - ) - self.assertIsNone(train_dataset, "Expected train_dataset to be None") - verify_deduplication(eval_dataset, self.expected_dataset, "eval_dataset") - - def test_only_eval_is_none(self): - # Test when only eval_dataset is None - train_dataset, eval_dataset, _ = deduplicate_and_log_datasets( - train_dataset=self.dataset, eval_dataset=None - ) - self.assertIsNone(eval_dataset, "Expected eval_dataset to be None") - verify_deduplication(train_dataset, self.expected_dataset, "train_dataset") - def test_exact_duplicates(self): # Test when datasets are exact duplicates duplicate_data = { @@ -115,8 +91,10 @@ class TestDeduplicateIndividualFunctions(unittest.TestCase): expected_dataset = Dataset.from_dict(expected_data) # Run deduplication - train_dataset, _, _ = deduplicate_and_log_datasets(train_dataset=dataset) - _, eval_dataset, _ = deduplicate_and_log_datasets(eval_dataset=dataset) + train_dataset, _ = deduplicate_and_log_datasets(dataset=dataset) + eval_dataset, _ = deduplicate_and_log_datasets( + dataset=dataset, dataset_name="eval" + ) verify_deduplication(train_dataset, expected_dataset, "train_dataset") verify_deduplication(eval_dataset, expected_dataset, "eval_dataset") @@ -139,8 +117,10 @@ class TestDeduplicateIndividualFunctions(unittest.TestCase): expected_dataset = Dataset.from_dict(expected_data) # Run deduplication - train_dataset, _, _ = deduplicate_and_log_datasets(train_dataset=dataset) - _, eval_dataset, _ = deduplicate_and_log_datasets(eval_dataset=dataset) + train_dataset, _ = 
deduplicate_and_log_datasets(dataset=dataset) + eval_dataset, _ = deduplicate_and_log_datasets( + dataset=dataset, dataset_name="eval" + ) verify_deduplication(train_dataset, expected_dataset, "train_dataset") verify_deduplication(eval_dataset, expected_dataset, "eval_dataset") @@ -169,8 +149,8 @@ class TestDeduplicateIndividualFunctions(unittest.TestCase): expected_dataset_eval = Dataset.from_dict(expected_data_eval) # Run deduplication - train_dataset, eval_dataset, _ = deduplicate_and_log_datasets( - train_dataset=dataset, eval_dataset=dataset + train_dataset, eval_dataset = deduplicate_and_log_datasets( + dataset=dataset, other_dataset=dataset ) verify_deduplication(train_dataset, expected_dataset_train, "train_dataset") @@ -206,8 +186,8 @@ class TestDeduplicateIndividualFunctions(unittest.TestCase): expected_dataset_eval = Dataset.from_dict(expected_data_eval) # Run deduplication - train_dataset, eval_dataset, _ = deduplicate_and_log_datasets( - train_dataset=dataset_train, eval_dataset=dataset_eval + train_dataset, eval_dataset = deduplicate_and_log_datasets( + dataset=dataset_train, other_dataset=dataset_eval ) verify_deduplication(train_dataset, expected_dataset_train, "train_dataset") @@ -245,7 +225,9 @@ class TestDeduplicateRLDataset: # pylint: disable=duplicate-code with ( - patch("axolotl.utils.data.rl.load_dataset_w_config") as mock_load_dataset, + patch( + "axolotl.utils.data.rl.load_dataset_with_config" + ) as mock_load_dataset, patch("axolotl.loaders.load_tokenizer") as mock_load_tokenizer, ): # Set up the mock to return different values on successive calls @@ -255,7 +237,8 @@ class TestDeduplicateRLDataset: ] mock_load_tokenizer.return_value = tokenizer_huggyllama - train_dataset, _ = load_prepare_preference_datasets(cfg) + tokenizer = load_tokenizer(cfg) + train_dataset, _ = prepare_preference_datasets(cfg, tokenizer) # Verify that the dataset has been deduplicated assert len(train_dataset) == 1800, "Dataset was not properly deduplicated" @@ 
-269,7 +252,9 @@ class TestDeduplicateRLDataset: ): # pylint: disable=duplicate-code with ( - patch("axolotl.utils.data.rl.load_dataset_w_config") as mock_load_dataset, + patch( + "axolotl.utils.data.rl.load_dataset_with_config" + ) as mock_load_dataset, patch("axolotl.loaders.load_tokenizer") as mock_load_tokenizer, ): # Set up the mock to return different values on successive calls @@ -279,9 +264,10 @@ class TestDeduplicateRLDataset: ] mock_load_tokenizer.return_value = tokenizer_huggyllama - cfg.dataset_exact_deduplication = False # Load the dataset without deduplication - train_dataset, _ = load_prepare_preference_datasets(cfg) + cfg.dataset_exact_deduplication = False + tokenizer = load_tokenizer(cfg) + train_dataset, _ = prepare_preference_datasets(cfg, tokenizer) # Verify that the dataset retains duplicates assert ( @@ -335,7 +321,7 @@ class TestDeduplicateNonRL(unittest.TestCase): ) # Prepare dataset using the prepare_dataset function - train_dataset, _, _, _ = prepare_dataset( + train_dataset, _, _, _ = prepare_datasets( self.cfg_1, tokenizer, processor=processor, @@ -362,7 +348,7 @@ class TestDeduplicateNonRL(unittest.TestCase): ) # Prepare dataset using the prepare_dataset function - _, eval_dataset, _, _ = prepare_dataset( + _, eval_dataset, _, _ = prepare_datasets( self.cfg_1, tokenizer, processor=processor, @@ -389,7 +375,7 @@ class TestDeduplicateNonRL(unittest.TestCase): ) # Prepare dataset using the prepare_dataset function - train_dataset, eval_dataset, _, _ = prepare_dataset( + train_dataset, eval_dataset, _, _ = prepare_datasets( self.cfg_1, tokenizer, processor=processor, @@ -428,41 +414,8 @@ class TestWrongCollisions(unittest.TestCase): self.eval_dataset = Dataset.from_dict(self.eval_data) self.dataset = Dataset.from_dict(self.dataset_data) - @patch( - "axolotl.utils.data.utils.sha256", - side_effect=lambda x: ( - hashlib.sha256("forced_collision_hash".encode("utf-8")).hexdigest() - if "sample 5" in x - else 
hashlib.sha256(x.encode("utf-8")).hexdigest() - ), - ) - def test_deduplication_wrong_collision_train_eval(self, _mock_sha256): - dedup_train, dedup_eval, _ = deduplicate_and_log_datasets( - train_dataset=self.train_dataset, eval_dataset=self.eval_dataset - ) - self.assertEqual( - len(dedup_train), - 2, - "train dataset should not deduplicate rows with forced hash collisions but different labels.", - ) - self.assertEqual( - len(dedup_eval), - 2, - "Eval dataset should not deduplicate rows with forced hash collisions but different labels.", - ) - self.assertEqual( - len(dedup_eval), - len(self.eval_dataset), - "The output eval dataset should have the same number of rows as the input eval dataset.", - ) - self.assertEqual( - str(dedup_eval), - str(self.eval_dataset), - "The string representation of the output eval dataset should be identical to the input eval dataset.", - ) - def test_deduplication_dataset_only(self): - _, _, dedup_dataset = deduplicate_and_log_datasets(dataset=self.dataset) + dedup_dataset, _ = deduplicate_and_log_datasets(dataset=self.dataset) self.assertEqual( len(dedup_dataset), 3, "Dataset should have all original values" )