Data loader refactor (#2707)

* data loading refactor (wip)

* updates

* progress

* pytest

* pytest fix

* lint

* zero_first -> filelock, more simplifications

* small simplification

* import change

* nit

* lint

* simplify dedup

* couldnt resist

* review comments WIP

* continued wip

* minor changes

* fix; remove contrived test

* further refactor

* set default seed in pydantic config

* lint

* continued simplication

* lint

* renaming and nits

* filelock tests

* fix

* fix

* lint

* remove nullable arg

* remove unnecessary code

* moving dataset save fn to shared module

* remove debug print

* matching var naming

* fn name change

* coderabbit comments

* naming nit

* fix test
This commit is contained in:
Dan Saunders
2025-06-10 19:53:07 -04:00
committed by GitHub
parent 52a0452acb
commit 00cda8cc70
62 changed files with 2125 additions and 1436 deletions

View File

@@ -1,5 +1,3 @@
"""
Various shared constants
"""
"""Various shared constants"""
DEFAULT_DATASET_PREPARED_PATH = "last_run_prepared"

View File

@@ -3,15 +3,13 @@
import math
import random
from dataclasses import dataclass
from typing import Optional, Union
from datasets import Dataset
import axolotl.monkeypatch.data.batch_dataset_fetcher # pylint: disable=unused-import # noqa: F401
from axolotl.cli.args import PreprocessCliArgs, TrainerCliArgs
from axolotl.loaders import load_processor, load_tokenizer
from axolotl.utils.data import prepare_dataset
from axolotl.utils.data.rl import load_prepare_preference_datasets
from axolotl.utils.data import prepare_datasets, prepare_preference_datasets
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
from axolotl.utils.schemas.enums import RLType
@@ -30,16 +28,7 @@ class TrainDatasetMeta:
def sample_dataset(dataset: Dataset, num_samples: int) -> Dataset:
"""
Randomly sample `num_samples` samples from `dataset`.
Args:
dataset: Dataset.
num_samples: Number of samples to return.
Returns:
Random sample (with replacement) of examples in `dataset`.
"""
"""Randomly sample `num_samples` samples with replacement from `dataset`."""
return dataset.select(
[random.randrange(0, len(dataset) - 1) for _ in range(num_samples)] # nosec
)
@@ -51,44 +40,37 @@ def load_datasets(
cli_args: PreprocessCliArgs | TrainerCliArgs | None = None,
debug: bool = False,
) -> TrainDatasetMeta:
"""
Loads one or more training or evaluation datasets, calling
`axolotl.utils.data.prepare_dataset`. Optionally, logs out debug information.
"""Loads one or more training or evaluation datasets, calling
`axolotl.utils.data.prepare_datasets`. Optionally, logs out debug information.
Args:
cfg: Dictionary mapping `axolotl` config keys to values.
cli_args: Command-specific CLI arguments.
debug: Whether to print out tokenization of sample
debug: Whether to print out tokenization of sample. This is duplicated in
`cfg` and `cli_args`, but is kept due to use in our Colab notebooks.
Returns:
Dataclass with fields for training and evaluation datasets and the computed
`total_num_steps`.
`total_num_steps`.
"""
tokenizer = load_tokenizer(cfg)
processor = load_processor(cfg, tokenizer=tokenizer) if cfg.processor_type else None
preprocess_iterable = (
cli_args
and hasattr(cli_args, "iterable")
and cli_args.iterable is not None
and cli_args.iterable
)
preprocess_iterable = getattr(cli_args, "iterable", False)
train_dataset, eval_dataset, total_num_steps, prompters = prepare_dataset(
train_dataset, eval_dataset, total_num_steps, prompters = prepare_datasets(
cfg,
tokenizer,
processor=processor,
preprocess_iterable=preprocess_iterable,
)
if ( # pylint: disable=too-many-boolean-expressions
cli_args
and (
cli_args.debug
or cfg.debug
or cli_args.debug_text_only
or int(cli_args.debug_num_examples) > 0
)
) or debug:
if (
cfg.debug
or getattr(cli_args, "debug", False)
or getattr(cli_args, "debug_text_only", False)
or getattr(cli_args, "debug_num_examples", 0) > 0
or debug
):
LOG.info("check_dataset_labels...")
num_examples = cli_args.debug_num_examples if cli_args else 1
@@ -113,13 +95,10 @@ def load_datasets(
def load_preference_datasets(
*,
cfg: DictDefault,
cli_args: Union[PreprocessCliArgs, TrainerCliArgs],
*, cfg: DictDefault, cli_args: PreprocessCliArgs | TrainerCliArgs
) -> TrainDatasetMeta:
"""
Loads one or more training or evaluation datasets for RL training using paired
preference data, calling `axolotl.utils.data.rl.load_prepare_preference_datasets`.
"""Loads one or more training or evaluation datasets for RL training using paired
preference data, calling `axolotl.utils.data.rl.prepare_preference_datasets`.
Optionally, logs out debug information.
Args:
@@ -130,12 +109,14 @@ def load_preference_datasets(
Dataclass with fields for training and evaluation datasets and the computed
`total_num_steps`.
"""
train_dataset, eval_dataset = load_prepare_preference_datasets(cfg)
total_num_steps: Optional[int] = int(
math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
)
if cfg.rl is RLType.GRPO:
total_num_steps = None
tokenizer = load_tokenizer(cfg)
train_dataset, eval_dataset = prepare_preference_datasets(cfg, tokenizer)
total_num_steps: int | None = None
if cfg.rl is not RLType.GRPO:
total_num_steps = int(
math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
)
if cli_args.debug or cfg.debug:
LOG.info("check_dataset_labels...")
@@ -143,8 +124,8 @@ def load_preference_datasets(
tokenizer = load_tokenizer(cfg)
train_samples = sample_dataset(train_dataset, cli_args.debug_num_examples)
check_dataset_labels(
train_samples,
tokenizer,
dataset=train_samples,
tokenizer=tokenizer,
num_examples=cli_args.debug_num_examples,
text_only=cli_args.debug_text_only,
rl_mode=True,

View File

@@ -381,7 +381,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
elif "tokenizer" in sig.parameters:
trainer_kwargs["tokenizer"] = self.tokenizer
if (
not (trainer_cls in [AxolotlRewardTrainer, AxolotlPRMTrainer])
trainer_cls not in [AxolotlRewardTrainer, AxolotlPRMTrainer]
and self.cfg.datasets is not None
):
trainer_kwargs["dataset_tags"] = [

View File

@@ -1,7 +1,6 @@
"""Module containing Dataset functionality"""
import os
from typing import List, Optional, Union
import torch
from datasets import Dataset, IterableDataset
@@ -20,21 +19,21 @@ LOG = get_logger(__name__)
class TokenizedPromptDataset(Dataset):
"""
Dataset that returns tokenized prompts from a stream of text files.
Args:
prompt_tokenizer (PromptTokenizingStrategy): The prompt tokenizing method for processing the data.
dataset (dataset.Dataset): Dataset with text files.
process_count (int): Number of processes to use for tokenizing.
keep_in_memory (bool): Whether to keep the tokenized dataset in memory.
"""Dataset that returns tokenized prompts from a stream of text files.
Args:
prompt_tokenizer: The prompt tokenizing method for processing the data.
dataset: Dataset with text files.
process_count: Number of processes to use for tokenizing.
keep_in_memory: Whether to keep the tokenized dataset in memory.
"""
def __init__( # pylint: disable=super-init-not-called
self,
prompt_tokenizer: PromptTokenizingStrategy,
dataset: Dataset,
process_count: Optional[int] = None,
keep_in_memory: Optional[bool] = False,
process_count: int | None = None,
keep_in_memory: bool | None = False,
**kwargs,
):
self.prompt_tokenizer = prompt_tokenizer
@@ -76,14 +75,14 @@ class TokenizedPromptDataset(Dataset):
def wrap_dataset_for_tokenized_prompt(
prompt_tokenizer: PromptTokenizingStrategy,
dataset: Union[Dataset, IterableDataset],
dataset: Dataset | IterableDataset,
**kwargs,
):
if isinstance(dataset, IterableDataset):
map_kwargs = {}
if prompt_tokenizer.supports_batched:
map_kwargs["batched"] = True
features = dataset.features.keys()
features = list(dataset.features.keys())
return dataset.map(
prompt_tokenizer.tokenize_prompt,
remove_columns=features,
@@ -94,12 +93,13 @@ def wrap_dataset_for_tokenized_prompt(
# TODO this isn't the best since it can't interleave datasets
class ConstantLengthDataset(IterableDataset):
"""
Iterable dataset that returns constant length chunks of tokens from stream of text files.
Args:
tokenizer (Tokenizer): The processor used for processing the data.
dataset (dataset.Dataset): Dataset with text files.
seq_length (int): Length of token sequences to return.
"""Iterable dataset that returns constant length chunks of tokens from stream of
text files.
Args:
tokenizer: The processor used for processing the data.
dataset: Dataset with text files.
seq_length: Length of token sequences to return.
"""
def __init__( # pylint: disable=super-init-not-called
@@ -110,7 +110,7 @@ class ConstantLengthDataset(IterableDataset):
):
self.tokenizer = tokenizer
self.concat_token_id = tokenizer.eos_token_id
self.datasets: List[IterableDataset] = datasets
self.datasets: list[IterableDataset] = datasets
self.seq_length = seq_length
vocab_size = len(tokenizer.get_vocab())
@@ -174,7 +174,10 @@ class ConstantLengthDataset(IterableDataset):
}
else:
LOG.warning(
f"dropping batch due to tensor size mismatch input_ids: {input_ids.size()}, labels: {labels.size()}, attention_mask: {attention_mask.size()}"
"Dropping batch due to tensor size mismatch "
f"input_ids: {input_ids.size()}, "
f"labels: {labels.size()}, "
f"attention_mask: {attention_mask.size()}"
)
buffer = {
"input_ids": [],

View File

@@ -7,12 +7,14 @@ import transformers
from transformers import (
AddedToken,
AutoTokenizer,
PreTrainedTokenizer,
)
from axolotl.integrations.base import PluginManager
from axolotl.loaders.utils import get_linear_embedding_layers, load_model_config
from axolotl.prompt_tokenizers import LLAMA_DEFAULT_EOS_TOKEN
from axolotl.utils.chat_templates import get_chat_template_from_config
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import (
barrier,
is_local_main_process,
@@ -117,7 +119,7 @@ def modify_tokenizer_files(
return tokenizer_dir
def load_tokenizer(cfg):
def load_tokenizer(cfg: DictDefault) -> PreTrainedTokenizer:
"""Load and configure the tokenizer based on the provided config."""
model_config = load_model_config(cfg)
tokenizer_kwargs = {}
@@ -207,11 +209,12 @@ def load_tokenizer(cfg):
)
and k != "pad_token"
):
lora_modules_to_save = ", ".join(
lora_modules_to_save_str = ", ".join(
[f"`{x}`" for x in lora_modules_to_save]
)
raise ValueError(
f"Please set lora_modules_to_save to [{lora_modules_to_save}] when using an adapter and changing the special tokens."
f"Please set lora_modules_to_save to [{lora_modules_to_save_str}] "
"when using an adapter and changing the special tokens."
)
tokenizer.add_special_tokens(

View File

@@ -32,4 +32,3 @@ def load(tokenizer, cfg, ds_cfg, processor=None):
except Exception as exc: # pylint: disable=broad-exception-caught
LOG.error(f"Failed to load prompt strategy `{strategy}`: {str(exc)}")
raise exc
return None

View File

@@ -3,6 +3,7 @@
import abc
from typing import Callable, Dict, List, Optional, Tuple, Union
from datasets import Dataset
from transformers import BatchEncoding, PreTrainedTokenizer
from axolotl.prompters import Prompter
@@ -28,6 +29,16 @@ class DatasetWrappingStrategy(abc.ABC):
Abstract class for wrapping datasets for Chat Messages
"""
@abc.abstractmethod
def wrap_dataset(
self,
dataset,
process_count: int | None = None,
keep_in_memory: bool | None = False,
**kwargs,
) -> Dataset:
pass
class PromptTokenizingStrategy(abc.ABC):
"""

View File

@@ -53,8 +53,8 @@ def setup_model_and_tokenizer(
) -> tuple[
PreTrainedModel, PreTrainedTokenizer, PeftConfig | None, ProcessorMixin | None
]:
"""
Load the tokenizer, processor (for multimodal models), and model based on configuration.
"""Load the tokenizer, processor (for multimodal models), and model based on
configuration.
Args:
cfg: Dictionary mapping `axolotl` config keys to values.

View File

@@ -1,16 +1,21 @@
"""
Data processing modules
"""
"""Init for `axolotl.utils.data` module."""
from axolotl.utils.data.pretraining import ( # noqa: F401
from axolotl.utils.data.pretraining import (
encode_pretraining,
wrap_pretraining_dataset,
)
from axolotl.utils.data.rl import load_prepare_preference_datasets # noqa: F401
from axolotl.utils.data.sft import ( # noqa: F401
from axolotl.utils.data.rl import prepare_preference_datasets
from axolotl.utils.data.sft import (
get_dataset_wrapper,
load_prepare_datasets,
load_tokenized_prepared_datasets,
prepare_dataset,
prepare_datasets,
)
from axolotl.utils.data.utils import md5 # noqa: F401
from axolotl.utils.data.utils import md5
__all__ = [
"encode_pretraining",
"wrap_pretraining_dataset",
"prepare_preference_datasets",
"get_dataset_wrapper",
"prepare_datasets",
"md5",
]

View File

@@ -0,0 +1,66 @@
"""Logic for loading / preparing a dataset once over all processes."""
import time
from pathlib import Path
from typing import Any, Callable
from filelock import FileLock
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
from axolotl.utils.dict import DictDefault
LOCK_FILE_NAME = "datasets_prep.lock"
READY_FILE_NAME = "datasets_ready.flag"
PROCESS_COUNTER_FILE_NAME = "process_counter.txt"
class FileLockLoader:
    """Simple class for abstracting single-process data loading / processing.

    The first process to acquire the lock file does the work; the remaining
    processes simply load the preprocessed dataset once the first process is done.
    A shared counter file tracks how many processes are participating so that the
    last one out can clean up the coordination files.
    """

    def __init__(self, cfg: "DictDefault"):
        self.cfg = cfg
        self.dataset_prepared_path = (
            cfg.dataset_prepared_path or DEFAULT_DATASET_PREPARED_PATH
        )
        self.lock_file_path = Path(self.dataset_prepared_path) / LOCK_FILE_NAME
        self.ready_flag_path = Path(self.dataset_prepared_path) / READY_FILE_NAME
        self.counter_path = Path(self.dataset_prepared_path) / PROCESS_COUNTER_FILE_NAME

    def load(self, load_fn: Callable[[], Any]) -> Any:
        """Run `load_fn` once in the first process; other processes wait for the
        ready flag, then call `load_fn` themselves (hitting the now-warm cache).

        Args:
            load_fn: Zero-argument callable that loads / prepares the dataset.

        Returns:
            Whatever `load_fn` returns.
        """
        with FileLock(str(self.lock_file_path)):
            self._increment_counter()
            if not self.ready_flag_path.exists():
                result = load_fn()
                self.ready_flag_path.touch()
                return result

        # Non-first processes: poll until the first process signals completion,
        # then load the (already prepared) dataset themselves.
        while not self.ready_flag_path.exists():
            time.sleep(1)
        return load_fn()

    def _increment_counter(self):
        """Safely increment the process counter (caller must hold the file lock)."""
        if self.counter_path.exists():
            count = int(self.counter_path.read_text().strip())
        else:
            count = 0
        self.counter_path.write_text(str(count + 1))

    def cleanup(self):
        """Clean up ready flag when the last process is done."""
        with FileLock(str(self.lock_file_path)):
            # Robustness fix: if the counter file is already gone (cleanup called
            # twice, or another process finished the teardown), there is nothing
            # left to do — the original would crash with FileNotFoundError here.
            if not self.counter_path.exists():
                return
            count = int(self.counter_path.read_text().strip()) - 1
            if count == 0:
                # Last process cleans everything up
                self.ready_flag_path.unlink(missing_ok=True)
                self.counter_path.unlink(missing_ok=True)
            else:
                # Still have active processes
                self.counter_path.write_text(str(count))

View File

@@ -250,7 +250,7 @@ def encode_packed_pretraining(
# pylint: disable=duplicate-code
# tokenize all the examples
# rows get split with stride (overlap)
train_dataset = ds_wrapper(Dataset.from_dict(examples))[0]
train_dataset = ds_wrapper(dataset=Dataset.from_dict(examples))[0]
train_dataset = process_pretraining_datasets_for_packing(
train_dataset,

View File

@@ -1,75 +1,117 @@
"""data handling specific to DPO"""
"""Data handling specific to RL trainers."""
import inspect
from functools import partial
from pathlib import Path
from typing import Any, List, Union
from typing import Any, Callable, Literal
import yaml
from datasets import Dataset, DatasetDict, concatenate_datasets, load_from_disk
from datasets import Dataset, DatasetDict
from transformers import PreTrainedTokenizer
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
from axolotl.loaders import load_tokenizer
from axolotl.prompt_strategies.dpo import load as load_dpo
from axolotl.prompt_strategies.kto import load as load_kto
from axolotl.prompt_strategies.orpo import load as load_orpo
from axolotl.utils.data.shared import datasets_w_name_generator, load_dataset_w_config
from axolotl.utils.data.utils import deduplicate_and_log_datasets, md5
from axolotl.utils.data.lock import FileLockLoader
from axolotl.utils.data.shared import (
create_train_validation_split,
datasets_with_name_generator,
generate_dataset_hash_from_config,
load_dataset_with_config,
load_preprocessed_dataset,
merge_datasets,
save_preprocessed_dataset,
try_load_from_hub,
)
from axolotl.utils.data.utils import (
deduplicate_and_log_datasets,
retry_on_request_exceptions,
)
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import is_main_process, zero_first
from axolotl.utils.logging import get_logger
from axolotl.utils.schemas.enums import RLType
LOG = get_logger(__name__)
def _get_path(ds_hash, cfg):
prepared_ds_path = (
Path(cfg.dataset_prepared_path) / ds_hash
if cfg.dataset_prepared_path
else Path(DEFAULT_DATASET_PREPARED_PATH) / ds_hash
)
@retry_on_request_exceptions(max_retries=3, delay=5)
def prepare_preference_datasets(
cfg: DictDefault, tokenizer: PreTrainedTokenizer
) -> tuple[Dataset, Dataset | None]:
"""Load and prepare preference datasets for RL training.
return prepared_ds_path
Loads training and evaluation datasets, handling preprocessing, caching, and
deduplication as configured. Uses FileLock for distributed coordination.
Args:
cfg: Configuration object containing dataset and training settings.
tokenizer: Tokenizer to use for processing text.
Returns:
Tuple of (train_dataset, eval_dataset). eval_dataset may be None
if no evaluation dataset is configured.
"""
def _load_datasets():
# Load training dataset
train_dataset = _load_or_create_dataset_split(cfg, tokenizer, split="train")
# Load or create evaluation dataset
eval_dataset: Dataset | None = None
if cfg.test_datasets:
eval_dataset = _load_or_create_dataset_split(cfg, tokenizer, split="test")
elif cfg.val_set_size:
# Create validation split from training data
train_dataset, eval_dataset = create_train_validation_split(
train_dataset, cfg, cfg.val_set_size
)
return train_dataset, eval_dataset
# Prepare datasets (with file locking logic for multiple ranks)
loader = FileLockLoader(cfg)
try:
train_dataset, eval_dataset = loader.load(_load_datasets)
finally:
loader.cleanup()
# Apply deduplication if configured
if cfg.dataset_exact_deduplication:
train_dataset, eval_dataset = deduplicate_and_log_datasets(
dataset=train_dataset, other_dataset=eval_dataset
)
return train_dataset, eval_dataset
def _load_preprocessed_ds(cfg, sub_cfg):
ds_hash = md5(yaml.dump(sub_cfg, Dumper=yaml.Dumper))
prepared_ds_path = _get_path(ds_hash, cfg)
dataset = None
def _map_dataset(
cfg: DictDefault,
dataset: Dataset | DatasetDict,
ds_transform_fn: Callable[..., Any],
tokenizer: Any | None = None,
**map_kwargs: Any,
) -> Dataset:
"""Apply transformation function to dataset.
# pylint: disable=duplicate-code
if (
cfg.dataset_prepared_path
and any(prepared_ds_path.glob("*"))
and not cfg.is_preprocess
):
LOG.info(f"Loading prepared dataset from disk at {prepared_ds_path}...")
dataset = load_from_disk(str(prepared_ds_path))
Args:
cfg: Configuration object.
dataset: Dataset to transform.
ds_transform_fn: Transformation function to apply.
tokenizer: Optional tokenizer for transformation.
**map_kwargs: Additional arguments for dataset mapping.
return dataset
def _save_preprocessed_ds(cfg, sub_cfg, dataset):
ds_hash = md5(yaml.dump(sub_cfg, Dumper=yaml.Dumper))
prepared_ds_path = _get_path(ds_hash, cfg)
if cfg.is_preprocess and is_main_process():
LOG.info(f"Loading prepared dataset from disk at {prepared_ds_path}...")
dataset.save_to_disk(str(prepared_ds_path))
def map_dataset(cfg, data_set, ds_transform_fn, tokenizer, **map_kwargs):
Returns:
Transformed dataset.
"""
sig = inspect.signature(ds_transform_fn)
if "tokenizer" in sig.parameters:
if not tokenizer:
tokenizer = load_tokenizer(cfg)
ds_transform_fn = partial(ds_transform_fn, tokenizer=tokenizer)
if isinstance(data_set, DatasetDict):
data_set = data_set["train"]
if isinstance(dataset, DatasetDict):
dataset = dataset["train"]
data_set = data_set.map(
dataset = dataset.map(
ds_transform_fn,
num_proc=cfg.dataset_processes,
load_from_cache_file=not cfg.is_preprocess,
@@ -77,13 +119,27 @@ def map_dataset(cfg, data_set, ds_transform_fn, tokenizer, **map_kwargs):
**map_kwargs,
)
return data_set
return dataset
def drop_long_rl_seq(
sample, rl, tokenizer, sequence_len # pylint: disable=invalid-name
):
if rl in (RLType.DPO, RLType.IPO, RLType.ORPO, RLType.SIMPO):
def _drop_long_sequences(
sample: dict[str, Any], rl: RLType, tokenizer: Any, sequence_len: int
) -> bool:
"""Filter out samples that exceed maximum sequence length.
Args:
sample: Dataset sample to check.
rl: Reinforcement learning type.
tokenizer: Tokenizer for length calculation.
sequence_len: Maximum allowed sequence length.
Returns:
True if sample should be kept, False if it should be dropped.
Raises:
ValueError: If required keys are missing or RL type is unknown.
"""
if rl in {RLType.DPO, RLType.IPO, RLType.ORPO, RLType.SIMPO}:
if not (
sample.get("prompt") and sample.get("chosen") and sample.get("rejected")
):
@@ -123,132 +179,115 @@ def drop_long_rl_seq(
raise ValueError("Unknown RL type")
def load_prepare_preference_datasets(cfg):
def load_split(dataset_cfgs, _cfg):
split_datasets: List[Any] = []
use_auth_token = _cfg.hf_use_auth_token
for config_dataset in datasets_w_name_generator(dataset_cfgs):
ds: Union[Dataset, DatasetDict] = load_dataset_w_config(
config_dataset, use_auth_token, streaming=False
)
split_datasets.append(ds)
def _load_split(cfg: DictDefault, split: Literal["train", "test"]) -> Dataset:
"""Load and process dataset split for RL training.
tokenizer = load_tokenizer(cfg)
Args:
cfg: Configuration object containing dataset settings.
split: Dataset split to load ("train" or "test").
for i, data_set in enumerate(split_datasets):
_type = dataset_cfgs[i]["type"]
if _type:
if isinstance(_type, DictDefault):
_type = "user_defined.default"
if _cfg.rl is RLType.ORPO:
ds_transform_fn = load_orpo(_type, _cfg, dataset_idx=i)
elif _cfg.rl is RLType.KTO:
ds_transform_fn = load_kto(_type, _cfg, dataset_idx=i)
else:
ds_transform_fn = load_dpo(_type, _cfg, dataset_idx=i)
Returns:
Combined and processed dataset for the specified split.
"""
datasets_configs = cfg.datasets if split == "train" else cfg.test_datasets
split_datasets: list[Dataset | DatasetDict] = []
map_kwargs = {}
if isinstance(ds_transform_fn, tuple):
ds_transform_fn, map_kwargs = ds_transform_fn
split_datasets[i] = map_dataset(
cfg, data_set, ds_transform_fn, tokenizer, **map_kwargs
)
elif _cfg.rl is RLType.KTO:
ds_transform_fn = load_kto(_type, _cfg, dataset_idx=i)
map_kwargs = {}
if isinstance(ds_transform_fn, tuple):
ds_transform_fn, map_kwargs = ds_transform_fn
split_datasets[i] = map_dataset(
cfg, data_set, ds_transform_fn, tokenizer, **map_kwargs
)
else:
# If no `type` is provided, assume the dataset is already in the expected format with
# "prompt", "chosen" and "rejected" already preprocessed
split_datasets[i] = data_set
if not cfg.skip_prepare_dataset:
drop_long = partial(
drop_long_rl_seq,
rl=_cfg.rl,
tokenizer=tokenizer,
sequence_len=cfg.sequence_len,
)
prior_len = len(split_datasets[i])
split_datasets[i] = split_datasets[i].filter(
drop_long,
num_proc=cfg.dataset_processes,
load_from_cache_file=not cfg.is_preprocess,
desc="Dropping Long Sequences",
)
dropped = prior_len - len(split_datasets[i])
if dropped:
LOG.warning(
f"Dropped {dropped} long samples from dataset index {i}"
)
combined_datasets = concatenate_datasets(split_datasets)
combined_datasets = combined_datasets.shuffle(seed=cfg.seed or 42)
return combined_datasets
with zero_first(is_main_process()):
train_is_preprocessed = False
eval_is_preprocessed = False
if train_dataset := _load_preprocessed_ds(cfg, cfg.datasets):
train_is_preprocessed = True
else:
train_dataset = load_split(cfg.datasets, cfg)
eval_dataset = None
if cfg.test_datasets:
if eval_dataset := _load_preprocessed_ds(cfg, cfg.test_datasets):
eval_is_preprocessed = True
else:
eval_dataset = load_split(cfg.test_datasets, cfg)
if not eval_dataset:
if cfg.val_set_size:
seed = cfg.seed if cfg.seed is not None else 42
# ensure we end up with the same fingerprint by doing rank0 first and being able to cache
to_hash_train = (
train_dataset._fingerprint # pylint: disable=protected-access
+ "|"
+ str(cfg.val_set_size)
+ "|"
+ "train"
+ "|"
+ str(cfg.seed or 42)
)
to_hash_test = (
train_dataset._fingerprint # pylint: disable=protected-access
+ "|"
+ str(cfg.val_set_size)
+ "|"
+ "test"
+ "|"
+ str(cfg.seed or 42)
)
train_fingerprint = md5(to_hash_train)
test_fingerprint = md5(to_hash_test)
ds_w_test_split = train_dataset.train_test_split(
test_size=cfg.val_set_size,
seed=seed,
shuffle=False,
train_new_fingerprint=train_fingerprint,
test_new_fingerprint=test_fingerprint,
)
eval_dataset = ds_w_test_split["test"]
train_dataset = ds_w_test_split["train"]
if not train_is_preprocessed:
_save_preprocessed_ds(cfg, cfg.datasets, train_dataset)
if eval_dataset and not eval_is_preprocessed:
_save_preprocessed_ds(cfg, cfg.test_datasets, eval_dataset)
if cfg.dataset_exact_deduplication:
train_dataset, eval_dataset, _ = deduplicate_and_log_datasets(
train_dataset=train_dataset, eval_dataset=eval_dataset
for dataset_config in datasets_with_name_generator(datasets_configs):
dataset: Dataset | DatasetDict = load_dataset_with_config(
dataset_config, cfg.hf_use_auth_token, streaming=False
)
split_datasets.append(dataset)
return train_dataset, eval_dataset
tokenizer = load_tokenizer(cfg)
for i, dataset in enumerate(split_datasets):
_type = datasets_configs[i]["type"]
if _type:
if isinstance(_type, DictDefault):
_type = "user_defined.default"
if cfg.rl is RLType.ORPO:
ds_transform_fn = load_orpo(_type, cfg, dataset_idx=i)
elif cfg.rl is RLType.KTO:
ds_transform_fn = load_kto(_type, cfg, dataset_idx=i)
else:
ds_transform_fn = load_dpo(_type, cfg, dataset_idx=i)
map_kwargs: dict[str, Any] = {}
if isinstance(ds_transform_fn, tuple):
ds_transform_fn, map_kwargs = ds_transform_fn
split_datasets[i] = _map_dataset(
cfg, dataset, ds_transform_fn, tokenizer, **map_kwargs
)
else:
# If no `type` is provided, assume the dataset is already in the expected format with
# "prompt", "chosen", and "rejected" already preprocessed
split_datasets[i] = dataset
if not cfg.skip_prepare_dataset:
drop_long = partial(
_drop_long_sequences,
rl=cfg.rl,
tokenizer=tokenizer,
sequence_len=cfg.sequence_len,
)
prior_len = len(split_datasets[i])
split_datasets[i] = split_datasets[i].filter(
drop_long,
num_proc=cfg.dataset_processes,
load_from_cache_file=not cfg.is_preprocess,
desc="Dropping Long Sequences",
)
dropped = prior_len - len(split_datasets[i])
if dropped:
LOG.warning(f"Dropped {dropped} long samples from dataset index {i}")
# Merge datasets
dataset = merge_datasets(split_datasets, cfg)
if not cfg.skip_prepare_dataset:
# Save preprocessed dataset
dataset_hash = generate_dataset_hash_from_config(
cfg, datasets_configs, tokenizer.name_or_path
)
save_preprocessed_dataset(cfg, dataset, dataset_hash, split)
return dataset
# pylint: disable=duplicate-code
def _load_or_create_dataset_split(
    cfg: DictDefault, tokenizer: PreTrainedTokenizer, split: Literal["train", "test"]
) -> Dataset:
    """Load a preprocessed dataset or create a new one for the given split.

    Tries sources in order: the HF hub (when ``push_dataset_to_hub`` is set),
    the local preprocessed cache, and finally a fresh load / transform via
    `_load_split`.

    Args:
        cfg: Configuration object.
        tokenizer: Tokenizer to use for processing text.
        split: Dataset split to load.

    Returns:
        Dataset for the requested split.
    """
    # Select correct dataset configuration based on split
    datasets_config = cfg.datasets if split == "train" else cfg.test_datasets

    # Generate dataset hash for caching
    dataset_hash = generate_dataset_hash_from_config(
        cfg, datasets_config, tokenizer.name_or_path
    )

    # Try loading from hub if push_dataset_to_hub is configured
    dataset = None
    if cfg.push_dataset_to_hub:
        dataset = try_load_from_hub(cfg, dataset_hash, split)

    # Attempt to load preprocessed dataset
    if dataset is None:
        dataset = load_preprocessed_dataset(cfg, dataset_hash)

    # Otherwise, load it
    if dataset is None:
        dataset = _load_split(cfg, split=split)

    return dataset

File diff suppressed because it is too large Load Diff

View File

@@ -1,11 +1,21 @@
"""
dataset loading shared utils
"""
"""Dataset loading shared utils."""
from __future__ import annotations
import functools
import os
from pathlib import Path
from typing import Optional, Union
from typing import TYPE_CHECKING, Any, Generator
from datasets import Dataset, DatasetDict, load_dataset, load_from_disk
from datasets import (
Dataset,
DatasetDict,
IterableDataset,
IterableDatasetDict,
concatenate_datasets,
load_dataset,
load_from_disk,
)
from huggingface_hub import hf_hub_download, snapshot_download
from huggingface_hub.errors import (
HFValidationError,
@@ -13,78 +23,141 @@ from huggingface_hub.errors import (
RevisionNotFoundError,
)
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
from axolotl.utils.data.utils import deduplicate_and_log_datasets, md5
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
if TYPE_CHECKING:
from adlfs import AzureBlobFileSystem
from gcsfs import GCSFileSystem
from ocifs import OCIFileSystem
from s3fs import S3FileSystem
LOG = get_logger(__name__)
EXTENSIONS_TO_DATASET_TYPES = {
".parquet": "parquet",
".arrow": "arrow",
".csv": "csv",
".txt": "text",
}
def get_ds_type(config_dataset: DictDefault):
"""
Get the dataset type from the path if it's not specified
"""
ds_type = "json"
if config_dataset.ds_type:
ds_type = config_dataset.ds_type
elif ".parquet" in config_dataset.path:
ds_type = "parquet"
elif ".arrow" in config_dataset.path:
ds_type = "arrow"
elif ".csv" in config_dataset.path:
ds_type = "csv"
elif ".txt" in config_dataset.path:
ds_type = "text"
return ds_type
def get_dataset_type(dataset_config: DictDefault) -> str:
    """Determine the dataset type for a dataset config.

    An explicit ``ds_type`` wins; otherwise the type is inferred from a known
    file extension appearing in the path, defaulting to ``"json"``.
    """
    explicit_type = dataset_config.ds_type
    if explicit_type:
        return explicit_type

    path = dataset_config.path
    inferred = next(
        (
            dataset_type
            for extension, dataset_type in EXTENSIONS_TO_DATASET_TYPES.items()
            if extension in path
        ),
        None,
    )
    return inferred if inferred is not None else "json"
def datasets_w_name_generator(dataset_configs: list[DictDefault]):
"""
Yields dataset configs handling multiple names or preprocess_shards
def datasets_with_name_generator(
dataset_configs: list[DictDefault],
) -> Generator[DictDefault, None, None]:
"""Yields expanded dataset configurations based on multiple names or preprocessing
shards.
When a dataset config has a list of names, it yields separate configs for each
name. When a dataset config specifies preprocessing shards, it yields configs for
each shard.
Args:
dataset_configs: list of dataset configs (equivalent to cfg.datasets)
dataset_configs: List of dataset configuration objects.
Yields:
Individual dataset configurations, expanded as needed for names or shards.
"""
for dataset in dataset_configs:
if dataset.name and isinstance(dataset.name, list):
# load_dataset doesn't properly handle multiple named configurations
# at the same time for a given dataset
for name in dataset.name:
yield DictDefault({**dataset, "name": name})
elif dataset.preprocess_shards and not dataset.shards:
for shard in range(dataset.preprocess_shards):
for config in dataset_configs:
if config.name and isinstance(config.name, list):
for name in config.name:
yield DictDefault({**config, "name": name})
elif config.preprocess_shards and not config.shards:
for shard_idx in range(config.preprocess_shards):
yield DictDefault(
{
**dataset,
"shards": dataset.preprocess_shards,
"shards_idx": shard,
**config,
"shards": config.preprocess_shards,
"shards_idx": shard_idx,
}
)
else:
yield dataset
yield config
def load_dataset_w_config(
config_dataset: DictDefault, use_auth_token: bool, streaming=False
) -> Union[Dataset, DatasetDict]:
"""
Load a dataset from a config
def load_dataset_with_config(
dataset_config: DictDefault, use_auth_token: bool, streaming=False
) -> Dataset | IterableDataset:
"""Load a dataset from a config. Handles datasets that are stored locally, in the
HuggingFace Hub, in a remote filesystem (S3, GCS, Azure, OCI), a URL, or
`data_files`.
Args:
config_dataset: single dataset config
use_auth_token: whether to use HF auth token
streaming: whether to stream the dataset
dataset_config: Single dataset config.
use_auth_token: Whether to use HF auth token.
streaming: Whether to stream the dataset.
Returns:
Loaded dataset.
"""
# pylint: disable=invalid-name
ds: Optional[Union[Dataset, DatasetDict]] = None # pylint: disable=invalid-name
ds_from_hub = False
# Set up common kwargs for dataset loading
load_dataset_kwargs = {
"split": dataset_config.split if dataset_config.split else None,
"name": dataset_config.name,
"streaming": streaming,
"trust_remote_code": dataset_config.trust_remote_code,
}
# First check if it's a local path
if Path(dataset_config.path).exists():
return _load_from_local_path(dataset_config, load_dataset_kwargs)
# Check if it's a HuggingFace dataset
is_hub_dataset = _check_if_hub_dataset(dataset_config, use_auth_token)
# Check if it's a cloud storage path and get appropriate filesystem
remote_fs, storage_options = _get_remote_filesystem(dataset_config.path)
is_cloud_dataset = False
if remote_fs:
try:
is_cloud_dataset = remote_fs.exists(dataset_config.path)
except (FileNotFoundError, ConnectionError):
pass
# Load from appropriate source
if is_hub_dataset:
return _load_from_hub(dataset_config, use_auth_token, load_dataset_kwargs)
if is_cloud_dataset:
return _load_from_cloud(
dataset_config, remote_fs, storage_options, load_dataset_kwargs
)
if dataset_config.path.startswith("https://"):
return _load_from_url(dataset_config, load_dataset_kwargs)
if dataset_config.data_files:
return _load_from_data_files(dataset_config, load_dataset_kwargs)
raise ValueError(
f"The dataset could not be loaded. This could be due to a misconfigured dataset path "
f"({dataset_config.path}). Try double-check your path / name / data_files. "
f"This is not caused by the dataset type."
)
def _check_if_hub_dataset(dataset_config: DictDefault, use_auth_token: bool) -> bool:
"""Check if a dataset exists on the HuggingFace Hub."""
try:
# this is just a basic check to see if the path is a
# valid HF dataset that's loadable
snapshot_download(
repo_id=config_dataset.path,
repo_id=dataset_config.path,
repo_type="dataset",
token=use_auth_token,
revision=config_dataset.revision,
revision=dataset_config.revision,
ignore_patterns=["*"],
)
ds_from_hub = True
return True
except (
RepositoryNotFoundError,
RevisionNotFoundError,
@@ -93,198 +166,373 @@ def load_dataset_w_config(
HFValidationError,
ValueError,
):
pass
return False
ds_from_cloud = False
storage_options: dict = {}
remote_file_system = None
if config_dataset.path.startswith("s3://"):
def _get_remote_filesystem(
path: str,
) -> tuple[
S3FileSystem | GCSFileSystem | AzureBlobFileSystem | OCIFileSystem | None, dict
]:
"""Get the appropriate filesystem for a remote path."""
if path.startswith("s3://"):
try:
import s3fs # type: ignore
import s3fs
storage_options = {"anon": False}
return s3fs.S3FileSystem(**storage_options), storage_options
except ImportError as exc:
raise ImportError("s3:// paths require s3fs to be installed") from exc
# Reads env, credentials from ~/.aws/credentials, or IAM metadata provider
# https://s3fs.readthedocs.io/en/latest/index.html?highlight=storage_options#credentials
storage_options = {"anon": False}
remote_file_system = s3fs.S3FileSystem(**storage_options)
elif config_dataset.path.startswith("gs://") or config_dataset.path.startswith(
"gcs://"
):
elif path.startswith(("gs://", "gcs://")):
try:
import gcsfs # type: ignore
import gcsfs
storage_options = {"token": None} # type: ignore
return gcsfs.GCSFileSystem(**storage_options), storage_options
except ImportError as exc:
raise ImportError(
"gs:// or gcs:// paths require gcsfs to be installed"
) from exc
# gcsfs will use default credentials from the environment else anon
# https://gcsfs.readthedocs.io/en/latest/#credentials
storage_options = {"token": None}
remote_file_system = gcsfs.GCSFileSystem(**storage_options)
elif (
config_dataset.path.startswith("adl://")
or config_dataset.path.startswith("abfs://")
or config_dataset.path.startswith("az://")
):
elif path.startswith(("adl://", "abfs://", "az://")):
try:
import adlfs
storage_options = {"anon": False}
return adlfs.AzureBlobFileSystem(**storage_options), storage_options
except ImportError as exc:
raise ImportError(
"adl:// or abfs:// paths require adlfs to be installed"
) from exc
# # Ensure you have the following environment variables set:
# # Gen 1
# storage_options = {
# "tenant_id": AZURE_STORAGE_TENANT_ID,
# "client_id": AZURE_STORAGE_CLIENT_ID,
# "client_secret": AZURE_STORAGE_CLIENT_SECRET,
# }
# # Gen 2
# storage_options = {
# "account_name": AZURE_STORAGE_ACCOUNT_NAME,
# "account_key": AZURE_STORAGE_ACCOUNT_KEY,
# }
# Reads env
# https://github.com/fsspec/adlfs?tab=readme-ov-file#setting-credentials
storage_options = {"anon": False}
remote_file_system = adlfs.AzureBlobFileSystem(**storage_options)
elif config_dataset.path.startswith("oci://"):
elif path.startswith("oci://"):
try:
import ocifs
storage_options = {}
return ocifs.OCIFileSystem(**storage_options), storage_options
except ImportError as exc:
raise ImportError("oci:// paths require ocifs to be installed") from exc
# https://ocifs.readthedocs.io/en/latest/getting-connected.html#Using-Environment-Variables
remote_file_system = ocifs.OCIFileSystem(**storage_options)
return None, {}
try:
if remote_file_system and remote_file_system.exists(config_dataset.path):
ds_from_cloud = True
except (FileNotFoundError, ConnectionError):
pass
# gather extra args from the config
load_ds_kwargs = {}
if config_dataset.split:
load_ds_kwargs["split"] = config_dataset.split
def _load_from_local_path(
dataset_config: DictDefault, load_dataset_kwargs: dict
) -> Dataset | IterableDataset | DatasetDict | IterableDatasetDict:
"""Load a dataset from a local path."""
local_path = Path(dataset_config.path)
if local_path.is_dir():
if dataset_config.data_files:
dataset_type = get_dataset_type(dataset_config)
return load_dataset(
dataset_type,
data_files=dataset_config.data_files,
**load_dataset_kwargs,
)
try:
return load_from_disk(dataset_config.path)
except FileNotFoundError:
load_dataset_kwargs["streaming"] = False
return load_dataset(dataset_config.path, **load_dataset_kwargs)
elif local_path.is_file():
dataset_type = get_dataset_type(dataset_config)
load_dataset_kwargs["streaming"] = False
return load_dataset(
dataset_type,
data_files=dataset_config.path,
**load_dataset_kwargs,
)
else:
load_ds_kwargs["split"] = None
# prefer local dataset, even if hub exists
local_path = Path(config_dataset.path)
if local_path.exists():
if local_path.is_dir():
if config_dataset.data_files:
ds_type = get_ds_type(config_dataset)
ds = load_dataset( # pylint: disable=invalid-name
ds_type,
name=config_dataset.name,
data_files=config_dataset.data_files,
streaming=streaming,
**load_ds_kwargs,
)
else:
try:
ds = load_from_disk(
config_dataset.path
) # pylint: disable=invalid-name
except FileNotFoundError:
ds = load_dataset(
config_dataset.path,
name=config_dataset.name,
streaming=False,
**load_ds_kwargs,
)
elif local_path.is_file():
ds_type = get_ds_type(config_dataset)
ds = load_dataset( # pylint: disable=invalid-name
ds_type,
name=config_dataset.name,
data_files=config_dataset.path,
streaming=False,
**load_ds_kwargs,
)
else:
raise ValueError(
"unhandled dataset load: local path exists, but is neither a directory or a file"
)
elif ds_from_hub:
ds = load_dataset(
config_dataset.path,
name=config_dataset.name,
streaming=streaming,
data_files=config_dataset.data_files,
token=use_auth_token,
revision=config_dataset.revision,
trust_remote_code=config_dataset.trust_remote_code,
**load_ds_kwargs,
)
elif ds_from_cloud and remote_file_system:
if remote_file_system.isdir(config_dataset.path):
ds = load_from_disk(
config_dataset.path,
storage_options=storage_options,
)
elif remote_file_system.isfile(config_dataset.path):
ds_type = get_ds_type(config_dataset)
ds = load_dataset(
ds_type,
name=config_dataset.name,
data_files=config_dataset.path,
streaming=streaming,
storage_options=storage_options,
trust_remote_code=config_dataset.trust_remote_code,
**load_ds_kwargs,
)
elif config_dataset.path.startswith("https://"):
ds_type = get_ds_type(config_dataset)
ds = load_dataset(
ds_type,
name=config_dataset.name,
data_files=config_dataset.path,
streaming=streaming,
storage_options=storage_options,
trust_remote_code=config_dataset.trust_remote_code,
**load_ds_kwargs,
)
elif config_dataset.data_files:
fp: str | list[str] | None = None
if isinstance(config_dataset.data_files, str):
fp = hf_hub_download(
repo_id=config_dataset.path,
repo_type="dataset",
filename=config_dataset.data_files,
revision=config_dataset.revision,
)
elif isinstance(config_dataset.data_files, list):
fp = []
for file in config_dataset.data_files:
fp.append(
hf_hub_download(
repo_id=config_dataset.path,
repo_type="dataset",
filename=file,
revision=config_dataset.revision,
)
)
else:
raise ValueError("data_files must be either a string or list of strings")
ds = load_dataset(
"json",
name=config_dataset.name,
data_files=fp,
streaming=streaming,
**load_ds_kwargs,
)
if not ds:
raise ValueError(
"The dataset could not be loaded. This could be due to a misconfigured dataset path "
f"({config_dataset.path}). Try double-check your path / name / data_files. "
"This is not caused by the dataset type."
"Unhandled dataset load: local path exists, but is neither a directory or a file"
)
return ds
def _load_from_hub(
    dataset_config: DictDefault, use_auth_token: bool, load_dataset_kwargs: dict
) -> Dataset | IterableDataset | DatasetDict | IterableDatasetDict:
    """Load a dataset hosted on the HuggingFace Hub."""
    hub_kwargs = {
        "data_files": dataset_config.data_files,
        "token": use_auth_token,
        "revision": dataset_config.revision,
        **load_dataset_kwargs,
    }
    return load_dataset(dataset_config.path, **hub_kwargs)
def _load_from_cloud(
    dataset_config: DictDefault,
    remote_fs: S3FileSystem | GCSFileSystem | AzureBlobFileSystem | OCIFileSystem,
    storage_options: dict,
    load_dataset_kwargs: dict,
) -> Dataset | IterableDataset | DatasetDict | IterableDatasetDict:
    """Load a dataset stored on a remote (S3 / GCS / Azure / OCI) filesystem."""
    path = dataset_config.path

    # A directory is assumed to be a dataset previously written via
    # `save_to_disk`; a single file goes through the format-specific builder.
    if remote_fs.isdir(path):
        return load_from_disk(path, storage_options=storage_options)

    if remote_fs.isfile(path):
        return load_dataset(
            get_dataset_type(dataset_config),
            data_files=path,
            storage_options=storage_options,
            **load_dataset_kwargs,
        )

    raise ValueError(f"Cloud path {path} is neither a directory nor a file")
def _load_from_url(
    dataset_config: DictDefault, load_dataset_kwargs: dict
) -> Dataset | IterableDataset | DatasetDict | IterableDatasetDict:
    """Load a dataset from a URL via the format-specific builder."""
    return load_dataset(
        get_dataset_type(dataset_config),
        data_files=dataset_config.path,
        **load_dataset_kwargs,
    )
def _load_from_data_files(
    dataset_config: DictDefault, load_dataset_kwargs: dict
) -> Dataset | IterableDataset | DatasetDict | IterableDatasetDict:
    """Download `data_files` from a Hub dataset repo and load them as JSON."""

    def _download(filename: str) -> str:
        # Fetch a single file from the repo at the pinned revision.
        return hf_hub_download(
            repo_id=dataset_config.path,
            repo_type="dataset",
            filename=filename,
            revision=dataset_config.revision,
        )

    data_files = dataset_config.data_files
    if isinstance(data_files, str):
        file_path = _download(data_files)
    elif isinstance(data_files, list):
        file_path = [_download(file) for file in data_files]
    else:
        raise ValueError("data_files must be either a string or list of strings")

    return load_dataset("json", data_files=file_path, **load_dataset_kwargs)
def generate_split_fingerprints(
    dataset: Dataset, val_set_size: int | float, seed: int
) -> tuple[str, str]:
    """Generate consistent fingerprints for train/test splits.

    The hash input combines the dataset's internal fingerprint, the split
    size, the split name, and the seed so identical configurations reuse
    cached splits.
    """
    base = dataset._fingerprint  # pylint: disable=protected-access
    train_fingerprint, test_fingerprint = (
        md5(f"{base}|{val_set_size}|{split}|{seed}") for split in ("train", "test")
    )
    return train_fingerprint, test_fingerprint
def get_prepared_dataset_path(cfg: DictDefault, dataset_hash: str) -> Path:
    """Get standardized path for prepared datasets.

    Args:
        cfg: Configuration object.
        dataset_hash: Hash identifying the specific dataset configuration.

    Returns:
        Path where the prepared dataset should be stored.
    """
    # Fall back to the package default when no explicit path is configured.
    if cfg.dataset_prepared_path:
        base_path = cfg.dataset_prepared_path
    else:
        base_path = DEFAULT_DATASET_PREPARED_PATH
    return Path(base_path) / dataset_hash
def create_train_validation_split(
    dataset: Dataset, cfg: DictDefault, val_set_size: int | float
) -> tuple[Dataset, Dataset]:
    """Create train/validation split with consistent fingerprinting.

    Args:
        dataset: Dataset to split.
        cfg: Configuration object containing seed and other settings.
        val_set_size: Size of validation set (absolute number or fraction).

    Returns:
        Tuple of (train_dataset, eval_dataset).
    """
    # Deduplicate BEFORE computing split fingerprints: deduplication changes
    # the dataset contents (and its internal fingerprint), so the fingerprints
    # must be derived from the dataset that is actually split. Otherwise runs
    # with and without `dataset_exact_deduplication` would share cache
    # fingerprints while holding different data, risking stale cached splits.
    if cfg.dataset_exact_deduplication:
        dataset, _ = deduplicate_and_log_datasets(dataset=dataset)

    train_fingerprint, test_fingerprint = generate_split_fingerprints(
        dataset, val_set_size, cfg.seed
    )

    split_dataset = dataset.train_test_split(
        test_size=val_set_size,
        shuffle=False,
        seed=cfg.seed,
        train_new_fingerprint=train_fingerprint,
        test_new_fingerprint=test_fingerprint,
    )

    return split_dataset["train"], split_dataset["test"]
def _generate_from_iterable_dataset(
dataset: IterableDataset, worker_id: list[int], num_workers: list[int]
) -> Generator[Any, None, None]:
"""Generator function to correctly split the dataset for each worker"""
for i, item in enumerate(dataset):
if i % num_workers[0] == worker_id[0]:
yield item
def save_preprocessed_dataset(
    cfg: DictDefault,
    dataset: Dataset,
    dataset_hash: str,
    split: str,
) -> None:
    """Save preprocessed dataset to disk and optionally push to the HF Hub.

    Args:
        cfg: Configuration object; `dataset_prepared_path`,
            `dataset_processes`, and `push_dataset_to_hub` control where and
            how the dataset is persisted.
        dataset: Dataset (or iterable dataset) to persist.
        dataset_hash: Hash identifying the dataset configuration; used as the
            prepared-dataset subdirectory and as the Hub config name.
        split: Split name recorded when materializing an iterable dataset.
    """
    prepared_ds_path = get_prepared_dataset_path(cfg, dataset_hash)
    if isinstance(dataset, IterableDataset):
        # Iterable datasets have no on-disk representation, so materialize
        # first. Each of the `num_workers` generator processes receives one
        # element from the `gen_kwargs` lists and yields its shard of the
        # stream (see `_generate_from_iterable_dataset`).
        num_workers = cfg.dataset_processes
        ds_from_iter = Dataset.from_generator(
            functools.partial(_generate_from_iterable_dataset, dataset),
            features=dataset.features,
            num_proc=num_workers,
            split=split,
            gen_kwargs={
                "worker_id": list(range(num_workers)),
                "num_workers": [num_workers] * num_workers,
            },
        )
        ds_from_iter.save_to_disk(str(prepared_ds_path))
    else:
        os.makedirs(prepared_ds_path, exist_ok=True)
        dataset.save_to_disk(str(prepared_ds_path))
    if cfg.push_dataset_to_hub:
        LOG.info(
            "Pushing merged prepared dataset to Huggingface hub at "
            f"{cfg.push_dataset_to_hub} (version {dataset_hash})...",
            main_process_only=False,
        )
        # NOTE(review): in the iterable branch this pushes the original
        # (non-materialized) dataset object rather than `ds_from_iter` —
        # confirm push_to_hub supports that input.
        dataset.push_to_hub(
            cfg.push_dataset_to_hub,
            dataset_hash,
            private=True,
        )
def load_preprocessed_dataset(cfg: DictDefault, dataset_hash: str) -> Dataset | None:
    """Load preprocessed dataset from disk if available.

    Args:
        cfg: Configuration object.
        dataset_hash: Hash identifying the dataset configuration.

    Returns:
        Loaded dataset if found and conditions are met, None otherwise.
    """
    prepared_ds_path = get_prepared_dataset_path(cfg, dataset_hash)

    # Only reuse the cache when an explicit prepared path is configured, we
    # are not (re)preprocessing, and the prepared directory is non-empty.
    use_cached = (
        bool(cfg.dataset_prepared_path)
        and not cfg.skip_prepare_dataset
        and not cfg.is_preprocess
        and any(prepared_ds_path.glob("*"))
    )
    if not use_cached:
        LOG.info(
            f"Unable to find prepared dataset in {prepared_ds_path}",
            main_process_only=False,
        )
        return None

    LOG.info(
        f"Loading prepared dataset from disk at {prepared_ds_path}...",
        main_process_only=False,
    )
    return load_from_disk(str(prepared_ds_path))
def try_load_from_hub(
    cfg: DictDefault, dataset_hash: str, split: str
) -> Dataset | None:
    """Best-effort load of a previously pushed prepared dataset from the Hub.

    Returns:
        The requested split, or None when the dataset (or split) is not
        available.
    """
    try:
        LOG.info(
            "Attempting to load prepared dataset from HuggingFace Hub at "
            f"{cfg.push_dataset_to_hub} (version {dataset_hash})..."
        )
        prepared = load_dataset(
            cfg.push_dataset_to_hub,
            dataset_hash,
            token=cfg.hf_use_auth_token,
        )
        return prepared[split]
    except Exception:  # pylint: disable=broad-except # nosec
        # Deliberately best-effort: any failure falls back to local prep.
        LOG.info("Unable to find prepared dataset in HuggingFace Hub")
        return None
def generate_dataset_hash_from_config(
    cfg: DictDefault, cfg_datasets: list, tokenizer_name: str
) -> str:
    """Generate a hash to uniquely identify a dataset configuration for SFT.

    Args:
        cfg: Main configuration object.
        cfg_datasets: List of dataset configurations.
        tokenizer_name: Name of the tokenizer being used.

    Returns:
        MD5 hash string representing the configuration.
    """
    # Sort the per-dataset keys so ordering of `datasets` doesn't change the hash.
    dataset_keys = sorted(
        f"{d.path}:{d.type}:{d.shards}:{d.conversation}:{d.split}:{d.temperature or 1.0}"
        for d in cfg_datasets
    )
    config_str = (
        f"{cfg.sequence_len}@{cfg.sample_packing}@{cfg.eval_sample_packing}@"
        f"{cfg.group_by_length}@{cfg.kd_temperature or 1.0}|"
        + "|".join(dataset_keys)
        + f"|{tokenizer_name}"
    )
    return str(md5(config_str))
def merge_datasets(datasets: list[Dataset], cfg: DictDefault) -> Dataset:
    """Merge multiple datasets into one with optional shuffling.

    Args:
        datasets: List of datasets to merge.
        cfg: Configuration object containing shuffle settings.

    Returns:
        Merged dataset.
    """
    # Nothing to merge when only one dataset was provided.
    if len(datasets) == 1:
        return datasets[0]

    LOG.info("Merging datasets...")
    merged_dataset = concatenate_datasets(datasets)

    if not cfg.shuffle_merged_datasets:
        LOG.debug("Not shuffling merged datasets.")
        return merged_dataset

    LOG.debug("Shuffling merged datasets...")
    return merged_dataset.shuffle(seed=cfg.seed)

View File

@@ -1,9 +1,11 @@
"""data handling helpers"""
"""Data handling helpers"""
import contextlib
import functools
import hashlib
import time
from enum import Enum
from typing import Callable
import huggingface_hub
import numpy as np
@@ -19,9 +21,7 @@ LOG = get_logger(__name__)
class RetryStrategy(Enum):
"""
Enum for retry strategies.
"""
"""Enum for retry strategies."""
CONSTANT = 1
LINEAR = 2
@@ -30,7 +30,18 @@ class RetryStrategy(Enum):
def retry_on_request_exceptions(
max_retries=3, delay=1, retry_strategy: RetryStrategy = RetryStrategy.LINEAR
):
) -> Callable:
"""Decorator that retries function calls on specific request exceptions.
Args:
max_retries: Maximum number of retry attempts.
delay: Base delay between retries in seconds.
retry_strategy: Strategy for calculating retry delays.
Returns:
Decorated function with retry logic.
"""
def decorator(func):
@functools.wraps(func)
def wrapper(*args, **kwargs): # pylint: disable=inconsistent-return-statements
@@ -59,6 +70,7 @@ def retry_on_request_exceptions(
def md5(to_hash: str, encoding: str = "utf-8") -> str:
"""Generate MD5 hash of a string."""
try:
return hashlib.md5(to_hash.encode(encoding), usedforsecurity=False).hexdigest()
except TypeError:
@@ -66,102 +78,89 @@ def md5(to_hash: str, encoding: str = "utf-8") -> str:
def sha256(to_hash: str, encoding: str = "utf-8") -> str:
    """Return the hex SHA256 digest of a string."""
    digest = hashlib.sha256(to_hash.encode(encoding))
    return digest.hexdigest()
def deduplicate_dataset(
dataset: Dataset, seen_hashes: dict[str, list[int]], other_dataset: Dataset = None
) -> Dataset:
unique_indices = []
def _deduplicate_dataset(
dataset: Dataset,
seen_hashes: set[str] | None = None,
) -> tuple[Dataset, set[str]]:
"""Remove duplicate rows from a dataset using SHA256 hashes.
Args:
dataset: Dataset to deduplicate.
seen_hashes: Set of previously seen row hashes (for cross-deduplication).
Returns:
Tuple of deduplicated dataset and the set of seen hashes.
"""
if seen_hashes is None:
seen_hashes = set()
unique_indices = []
for idx, row in enumerate(dataset):
row_hash = sha256(str(row)) # Using SHA256 for collision resistance.
row_hash = sha256(str(row)) # Using SHA256 for collision resistance
if row_hash not in seen_hashes:
seen_hashes[row_hash] = [idx]
seen_hashes.add(row_hash)
unique_indices.append(idx)
else:
# Check for collision by looking up the original dataset indices
original_indices = seen_hashes[row_hash]
is_duplicate = False
for original_idx in original_indices:
if (
not idx == original_idx
and original_idx < len(dataset)
and str(dataset[original_idx]) == str(row)
):
is_duplicate = True
break
# Check in the other dataset if provided
if other_dataset is not None:
if original_idx < len(other_dataset) and str(
other_dataset[original_idx]
) == str(row):
is_duplicate = True
break
if not is_duplicate:
seen_hashes[row_hash].append(idx)
unique_indices.append(idx)
continue
return dataset.select(unique_indices)
return dataset.select(unique_indices), seen_hashes
def deduplicate_and_log_datasets(
*,
train_dataset: Dataset = None,
eval_dataset: Dataset = None,
dataset: Dataset = None,
) -> tuple[Dataset, Dataset, Dataset]:
"""
Deduplicates train, eval, and an optional dataset if provided, logging original and new sizes.
dataset: Dataset,
other_dataset: Dataset | None = None,
dataset_name: str | None = "train",
other_name: str | None = "eval",
) -> tuple[Dataset, Dataset | None]:
"""Deduplicate datasets, with optional cross-dataset deduplication.
Args:
dataset: Primary dataset to deduplicate.
other_dataset: Optional second dataset to deduplicate against the first.
dataset_name: Name for the primary dataset (for logging).
other_name: Name for the second dataset (for logging).
Returns:
tuple: Deduplicated train, eval, and additional datasets.
Tuple of (deduplicated_dataset, deduplicated_other_dataset).
"""
seen_hashes: dict[str, list[int]] = {}
# Deduplicate primary dataset
LOG.info(
f"Starting deduplication for {dataset_name} dataset. Original size: {len(dataset)}"
)
dataset, seen_rows = _deduplicate_dataset(dataset)
LOG.info(
f"Deduplication complete for {dataset_name} dataset. New size: {len(dataset)}"
)
# Handle cases where datasets are None
if train_dataset is not None:
# Deduplicate second dataset if provided
if other_dataset is not None:
LOG.info(
f"Starting deduplication for train dataset. Original size: {len(train_dataset)}"
)
train_dataset = deduplicate_dataset(
dataset=train_dataset, seen_hashes=seen_hashes
f"Starting deduplication for {other_name} dataset. Original size: {len(other_dataset)}"
)
other_dataset, _ = _deduplicate_dataset(other_dataset, seen_rows)
LOG.info(
f"Deduplication complete for train dataset. New size: {len(train_dataset)}"
)
else:
LOG.info("Train dataset is None. Skipping deduplication.")
if eval_dataset is not None:
LOG.info(
f"Starting deduplication for eval dataset. Original size: {len(eval_dataset)}"
)
eval_dataset = deduplicate_dataset(
dataset=eval_dataset, seen_hashes=seen_hashes, other_dataset=train_dataset
)
LOG.info(
f"Deduplication complete for eval dataset. New size: {len(eval_dataset)}"
)
else:
LOG.info("Eval dataset is None. Skipping deduplication.")
if dataset is not None and (eval_dataset is None and train_dataset is None):
LOG.info(
f"Starting deduplication for combined dataset. Original size: {len(dataset)}"
)
dataset = deduplicate_dataset(dataset=dataset, seen_hashes=seen_hashes)
LOG.info(
f"Deduplication complete for combined dataset. New size: {len(dataset)}"
f"Deduplication complete for {other_name} dataset. New size: {len(other_dataset)}"
)
return train_dataset, eval_dataset, dataset
return dataset, other_dataset
def drop_long_seq_in_dataset(dataset: Dataset, cfg: DictDefault):
def drop_long_seq_in_dataset(dataset: Dataset, cfg: DictDefault) -> Dataset:
"""Remove sequences longer than configured maximum from dataset.
Args:
dataset: Dataset to filter.
cfg: Dictionary mapping `axolotl` config keys to values.
Returns:
Filtered dataset with long sequences removed.
"""
if "input_ids" not in dataset.column_names:
LOG.warning(
"Dataset does not contain 'input_ids' column. Skip drop long seq. This is expected for RewardModeling."
"Dataset does not contain 'input_ids' column. Skip drop long seq. This is "
"expected for reward modeling."
)
return dataset
@@ -171,20 +170,14 @@ def drop_long_seq_in_dataset(dataset: Dataset, cfg: DictDefault):
min_sequence_len=cfg.min_sample_len,
)
try:
with contextlib.suppress(AttributeError):
ds_lengths = get_dataset_lengths(dataset, from_arrow=True)
min_input_len = np.min(ds_lengths)
LOG.info(f"min_input_len: {min_input_len}")
max_input_len = np.max(ds_lengths)
LOG.info(f"max_input_len: {max_input_len}")
except AttributeError:
pass
try:
prior_len = len(dataset)
except TypeError:
# handle iterable datasets case
prior_len = None
prior_len = len(dataset) if hasattr(dataset, "__len__") else None
filter_map_kwargs = {}
if not isinstance(dataset, IterableDataset):

View File

@@ -0,0 +1,425 @@
"""Data handling specific to SFT."""
import logging
from typing import Any, NoReturn, cast
from datasets import (
Dataset,
IterableDataset,
Sequence,
Value,
)
from transformers import PreTrainedTokenizer
from transformers.processing_utils import ProcessorMixin
from axolotl.datasets import TokenizedPromptDataset, wrap_dataset_for_tokenized_prompt
from axolotl.prompt_strategies import load
from axolotl.prompt_strategies.bradley_terry import load as bradley_terry_load
from axolotl.prompt_tokenizers import (
AlpacaMultipleChoicePromptTokenizingStrategy,
AlpacaPromptTokenizingStrategy,
AlpacaReflectionPTStrategy,
DatasetWrappingStrategy,
GPTeacherPromptTokenizingStrategy,
JeopardyPromptTokenizingStrategy,
OpenAssistantPromptTokenizingStrategy,
PromptTokenizingStrategy,
SummarizeTLDRPromptTokenizingStrategy,
)
from axolotl.prompters import (
AlpacaPrompter,
GPTeacherPrompter,
JeopardyPrompter,
MultipleChoiceConcisePrompter,
MultipleChoiceExplainPrompter,
Prompter,
ReflectAlpacaPrompter,
SummarizeTLDRPrompter,
UnsupportedPrompter,
)
from axolotl.utils.dict import DictDefault
LOG = logging.getLogger(__name__)
def handle_unknown_dataset_strategy(dataset_config: DictDefault) -> NoReturn:
    """Raise error for unknown dataset strategy."""
    ds_type = dataset_config.type
    # A ':load_' spelling is a common typo for the '.load_' module syntax.
    hint = ""
    if ":load_" in ds_type:
        hint = f"Did you mean {ds_type.replace(':load_', '.load_')}?"
    message = f"unhandled prompt tokenization strategy: {ds_type}. {hint}"
    LOG.error(message)
    raise ValueError(message)
# pylint: disable=too-many-return-statements
def get_dataset_wrapper(
    dataset_config: DictDefault,
    tokenizer: PreTrainedTokenizer,
    cfg: DictDefault,
    dataset_base_type: str | None,
    dataset: Dataset | IterableDataset,
    dataset_prompt_style: str | None = None,
    processor: ProcessorMixin | None = None,  # pylint: disable=unused-argument
) -> tuple[Dataset | IterableDataset, Prompter | None]:
    """Create an appropriate dataset wrapper and prompter based on dataset
    configuration.

    Dispatch is ordered: already-tokenized data is returned unchanged, then
    custom (dict-based) types, then `skip_prepare_dataset`, then the
    special-cased type prefixes, then the strategy registry, and finally the
    static `DATASET_HANDLERS` table. Anything else raises via
    `handle_unknown_dataset_strategy`.

    Args:
        dataset_config: Configuration for the dataset.
        tokenizer: Tokenizer to use for processing text.
        cfg: Global configuration object.
        dataset_base_type: The base type of the dataset.
        dataset: The actual dataset object.
        dataset_prompt_style: Optional prompt style specification.
        processor: Optional processor for multimodal datasets.

    Returns:
        tuple of (dataset_wrapper, dataset_prompter).
    """
    # Common parameters for dataset wrapping
    dataset_kwargs: dict[str, Any] = {
        "process_count": cfg.dataset_processes,
        # `is True` forces a strict boolean (config value may be None)
        "keep_in_memory": cfg.dataset_keep_in_memory is True,
    }

    LOG.info(
        f"Loading dataset: {dataset_config['path']} with base_type: "
        f"{dataset_base_type} and prompt_style: {dataset_prompt_style}"
    )

    # Dataset is already tokenized; no prompter/strategy work needed
    if _is_dataset_already_tokenized(dataset):
        return dataset, UnsupportedPrompter()

    # Custom dataset type definition (user-supplied dict instead of a name)
    if isinstance(dataset_config.type, DictDefault):
        return _handle_custom_dataset_type(
            dataset_config, tokenizer, cfg, dataset, dataset_kwargs
        )

    # Skip preparation if configured; returns the raw dataset with no prompter
    if cfg.skip_prepare_dataset:
        return dataset, None

    # Bradley-Terry dataset
    if dataset_config.type.startswith("bradley_terry"):
        return _handle_bradley_terry_dataset(
            dataset_config, tokenizer, cfg, dataset, dataset_kwargs
        )

    # Stepwise supervised dataset
    if dataset_config.type.startswith("stepwise_supervised"):
        return _handle_stepwise_supervised_dataset(
            dataset_config, tokenizer, cfg, dataset, dataset_kwargs
        )

    # Try to load prompt tokenizer / dataset wrapper strategy from registry
    dataset_strategy = load(
        dataset_config.type, tokenizer, cfg, dataset_config, processor=processor
    )
    if dataset_strategy:
        return _handle_loaded_strategy(dataset_strategy, dataset, dataset_kwargs)

    # Known dataset types with specific handling
    if dataset_base_type in DATASET_HANDLERS:
        handler = DATASET_HANDLERS[dataset_base_type]
        return handler(dataset_prompt_style, tokenizer, cfg, dataset, dataset_kwargs)

    # Unhandled dataset type - raises ValueError
    handle_unknown_dataset_strategy(dataset_config)
def _is_dataset_already_tokenized(dataset: Dataset | IterableDataset) -> bool:
    """Check if the dataset is already tokenized (has the model-input columns)."""
    # Iterable datasets are never treated as pre-tokenized here.
    if not isinstance(dataset, Dataset):
        return False
    required_features = ("input_ids", "attention_mask", "labels")
    return all(feature in dataset.features for feature in required_features)
def _handle_custom_dataset_type(
    dataset_config: DictDefault,
    tokenizer: PreTrainedTokenizer,
    cfg: DictDefault,
    dataset: Dataset | IterableDataset,
    dataset_kwargs: dict[str, Any],
) -> tuple[Dataset | IterableDataset, Prompter]:
    """Handle a dataset whose type is a user-defined (dict-based) definition."""
    strategy = cast(
        PromptTokenizingStrategy,
        load("user_defined", tokenizer, cfg, dataset_config.type.to_dict()),
    )
    wrapped = wrap_dataset_for_tokenized_prompt(
        strategy,
        dataset,
        **dataset_kwargs,
    )
    return wrapped, UnsupportedPrompter()
def _handle_bradley_terry_dataset(
    dataset_config: DictDefault,
    tokenizer: PreTrainedTokenizer,
    cfg: DictDefault,
    dataset: Dataset | IterableDataset,
    dataset_kwargs: dict[str, Any],
) -> tuple[Dataset | IterableDataset, Prompter | None]:
    """Handle a Bradley-Terry dataset."""
    # The text after the first "." selects the bradley_terry sub-strategy.
    bt_type = dataset_config.type.split(".", 1)[1]
    strategy = bradley_terry_load(bt_type, tokenizer, cfg, dataset_config)
    if not strategy:
        handle_unknown_dataset_strategy(dataset_config)
    wrapped = wrap_dataset_for_tokenized_prompt(
        strategy,
        dataset,
        **dataset_kwargs,
    )
    return wrapped, UnsupportedPrompter()
def _handle_stepwise_supervised_dataset(
    dataset_config: DictDefault,
    tokenizer: PreTrainedTokenizer,
    cfg: DictDefault,
    dataset: Dataset | IterableDataset,
    dataset_kwargs: dict[str, Any],
) -> tuple[Dataset | IterableDataset, Prompter]:
    """Handle a stepwise supervised dataset."""
    strategy = load(dataset_config.type, tokenizer, cfg, dataset_config)

    # trl's PRMTrainer expects integer labels, so boolean labels must be
    # explicitly cast to int64.
    if isinstance(dataset, Dataset):
        dataset = dataset.cast_column("labels", Sequence(Value("int64")))

    wrapped = TokenizedPromptDataset(
        strategy,
        dataset,
        **dataset_kwargs,
    )
    return wrapped, UnsupportedPrompter()
def _handle_loaded_strategy(
    dataset_strategy: PromptTokenizingStrategy | DatasetWrappingStrategy,
    dataset: Dataset | IterableDataset,
    dataset_kwargs: dict[str, Any],
) -> tuple[Dataset | IterableDataset, Prompter | None]:
    """Handle a dataset whose strategy was resolved from the registry."""
    # Wrapping strategies manage the dataset themselves and use no prompter.
    if isinstance(dataset_strategy, DatasetWrappingStrategy):
        return dataset_strategy.wrap_dataset(dataset, **dataset_kwargs), None

    wrapped = wrap_dataset_for_tokenized_prompt(
        dataset_strategy,
        dataset,
        **dataset_kwargs,
    )
    return wrapped, UnsupportedPrompter()
def _handle_alpaca_dataset(
    dataset_prompt_style: str | None,
    tokenizer: PreTrainedTokenizer,
    cfg: DictDefault,
    dataset: Dataset | IterableDataset,
    dataset_kwargs: dict[str, Any],
) -> tuple[Dataset | IterableDataset, Prompter]:
    """Handle an Alpaca dataset."""
    prompter = AlpacaPrompter(dataset_prompt_style)
    strategy = AlpacaPromptTokenizingStrategy(
        prompter,
        tokenizer,
        cfg.train_on_inputs,
        cfg.sequence_len,
    )
    return (
        wrap_dataset_for_tokenized_prompt(strategy, dataset, **dataset_kwargs),
        prompter,
    )
def _handle_explainchoice_dataset(
    dataset_prompt_style: str | None,
    tokenizer: PreTrainedTokenizer,
    cfg: DictDefault,
    dataset: Dataset | IterableDataset,
    dataset_kwargs: dict[str, Any],
) -> tuple[Dataset | IterableDataset, Prompter]:
    """Tokenize an ExplainChoice (multiple choice with explanation) dataset.

    Returns the wrapped (tokenized) dataset together with the prompter used
    to render its examples.
    """
    prompter = MultipleChoiceExplainPrompter(dataset_prompt_style)
    strategy = AlpacaMultipleChoicePromptTokenizingStrategy(
        prompter, tokenizer, cfg.train_on_inputs, cfg.sequence_len
    )
    return (
        wrap_dataset_for_tokenized_prompt(strategy, dataset, **dataset_kwargs),
        prompter,
    )
def _handle_concisechoice_dataset(
    dataset_prompt_style: str | None,
    tokenizer: PreTrainedTokenizer,
    cfg: DictDefault,
    dataset: Dataset | IterableDataset,
    dataset_kwargs: dict[str, Any],
) -> tuple[Dataset | IterableDataset, Prompter]:
    """Tokenize a ConciseChoice (multiple choice, concise answer) dataset.

    Returns the wrapped (tokenized) dataset together with the prompter used
    to render its examples.
    """
    prompter = MultipleChoiceConcisePrompter(dataset_prompt_style)
    strategy = AlpacaMultipleChoicePromptTokenizingStrategy(
        prompter, tokenizer, cfg.train_on_inputs, cfg.sequence_len
    )
    return (
        wrap_dataset_for_tokenized_prompt(strategy, dataset, **dataset_kwargs),
        prompter,
    )
def _handle_summarizetldr_dataset(
    dataset_prompt_style: str | None,
    tokenizer: PreTrainedTokenizer,
    cfg: DictDefault,
    dataset: Dataset | IterableDataset,
    dataset_kwargs: dict[str, Any],
) -> tuple[Dataset | IterableDataset, Prompter]:
    """Tokenize a SummarizeTLDR dataset.

    Returns the wrapped (tokenized) dataset together with the prompter used
    to render its examples.
    """
    prompter = SummarizeTLDRPrompter(dataset_prompt_style)
    strategy = SummarizeTLDRPromptTokenizingStrategy(
        prompter, tokenizer, cfg.train_on_inputs, cfg.sequence_len
    )
    return (
        wrap_dataset_for_tokenized_prompt(strategy, dataset, **dataset_kwargs),
        prompter,
    )
def _handle_jeopardy_dataset(
    dataset_prompt_style: str | None,
    tokenizer: PreTrainedTokenizer,
    cfg: DictDefault,
    dataset: Dataset | IterableDataset,
    dataset_kwargs: dict[str, Any],
) -> tuple[Dataset | IterableDataset, Prompter]:
    """Tokenize a Jeopardy dataset.

    Returns the wrapped (tokenized) dataset together with the prompter used
    to render its examples.
    """
    prompter = JeopardyPrompter(dataset_prompt_style)
    strategy = JeopardyPromptTokenizingStrategy(
        prompter, tokenizer, cfg.train_on_inputs, cfg.sequence_len
    )
    return (
        wrap_dataset_for_tokenized_prompt(strategy, dataset, **dataset_kwargs),
        prompter,
    )
def _handle_oasst_dataset(
    dataset_prompt_style: str | None,
    tokenizer: PreTrainedTokenizer,
    cfg: DictDefault,
    dataset: Dataset | IterableDataset,
    dataset_kwargs: dict[str, Any],
) -> tuple[Dataset | IterableDataset, Prompter]:
    """Tokenize an OpenAssistant dataset.

    Uses the Alpaca prompter with an OpenAssistant-specific tokenizing
    strategy; returns the wrapped dataset plus the prompter.
    """
    prompter = AlpacaPrompter(dataset_prompt_style)
    strategy = OpenAssistantPromptTokenizingStrategy(
        prompter, tokenizer, cfg.train_on_inputs, cfg.sequence_len
    )
    return (
        wrap_dataset_for_tokenized_prompt(strategy, dataset, **dataset_kwargs),
        prompter,
    )
def _handle_gpteacher_dataset(
    dataset_prompt_style: str | None,
    tokenizer: PreTrainedTokenizer,
    cfg: DictDefault,
    dataset: Dataset | IterableDataset,
    dataset_kwargs: dict[str, Any],
) -> tuple[Dataset | IterableDataset, Prompter]:
    """Tokenize a GPTeacher dataset.

    Returns the wrapped (tokenized) dataset together with the prompter used
    to render its examples.
    """
    prompter = GPTeacherPrompter(dataset_prompt_style)
    strategy = GPTeacherPromptTokenizingStrategy(
        prompter, tokenizer, cfg.train_on_inputs, cfg.sequence_len
    )
    return (
        wrap_dataset_for_tokenized_prompt(strategy, dataset, **dataset_kwargs),
        prompter,
    )
def _handle_reflection_dataset(
    dataset_prompt_style: str | None,
    tokenizer: PreTrainedTokenizer,
    cfg: DictDefault,
    dataset: Dataset | IterableDataset,
    dataset_kwargs: dict[str, Any],
) -> tuple[Dataset | IterableDataset, Prompter]:
    """Tokenize a Reflection-style Alpaca dataset.

    Returns the wrapped (tokenized) dataset together with the prompter used
    to render its examples.
    """
    prompter = ReflectAlpacaPrompter(dataset_prompt_style)
    strategy = AlpacaReflectionPTStrategy(
        prompter, tokenizer, cfg.train_on_inputs, cfg.sequence_len
    )
    return (
        wrap_dataset_for_tokenized_prompt(strategy, dataset, **dataset_kwargs),
        prompter,
    )
# Maps a dataset config `type` value to its handler. Every handler builds a
# prompter + tokenizing strategy pair and returns (wrapped dataset, prompter).
DATASET_HANDLERS = {
    "alpaca": _handle_alpaca_dataset,
    "explainchoice": _handle_explainchoice_dataset,
    "concisechoice": _handle_concisechoice_dataset,
    "summarizetldr": _handle_summarizetldr_dataset,
    "jeopardy": _handle_jeopardy_dataset,
    "oasst": _handle_oasst_dataset,
    "gpteacher": _handle_gpteacher_dataset,
    "reflection": _handle_reflection_dataset,
}

View File

@@ -336,6 +336,14 @@ class AxolotlInputConfig(
plugins: list[str] | None = Field(default=None)
@field_validator("seed", mode="after")
@classmethod
def set_default_seed(cls, seed):
if seed is None:
LOG.info("`seed` not set in config; setting to 42")
seed = 42
return seed
@field_validator("datasets", mode="before")
@classmethod
def deprecate_sharegpt_datasets(cls, datasets):
@@ -1199,7 +1207,7 @@ class AxolotlInputConfig(
"flash_attention: true must be set with sequence_parallel_degree > 1"
)
if self.sample_packing and self.micro_batch_size > 1:
if self.sample_packing and getattr(self, "micro_batch_size", 1) > 1:
raise ValueError(
"micro_batch_size must be set to 1 when sample_packing is enabled "
"due to a `ring-flash-attn` requirement"

View File

@@ -12,7 +12,7 @@ from axolotl.common.datasets import load_datasets
from axolotl.core.builders import HFCausalTrainerBuilder, HFRLTrainerBuilder
from axolotl.loaders import ModelLoader, load_tokenizer
from axolotl.utils.config import normalize_config
from axolotl.utils.data.rl import load_prepare_preference_datasets
from axolotl.utils.data import prepare_preference_datasets
from axolotl.utils.dict import DictDefault
from axolotl.utils.schemas.enums import RLType
@@ -451,15 +451,19 @@ def rand_reward_func(prompts, completions) -> list[float]:
# Only use mock for the commented out configs
if dataset_name is not None:
with patch(
"axolotl.utils.data.rl.load_dataset_w_config"
"axolotl.utils.data.rl.load_dataset_with_config"
) as mock_load_dataset:
mock_load_dataset.return_value = request.getfixturevalue(
dataset_name
)
train_dataset, eval_dataset = load_prepare_preference_datasets(cfg)
train_dataset, eval_dataset = prepare_preference_datasets(
cfg, tokenizer
)
else:
# Load actual datasets for orpo_cfg and kto_cfg
train_dataset, eval_dataset = load_prepare_preference_datasets(cfg)
train_dataset, eval_dataset = prepare_preference_datasets(
cfg, tokenizer
)
builder.train_dataset = train_dataset
builder.eval_dataset = eval_dataset

View File

@@ -4,7 +4,6 @@ Simple end-to-end test for Cut Cross Entropy integration
import pytest
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils import get_pytorch_version
@@ -59,8 +58,7 @@ class TestCutCrossEntropyIntegration:
cfg = validate_config(cfg)
prepare_plugins(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
dataset_meta = load_datasets(cfg=cfg)
major, minor, _ = get_pytorch_version()
if (major, minor) < (2, 4):
@@ -105,8 +103,7 @@ class TestCutCrossEntropyIntegration:
cfg = validate_config(cfg)
prepare_plugins(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
dataset_meta = load_datasets(cfg=cfg)
major, minor, _ = get_pytorch_version()
if (major, minor) < (2, 4):
@@ -134,8 +131,7 @@ class TestCutCrossEntropyIntegration:
cfg = validate_config(cfg)
prepare_plugins(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
dataset_meta = load_datasets(cfg=cfg)
major, minor, _ = get_pytorch_version()
if (major, minor) < (2, 4):

View File

@@ -5,7 +5,6 @@ e2e tests to make sure all the hooks are fired on the plugin
import os
from pathlib import Path
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.integrations.base import BasePlugin
from axolotl.train import train
@@ -160,8 +159,7 @@ class TestPluginHooks:
cfg = validate_config(cfg)
prepare_plugins(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
dataset_meta = load_datasets(cfg=cfg)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)

View File

@@ -6,7 +6,6 @@ from pathlib import Path
import pytest
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, prepare_plugins, validate_config
@@ -84,8 +83,7 @@ class TestKnowledgeDistillation:
cfg = validate_config(cfg)
prepare_plugins(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
dataset_meta = load_datasets(cfg=cfg)
train(cfg=cfg, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "model.safetensors").exists()
@@ -115,8 +113,7 @@ class TestKnowledgeDistillation:
cfg = validate_config(cfg)
prepare_plugins(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
dataset_meta = load_datasets(cfg=cfg)
train(cfg=cfg, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "adapter_model.safetensors").exists()

View File

@@ -2,7 +2,6 @@
Simple end-to-end test for Liger integration
"""
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, prepare_plugins, validate_config
@@ -57,8 +56,7 @@ class LigerIntegrationTestCase:
cfg = validate_config(cfg)
prepare_plugins(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
dataset_meta = load_datasets(cfg=cfg)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
@@ -104,8 +102,7 @@ class LigerIntegrationTestCase:
cfg = validate_config(cfg)
prepare_plugins(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
dataset_meta = load_datasets(cfg=cfg)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)

View File

@@ -6,7 +6,6 @@ from pathlib import Path
import pytest
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, prepare_plugins, validate_config
@@ -88,8 +87,7 @@ class TestLLMCompressorIntegration:
prepare_plugins(cfg)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
dataset_meta = load_datasets(cfg=cfg)
try:
train(cfg=cfg, dataset_meta=dataset_meta)

View File

@@ -105,7 +105,7 @@ def start_vllm(
print(f"{i}: VLLM server failed to start: {str(exc)}")
# also check if the process.pid is still running
if not process.poll() is None:
if process.poll() is not None:
break
time.sleep(period_seconds)

View File

@@ -0,0 +1,192 @@
"""Tests for FileLockLoader class."""
import tempfile
import threading
import time
from pathlib import Path
from unittest.mock import MagicMock, Mock, patch
import pytest
from axolotl.utils.data.lock import FileLockLoader
from axolotl.utils.dict import DictDefault
class TestFileLockLoader:
    """Class with tests for FileLockLoader."""

    @pytest.fixture
    def temp_dir(self):
        """Create a temporary directory for testing."""
        with tempfile.TemporaryDirectory() as tmp_dir:
            yield Path(tmp_dir)

    @pytest.fixture
    def cfg(self, temp_dir):
        """Create a test configuration."""
        return DictDefault({"dataset_prepared_path": str(temp_dir)})

    @pytest.fixture
    def loader(self, cfg):
        """Create a FileLockLoader instance for testing."""
        return FileLockLoader(cfg)

    def test_load_first_process(self, loader):
        """Test load() when no ready flag exists (first process)."""
        mock_load_fn = Mock(return_value="test_data")
        result = loader.load(mock_load_fn)
        # Should call the load function
        mock_load_fn.assert_called_once()
        assert result == "test_data"
        # Should create the ready flag
        assert loader.ready_flag_path.exists()

    def test_load_subsequent_process(self, loader):
        """Test load() when ready flag already exists (subsequent process)."""
        # Create ready flag first
        loader.ready_flag_path.touch()
        mock_load_fn = Mock(return_value="loaded_data")
        result = loader.load(mock_load_fn)
        # Should still call load function (to load the prepared data)
        mock_load_fn.assert_called_once()
        assert result == "loaded_data"

    def test_load_concurrent_processes(self, cfg):
        """Test that concurrent processes coordinate correctly."""
        results = []
        call_count = 0

        def slow_load_fn():
            # Shared counter distinguishes which call produced each result.
            nonlocal call_count
            call_count += 1
            time.sleep(0.1)  # Simulate slow loading
            return f"data_{call_count}"

        def worker():
            # Each "process" is simulated with its own loader instance.
            loader = FileLockLoader(cfg)
            result = loader.load(slow_load_fn)
            results.append(result)

        # Start multiple threads simultaneously
        threads = [threading.Thread(target=worker) for _ in range(3)]
        for t in threads:
            t.start()
        for t in threads:
            t.join()
        # Only one thread should have done the initial loading
        # All should return data, but the load function should be called
        # once by the first process and once by each subsequent process
        assert len(results) == 3
        assert all(result.startswith("data_") for result in results)

    @patch("time.sleep")
    def test_load_waiting_for_ready_flag(self, mock_sleep, loader):
        """Test that processes wait for the ready flag to appear."""
        mock_load_fn = Mock(return_value="waiting_data")
        mock_ready_flag_path = Mock()
        exists_call_count = 0

        def mock_exists():
            # Scripted sequence of exists() results driving the wait loop.
            nonlocal exists_call_count
            exists_call_count += 1
            if exists_call_count == 1:
                # First check: ready flag exists (not first process)
                return True
            if exists_call_count <= 3:
                # While loop checks: flag doesn't exist yet
                return False
            return True

        mock_ready_flag_path.exists.side_effect = mock_exists
        # Replace the ready_flag_path with our mock
        original_path = loader.ready_flag_path
        loader.ready_flag_path = mock_ready_flag_path
        try:
            result = loader.load(mock_load_fn)
        finally:
            # Restore original path
            loader.ready_flag_path = original_path
        # Should have slept twice while waiting
        assert mock_sleep.call_count == 2
        mock_sleep.assert_called_with(1)
        # Should eventually call load function
        mock_load_fn.assert_called_once()
        assert result == "waiting_data"

    def test_complete_workflow_with_cleanup(self, loader):
        """Test the complete load -> cleanup workflow."""
        mock_load_fn = Mock(return_value="test_data")
        # First process calls load (this should set up counter)
        result = loader.load(mock_load_fn)
        assert result == "test_data"
        assert loader.ready_flag_path.exists()
        assert loader.counter_path.exists()
        # Cleanup should remove everything since there's only one process
        loader.cleanup()
        assert not loader.ready_flag_path.exists()
        assert not loader.counter_path.exists()

    def test_multiple_processes_workflow(self, loader):
        """Test workflow with multiple processes."""
        # Simulate multiple processes by manually setting up counter
        loader.ready_flag_path.touch()
        loader.counter_path.write_text("3")  # 3 processes
        # First process cleanup
        loader.cleanup()
        assert loader.ready_flag_path.exists()
        assert loader.counter_path.read_text().strip() == "2"
        # Second process cleanup
        loader.cleanup()
        assert loader.ready_flag_path.exists()
        assert loader.counter_path.read_text().strip() == "1"
        # Last process cleanup
        loader.cleanup()
        assert not loader.ready_flag_path.exists()
        assert not loader.counter_path.exists()

    def test_load_exception_handling(self, loader):
        """Test behavior when load_fn raises an exception."""

        def failing_load_fn():
            raise ValueError("Load failed")

        with pytest.raises(ValueError, match="Load failed"):
            loader.load(failing_load_fn)
        # Ready flag should not be created on failure
        assert not loader.ready_flag_path.exists()

    def test_file_lock_called(self, loader):
        """Test that FileLock is properly used."""
        mock_load_fn = Mock(return_value="locked_data")
        with patch("axolotl.utils.data.lock.FileLock") as mock_filelock:
            mock_context = MagicMock()
            mock_filelock.return_value.__enter__ = Mock(return_value=mock_context)
            mock_filelock.return_value.__exit__ = Mock(return_value=None)
            loader.load(mock_load_fn)
            # Verify FileLock was called with correct path
            mock_filelock.assert_called_once_with(str(loader.lock_file_path))
            # Verify context manager was used
            mock_filelock.return_value.__enter__.assert_called_once()
            mock_filelock.return_value.__exit__.assert_called_once()

View File

@@ -4,7 +4,6 @@ E2E tests for multipack fft llama using 4d attention masks
import unittest
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
@@ -60,8 +59,7 @@ class Test4dMultipackLlama(unittest.TestCase):
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
dataset_meta = load_datasets(cfg=cfg)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
@@ -108,8 +106,7 @@ class Test4dMultipackLlama(unittest.TestCase):
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
dataset_meta = load_datasets(cfg=cfg)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)

View File

@@ -6,7 +6,6 @@ import pytest
import transformers
from torch.utils.checkpoint import checkpoint
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
@@ -75,8 +74,7 @@ class TestActivationCheckpointing:
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
dataset_meta = load_datasets(cfg=cfg)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)

View File

@@ -5,7 +5,6 @@ E2E tests for lora llama
import pytest
from transformers.utils import is_torch_bf16_gpu_available
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
@@ -73,8 +72,7 @@ class TestFAXentropyLlama:
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
dataset_meta = load_datasets(cfg=cfg)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)

View File

@@ -6,7 +6,6 @@ import unittest
import pytest
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
@@ -63,8 +62,7 @@ class TestFalconPatched(unittest.TestCase):
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
dataset_meta = load_datasets(cfg=cfg)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
@@ -105,8 +103,7 @@ class TestFalconPatched(unittest.TestCase):
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
dataset_meta = load_datasets(cfg=cfg)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)

View File

@@ -7,7 +7,6 @@ import unittest
import pytest
from transformers.utils import is_torch_bf16_gpu_available
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
@@ -62,8 +61,7 @@ class TestFusedLlama(unittest.TestCase):
cfg.fp16 = True
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
dataset_meta = load_datasets(cfg=cfg)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)

View File

@@ -6,7 +6,6 @@ import unittest
import pytest
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
@@ -64,8 +63,7 @@ class TestLlamaShiftedSparseAttention(unittest.TestCase):
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
dataset_meta = load_datasets(cfg=cfg)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
@@ -107,8 +105,7 @@ class TestLlamaShiftedSparseAttention(unittest.TestCase):
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
dataset_meta = load_datasets(cfg=cfg)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)

View File

@@ -7,7 +7,6 @@ import unittest
import pytest
from transformers.utils import is_auto_gptq_available, is_torch_bf16_gpu_available
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
@@ -65,8 +64,7 @@ class TestLoraLlama(unittest.TestCase):
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
dataset_meta = load_datasets(cfg=cfg)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
@@ -114,8 +112,7 @@ class TestLoraLlama(unittest.TestCase):
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
dataset_meta = load_datasets(cfg=cfg)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)

View File

@@ -4,7 +4,6 @@ E2E tests for lora llama
import unittest
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
@@ -60,8 +59,7 @@ class TestMistral(unittest.TestCase):
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
dataset_meta = load_datasets(cfg=cfg)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
@@ -102,8 +100,7 @@ class TestMistral(unittest.TestCase):
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
dataset_meta = load_datasets(cfg=cfg)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)

View File

@@ -4,7 +4,6 @@ E2E tests for mixtral
import unittest
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
@@ -57,8 +56,7 @@ class TestMixtral(unittest.TestCase):
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
dataset_meta = load_datasets(cfg=cfg)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
@@ -96,8 +94,7 @@ class TestMixtral(unittest.TestCase):
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
dataset_meta = load_datasets(cfg=cfg)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)

View File

@@ -4,7 +4,6 @@ E2E tests for lora llama
import unittest
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
@@ -60,8 +59,7 @@ class TestPhiMultipack(unittest.TestCase):
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
dataset_meta = load_datasets(cfg=cfg)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
@@ -112,8 +110,7 @@ class TestPhiMultipack(unittest.TestCase):
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
dataset_meta = load_datasets(cfg=cfg)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)

View File

@@ -7,7 +7,6 @@ import subprocess
from transformers.utils import is_torch_bf16_gpu_available
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
@@ -67,8 +66,7 @@ class TestResumeLlama:
cfg.fp16 = True
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
dataset_meta = load_datasets(cfg=cfg)
train(cfg=cfg, dataset_meta=dataset_meta)
@@ -78,7 +76,6 @@ class TestResumeLlama:
}
)
normalize_config(resume_cfg)
cli_args = TrainerCliArgs()
train(cfg=resume_cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)

View File

@@ -4,7 +4,6 @@ e2e tests for unsloth qlora
import pytest
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
@@ -68,8 +67,7 @@ class TestUnslothQLoRA:
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
dataset_meta = load_datasets(cfg=cfg)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
@@ -119,8 +117,7 @@ class TestUnslothQLoRA:
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
dataset_meta = load_datasets(cfg=cfg)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
@@ -175,8 +172,7 @@ class TestUnslothQLoRA:
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
dataset_meta = load_datasets(cfg=cfg)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)

View File

@@ -6,7 +6,6 @@ import unittest
from transformers.utils import is_torch_bf16_gpu_available
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
@@ -59,8 +58,7 @@ class TestPackedFlex(unittest.TestCase):
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
dataset_meta = load_datasets(cfg=cfg)
train(cfg=cfg, dataset_meta=dataset_meta)

View File

@@ -5,7 +5,6 @@ E2E tests for relora llama
import unittest
from pathlib import Path
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
@@ -71,8 +70,7 @@ class TestReLoraLlama(unittest.TestCase):
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
dataset_meta = load_datasets(cfg=cfg)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(Path(temp_dir) / "checkpoint-100/adapter", cfg)

View File

@@ -6,7 +6,6 @@ from pathlib import Path
import pytest
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
@@ -72,8 +71,7 @@ class TestDeepseekV3:
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
dataset_meta = load_datasets(cfg=cfg)
train(cfg=cfg, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "adapter_model.safetensors").exists()
@@ -122,8 +120,7 @@ class TestDeepseekV3:
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
dataset_meta = load_datasets(cfg=cfg)
train(cfg=cfg, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "model.safetensors").exists()

View File

@@ -1,6 +1,4 @@
"""
E2E tests for lora llama
"""
"""E2E tests for lora llama"""
import unittest
from pathlib import Path

View File

@@ -4,7 +4,6 @@ E2E tests for llama pretrain
import unittest
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
@@ -54,8 +53,7 @@ class TestEmbeddingsLrScale(unittest.TestCase):
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
dataset_meta = load_datasets(cfg=cfg)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
@@ -99,8 +97,7 @@ class TestEmbeddingsLrScale(unittest.TestCase):
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
dataset_meta = load_datasets(cfg=cfg)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)

View File

@@ -6,7 +6,6 @@ import unittest
import pytest
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
@@ -66,8 +65,7 @@ class TestFalcon(unittest.TestCase):
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
dataset_meta = load_datasets(cfg=cfg)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
@@ -122,8 +120,7 @@ class TestFalcon(unittest.TestCase):
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
dataset_meta = load_datasets(cfg=cfg)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
@@ -164,8 +161,7 @@ class TestFalcon(unittest.TestCase):
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
dataset_meta = load_datasets(cfg=cfg)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)

View File

@@ -6,7 +6,6 @@ from pathlib import Path
import pytest
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
@@ -69,8 +68,7 @@ class TestGemma2:
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
dataset_meta = load_datasets(cfg=cfg)
train(cfg=cfg, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "adapter_model.safetensors").exists()
@@ -121,8 +119,7 @@ class TestGemma2:
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
dataset_meta = load_datasets(cfg=cfg)
train(cfg=cfg, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "model.safetensors").exists()

View File

@@ -6,7 +6,6 @@ from pathlib import Path
import pytest
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
@@ -68,8 +67,7 @@ class TestGemma3Text:
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
dataset_meta = load_datasets(cfg=cfg)
train(cfg=cfg, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "adapter_model.safetensors").exists()
@@ -119,8 +117,7 @@ class TestGemma3Text:
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
dataset_meta = load_datasets(cfg=cfg)
train(cfg=cfg, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "model.safetensors").exists()

View File

@@ -2,7 +2,6 @@
E2E tests for llama
"""
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
@@ -51,8 +50,7 @@ class TestLlama:
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
dataset_meta = load_datasets(cfg=cfg)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
@@ -99,8 +97,7 @@ class TestLlama:
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
dataset_meta = load_datasets(cfg=cfg)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
@@ -144,8 +141,7 @@ class TestLlama:
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
dataset_meta = load_datasets(cfg=cfg)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
@@ -185,8 +181,7 @@ class TestLlama:
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
dataset_meta = load_datasets(cfg=cfg)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)

View File

@@ -1,10 +1,7 @@
"""
E2E tests for llama pretrain
"""
"""E2E tests for llama pretrain"""
import pytest
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
@@ -14,9 +11,7 @@ from .utils import check_model_output_exists, check_tensorboard
class TestPretrainLlama:
"""
Test case for Llama models w pretraining
"""
"""Test case for Llama models w pretraining"""
@pytest.mark.parametrize(
"sample_packing",
@@ -66,8 +61,7 @@ class TestPretrainLlama:
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
dataset_meta = load_datasets(cfg=cfg)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)

View File

@@ -4,7 +4,6 @@ E2E tests for lora llama
import unittest
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
@@ -60,8 +59,7 @@ class TestLlamaVision(unittest.TestCase):
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
dataset_meta = load_datasets(cfg=cfg)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
@@ -106,8 +104,7 @@ class TestLlamaVision(unittest.TestCase):
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
dataset_meta = load_datasets(cfg=cfg)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)

View File

@@ -4,7 +4,6 @@ E2E tests for lora llama
import unittest
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
@@ -55,8 +54,7 @@ class TestLoraLlama(unittest.TestCase):
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
dataset_meta = load_datasets(cfg=cfg)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)

View File

@@ -6,7 +6,6 @@ import unittest
import pytest
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
@@ -57,8 +56,7 @@ class TestMamba(unittest.TestCase):
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
dataset_meta = load_datasets(cfg=cfg)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)

View File

@@ -6,7 +6,6 @@ import unittest
from transformers.utils import is_torch_bf16_gpu_available
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
@@ -61,8 +60,7 @@ class TestMistral(unittest.TestCase):
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
dataset_meta = load_datasets(cfg=cfg)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
@@ -106,8 +104,7 @@ class TestMistral(unittest.TestCase):
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
dataset_meta = load_datasets(cfg=cfg)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)

View File

@@ -7,7 +7,6 @@ import unittest
import torch
from transformers.utils import is_torch_bf16_gpu_available
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
@@ -67,8 +66,7 @@ class TestMixtral(unittest.TestCase):
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
dataset_meta = load_datasets(cfg=cfg)
model, _, _ = train(cfg=cfg, dataset_meta=dataset_meta)
assert (
@@ -123,8 +121,7 @@ class TestMixtral(unittest.TestCase):
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
dataset_meta = load_datasets(cfg=cfg)
model, _, _ = train(cfg=cfg, dataset_meta=dataset_meta)
assert (
@@ -182,8 +179,7 @@ class TestMixtral(unittest.TestCase):
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
dataset_meta = load_datasets(cfg=cfg)
model, _, _ = train(cfg=cfg, dataset_meta=dataset_meta)
assert (
@@ -241,8 +237,7 @@ class TestMixtral(unittest.TestCase):
cfg.bf16 = True
else:
cfg.fp16 = True
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
dataset_meta = load_datasets(cfg=cfg)
model, _, _ = train(cfg=cfg, dataset_meta=dataset_meta)
assert (
@@ -287,8 +282,7 @@ class TestMixtral(unittest.TestCase):
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
dataset_meta = load_datasets(cfg=cfg)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)

View File

@@ -4,7 +4,6 @@ E2E tests for custom optimizers using Llama
import unittest
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
@@ -61,8 +60,7 @@ class TestCustomOptimizers(unittest.TestCase):
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
dataset_meta = load_datasets(cfg=cfg)
_, _, trainer = train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
@@ -107,8 +105,7 @@ class TestCustomOptimizers(unittest.TestCase):
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
dataset_meta = load_datasets(cfg=cfg)
_, _, trainer = train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
@@ -154,8 +151,7 @@ class TestCustomOptimizers(unittest.TestCase):
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
dataset_meta = load_datasets(cfg=cfg)
_, _, trainer = train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
@@ -194,8 +190,7 @@ class TestCustomOptimizers(unittest.TestCase):
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
dataset_meta = load_datasets(cfg=cfg)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
@@ -242,8 +237,7 @@ class TestCustomOptimizers(unittest.TestCase):
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
dataset_meta = load_datasets(cfg=cfg)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)

View File

@@ -6,7 +6,6 @@ import unittest
from transformers.utils import is_torch_bf16_gpu_available
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
@@ -58,8 +57,7 @@ class TestPackedLlama(unittest.TestCase):
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
dataset_meta = load_datasets(cfg=cfg)
train(cfg=cfg, dataset_meta=dataset_meta)

View File

@@ -4,7 +4,6 @@ E2E tests for lora llama
import unittest
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
@@ -58,8 +57,7 @@ class TestPhi(unittest.TestCase):
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
dataset_meta = load_datasets(cfg=cfg)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
@@ -108,8 +106,7 @@ class TestPhi(unittest.TestCase):
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
dataset_meta = load_datasets(cfg=cfg)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)

View File

@@ -4,7 +4,6 @@ E2E tests for process reward model w/ lora llama
import unittest
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
@@ -54,8 +53,7 @@ class TestProcessRewardSmolLM2(unittest.TestCase):
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
dataset_meta = load_datasets(cfg=cfg)
train(cfg=cfg, dataset_meta=dataset_meta)
check_tensorboard(

View File

@@ -5,7 +5,6 @@ E2E tests for QAT
import unittest
from pathlib import Path
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
@@ -64,8 +63,7 @@ class TestQATLlama(unittest.TestCase):
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
dataset_meta = load_datasets(cfg=cfg)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(Path(temp_dir) / "checkpoint-5", cfg)

View File

@@ -4,7 +4,6 @@ E2E tests for reward model lora llama
import unittest
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
@@ -63,8 +62,7 @@ class TestRewardModelLoraSmolLM2(unittest.TestCase):
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
dataset_meta = load_datasets(cfg=cfg)
train(cfg=cfg, dataset_meta=dataset_meta)
check_tensorboard(

View File

@@ -4,7 +4,6 @@ E2E tests for custom schedulers using Llama
import unittest
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
@@ -57,8 +56,7 @@ class TestCustomSchedulers(unittest.TestCase):
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
dataset_meta = load_datasets(cfg=cfg)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)

View File

@@ -6,8 +6,9 @@ import unittest
import pytest
from axolotl.loaders.tokenizer import load_tokenizer
from axolotl.prompt_strategies.dpo import load as load_dpo
from axolotl.utils.data.rl import load_prepare_preference_datasets
from axolotl.utils.data.rl import prepare_preference_datasets
from axolotl.utils.dict import DictDefault
from tests.hf_offline_utils import enable_hf_offline
@@ -55,7 +56,8 @@ class TestDPOChatml:
# test that dpo.load works
load_dpo("chatml", cfg)
# now actually load the datasets with the strategy
train_ds, _ = load_prepare_preference_datasets(cfg)
tokenizer = load_tokenizer(cfg)
train_ds, _ = prepare_preference_datasets(cfg, tokenizer)
assert train_ds[0]["prompt"].startswith("<|im_start|>")
assert train_ds[0]["prompt"].endswith("<|im_start|>assistant\n")
assert "chosen" in train_ds[0]

View File

@@ -1,10 +1,9 @@
"""
Test dataset loading under various conditions.
"""
"""Test dataset loading under various conditions."""
import shutil
import tempfile
from pathlib import Path
from typing import Any, Generator
from unittest.mock import patch
import pytest
@@ -12,8 +11,9 @@ from datasets import Dataset
from huggingface_hub import snapshot_download
from transformers import PreTrainedTokenizer
from axolotl.utils.data import load_tokenized_prepared_datasets
from axolotl.utils.data.rl import load_prepare_preference_datasets
from axolotl.loaders.tokenizer import load_tokenizer
from axolotl.utils.data.rl import prepare_preference_datasets
from axolotl.utils.data.sft import _load_tokenized_prepared_datasets
from axolotl.utils.dict import DictDefault
from tests.constants import (
@@ -28,7 +28,9 @@ class TestDatasetPreparation:
"""Test a configured dataloader."""
@pytest.fixture
def tokenizer(self, tokenizer_huggyllama) -> PreTrainedTokenizer:
def tokenizer(
self, tokenizer_huggyllama
) -> Generator[PreTrainedTokenizer, Any, Any]:
tokenizer_huggyllama.add_special_tokens(SPECIAL_TOKENS)
yield tokenizer_huggyllama
@@ -63,7 +65,10 @@ class TestDatasetPreparation:
}
)
dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path)
with patch(
"axolotl.common.const.DEFAULT_DATASET_PREPARED_PATH", str(prepared_path)
):
dataset, _ = _load_tokenized_prepared_datasets(tokenizer, cfg)
assert len(dataset) == 2000
assert "input_ids" in dataset.features
@@ -107,7 +112,10 @@ class TestDatasetPreparation:
}
)
dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path)
with patch(
"axolotl.common.const.DEFAULT_DATASET_PREPARED_PATH", str(prepared_path)
):
dataset, _ = _load_tokenized_prepared_datasets(tokenizer, cfg)
assert len(dataset) == 2000
assert "input_ids" in dataset.features
@@ -136,7 +144,10 @@ class TestDatasetPreparation:
}
)
dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path)
with patch(
"axolotl.common.const.DEFAULT_DATASET_PREPARED_PATH", str(prepared_path)
):
dataset, _ = _load_tokenized_prepared_datasets(tokenizer, cfg)
assert len(dataset) == 1
assert "input_ids" in dataset.features
@@ -145,7 +156,7 @@ class TestDatasetPreparation:
@enable_hf_offline
def test_load_from_dir_of_parquet(self, tokenizer, dataset_fixture):
"""Usual use case. Verify a directory of parquet files can be loaded."""
"""Usual use case. Verify a directory of parquet files can be loaded."""
with tempfile.TemporaryDirectory() as tmp_dir:
tmp_ds_dir = Path(tmp_dir) / "tmp_dataset"
tmp_ds_dir.mkdir()
@@ -171,7 +182,10 @@ class TestDatasetPreparation:
}
)
dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path)
with patch(
"axolotl.common.const.DEFAULT_DATASET_PREPARED_PATH", str(prepared_path)
):
dataset, _ = _load_tokenized_prepared_datasets(tokenizer, cfg)
assert len(dataset) == 1
assert "input_ids" in dataset.features
@@ -206,7 +220,10 @@ class TestDatasetPreparation:
}
)
dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path)
with patch(
"axolotl.common.const.DEFAULT_DATASET_PREPARED_PATH", str(prepared_path)
):
dataset, _ = _load_tokenized_prepared_datasets(tokenizer, cfg)
assert len(dataset) == 1
assert "input_ids" in dataset.features
@@ -235,7 +252,10 @@ class TestDatasetPreparation:
}
)
dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path)
with patch(
"axolotl.common.const.DEFAULT_DATASET_PREPARED_PATH", str(prepared_path)
):
dataset, _ = _load_tokenized_prepared_datasets(tokenizer, cfg)
assert len(dataset) == 1
assert "input_ids" in dataset.features
@@ -264,7 +284,10 @@ class TestDatasetPreparation:
}
)
dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path)
with patch(
"axolotl.common.const.DEFAULT_DATASET_PREPARED_PATH", str(prepared_path)
):
dataset, _ = _load_tokenized_prepared_datasets(tokenizer, cfg)
assert len(dataset) == 1
assert "input_ids" in dataset.features
@@ -286,7 +309,8 @@ class TestDatasetPreparation:
}
)
train_dataset, _ = load_prepare_preference_datasets(cfg)
tokenizer = load_tokenizer(cfg)
train_dataset, _ = prepare_preference_datasets(cfg, tokenizer)
assert len(train_dataset) == 1800
assert "conversation" not in train_dataset.features
@@ -318,7 +342,10 @@ class TestDatasetPreparation:
}
)
dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path)
with patch(
"axolotl.common.const.DEFAULT_DATASET_PREPARED_PATH", str(prepared_path)
):
dataset, _ = _load_tokenized_prepared_datasets(tokenizer, cfg)
assert len(dataset) == 2000
assert "input_ids" in dataset.features
@@ -342,13 +369,16 @@ class TestDatasetPreparation:
)
# pylint: disable=duplicate-code
with patch("axolotl.utils.data.rl.load_dataset_w_config") as mock_load_dataset:
with patch(
"axolotl.utils.data.rl.load_dataset_with_config"
) as mock_load_dataset:
# Set up the mock to return different values on successive calls
mock_load_dataset.return_value = (
dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff
)
train_dataset, _ = load_prepare_preference_datasets(cfg)
tokenizer = load_tokenizer(cfg)
train_dataset, _ = prepare_preference_datasets(cfg, tokenizer)
assert len(train_dataset) == 1800
assert "conversation" not in train_dataset.features
@@ -393,16 +423,18 @@ class TestDatasetPreparation:
)
with patch(
"axolotl.utils.data.shared.load_dataset_w_config"
"axolotl.utils.data.shared.load_dataset_with_config"
) as mock_load_dataset:
# Set up the mock to return different values on successive calls
mock_load_dataset.return_value = (
dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff
)
dataset, _ = load_tokenized_prepared_datasets(
tokenizer, cfg, prepared_path
)
with patch(
"axolotl.common.const.DEFAULT_DATASET_PREPARED_PATH",
str(prepared_path),
):
dataset, _ = _load_tokenized_prepared_datasets(tokenizer, cfg)
assert len(dataset) == 2000
assert "input_ids" in dataset.features
@@ -437,7 +469,10 @@ class TestDatasetPreparation:
}
)
dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path)
with patch(
"axolotl.common.const.DEFAULT_DATASET_PREPARED_PATH", str(prepared_path)
):
dataset, _ = _load_tokenized_prepared_datasets(tokenizer, cfg)
assert len(dataset) == 2000
assert "input_ids" in dataset.features

View File

@@ -5,7 +5,6 @@ Additionally, this test suite includes tests for functions that indirectly call
`deduplicate_and_log_datasets` during the execution of the preprocess command.
"""
import hashlib
import unittest
from unittest.mock import patch
@@ -14,8 +13,7 @@ from datasets import Dataset
from axolotl.loaders import load_processor, load_tokenizer
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.data import prepare_dataset
from axolotl.utils.data.rl import load_prepare_preference_datasets
from axolotl.utils.data import prepare_datasets, prepare_preference_datasets
from axolotl.utils.data.utils import deduplicate_and_log_datasets
from axolotl.utils.dict import DictDefault
@@ -71,36 +69,14 @@ class TestDeduplicateIndividualFunctions(unittest.TestCase):
self.expected_dataset = Dataset.from_dict(self.expected_data)
def test_deduplication(self):
train_dataset, _, _ = deduplicate_and_log_datasets(train_dataset=self.dataset)
_, eval_dataset, _ = deduplicate_and_log_datasets(eval_dataset=self.dataset)
train_dataset, _ = deduplicate_and_log_datasets(dataset=self.dataset)
eval_dataset, _ = deduplicate_and_log_datasets(
dataset=self.dataset, dataset_name="eval"
)
verify_deduplication(train_dataset, self.expected_dataset, "train_dataset")
verify_deduplication(eval_dataset, self.expected_dataset, "eval_dataset")
def test_datasets_are_none(self):
# Test when both datasets are None
train_dataset, eval_dataset, _ = deduplicate_and_log_datasets(
train_dataset=None, eval_dataset=None
)
self.assertIsNone(train_dataset, "Expected train_dataset to be None")
self.assertIsNone(eval_dataset, "Expected eval_dataset to be None")
def test_only_train_is_none(self):
# Test when only train_dataset is None
train_dataset, eval_dataset, _ = deduplicate_and_log_datasets(
train_dataset=None, eval_dataset=self.dataset
)
self.assertIsNone(train_dataset, "Expected train_dataset to be None")
verify_deduplication(eval_dataset, self.expected_dataset, "eval_dataset")
def test_only_eval_is_none(self):
# Test when only eval_dataset is None
train_dataset, eval_dataset, _ = deduplicate_and_log_datasets(
train_dataset=self.dataset, eval_dataset=None
)
self.assertIsNone(eval_dataset, "Expected eval_dataset to be None")
verify_deduplication(train_dataset, self.expected_dataset, "train_dataset")
def test_exact_duplicates(self):
# Test when datasets are exact duplicates
duplicate_data = {
@@ -115,8 +91,10 @@ class TestDeduplicateIndividualFunctions(unittest.TestCase):
expected_dataset = Dataset.from_dict(expected_data)
# Run deduplication
train_dataset, _, _ = deduplicate_and_log_datasets(train_dataset=dataset)
_, eval_dataset, _ = deduplicate_and_log_datasets(eval_dataset=dataset)
train_dataset, _ = deduplicate_and_log_datasets(dataset=dataset)
eval_dataset, _ = deduplicate_and_log_datasets(
dataset=dataset, dataset_name="eval"
)
verify_deduplication(train_dataset, expected_dataset, "train_dataset")
verify_deduplication(eval_dataset, expected_dataset, "eval_dataset")
@@ -139,8 +117,10 @@ class TestDeduplicateIndividualFunctions(unittest.TestCase):
expected_dataset = Dataset.from_dict(expected_data)
# Run deduplication
train_dataset, _, _ = deduplicate_and_log_datasets(train_dataset=dataset)
_, eval_dataset, _ = deduplicate_and_log_datasets(eval_dataset=dataset)
train_dataset, _ = deduplicate_and_log_datasets(dataset=dataset)
eval_dataset, _ = deduplicate_and_log_datasets(
dataset=dataset, dataset_name="eval"
)
verify_deduplication(train_dataset, expected_dataset, "train_dataset")
verify_deduplication(eval_dataset, expected_dataset, "eval_dataset")
@@ -169,8 +149,8 @@ class TestDeduplicateIndividualFunctions(unittest.TestCase):
expected_dataset_eval = Dataset.from_dict(expected_data_eval)
# Run deduplication
train_dataset, eval_dataset, _ = deduplicate_and_log_datasets(
train_dataset=dataset, eval_dataset=dataset
train_dataset, eval_dataset = deduplicate_and_log_datasets(
dataset=dataset, other_dataset=dataset
)
verify_deduplication(train_dataset, expected_dataset_train, "train_dataset")
@@ -206,8 +186,8 @@ class TestDeduplicateIndividualFunctions(unittest.TestCase):
expected_dataset_eval = Dataset.from_dict(expected_data_eval)
# Run deduplication
train_dataset, eval_dataset, _ = deduplicate_and_log_datasets(
train_dataset=dataset_train, eval_dataset=dataset_eval
train_dataset, eval_dataset = deduplicate_and_log_datasets(
dataset=dataset_train, other_dataset=dataset_eval
)
verify_deduplication(train_dataset, expected_dataset_train, "train_dataset")
@@ -245,7 +225,9 @@ class TestDeduplicateRLDataset:
# pylint: disable=duplicate-code
with (
patch("axolotl.utils.data.rl.load_dataset_w_config") as mock_load_dataset,
patch(
"axolotl.utils.data.rl.load_dataset_with_config"
) as mock_load_dataset,
patch("axolotl.loaders.load_tokenizer") as mock_load_tokenizer,
):
# Set up the mock to return different values on successive calls
@@ -255,7 +237,8 @@ class TestDeduplicateRLDataset:
]
mock_load_tokenizer.return_value = tokenizer_huggyllama
train_dataset, _ = load_prepare_preference_datasets(cfg)
tokenizer = load_tokenizer(cfg)
train_dataset, _ = prepare_preference_datasets(cfg, tokenizer)
# Verify that the dataset has been deduplicated
assert len(train_dataset) == 1800, "Dataset was not properly deduplicated"
@@ -269,7 +252,9 @@ class TestDeduplicateRLDataset:
):
# pylint: disable=duplicate-code
with (
patch("axolotl.utils.data.rl.load_dataset_w_config") as mock_load_dataset,
patch(
"axolotl.utils.data.rl.load_dataset_with_config"
) as mock_load_dataset,
patch("axolotl.loaders.load_tokenizer") as mock_load_tokenizer,
):
# Set up the mock to return different values on successive calls
@@ -279,9 +264,10 @@ class TestDeduplicateRLDataset:
]
mock_load_tokenizer.return_value = tokenizer_huggyllama
cfg.dataset_exact_deduplication = False
# Load the dataset without deduplication
train_dataset, _ = load_prepare_preference_datasets(cfg)
cfg.dataset_exact_deduplication = False
tokenizer = load_tokenizer(cfg)
train_dataset, _ = prepare_preference_datasets(cfg, tokenizer)
# Verify that the dataset retains duplicates
assert (
@@ -335,7 +321,7 @@ class TestDeduplicateNonRL(unittest.TestCase):
)
# Prepare dataset using the prepare_dataset function
train_dataset, _, _, _ = prepare_dataset(
train_dataset, _, _, _ = prepare_datasets(
self.cfg_1,
tokenizer,
processor=processor,
@@ -362,7 +348,7 @@ class TestDeduplicateNonRL(unittest.TestCase):
)
# Prepare dataset using the prepare_dataset function
_, eval_dataset, _, _ = prepare_dataset(
_, eval_dataset, _, _ = prepare_datasets(
self.cfg_1,
tokenizer,
processor=processor,
@@ -389,7 +375,7 @@ class TestDeduplicateNonRL(unittest.TestCase):
)
# Prepare dataset using the prepare_dataset function
train_dataset, eval_dataset, _, _ = prepare_dataset(
train_dataset, eval_dataset, _, _ = prepare_datasets(
self.cfg_1,
tokenizer,
processor=processor,
@@ -428,41 +414,8 @@ class TestWrongCollisions(unittest.TestCase):
self.eval_dataset = Dataset.from_dict(self.eval_data)
self.dataset = Dataset.from_dict(self.dataset_data)
@patch(
"axolotl.utils.data.utils.sha256",
side_effect=lambda x: (
hashlib.sha256("forced_collision_hash".encode("utf-8")).hexdigest()
if "sample 5" in x
else hashlib.sha256(x.encode("utf-8")).hexdigest()
),
)
def test_deduplication_wrong_collision_train_eval(self, _mock_sha256):
dedup_train, dedup_eval, _ = deduplicate_and_log_datasets(
train_dataset=self.train_dataset, eval_dataset=self.eval_dataset
)
self.assertEqual(
len(dedup_train),
2,
"train dataset should not deduplicate rows with forced hash collisions but different labels.",
)
self.assertEqual(
len(dedup_eval),
2,
"Eval dataset should not deduplicate rows with forced hash collisions but different labels.",
)
self.assertEqual(
len(dedup_eval),
len(self.eval_dataset),
"The output eval dataset should have the same number of rows as the input eval dataset.",
)
self.assertEqual(
str(dedup_eval),
str(self.eval_dataset),
"The string representation of the output eval dataset should be identical to the input eval dataset.",
)
def test_deduplication_dataset_only(self):
_, _, dedup_dataset = deduplicate_and_log_datasets(dataset=self.dataset)
dedup_dataset, _ = deduplicate_and_log_datasets(dataset=self.dataset)
self.assertEqual(
len(dedup_dataset), 3, "Dataset should have all original values"
)