Data loader refactor (#2707)

* data loading refactor (wip)

* updates

* progress

* pytest

* pytest fix

* lint

* zero_first -> filelock, more simplifications

* small simplification

* import change

* nit

* lint

* simplify dedup

* couldnt resist

* review comments WIP

* continued wip

* minor changes

* fix; remove contrived test

* further refactor

* set default seed in pydantic config

* lint

* continued simplication

* lint

* renaming and nits

* filelock tests

* fix

* fix

* lint

* remove nullable arg

* remove unnecessary code

* moving dataset save fn to shared module

* remove debug print

* matching var naming

* fn name change

* coderabbit comments

* naming nit

* fix test
This commit is contained in:
Dan Saunders
2025-06-10 19:53:07 -04:00
committed by GitHub
parent 52a0452acb
commit 00cda8cc70
62 changed files with 2125 additions and 1436 deletions

View File

@@ -1,5 +1,3 @@
""" """Various shared constants"""
Various shared constants
"""
DEFAULT_DATASET_PREPARED_PATH = "last_run_prepared" DEFAULT_DATASET_PREPARED_PATH = "last_run_prepared"

View File

@@ -3,15 +3,13 @@
import math import math
import random import random
from dataclasses import dataclass from dataclasses import dataclass
from typing import Optional, Union
from datasets import Dataset from datasets import Dataset
import axolotl.monkeypatch.data.batch_dataset_fetcher # pylint: disable=unused-import # noqa: F401 import axolotl.monkeypatch.data.batch_dataset_fetcher # pylint: disable=unused-import # noqa: F401
from axolotl.cli.args import PreprocessCliArgs, TrainerCliArgs from axolotl.cli.args import PreprocessCliArgs, TrainerCliArgs
from axolotl.loaders import load_processor, load_tokenizer from axolotl.loaders import load_processor, load_tokenizer
from axolotl.utils.data import prepare_dataset from axolotl.utils.data import prepare_datasets, prepare_preference_datasets
from axolotl.utils.data.rl import load_prepare_preference_datasets
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger from axolotl.utils.logging import get_logger
from axolotl.utils.schemas.enums import RLType from axolotl.utils.schemas.enums import RLType
@@ -30,16 +28,7 @@ class TrainDatasetMeta:
def sample_dataset(dataset: Dataset, num_samples: int) -> Dataset: def sample_dataset(dataset: Dataset, num_samples: int) -> Dataset:
""" """Randomly sample `num_samples` samples with replacement from `dataset`."""
Randomly sample `num_samples` samples from `dataset`.
Args:
dataset: Dataset.
num_samples: Number of samples to return.
Returns:
Random sample (with replacement) of examples in `dataset`.
"""
return dataset.select( return dataset.select(
[random.randrange(0, len(dataset) - 1) for _ in range(num_samples)] # nosec [random.randrange(0, len(dataset) - 1) for _ in range(num_samples)] # nosec
) )
@@ -51,44 +40,37 @@ def load_datasets(
cli_args: PreprocessCliArgs | TrainerCliArgs | None = None, cli_args: PreprocessCliArgs | TrainerCliArgs | None = None,
debug: bool = False, debug: bool = False,
) -> TrainDatasetMeta: ) -> TrainDatasetMeta:
""" """Loads one or more training or evaluation datasets, calling
Loads one or more training or evaluation datasets, calling `axolotl.utils.data.prepare_datasets`. Optionally, logs out debug information.
`axolotl.utils.data.prepare_dataset`. Optionally, logs out debug information.
Args: Args:
cfg: Dictionary mapping `axolotl` config keys to values. cfg: Dictionary mapping `axolotl` config keys to values.
cli_args: Command-specific CLI arguments. cli_args: Command-specific CLI arguments.
debug: Whether to print out tokenization of sample debug: Whether to print out tokenization of sample. This is duplicated in
`cfg` and `cli_args`, but is kept due to use in our Colab notebooks.
Returns: Returns:
Dataclass with fields for training and evaluation datasets and the computed Dataclass with fields for training and evaluation datasets and the computed
`total_num_steps`. `total_num_steps`.
""" """
tokenizer = load_tokenizer(cfg) tokenizer = load_tokenizer(cfg)
processor = load_processor(cfg, tokenizer=tokenizer) if cfg.processor_type else None processor = load_processor(cfg, tokenizer=tokenizer) if cfg.processor_type else None
preprocess_iterable = ( preprocess_iterable = getattr(cli_args, "iterable", False)
cli_args
and hasattr(cli_args, "iterable")
and cli_args.iterable is not None
and cli_args.iterable
)
train_dataset, eval_dataset, total_num_steps, prompters = prepare_dataset( train_dataset, eval_dataset, total_num_steps, prompters = prepare_datasets(
cfg, cfg,
tokenizer, tokenizer,
processor=processor, processor=processor,
preprocess_iterable=preprocess_iterable, preprocess_iterable=preprocess_iterable,
) )
if ( # pylint: disable=too-many-boolean-expressions if (
cli_args cfg.debug
and ( or getattr(cli_args, "debug", False)
cli_args.debug or getattr(cli_args, "debug_text_only", False)
or cfg.debug or getattr(cli_args, "debug_num_examples", 0) > 0
or cli_args.debug_text_only or debug
or int(cli_args.debug_num_examples) > 0 ):
)
) or debug:
LOG.info("check_dataset_labels...") LOG.info("check_dataset_labels...")
num_examples = cli_args.debug_num_examples if cli_args else 1 num_examples = cli_args.debug_num_examples if cli_args else 1
@@ -113,13 +95,10 @@ def load_datasets(
def load_preference_datasets( def load_preference_datasets(
*, *, cfg: DictDefault, cli_args: PreprocessCliArgs | TrainerCliArgs
cfg: DictDefault,
cli_args: Union[PreprocessCliArgs, TrainerCliArgs],
) -> TrainDatasetMeta: ) -> TrainDatasetMeta:
""" """Loads one or more training or evaluation datasets for RL training using paired
Loads one or more training or evaluation datasets for RL training using paired preference data, calling `axolotl.utils.data.rl.prepare_preference_datasets`.
preference data, calling `axolotl.utils.data.rl.load_prepare_preference_datasets`.
Optionally, logs out debug information. Optionally, logs out debug information.
Args: Args:
@@ -130,12 +109,14 @@ def load_preference_datasets(
Dataclass with fields for training and evaluation datasets and the computed Dataclass with fields for training and evaluation datasets and the computed
`total_num_steps`. `total_num_steps`.
""" """
train_dataset, eval_dataset = load_prepare_preference_datasets(cfg) tokenizer = load_tokenizer(cfg)
total_num_steps: Optional[int] = int( train_dataset, eval_dataset = prepare_preference_datasets(cfg, tokenizer)
math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
) total_num_steps: int | None = None
if cfg.rl is RLType.GRPO: if cfg.rl is not RLType.GRPO:
total_num_steps = None total_num_steps = int(
math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
)
if cli_args.debug or cfg.debug: if cli_args.debug or cfg.debug:
LOG.info("check_dataset_labels...") LOG.info("check_dataset_labels...")
@@ -143,8 +124,8 @@ def load_preference_datasets(
tokenizer = load_tokenizer(cfg) tokenizer = load_tokenizer(cfg)
train_samples = sample_dataset(train_dataset, cli_args.debug_num_examples) train_samples = sample_dataset(train_dataset, cli_args.debug_num_examples)
check_dataset_labels( check_dataset_labels(
train_samples, dataset=train_samples,
tokenizer, tokenizer=tokenizer,
num_examples=cli_args.debug_num_examples, num_examples=cli_args.debug_num_examples,
text_only=cli_args.debug_text_only, text_only=cli_args.debug_text_only,
rl_mode=True, rl_mode=True,

View File

@@ -381,7 +381,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
elif "tokenizer" in sig.parameters: elif "tokenizer" in sig.parameters:
trainer_kwargs["tokenizer"] = self.tokenizer trainer_kwargs["tokenizer"] = self.tokenizer
if ( if (
not (trainer_cls in [AxolotlRewardTrainer, AxolotlPRMTrainer]) trainer_cls not in [AxolotlRewardTrainer, AxolotlPRMTrainer]
and self.cfg.datasets is not None and self.cfg.datasets is not None
): ):
trainer_kwargs["dataset_tags"] = [ trainer_kwargs["dataset_tags"] = [

View File

@@ -1,7 +1,6 @@
"""Module containing Dataset functionality""" """Module containing Dataset functionality"""
import os import os
from typing import List, Optional, Union
import torch import torch
from datasets import Dataset, IterableDataset from datasets import Dataset, IterableDataset
@@ -20,21 +19,21 @@ LOG = get_logger(__name__)
class TokenizedPromptDataset(Dataset): class TokenizedPromptDataset(Dataset):
""" """Dataset that returns tokenized prompts from a stream of text files.
Dataset that returns tokenized prompts from a stream of text files.
Args: Args:
prompt_tokenizer (PromptTokenizingStrategy): The prompt tokenizing method for processing the data. prompt_tokenizer: The prompt tokenizing method for processing the data.
dataset (dataset.Dataset): Dataset with text files. dataset: Dataset with text files.
process_count (int): Number of processes to use for tokenizing. process_count: Number of processes to use for tokenizing.
keep_in_memory (bool): Whether to keep the tokenized dataset in memory. keep_in_memory: Whether to keep the tokenized dataset in memory.
""" """
def __init__( # pylint: disable=super-init-not-called def __init__( # pylint: disable=super-init-not-called
self, self,
prompt_tokenizer: PromptTokenizingStrategy, prompt_tokenizer: PromptTokenizingStrategy,
dataset: Dataset, dataset: Dataset,
process_count: Optional[int] = None, process_count: int | None = None,
keep_in_memory: Optional[bool] = False, keep_in_memory: bool | None = False,
**kwargs, **kwargs,
): ):
self.prompt_tokenizer = prompt_tokenizer self.prompt_tokenizer = prompt_tokenizer
@@ -76,14 +75,14 @@ class TokenizedPromptDataset(Dataset):
def wrap_dataset_for_tokenized_prompt( def wrap_dataset_for_tokenized_prompt(
prompt_tokenizer: PromptTokenizingStrategy, prompt_tokenizer: PromptTokenizingStrategy,
dataset: Union[Dataset, IterableDataset], dataset: Dataset | IterableDataset,
**kwargs, **kwargs,
): ):
if isinstance(dataset, IterableDataset): if isinstance(dataset, IterableDataset):
map_kwargs = {} map_kwargs = {}
if prompt_tokenizer.supports_batched: if prompt_tokenizer.supports_batched:
map_kwargs["batched"] = True map_kwargs["batched"] = True
features = dataset.features.keys() features = list(dataset.features.keys())
return dataset.map( return dataset.map(
prompt_tokenizer.tokenize_prompt, prompt_tokenizer.tokenize_prompt,
remove_columns=features, remove_columns=features,
@@ -94,12 +93,13 @@ def wrap_dataset_for_tokenized_prompt(
# TODO this isn't the best since it can't interleave datasets # TODO this isn't the best since it can't interleave datasets
class ConstantLengthDataset(IterableDataset): class ConstantLengthDataset(IterableDataset):
""" """Iterable dataset that returns constant length chunks of tokens from stream of
Iterable dataset that returns constant length chunks of tokens from stream of text files. text files.
Args:
tokenizer (Tokenizer): The processor used for processing the data. Args:
dataset (dataset.Dataset): Dataset with text files. tokenizer: The processor used for processing the data.
seq_length (int): Length of token sequences to return. dataset: Dataset with text files.
seq_length: Length of token sequences to return.
""" """
def __init__( # pylint: disable=super-init-not-called def __init__( # pylint: disable=super-init-not-called
@@ -110,7 +110,7 @@ class ConstantLengthDataset(IterableDataset):
): ):
self.tokenizer = tokenizer self.tokenizer = tokenizer
self.concat_token_id = tokenizer.eos_token_id self.concat_token_id = tokenizer.eos_token_id
self.datasets: List[IterableDataset] = datasets self.datasets: list[IterableDataset] = datasets
self.seq_length = seq_length self.seq_length = seq_length
vocab_size = len(tokenizer.get_vocab()) vocab_size = len(tokenizer.get_vocab())
@@ -174,7 +174,10 @@ class ConstantLengthDataset(IterableDataset):
} }
else: else:
LOG.warning( LOG.warning(
f"dropping batch due to tensor size mismatch input_ids: {input_ids.size()}, labels: {labels.size()}, attention_mask: {attention_mask.size()}" "Dropping batch due to tensor size mismatch "
f"input_ids: {input_ids.size()}, "
f"labels: {labels.size()}, "
f"attention_mask: {attention_mask.size()}"
) )
buffer = { buffer = {
"input_ids": [], "input_ids": [],

View File

@@ -7,12 +7,14 @@ import transformers
from transformers import ( from transformers import (
AddedToken, AddedToken,
AutoTokenizer, AutoTokenizer,
PreTrainedTokenizer,
) )
from axolotl.integrations.base import PluginManager from axolotl.integrations.base import PluginManager
from axolotl.loaders.utils import get_linear_embedding_layers, load_model_config from axolotl.loaders.utils import get_linear_embedding_layers, load_model_config
from axolotl.prompt_tokenizers import LLAMA_DEFAULT_EOS_TOKEN from axolotl.prompt_tokenizers import LLAMA_DEFAULT_EOS_TOKEN
from axolotl.utils.chat_templates import get_chat_template_from_config from axolotl.utils.chat_templates import get_chat_template_from_config
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import ( from axolotl.utils.distributed import (
barrier, barrier,
is_local_main_process, is_local_main_process,
@@ -117,7 +119,7 @@ def modify_tokenizer_files(
return tokenizer_dir return tokenizer_dir
def load_tokenizer(cfg): def load_tokenizer(cfg: DictDefault) -> PreTrainedTokenizer:
"""Load and configure the tokenizer based on the provided config.""" """Load and configure the tokenizer based on the provided config."""
model_config = load_model_config(cfg) model_config = load_model_config(cfg)
tokenizer_kwargs = {} tokenizer_kwargs = {}
@@ -207,11 +209,12 @@ def load_tokenizer(cfg):
) )
and k != "pad_token" and k != "pad_token"
): ):
lora_modules_to_save = ", ".join( lora_modules_to_save_str = ", ".join(
[f"`{x}`" for x in lora_modules_to_save] [f"`{x}`" for x in lora_modules_to_save]
) )
raise ValueError( raise ValueError(
f"Please set lora_modules_to_save to [{lora_modules_to_save}] when using an adapter and changing the special tokens." f"Please set lora_modules_to_save to [{lora_modules_to_save_str}] "
"when using an adapter and changing the special tokens."
) )
tokenizer.add_special_tokens( tokenizer.add_special_tokens(

View File

@@ -32,4 +32,3 @@ def load(tokenizer, cfg, ds_cfg, processor=None):
except Exception as exc: # pylint: disable=broad-exception-caught except Exception as exc: # pylint: disable=broad-exception-caught
LOG.error(f"Failed to load prompt strategy `{strategy}`: {str(exc)}") LOG.error(f"Failed to load prompt strategy `{strategy}`: {str(exc)}")
raise exc raise exc
return None

View File

@@ -3,6 +3,7 @@
import abc import abc
from typing import Callable, Dict, List, Optional, Tuple, Union from typing import Callable, Dict, List, Optional, Tuple, Union
from datasets import Dataset
from transformers import BatchEncoding, PreTrainedTokenizer from transformers import BatchEncoding, PreTrainedTokenizer
from axolotl.prompters import Prompter from axolotl.prompters import Prompter
@@ -28,6 +29,16 @@ class DatasetWrappingStrategy(abc.ABC):
Abstract class for wrapping datasets for Chat Messages Abstract class for wrapping datasets for Chat Messages
""" """
@abc.abstractmethod
def wrap_dataset(
self,
dataset,
process_count: int | None = None,
keep_in_memory: bool | None = False,
**kwargs,
) -> Dataset:
pass
class PromptTokenizingStrategy(abc.ABC): class PromptTokenizingStrategy(abc.ABC):
""" """

View File

@@ -53,8 +53,8 @@ def setup_model_and_tokenizer(
) -> tuple[ ) -> tuple[
PreTrainedModel, PreTrainedTokenizer, PeftConfig | None, ProcessorMixin | None PreTrainedModel, PreTrainedTokenizer, PeftConfig | None, ProcessorMixin | None
]: ]:
""" """Load the tokenizer, processor (for multimodal models), and model based on
Load the tokenizer, processor (for multimodal models), and model based on configuration. configuration.
Args: Args:
cfg: Dictionary mapping `axolotl` config keys to values. cfg: Dictionary mapping `axolotl` config keys to values.

View File

@@ -1,16 +1,21 @@
""" """Init for `axolotl.utils.data` module."""
Data processing modules
"""
from axolotl.utils.data.pretraining import ( # noqa: F401 from axolotl.utils.data.pretraining import (
encode_pretraining, encode_pretraining,
wrap_pretraining_dataset, wrap_pretraining_dataset,
) )
from axolotl.utils.data.rl import load_prepare_preference_datasets # noqa: F401 from axolotl.utils.data.rl import prepare_preference_datasets
from axolotl.utils.data.sft import ( # noqa: F401 from axolotl.utils.data.sft import (
get_dataset_wrapper, get_dataset_wrapper,
load_prepare_datasets, prepare_datasets,
load_tokenized_prepared_datasets,
prepare_dataset,
) )
from axolotl.utils.data.utils import md5 # noqa: F401 from axolotl.utils.data.utils import md5
__all__ = [
"encode_pretraining",
"wrap_pretraining_dataset",
"prepare_preference_datasets",
"get_dataset_wrapper",
"prepare_datasets",
"md5",
]

View File

@@ -0,0 +1,66 @@
"""Logic for loading / preparing a dataset once over all processes."""
import time
from pathlib import Path
from typing import Any, Callable
from filelock import FileLock
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
from axolotl.utils.dict import DictDefault
LOCK_FILE_NAME = "datasets_prep.lock"
READY_FILE_NAME = "datasets_ready.flag"
PROCESS_COUNTER_FILE_NAME = "process_counter.txt"
class FileLockLoader:
"""
Simple class for abstracting single process data loading / processing. The first
process that creates a lock file does the work; the remaining procesees simply load
the preprocessed dataset once the first process is done.
"""
def __init__(self, cfg: DictDefault):
self.cfg = cfg
self.dataset_prepared_path = (
cfg.dataset_prepared_path or DEFAULT_DATASET_PREPARED_PATH
)
self.lock_file_path = Path(self.dataset_prepared_path) / LOCK_FILE_NAME
self.ready_flag_path = Path(self.dataset_prepared_path) / READY_FILE_NAME
self.counter_path = Path(self.dataset_prepared_path) / PROCESS_COUNTER_FILE_NAME
def load(self, load_fn: Callable[[], Any]) -> Any:
with FileLock(str(self.lock_file_path)):
self._increment_counter()
if not self.ready_flag_path.exists():
result = load_fn()
self.ready_flag_path.touch()
return result
while not self.ready_flag_path.exists():
time.sleep(1)
return load_fn()
def _increment_counter(self):
"""Safely increment the process counter."""
if self.counter_path.exists():
count = int(self.counter_path.read_text().strip())
else:
count = 0
self.counter_path.write_text(str(count + 1))
def cleanup(self):
"""Clean up ready flag when last process is done."""
with FileLock(str(self.lock_file_path)):
count = int(self.counter_path.read_text().strip())
count -= 1
if count == 0:
# Last process cleans everything up
self.ready_flag_path.unlink(missing_ok=True)
self.counter_path.unlink(missing_ok=True)
else:
# Still have active processes
self.counter_path.write_text(str(count))

View File

@@ -250,7 +250,7 @@ def encode_packed_pretraining(
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
# tokenize all the examples # tokenize all the examples
# rows get split with stride (overlap) # rows get split with stride (overlap)
train_dataset = ds_wrapper(Dataset.from_dict(examples))[0] train_dataset = ds_wrapper(dataset=Dataset.from_dict(examples))[0]
train_dataset = process_pretraining_datasets_for_packing( train_dataset = process_pretraining_datasets_for_packing(
train_dataset, train_dataset,

View File

@@ -1,75 +1,117 @@
"""data handling specific to DPO""" """Data handling specific to RL trainers."""
import inspect import inspect
from functools import partial from functools import partial
from pathlib import Path from typing import Any, Callable, Literal
from typing import Any, List, Union
import yaml from datasets import Dataset, DatasetDict
from datasets import Dataset, DatasetDict, concatenate_datasets, load_from_disk from transformers import PreTrainedTokenizer
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
from axolotl.loaders import load_tokenizer from axolotl.loaders import load_tokenizer
from axolotl.prompt_strategies.dpo import load as load_dpo from axolotl.prompt_strategies.dpo import load as load_dpo
from axolotl.prompt_strategies.kto import load as load_kto from axolotl.prompt_strategies.kto import load as load_kto
from axolotl.prompt_strategies.orpo import load as load_orpo from axolotl.prompt_strategies.orpo import load as load_orpo
from axolotl.utils.data.shared import datasets_w_name_generator, load_dataset_w_config from axolotl.utils.data.lock import FileLockLoader
from axolotl.utils.data.utils import deduplicate_and_log_datasets, md5 from axolotl.utils.data.shared import (
create_train_validation_split,
datasets_with_name_generator,
generate_dataset_hash_from_config,
load_dataset_with_config,
load_preprocessed_dataset,
merge_datasets,
save_preprocessed_dataset,
try_load_from_hub,
)
from axolotl.utils.data.utils import (
deduplicate_and_log_datasets,
retry_on_request_exceptions,
)
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import is_main_process, zero_first
from axolotl.utils.logging import get_logger from axolotl.utils.logging import get_logger
from axolotl.utils.schemas.enums import RLType from axolotl.utils.schemas.enums import RLType
LOG = get_logger(__name__) LOG = get_logger(__name__)
def _get_path(ds_hash, cfg): @retry_on_request_exceptions(max_retries=3, delay=5)
prepared_ds_path = ( def prepare_preference_datasets(
Path(cfg.dataset_prepared_path) / ds_hash cfg: DictDefault, tokenizer: PreTrainedTokenizer
if cfg.dataset_prepared_path ) -> tuple[Dataset, Dataset | None]:
else Path(DEFAULT_DATASET_PREPARED_PATH) / ds_hash """Load and prepare preference datasets for RL training.
)
return prepared_ds_path Loads training and evaluation datasets, handling preprocessing, caching, and
deduplication as configured. Uses FileLock for distributed coordination.
Args:
cfg: Configuration object containing dataset and training settings.
tokenizer: Tokenizer to use for processing text.
Returns:
Tuple of (train_dataset, eval_dataset). eval_dataset may be None
if no evaluation dataset is configured.
"""
def _load_datasets():
# Load training dataset
train_dataset = _load_or_create_dataset_split(cfg, tokenizer, split="train")
# Load or create evaluation dataset
eval_dataset: Dataset | None = None
if cfg.test_datasets:
eval_dataset = _load_or_create_dataset_split(cfg, tokenizer, split="test")
elif cfg.val_set_size:
# Create validation split from training data
train_dataset, eval_dataset = create_train_validation_split(
train_dataset, cfg, cfg.val_set_size
)
return train_dataset, eval_dataset
# Prepare datasets (with file locking logic for multiple ranks)
loader = FileLockLoader(cfg)
try:
train_dataset, eval_dataset = loader.load(_load_datasets)
finally:
loader.cleanup()
# Apply deduplication if configured
if cfg.dataset_exact_deduplication:
train_dataset, eval_dataset = deduplicate_and_log_datasets(
dataset=train_dataset, other_dataset=eval_dataset
)
return train_dataset, eval_dataset
def _load_preprocessed_ds(cfg, sub_cfg): def _map_dataset(
ds_hash = md5(yaml.dump(sub_cfg, Dumper=yaml.Dumper)) cfg: DictDefault,
prepared_ds_path = _get_path(ds_hash, cfg) dataset: Dataset | DatasetDict,
dataset = None ds_transform_fn: Callable[..., Any],
tokenizer: Any | None = None,
**map_kwargs: Any,
) -> Dataset:
"""Apply transformation function to dataset.
# pylint: disable=duplicate-code Args:
if ( cfg: Configuration object.
cfg.dataset_prepared_path dataset: Dataset to transform.
and any(prepared_ds_path.glob("*")) ds_transform_fn: Transformation function to apply.
and not cfg.is_preprocess tokenizer: Optional tokenizer for transformation.
): **map_kwargs: Additional arguments for dataset mapping.
LOG.info(f"Loading prepared dataset from disk at {prepared_ds_path}...")
dataset = load_from_disk(str(prepared_ds_path))
return dataset Returns:
Transformed dataset.
"""
def _save_preprocessed_ds(cfg, sub_cfg, dataset):
ds_hash = md5(yaml.dump(sub_cfg, Dumper=yaml.Dumper))
prepared_ds_path = _get_path(ds_hash, cfg)
if cfg.is_preprocess and is_main_process():
LOG.info(f"Loading prepared dataset from disk at {prepared_ds_path}...")
dataset.save_to_disk(str(prepared_ds_path))
def map_dataset(cfg, data_set, ds_transform_fn, tokenizer, **map_kwargs):
sig = inspect.signature(ds_transform_fn) sig = inspect.signature(ds_transform_fn)
if "tokenizer" in sig.parameters: if "tokenizer" in sig.parameters:
if not tokenizer: if not tokenizer:
tokenizer = load_tokenizer(cfg) tokenizer = load_tokenizer(cfg)
ds_transform_fn = partial(ds_transform_fn, tokenizer=tokenizer) ds_transform_fn = partial(ds_transform_fn, tokenizer=tokenizer)
if isinstance(data_set, DatasetDict): if isinstance(dataset, DatasetDict):
data_set = data_set["train"] dataset = dataset["train"]
data_set = data_set.map( dataset = dataset.map(
ds_transform_fn, ds_transform_fn,
num_proc=cfg.dataset_processes, num_proc=cfg.dataset_processes,
load_from_cache_file=not cfg.is_preprocess, load_from_cache_file=not cfg.is_preprocess,
@@ -77,13 +119,27 @@ def map_dataset(cfg, data_set, ds_transform_fn, tokenizer, **map_kwargs):
**map_kwargs, **map_kwargs,
) )
return data_set return dataset
def drop_long_rl_seq( def _drop_long_sequences(
sample, rl, tokenizer, sequence_len # pylint: disable=invalid-name sample: dict[str, Any], rl: RLType, tokenizer: Any, sequence_len: int
): ) -> bool:
if rl in (RLType.DPO, RLType.IPO, RLType.ORPO, RLType.SIMPO): """Filter out samples that exceed maximum sequence length.
Args:
sample: Dataset sample to check.
rl: Reinforcement learning type.
tokenizer: Tokenizer for length calculation.
sequence_len: Maximum allowed sequence length.
Returns:
True if sample should be kept, False if it should be dropped.
Raises:
ValueError: If required keys are missing or RL type is unknown.
"""
if rl in {RLType.DPO, RLType.IPO, RLType.ORPO, RLType.SIMPO}:
if not ( if not (
sample.get("prompt") and sample.get("chosen") and sample.get("rejected") sample.get("prompt") and sample.get("chosen") and sample.get("rejected")
): ):
@@ -123,132 +179,115 @@ def drop_long_rl_seq(
raise ValueError("Unknown RL type") raise ValueError("Unknown RL type")
def load_prepare_preference_datasets(cfg): def _load_split(cfg: DictDefault, split: Literal["train", "test"]) -> Dataset:
def load_split(dataset_cfgs, _cfg): """Load and process dataset split for RL training.
split_datasets: List[Any] = []
use_auth_token = _cfg.hf_use_auth_token
for config_dataset in datasets_w_name_generator(dataset_cfgs):
ds: Union[Dataset, DatasetDict] = load_dataset_w_config(
config_dataset, use_auth_token, streaming=False
)
split_datasets.append(ds)
tokenizer = load_tokenizer(cfg) Args:
cfg: Configuration object containing dataset settings.
split: Dataset split to load ("train" or "test").
for i, data_set in enumerate(split_datasets): Returns:
_type = dataset_cfgs[i]["type"] Combined and processed dataset for the specified split.
if _type: """
if isinstance(_type, DictDefault): datasets_configs = cfg.datasets if split == "train" else cfg.test_datasets
_type = "user_defined.default" split_datasets: list[Dataset | DatasetDict] = []
if _cfg.rl is RLType.ORPO:
ds_transform_fn = load_orpo(_type, _cfg, dataset_idx=i)
elif _cfg.rl is RLType.KTO:
ds_transform_fn = load_kto(_type, _cfg, dataset_idx=i)
else:
ds_transform_fn = load_dpo(_type, _cfg, dataset_idx=i)
map_kwargs = {} for dataset_config in datasets_with_name_generator(datasets_configs):
if isinstance(ds_transform_fn, tuple): dataset: Dataset | DatasetDict = load_dataset_with_config(
ds_transform_fn, map_kwargs = ds_transform_fn dataset_config, cfg.hf_use_auth_token, streaming=False
split_datasets[i] = map_dataset(
cfg, data_set, ds_transform_fn, tokenizer, **map_kwargs
)
elif _cfg.rl is RLType.KTO:
ds_transform_fn = load_kto(_type, _cfg, dataset_idx=i)
map_kwargs = {}
if isinstance(ds_transform_fn, tuple):
ds_transform_fn, map_kwargs = ds_transform_fn
split_datasets[i] = map_dataset(
cfg, data_set, ds_transform_fn, tokenizer, **map_kwargs
)
else:
# If no `type` is provided, assume the dataset is already in the expected format with
# "prompt", "chosen" and "rejected" already preprocessed
split_datasets[i] = data_set
if not cfg.skip_prepare_dataset:
drop_long = partial(
drop_long_rl_seq,
rl=_cfg.rl,
tokenizer=tokenizer,
sequence_len=cfg.sequence_len,
)
prior_len = len(split_datasets[i])
split_datasets[i] = split_datasets[i].filter(
drop_long,
num_proc=cfg.dataset_processes,
load_from_cache_file=not cfg.is_preprocess,
desc="Dropping Long Sequences",
)
dropped = prior_len - len(split_datasets[i])
if dropped:
LOG.warning(
f"Dropped {dropped} long samples from dataset index {i}"
)
combined_datasets = concatenate_datasets(split_datasets)
combined_datasets = combined_datasets.shuffle(seed=cfg.seed or 42)
return combined_datasets
with zero_first(is_main_process()):
train_is_preprocessed = False
eval_is_preprocessed = False
if train_dataset := _load_preprocessed_ds(cfg, cfg.datasets):
train_is_preprocessed = True
else:
train_dataset = load_split(cfg.datasets, cfg)
eval_dataset = None
if cfg.test_datasets:
if eval_dataset := _load_preprocessed_ds(cfg, cfg.test_datasets):
eval_is_preprocessed = True
else:
eval_dataset = load_split(cfg.test_datasets, cfg)
if not eval_dataset:
if cfg.val_set_size:
seed = cfg.seed if cfg.seed is not None else 42
# ensure we end up with the same fingerprint by doing rank0 first and being able to cache
to_hash_train = (
train_dataset._fingerprint # pylint: disable=protected-access
+ "|"
+ str(cfg.val_set_size)
+ "|"
+ "train"
+ "|"
+ str(cfg.seed or 42)
)
to_hash_test = (
train_dataset._fingerprint # pylint: disable=protected-access
+ "|"
+ str(cfg.val_set_size)
+ "|"
+ "test"
+ "|"
+ str(cfg.seed or 42)
)
train_fingerprint = md5(to_hash_train)
test_fingerprint = md5(to_hash_test)
ds_w_test_split = train_dataset.train_test_split(
test_size=cfg.val_set_size,
seed=seed,
shuffle=False,
train_new_fingerprint=train_fingerprint,
test_new_fingerprint=test_fingerprint,
)
eval_dataset = ds_w_test_split["test"]
train_dataset = ds_w_test_split["train"]
if not train_is_preprocessed:
_save_preprocessed_ds(cfg, cfg.datasets, train_dataset)
if eval_dataset and not eval_is_preprocessed:
_save_preprocessed_ds(cfg, cfg.test_datasets, eval_dataset)
if cfg.dataset_exact_deduplication:
train_dataset, eval_dataset, _ = deduplicate_and_log_datasets(
train_dataset=train_dataset, eval_dataset=eval_dataset
) )
split_datasets.append(dataset)
return train_dataset, eval_dataset tokenizer = load_tokenizer(cfg)
for i, dataset in enumerate(split_datasets):
_type = datasets_configs[i]["type"]
if _type:
if isinstance(_type, DictDefault):
_type = "user_defined.default"
if cfg.rl is RLType.ORPO:
ds_transform_fn = load_orpo(_type, cfg, dataset_idx=i)
elif cfg.rl is RLType.KTO:
ds_transform_fn = load_kto(_type, cfg, dataset_idx=i)
else:
ds_transform_fn = load_dpo(_type, cfg, dataset_idx=i)
map_kwargs: dict[str, Any] = {}
if isinstance(ds_transform_fn, tuple):
ds_transform_fn, map_kwargs = ds_transform_fn
split_datasets[i] = _map_dataset(
cfg, dataset, ds_transform_fn, tokenizer, **map_kwargs
)
else:
# If no `type` is provided, assume the dataset is already in the expected format with
# "prompt", "chosen", and "rejected" already preprocessed
split_datasets[i] = dataset
if not cfg.skip_prepare_dataset:
drop_long = partial(
_drop_long_sequences,
rl=cfg.rl,
tokenizer=tokenizer,
sequence_len=cfg.sequence_len,
)
prior_len = len(split_datasets[i])
split_datasets[i] = split_datasets[i].filter(
drop_long,
num_proc=cfg.dataset_processes,
load_from_cache_file=not cfg.is_preprocess,
desc="Dropping Long Sequences",
)
dropped = prior_len - len(split_datasets[i])
if dropped:
LOG.warning(f"Dropped {dropped} long samples from dataset index {i}")
# Merge datasets
dataset = merge_datasets(split_datasets, cfg)
if not cfg.skip_prepare_dataset:
# Save preprocessed dataset
dataset_hash = generate_dataset_hash_from_config(
cfg, datasets_configs, tokenizer.name_or_path
)
save_preprocessed_dataset(cfg, dataset, dataset_hash, split)
return dataset
# pylint: disable=duplicate-code
def _load_or_create_dataset_split(
cfg: DictDefault, tokenizer: PreTrainedTokenizer, split: Literal["train", "test"]
) -> Dataset:
"""Load preprocessed dataset or create new one for given split.
Args:
cfg: Configuration object.
tokenizer: Tokenizer to use for processing text.
split: Dataset split to load.
Returns:
Tuple of (dataset, is_preprocessed).
"""
# Select correct dataset configuration based on split
datasets_config = cfg.datasets if split == "train" else cfg.test_datasets
# Generate dataset hash for caching
dataset_hash = generate_dataset_hash_from_config(
cfg, datasets_config, tokenizer.name_or_path
)
# Try loading from hub if push_dataset_to_hub is configured
dataset = None
if cfg.push_dataset_to_hub:
dataset = try_load_from_hub(cfg, dataset_hash, split)
# Attempt to load preprocessed dataset
if dataset is None:
dataset = load_preprocessed_dataset(cfg, dataset_hash)
# Otherwise, load it
if dataset is None:
dataset = _load_split(cfg, split=split)
return dataset

File diff suppressed because it is too large Load Diff

View File

@@ -1,11 +1,21 @@
""" """Dataset loading shared utils."""
dataset loading shared utils
"""
from __future__ import annotations
import functools
import os
from pathlib import Path from pathlib import Path
from typing import Optional, Union from typing import TYPE_CHECKING, Any, Generator
from datasets import Dataset, DatasetDict, load_dataset, load_from_disk from datasets import (
Dataset,
DatasetDict,
IterableDataset,
IterableDatasetDict,
concatenate_datasets,
load_dataset,
load_from_disk,
)
from huggingface_hub import hf_hub_download, snapshot_download from huggingface_hub import hf_hub_download, snapshot_download
from huggingface_hub.errors import ( from huggingface_hub.errors import (
HFValidationError, HFValidationError,
@@ -13,78 +23,141 @@ from huggingface_hub.errors import (
RevisionNotFoundError, RevisionNotFoundError,
) )
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
from axolotl.utils.data.utils import deduplicate_and_log_datasets, md5
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
if TYPE_CHECKING:
from adlfs import AzureBlobFileSystem
from gcsfs import GCSFileSystem
from ocifs import OCIFileSystem
from s3fs import S3FileSystem
LOG = get_logger(__name__)
EXTENSIONS_TO_DATASET_TYPES = {
".parquet": "parquet",
".arrow": "arrow",
".csv": "csv",
".txt": "text",
}
def get_ds_type(config_dataset: DictDefault): def get_dataset_type(dataset_config: DictDefault) -> str:
""" """Get the dataset type from the path if it's not specified."""
Get the dataset type from the path if it's not specified if dataset_config.ds_type:
""" return dataset_config.ds_type
ds_type = "json"
if config_dataset.ds_type: for extension, dataset_type in EXTENSIONS_TO_DATASET_TYPES.items():
ds_type = config_dataset.ds_type if extension in dataset_config.path:
elif ".parquet" in config_dataset.path: return dataset_type
ds_type = "parquet"
elif ".arrow" in config_dataset.path: return "json"
ds_type = "arrow"
elif ".csv" in config_dataset.path:
ds_type = "csv"
elif ".txt" in config_dataset.path:
ds_type = "text"
return ds_type
def datasets_w_name_generator(dataset_configs: list[DictDefault]): def datasets_with_name_generator(
""" dataset_configs: list[DictDefault],
Yields dataset configs handling multiple names or preprocess_shards ) -> Generator[DictDefault, None, None]:
"""Yields expanded dataset configurations based on multiple names or preprocessing
shards.
When a dataset config has a list of names, it yields separate configs for each
name. When a dataset config specifies preprocessing shards, it yields configs for
each shard.
Args: Args:
dataset_configs: list of dataset configs (equivalent to cfg.datasets) dataset_configs: List of dataset configuration objects.
Yields:
Individual dataset configurations, expanded as needed for names or shards.
""" """
for dataset in dataset_configs: for config in dataset_configs:
if dataset.name and isinstance(dataset.name, list): if config.name and isinstance(config.name, list):
# load_dataset doesn't properly handle multiple named configurations for name in config.name:
# at the same time for a given dataset yield DictDefault({**config, "name": name})
for name in dataset.name: elif config.preprocess_shards and not config.shards:
yield DictDefault({**dataset, "name": name}) for shard_idx in range(config.preprocess_shards):
elif dataset.preprocess_shards and not dataset.shards:
for shard in range(dataset.preprocess_shards):
yield DictDefault( yield DictDefault(
{ {
**dataset, **config,
"shards": dataset.preprocess_shards, "shards": config.preprocess_shards,
"shards_idx": shard, "shards_idx": shard_idx,
} }
) )
else: else:
yield dataset yield config
def load_dataset_w_config( def load_dataset_with_config(
config_dataset: DictDefault, use_auth_token: bool, streaming=False dataset_config: DictDefault, use_auth_token: bool, streaming=False
) -> Union[Dataset, DatasetDict]: ) -> Dataset | IterableDataset:
""" """Load a dataset from a config. Handles datasets that are stored locally, in the
Load a dataset from a config HuggingFace Hub, in a remote filesystem (S3, GCS, Azure, OCI), a URL, or
`data_files`.
Args: Args:
config_dataset: single dataset config dataset_config: Single dataset config.
use_auth_token: whether to use HF auth token use_auth_token: Whether to use HF auth token.
streaming: whether to stream the dataset streaming: Whether to stream the dataset.
Returns:
Loaded dataset.
""" """
# pylint: disable=invalid-name # Set up common kwargs for dataset loading
ds: Optional[Union[Dataset, DatasetDict]] = None # pylint: disable=invalid-name load_dataset_kwargs = {
ds_from_hub = False "split": dataset_config.split if dataset_config.split else None,
"name": dataset_config.name,
"streaming": streaming,
"trust_remote_code": dataset_config.trust_remote_code,
}
# First check if it's a local path
if Path(dataset_config.path).exists():
return _load_from_local_path(dataset_config, load_dataset_kwargs)
# Check if it's a HuggingFace dataset
is_hub_dataset = _check_if_hub_dataset(dataset_config, use_auth_token)
# Check if it's a cloud storage path and get appropriate filesystem
remote_fs, storage_options = _get_remote_filesystem(dataset_config.path)
is_cloud_dataset = False
if remote_fs:
try:
is_cloud_dataset = remote_fs.exists(dataset_config.path)
except (FileNotFoundError, ConnectionError):
pass
# Load from appropriate source
if is_hub_dataset:
return _load_from_hub(dataset_config, use_auth_token, load_dataset_kwargs)
if is_cloud_dataset:
return _load_from_cloud(
dataset_config, remote_fs, storage_options, load_dataset_kwargs
)
if dataset_config.path.startswith("https://"):
return _load_from_url(dataset_config, load_dataset_kwargs)
if dataset_config.data_files:
return _load_from_data_files(dataset_config, load_dataset_kwargs)
raise ValueError(
f"The dataset could not be loaded. This could be due to a misconfigured dataset path "
f"({dataset_config.path}). Try double-check your path / name / data_files. "
f"This is not caused by the dataset type."
)
def _check_if_hub_dataset(dataset_config: DictDefault, use_auth_token: bool) -> bool:
"""Check if a dataset exists on the HuggingFace Hub."""
try: try:
# this is just a basic check to see if the path is a
# valid HF dataset that's loadable
snapshot_download( snapshot_download(
repo_id=config_dataset.path, repo_id=dataset_config.path,
repo_type="dataset", repo_type="dataset",
token=use_auth_token, token=use_auth_token,
revision=config_dataset.revision, revision=dataset_config.revision,
ignore_patterns=["*"], ignore_patterns=["*"],
) )
ds_from_hub = True return True
except ( except (
RepositoryNotFoundError, RepositoryNotFoundError,
RevisionNotFoundError, RevisionNotFoundError,
@@ -93,198 +166,373 @@ def load_dataset_w_config(
HFValidationError, HFValidationError,
ValueError, ValueError,
): ):
pass return False
ds_from_cloud = False
storage_options: dict = {} def _get_remote_filesystem(
remote_file_system = None path: str,
if config_dataset.path.startswith("s3://"): ) -> tuple[
S3FileSystem | GCSFileSystem | AzureBlobFileSystem | OCIFileSystem | None, dict
]:
"""Get the appropriate filesystem for a remote path."""
if path.startswith("s3://"):
try: try:
import s3fs # type: ignore import s3fs
storage_options = {"anon": False}
return s3fs.S3FileSystem(**storage_options), storage_options
except ImportError as exc: except ImportError as exc:
raise ImportError("s3:// paths require s3fs to be installed") from exc raise ImportError("s3:// paths require s3fs to be installed") from exc
# Reads env, credentials from ~/.aws/credentials, or IAM metadata provider elif path.startswith(("gs://", "gcs://")):
# https://s3fs.readthedocs.io/en/latest/index.html?highlight=storage_options#credentials
storage_options = {"anon": False}
remote_file_system = s3fs.S3FileSystem(**storage_options)
elif config_dataset.path.startswith("gs://") or config_dataset.path.startswith(
"gcs://"
):
try: try:
import gcsfs # type: ignore import gcsfs
storage_options = {"token": None} # type: ignore
return gcsfs.GCSFileSystem(**storage_options), storage_options
except ImportError as exc: except ImportError as exc:
raise ImportError( raise ImportError(
"gs:// or gcs:// paths require gcsfs to be installed" "gs:// or gcs:// paths require gcsfs to be installed"
) from exc ) from exc
# gcsfs will use default credentials from the environment else anon elif path.startswith(("adl://", "abfs://", "az://")):
# https://gcsfs.readthedocs.io/en/latest/#credentials
storage_options = {"token": None}
remote_file_system = gcsfs.GCSFileSystem(**storage_options)
elif (
config_dataset.path.startswith("adl://")
or config_dataset.path.startswith("abfs://")
or config_dataset.path.startswith("az://")
):
try: try:
import adlfs import adlfs
storage_options = {"anon": False}
return adlfs.AzureBlobFileSystem(**storage_options), storage_options
except ImportError as exc: except ImportError as exc:
raise ImportError( raise ImportError(
"adl:// or abfs:// paths require adlfs to be installed" "adl:// or abfs:// paths require adlfs to be installed"
) from exc ) from exc
# # Ensure you have the following environment variables set: elif path.startswith("oci://"):
# # Gen 1
# storage_options = {
# "tenant_id": AZURE_STORAGE_TENANT_ID,
# "client_id": AZURE_STORAGE_CLIENT_ID,
# "client_secret": AZURE_STORAGE_CLIENT_SECRET,
# }
# # Gen 2
# storage_options = {
# "account_name": AZURE_STORAGE_ACCOUNT_NAME,
# "account_key": AZURE_STORAGE_ACCOUNT_KEY,
# }
# Reads env
# https://github.com/fsspec/adlfs?tab=readme-ov-file#setting-credentials
storage_options = {"anon": False}
remote_file_system = adlfs.AzureBlobFileSystem(**storage_options)
elif config_dataset.path.startswith("oci://"):
try: try:
import ocifs import ocifs
storage_options = {}
return ocifs.OCIFileSystem(**storage_options), storage_options
except ImportError as exc: except ImportError as exc:
raise ImportError("oci:// paths require ocifs to be installed") from exc raise ImportError("oci:// paths require ocifs to be installed") from exc
# https://ocifs.readthedocs.io/en/latest/getting-connected.html#Using-Environment-Variables return None, {}
remote_file_system = ocifs.OCIFileSystem(**storage_options)
try:
if remote_file_system and remote_file_system.exists(config_dataset.path):
ds_from_cloud = True
except (FileNotFoundError, ConnectionError):
pass
# gather extra args from the config def _load_from_local_path(
load_ds_kwargs = {} dataset_config: DictDefault, load_dataset_kwargs: dict
if config_dataset.split: ) -> Dataset | IterableDataset | DatasetDict | IterableDatasetDict:
load_ds_kwargs["split"] = config_dataset.split """Load a dataset from a local path."""
local_path = Path(dataset_config.path)
if local_path.is_dir():
if dataset_config.data_files:
dataset_type = get_dataset_type(dataset_config)
return load_dataset(
dataset_type,
data_files=dataset_config.data_files,
**load_dataset_kwargs,
)
try:
return load_from_disk(dataset_config.path)
except FileNotFoundError:
load_dataset_kwargs["streaming"] = False
return load_dataset(dataset_config.path, **load_dataset_kwargs)
elif local_path.is_file():
dataset_type = get_dataset_type(dataset_config)
load_dataset_kwargs["streaming"] = False
return load_dataset(
dataset_type,
data_files=dataset_config.path,
**load_dataset_kwargs,
)
else: else:
load_ds_kwargs["split"] = None
# prefer local dataset, even if hub exists
local_path = Path(config_dataset.path)
if local_path.exists():
if local_path.is_dir():
if config_dataset.data_files:
ds_type = get_ds_type(config_dataset)
ds = load_dataset( # pylint: disable=invalid-name
ds_type,
name=config_dataset.name,
data_files=config_dataset.data_files,
streaming=streaming,
**load_ds_kwargs,
)
else:
try:
ds = load_from_disk(
config_dataset.path
) # pylint: disable=invalid-name
except FileNotFoundError:
ds = load_dataset(
config_dataset.path,
name=config_dataset.name,
streaming=False,
**load_ds_kwargs,
)
elif local_path.is_file():
ds_type = get_ds_type(config_dataset)
ds = load_dataset( # pylint: disable=invalid-name
ds_type,
name=config_dataset.name,
data_files=config_dataset.path,
streaming=False,
**load_ds_kwargs,
)
else:
raise ValueError(
"unhandled dataset load: local path exists, but is neither a directory or a file"
)
elif ds_from_hub:
ds = load_dataset(
config_dataset.path,
name=config_dataset.name,
streaming=streaming,
data_files=config_dataset.data_files,
token=use_auth_token,
revision=config_dataset.revision,
trust_remote_code=config_dataset.trust_remote_code,
**load_ds_kwargs,
)
elif ds_from_cloud and remote_file_system:
if remote_file_system.isdir(config_dataset.path):
ds = load_from_disk(
config_dataset.path,
storage_options=storage_options,
)
elif remote_file_system.isfile(config_dataset.path):
ds_type = get_ds_type(config_dataset)
ds = load_dataset(
ds_type,
name=config_dataset.name,
data_files=config_dataset.path,
streaming=streaming,
storage_options=storage_options,
trust_remote_code=config_dataset.trust_remote_code,
**load_ds_kwargs,
)
elif config_dataset.path.startswith("https://"):
ds_type = get_ds_type(config_dataset)
ds = load_dataset(
ds_type,
name=config_dataset.name,
data_files=config_dataset.path,
streaming=streaming,
storage_options=storage_options,
trust_remote_code=config_dataset.trust_remote_code,
**load_ds_kwargs,
)
elif config_dataset.data_files:
fp: str | list[str] | None = None
if isinstance(config_dataset.data_files, str):
fp = hf_hub_download(
repo_id=config_dataset.path,
repo_type="dataset",
filename=config_dataset.data_files,
revision=config_dataset.revision,
)
elif isinstance(config_dataset.data_files, list):
fp = []
for file in config_dataset.data_files:
fp.append(
hf_hub_download(
repo_id=config_dataset.path,
repo_type="dataset",
filename=file,
revision=config_dataset.revision,
)
)
else:
raise ValueError("data_files must be either a string or list of strings")
ds = load_dataset(
"json",
name=config_dataset.name,
data_files=fp,
streaming=streaming,
**load_ds_kwargs,
)
if not ds:
raise ValueError( raise ValueError(
"The dataset could not be loaded. This could be due to a misconfigured dataset path " "Unhandled dataset load: local path exists, but is neither a directory or a file"
f"({config_dataset.path}). Try double-check your path / name / data_files. "
"This is not caused by the dataset type."
) )
return ds
def _load_from_hub(
dataset_config: DictDefault, use_auth_token: bool, load_dataset_kwargs: dict
) -> Dataset | IterableDataset | DatasetDict | IterableDatasetDict:
"""Load a dataset from the HuggingFace Hub."""
return load_dataset(
dataset_config.path,
data_files=dataset_config.data_files,
token=use_auth_token,
revision=dataset_config.revision,
**load_dataset_kwargs,
)
def _load_from_cloud(
dataset_config: DictDefault,
remote_fs: S3FileSystem | GCSFileSystem | AzureBlobFileSystem | OCIFileSystem,
storage_options: dict,
load_dataset_kwargs: dict,
) -> Dataset | IterableDataset | DatasetDict | IterableDatasetDict:
"""Load a dataset from cloud storage."""
if remote_fs.isdir(dataset_config.path):
return load_from_disk(
dataset_config.path,
storage_options=storage_options,
)
if remote_fs.isfile(dataset_config.path):
dataset_type = get_dataset_type(dataset_config)
return load_dataset(
dataset_type,
data_files=dataset_config.path,
storage_options=storage_options,
**load_dataset_kwargs,
)
raise ValueError(
f"Cloud path {dataset_config.path} is neither a directory nor a file"
)
def _load_from_url(
dataset_config: DictDefault, load_dataset_kwargs: dict
) -> Dataset | IterableDataset | DatasetDict | IterableDatasetDict:
"""Load a dataset from a URL."""
dataset_type = get_dataset_type(dataset_config)
return load_dataset(
dataset_type,
data_files=dataset_config.path,
**load_dataset_kwargs,
)
def _load_from_data_files(
dataset_config: DictDefault, load_dataset_kwargs: dict
) -> Dataset | IterableDataset | DatasetDict | IterableDatasetDict:
"""Load a dataset from data files."""
file_path = None
if isinstance(dataset_config.data_files, str):
file_path = hf_hub_download(
repo_id=dataset_config.path,
repo_type="dataset",
filename=dataset_config.data_files,
revision=dataset_config.revision,
)
elif isinstance(dataset_config.data_files, list):
file_path = [
hf_hub_download(
repo_id=dataset_config.path,
repo_type="dataset",
filename=file,
revision=dataset_config.revision,
)
for file in dataset_config.data_files
]
else:
raise ValueError("data_files must be either a string or list of strings")
return load_dataset("json", data_files=file_path, **load_dataset_kwargs)
def generate_split_fingerprints(
dataset: Dataset, val_set_size: int | float, seed: int
) -> tuple[str, str]:
"""Generate consistent fingerprints for train/test splits."""
fingerprint = dataset._fingerprint # pylint: disable=protected-access
train_hash_input = f"{fingerprint}|{val_set_size}|train|{seed}"
test_hash_input = f"{fingerprint}|{val_set_size}|test|{seed}"
train_fingerprint = md5(train_hash_input)
test_fingerprint = md5(test_hash_input)
return train_fingerprint, test_fingerprint
def get_prepared_dataset_path(cfg: DictDefault, dataset_hash: str) -> Path:
"""Get standardized path for prepared datasets.
Args:
cfg: Configuration object.
dataset_hash: Hash identifying the specific dataset configuration.
Returns:
Path where the prepared dataset should be stored.
"""
base_path = cfg.dataset_prepared_path or DEFAULT_DATASET_PREPARED_PATH
return Path(base_path) / dataset_hash
def create_train_validation_split(
dataset: Dataset, cfg: DictDefault, val_set_size: int | float
) -> tuple[Dataset, Dataset]:
"""Create train/validation split with consistent fingerprinting.
Args:
dataset: Dataset to split.
cfg: Configuration object containing seed and other settings.
val_set_size: Size of validation set (absolute number or fraction).
Returns:
Tuple of (train_dataset, eval_dataset).
"""
train_fingerprint, test_fingerprint = generate_split_fingerprints(
dataset, val_set_size, cfg.seed
)
# Apply deduplication before splitting if configured
if cfg.dataset_exact_deduplication:
dataset, _ = deduplicate_and_log_datasets(dataset=dataset)
split_dataset = dataset.train_test_split(
test_size=val_set_size,
shuffle=False,
seed=cfg.seed,
train_new_fingerprint=train_fingerprint,
test_new_fingerprint=test_fingerprint,
)
return split_dataset["train"], split_dataset["test"]
def _generate_from_iterable_dataset(
dataset: IterableDataset, worker_id: list[int], num_workers: list[int]
) -> Generator[Any, None, None]:
"""Generator function to correctly split the dataset for each worker"""
for i, item in enumerate(dataset):
if i % num_workers[0] == worker_id[0]:
yield item
def save_preprocessed_dataset(
cfg: DictDefault,
dataset: Dataset,
dataset_hash: str,
split: str,
) -> None:
"""Save preprocessed dataset to disk and optionally push to the HF Hub."""
prepared_ds_path = get_prepared_dataset_path(cfg, dataset_hash)
if isinstance(dataset, IterableDataset):
num_workers = cfg.dataset_processes
ds_from_iter = Dataset.from_generator(
functools.partial(_generate_from_iterable_dataset, dataset),
features=dataset.features,
num_proc=num_workers,
split=split,
gen_kwargs={
"worker_id": list(range(num_workers)),
"num_workers": [num_workers] * num_workers,
},
)
ds_from_iter.save_to_disk(str(prepared_ds_path))
else:
os.makedirs(prepared_ds_path, exist_ok=True)
dataset.save_to_disk(str(prepared_ds_path))
if cfg.push_dataset_to_hub:
LOG.info(
"Pushing merged prepared dataset to Huggingface hub at "
f"{cfg.push_dataset_to_hub} (version {dataset_hash})...",
main_process_only=False,
)
dataset.push_to_hub(
cfg.push_dataset_to_hub,
dataset_hash,
private=True,
)
def load_preprocessed_dataset(cfg: DictDefault, dataset_hash: str) -> Dataset | None:
"""Load preprocessed dataset from disk if available.
Args:
cfg: Configuration object.
dataset_hash: Hash identifying the dataset configuration.
Returns:
Loaded dataset if found and conditions are met, None otherwise.
"""
prepared_ds_path = get_prepared_dataset_path(cfg, dataset_hash)
if (
cfg.dataset_prepared_path
and any(prepared_ds_path.glob("*"))
and not cfg.skip_prepare_dataset
and not cfg.is_preprocess
):
LOG.info(
f"Loading prepared dataset from disk at {prepared_ds_path}...",
main_process_only=False,
)
return load_from_disk(str(prepared_ds_path))
LOG.info(
f"Unable to find prepared dataset in {prepared_ds_path}",
main_process_only=False,
)
return None
def try_load_from_hub(
cfg: DictDefault, dataset_hash: str, split: str
) -> Dataset | None:
"""Try to load the prepared dataset from HuggingFace Hub."""
try:
LOG.info(
"Attempting to load prepared dataset from HuggingFace Hub at "
f"{cfg.push_dataset_to_hub} (version {dataset_hash})..."
)
dataset = load_dataset(
cfg.push_dataset_to_hub,
dataset_hash,
token=cfg.hf_use_auth_token,
)
return dataset[split]
except Exception: # pylint: disable=broad-except # nosec
LOG.info("Unable to find prepared dataset in HuggingFace Hub")
return None
def generate_dataset_hash_from_config(
cfg: DictDefault, cfg_datasets: list, tokenizer_name: str
) -> str:
"""Generate a hash to uniquely identify a dataset configuration for SFT.
Args:
cfg: Main configuration object.
cfg_datasets: List of dataset configurations.
tokenizer_name: Name of the tokenizer being used.
Returns:
MD5 hash string representing the configuration.
"""
config_str = (
f"{cfg.sequence_len}@{cfg.sample_packing}@{cfg.eval_sample_packing}@"
f"{cfg.group_by_length}@{cfg.kd_temperature or 1.0}|"
f"{'|'.join(sorted([f'{d.path}:{d.type}:{d.shards}:{d.conversation}:{d.split}:{d.temperature or 1.0}' for d in cfg_datasets]))}"
f"|{tokenizer_name}"
)
return str(md5(config_str))
def merge_datasets(datasets: list[Dataset], cfg: DictDefault) -> Dataset:
"""Merge multiple datasets into one with optional shuffling.
Args:
datasets: List of datasets to merge.
cfg: Configuration object containing shuffle settings.
Returns:
Merged dataset.
"""
if len(datasets) == 1:
return datasets[0]
LOG.info("Merging datasets...")
merged_dataset = concatenate_datasets(datasets)
if cfg.shuffle_merged_datasets:
LOG.debug("Shuffling merged datasets...")
merged_dataset = merged_dataset.shuffle(seed=cfg.seed)
else:
LOG.debug("Not shuffling merged datasets.")
return merged_dataset

View File

@@ -1,9 +1,11 @@
"""data handling helpers""" """Data handling helpers"""
import contextlib
import functools import functools
import hashlib import hashlib
import time import time
from enum import Enum from enum import Enum
from typing import Callable
import huggingface_hub import huggingface_hub
import numpy as np import numpy as np
@@ -19,9 +21,7 @@ LOG = get_logger(__name__)
class RetryStrategy(Enum): class RetryStrategy(Enum):
""" """Enum for retry strategies."""
Enum for retry strategies.
"""
CONSTANT = 1 CONSTANT = 1
LINEAR = 2 LINEAR = 2
@@ -30,7 +30,18 @@ class RetryStrategy(Enum):
def retry_on_request_exceptions( def retry_on_request_exceptions(
max_retries=3, delay=1, retry_strategy: RetryStrategy = RetryStrategy.LINEAR max_retries=3, delay=1, retry_strategy: RetryStrategy = RetryStrategy.LINEAR
): ) -> Callable:
"""Decorator that retries function calls on specific request exceptions.
Args:
max_retries: Maximum number of retry attempts.
delay: Base delay between retries in seconds.
retry_strategy: Strategy for calculating retry delays.
Returns:
Decorated function with retry logic.
"""
def decorator(func): def decorator(func):
@functools.wraps(func) @functools.wraps(func)
def wrapper(*args, **kwargs): # pylint: disable=inconsistent-return-statements def wrapper(*args, **kwargs): # pylint: disable=inconsistent-return-statements
@@ -59,6 +70,7 @@ def retry_on_request_exceptions(
def md5(to_hash: str, encoding: str = "utf-8") -> str: def md5(to_hash: str, encoding: str = "utf-8") -> str:
"""Generate MD5 hash of a string."""
try: try:
return hashlib.md5(to_hash.encode(encoding), usedforsecurity=False).hexdigest() return hashlib.md5(to_hash.encode(encoding), usedforsecurity=False).hexdigest()
except TypeError: except TypeError:
@@ -66,102 +78,89 @@ def md5(to_hash: str, encoding: str = "utf-8") -> str:
def sha256(to_hash: str, encoding: str = "utf-8") -> str: def sha256(to_hash: str, encoding: str = "utf-8") -> str:
"""Generate SHA256 hash of a string."""
return hashlib.sha256(to_hash.encode(encoding)).hexdigest() return hashlib.sha256(to_hash.encode(encoding)).hexdigest()
def deduplicate_dataset( def _deduplicate_dataset(
dataset: Dataset, seen_hashes: dict[str, list[int]], other_dataset: Dataset = None dataset: Dataset,
) -> Dataset: seen_hashes: set[str] | None = None,
unique_indices = [] ) -> tuple[Dataset, set[str]]:
"""Remove duplicate rows from a dataset using SHA256 hashes.
Args:
dataset: Dataset to deduplicate.
seen_hashes: Set of previously seen row hashes (for cross-deduplication).
Returns:
Tuple of deduplicated dataset and the set of seen hashes.
"""
if seen_hashes is None:
seen_hashes = set()
unique_indices = []
for idx, row in enumerate(dataset): for idx, row in enumerate(dataset):
row_hash = sha256(str(row)) # Using SHA256 for collision resistance. row_hash = sha256(str(row)) # Using SHA256 for collision resistance
if row_hash not in seen_hashes: if row_hash not in seen_hashes:
seen_hashes[row_hash] = [idx] seen_hashes.add(row_hash)
unique_indices.append(idx) unique_indices.append(idx)
else:
# Check for collision by looking up the original dataset indices return dataset.select(unique_indices), seen_hashes
original_indices = seen_hashes[row_hash]
is_duplicate = False
for original_idx in original_indices:
if (
not idx == original_idx
and original_idx < len(dataset)
and str(dataset[original_idx]) == str(row)
):
is_duplicate = True
break
# Check in the other dataset if provided
if other_dataset is not None:
if original_idx < len(other_dataset) and str(
other_dataset[original_idx]
) == str(row):
is_duplicate = True
break
if not is_duplicate:
seen_hashes[row_hash].append(idx)
unique_indices.append(idx)
continue
return dataset.select(unique_indices)
def deduplicate_and_log_datasets( def deduplicate_and_log_datasets(
*, dataset: Dataset,
train_dataset: Dataset = None, other_dataset: Dataset | None = None,
eval_dataset: Dataset = None, dataset_name: str | None = "train",
dataset: Dataset = None, other_name: str | None = "eval",
) -> tuple[Dataset, Dataset, Dataset]: ) -> tuple[Dataset, Dataset | None]:
""" """Deduplicate datasets, with optional cross-dataset deduplication.
Deduplicates train, eval, and an optional dataset if provided, logging original and new sizes.
Args:
dataset: Primary dataset to deduplicate.
other_dataset: Optional second dataset to deduplicate against the first.
dataset_name: Name for the primary dataset (for logging).
other_name: Name for the second dataset (for logging).
Returns: Returns:
tuple: Deduplicated train, eval, and additional datasets. Tuple of (deduplicated_dataset, deduplicated_other_dataset).
""" """
seen_hashes: dict[str, list[int]] = {} # Deduplicate primary dataset
LOG.info(
f"Starting deduplication for {dataset_name} dataset. Original size: {len(dataset)}"
)
dataset, seen_rows = _deduplicate_dataset(dataset)
LOG.info(
f"Deduplication complete for {dataset_name} dataset. New size: {len(dataset)}"
)
# Handle cases where datasets are None # Deduplicate second dataset if provided
if train_dataset is not None: if other_dataset is not None:
LOG.info( LOG.info(
f"Starting deduplication for train dataset. Original size: {len(train_dataset)}" f"Starting deduplication for {other_name} dataset. Original size: {len(other_dataset)}"
)
train_dataset = deduplicate_dataset(
dataset=train_dataset, seen_hashes=seen_hashes
) )
other_dataset, _ = _deduplicate_dataset(other_dataset, seen_rows)
LOG.info( LOG.info(
f"Deduplication complete for train dataset. New size: {len(train_dataset)}" f"Deduplication complete for {other_name} dataset. New size: {len(other_dataset)}"
)
else:
LOG.info("Train dataset is None. Skipping deduplication.")
if eval_dataset is not None:
LOG.info(
f"Starting deduplication for eval dataset. Original size: {len(eval_dataset)}"
)
eval_dataset = deduplicate_dataset(
dataset=eval_dataset, seen_hashes=seen_hashes, other_dataset=train_dataset
)
LOG.info(
f"Deduplication complete for eval dataset. New size: {len(eval_dataset)}"
)
else:
LOG.info("Eval dataset is None. Skipping deduplication.")
if dataset is not None and (eval_dataset is None and train_dataset is None):
LOG.info(
f"Starting deduplication for combined dataset. Original size: {len(dataset)}"
)
dataset = deduplicate_dataset(dataset=dataset, seen_hashes=seen_hashes)
LOG.info(
f"Deduplication complete for combined dataset. New size: {len(dataset)}"
) )
return train_dataset, eval_dataset, dataset return dataset, other_dataset
def drop_long_seq_in_dataset(dataset: Dataset, cfg: DictDefault): def drop_long_seq_in_dataset(dataset: Dataset, cfg: DictDefault) -> Dataset:
"""Remove sequences longer than configured maximum from dataset.
Args:
dataset: Dataset to filter.
cfg: Dictionary mapping `axolotl` config keys to values.
Returns:
Filtered dataset with long sequences removed.
"""
if "input_ids" not in dataset.column_names: if "input_ids" not in dataset.column_names:
LOG.warning( LOG.warning(
"Dataset does not contain 'input_ids' column. Skip drop long seq. This is expected for RewardModeling." "Dataset does not contain 'input_ids' column. Skip drop long seq. This is "
"expected for reward modeling."
) )
return dataset return dataset
@@ -171,20 +170,14 @@ def drop_long_seq_in_dataset(dataset: Dataset, cfg: DictDefault):
min_sequence_len=cfg.min_sample_len, min_sequence_len=cfg.min_sample_len,
) )
try: with contextlib.suppress(AttributeError):
ds_lengths = get_dataset_lengths(dataset, from_arrow=True) ds_lengths = get_dataset_lengths(dataset, from_arrow=True)
min_input_len = np.min(ds_lengths) min_input_len = np.min(ds_lengths)
LOG.info(f"min_input_len: {min_input_len}") LOG.info(f"min_input_len: {min_input_len}")
max_input_len = np.max(ds_lengths) max_input_len = np.max(ds_lengths)
LOG.info(f"max_input_len: {max_input_len}") LOG.info(f"max_input_len: {max_input_len}")
except AttributeError:
pass
try: prior_len = len(dataset) if hasattr(dataset, "__len__") else None
prior_len = len(dataset)
except TypeError:
# handle iterable datasets case
prior_len = None
filter_map_kwargs = {} filter_map_kwargs = {}
if not isinstance(dataset, IterableDataset): if not isinstance(dataset, IterableDataset):

View File

@@ -0,0 +1,425 @@
"""Data handling specific to SFT."""
import logging
from typing import Any, NoReturn, cast
from datasets import (
Dataset,
IterableDataset,
Sequence,
Value,
)
from transformers import PreTrainedTokenizer
from transformers.processing_utils import ProcessorMixin
from axolotl.datasets import TokenizedPromptDataset, wrap_dataset_for_tokenized_prompt
from axolotl.prompt_strategies import load
from axolotl.prompt_strategies.bradley_terry import load as bradley_terry_load
from axolotl.prompt_tokenizers import (
AlpacaMultipleChoicePromptTokenizingStrategy,
AlpacaPromptTokenizingStrategy,
AlpacaReflectionPTStrategy,
DatasetWrappingStrategy,
GPTeacherPromptTokenizingStrategy,
JeopardyPromptTokenizingStrategy,
OpenAssistantPromptTokenizingStrategy,
PromptTokenizingStrategy,
SummarizeTLDRPromptTokenizingStrategy,
)
from axolotl.prompters import (
AlpacaPrompter,
GPTeacherPrompter,
JeopardyPrompter,
MultipleChoiceConcisePrompter,
MultipleChoiceExplainPrompter,
Prompter,
ReflectAlpacaPrompter,
SummarizeTLDRPrompter,
UnsupportedPrompter,
)
from axolotl.utils.dict import DictDefault
LOG = logging.getLogger(__name__)
def handle_unknown_dataset_strategy(dataset_config: DictDefault) -> NoReturn:
"""Raise error for unknown dataset strategy."""
ds_type = dataset_config.type
suffix = ""
if ":load_" in ds_type:
suffix = f"Did you mean {ds_type.replace(':load_', '.load_')}?"
error_message = f"unhandled prompt tokenization strategy: {ds_type}. {suffix}"
LOG.error(error_message)
raise ValueError(error_message)
# pylint: disable=too-many-return-statements
def get_dataset_wrapper(
dataset_config: DictDefault,
tokenizer: PreTrainedTokenizer,
cfg: DictDefault,
dataset_base_type: str | None,
dataset: Dataset | IterableDataset,
dataset_prompt_style: str | None = None,
processor: ProcessorMixin | None = None, # pylint: disable=unused-argument
) -> tuple[Dataset | IterableDataset, Prompter | None]:
"""Create an appropriate dataset wrapper and prompter based on dataset
configuration.
Args:
dataset_config: Configuration for the dataset.
tokenizer: Tokenizer to use for processing text.
cfg: Global configuration object.
dataset_base_type: The base type of the dataset.
dataset: The actual dataset object.
dataset_prompt_style: Optional prompt style specification.
processor: Optional processor for multimodal datasets.
Returns:
tuple of (dataset_wrapper, dataset_prompter).
"""
# Common parameters for dataset wrapping
dataset_kwargs: dict[str, Any] = {
"process_count": cfg.dataset_processes,
"keep_in_memory": cfg.dataset_keep_in_memory is True,
}
LOG.info(
f"Loading dataset: {dataset_config['path']} with base_type: "
f"{dataset_base_type} and prompt_style: {dataset_prompt_style}"
)
# Dataset is already tokenized
if _is_dataset_already_tokenized(dataset):
return dataset, UnsupportedPrompter()
# Custom dataset type definition
if isinstance(dataset_config.type, DictDefault):
return _handle_custom_dataset_type(
dataset_config, tokenizer, cfg, dataset, dataset_kwargs
)
# Skip preparation if configured
if cfg.skip_prepare_dataset:
return dataset, None
# Bradley-Terry dataset
if dataset_config.type.startswith("bradley_terry"):
return _handle_bradley_terry_dataset(
dataset_config, tokenizer, cfg, dataset, dataset_kwargs
)
# Stepwise supervised dataset
if dataset_config.type.startswith("stepwise_supervised"):
return _handle_stepwise_supervised_dataset(
dataset_config, tokenizer, cfg, dataset, dataset_kwargs
)
# Try to load prompt tokenizer / dataset wrapper strategy from registry
dataset_strategy = load(
dataset_config.type, tokenizer, cfg, dataset_config, processor=processor
)
if dataset_strategy:
return _handle_loaded_strategy(dataset_strategy, dataset, dataset_kwargs)
# Known dataset types with specific handling
if dataset_base_type in DATASET_HANDLERS:
handler = DATASET_HANDLERS[dataset_base_type]
return handler(dataset_prompt_style, tokenizer, cfg, dataset, dataset_kwargs)
# Unhandled dataset type
handle_unknown_dataset_strategy(dataset_config)
def _is_dataset_already_tokenized(dataset: Dataset | IterableDataset) -> bool:
"""Check if the dataset is already tokenized."""
return (
isinstance(dataset, Dataset)
and "input_ids" in dataset.features
and "attention_mask" in dataset.features
and "labels" in dataset.features
)
def _handle_custom_dataset_type(
dataset_config: DictDefault,
tokenizer: PreTrainedTokenizer,
cfg: DictDefault,
dataset: Dataset | IterableDataset,
dataset_kwargs: dict[str, Any],
) -> tuple[Dataset | IterableDataset, Prompter]:
"""Handle a custom dataset type defined in the configuration."""
dataset_strategy = cast(
PromptTokenizingStrategy,
load("user_defined", tokenizer, cfg, dataset_config.type.to_dict()),
)
dataset_prompter = UnsupportedPrompter()
dataset_wrapper = wrap_dataset_for_tokenized_prompt(
dataset_strategy,
dataset,
**dataset_kwargs,
)
return dataset_wrapper, dataset_prompter
def _handle_bradley_terry_dataset(
dataset_config: DictDefault,
tokenizer: PreTrainedTokenizer,
cfg: DictDefault,
dataset: Dataset | IterableDataset,
dataset_kwargs: dict[str, Any],
) -> tuple[Dataset | IterableDataset, Prompter | None]:
"""Handle a Bradley-Terry dataset."""
bt_type = dataset_config.type.split(".", 1)[1]
dataset_strategy = bradley_terry_load(bt_type, tokenizer, cfg, dataset_config)
if not dataset_strategy:
handle_unknown_dataset_strategy(dataset_config)
dataset_prompter = UnsupportedPrompter()
dataset_wrapper = wrap_dataset_for_tokenized_prompt(
dataset_strategy,
dataset,
**dataset_kwargs,
)
return dataset_wrapper, dataset_prompter
def _handle_stepwise_supervised_dataset(
dataset_config: DictDefault,
tokenizer: PreTrainedTokenizer,
cfg: DictDefault,
dataset: Dataset | IterableDataset,
dataset_kwargs: dict[str, Any],
) -> tuple[Dataset | IterableDataset, Prompter]:
"""Handle a stepwise supervised dataset."""
dataset_prompter = UnsupportedPrompter()
dataset_strategy = load(dataset_config.type, tokenizer, cfg, dataset_config)
# We need to explicitly cast boolean labels to int
# for compatibility with how trl's PRMTrainer works
if isinstance(dataset, Dataset):
dataset = dataset.cast_column("labels", Sequence(Value("int64")))
dataset_wrapper = TokenizedPromptDataset(
dataset_strategy,
dataset,
**dataset_kwargs,
)
return dataset_wrapper, dataset_prompter
def _handle_loaded_strategy(
dataset_strategy: PromptTokenizingStrategy | DatasetWrappingStrategy,
dataset: Dataset | IterableDataset,
dataset_kwargs: dict[str, Any],
) -> tuple[Dataset | IterableDataset, Prompter | None]:
"""Handle a dataset with a strategy loaded from the registry."""
if isinstance(dataset_strategy, DatasetWrappingStrategy):
return dataset_strategy.wrap_dataset(dataset, **dataset_kwargs), None
dataset_prompter = UnsupportedPrompter()
dataset_wrapper = wrap_dataset_for_tokenized_prompt(
dataset_strategy,
dataset,
**dataset_kwargs,
)
return dataset_wrapper, dataset_prompter
def _handle_alpaca_dataset(
dataset_prompt_style: str | None,
tokenizer: PreTrainedTokenizer,
cfg: DictDefault,
dataset: Dataset | IterableDataset,
dataset_kwargs: dict[str, Any],
) -> tuple[Dataset | IterableDataset, Prompter]:
"""Handle an Alpaca dataset."""
dataset_prompter = AlpacaPrompter(dataset_prompt_style)
dataset_strategy = AlpacaPromptTokenizingStrategy(
dataset_prompter,
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
dataset_wrapper = wrap_dataset_for_tokenized_prompt(
dataset_strategy,
dataset,
**dataset_kwargs,
)
return dataset_wrapper, dataset_prompter
def _handle_explainchoice_dataset(
dataset_prompt_style: str | None,
tokenizer: PreTrainedTokenizer,
cfg: DictDefault,
dataset: Dataset | IterableDataset,
dataset_kwargs: dict[str, Any],
) -> tuple[Dataset | IterableDataset, Prompter]:
"""Handle an ExplainChoice dataset."""
dataset_prompter = MultipleChoiceExplainPrompter(dataset_prompt_style)
dataset_strategy = AlpacaMultipleChoicePromptTokenizingStrategy(
dataset_prompter,
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
dataset_wrapper = wrap_dataset_for_tokenized_prompt(
dataset_strategy,
dataset,
**dataset_kwargs,
)
return dataset_wrapper, dataset_prompter
def _handle_concisechoice_dataset(
dataset_prompt_style: str | None,
tokenizer: PreTrainedTokenizer,
cfg: DictDefault,
dataset: Dataset | IterableDataset,
dataset_kwargs: dict[str, Any],
) -> tuple[Dataset | IterableDataset, Prompter]:
"""Handle a ConciseChoice dataset."""
dataset_prompter = MultipleChoiceConcisePrompter(dataset_prompt_style)
dataset_strategy = AlpacaMultipleChoicePromptTokenizingStrategy(
dataset_prompter,
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
dataset_wrapper = wrap_dataset_for_tokenized_prompt(
dataset_strategy,
dataset,
**dataset_kwargs,
)
return dataset_wrapper, dataset_prompter
def _handle_summarizetldr_dataset(
dataset_prompt_style: str | None,
tokenizer: PreTrainedTokenizer,
cfg: DictDefault,
dataset: Dataset | IterableDataset,
dataset_kwargs: dict[str, Any],
) -> tuple[Dataset | IterableDataset, Prompter]:
"""Handle a SummarizeTLDR dataset."""
dataset_prompter = SummarizeTLDRPrompter(dataset_prompt_style)
dataset_strategy = SummarizeTLDRPromptTokenizingStrategy(
dataset_prompter,
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
dataset_wrapper = wrap_dataset_for_tokenized_prompt(
dataset_strategy,
dataset,
**dataset_kwargs,
)
return dataset_wrapper, dataset_prompter
def _handle_jeopardy_dataset(
dataset_prompt_style: str | None,
tokenizer: PreTrainedTokenizer,
cfg: DictDefault,
dataset: Dataset | IterableDataset,
dataset_kwargs: dict[str, Any],
) -> tuple[Dataset | IterableDataset, Prompter]:
"""Handle a Jeopardy dataset."""
dataset_prompter = JeopardyPrompter(dataset_prompt_style)
dataset_strategy = JeopardyPromptTokenizingStrategy(
dataset_prompter,
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
dataset_wrapper = wrap_dataset_for_tokenized_prompt(
dataset_strategy,
dataset,
**dataset_kwargs,
)
return dataset_wrapper, dataset_prompter
def _handle_oasst_dataset(
dataset_prompt_style: str | None,
tokenizer: PreTrainedTokenizer,
cfg: DictDefault,
dataset: Dataset | IterableDataset,
dataset_kwargs: dict[str, Any],
) -> tuple[Dataset | IterableDataset, Prompter]:
"""Handle an OpenAssistant dataset."""
dataset_prompter = AlpacaPrompter(dataset_prompt_style)
dataset_strategy = OpenAssistantPromptTokenizingStrategy(
dataset_prompter,
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
dataset_wrapper = wrap_dataset_for_tokenized_prompt(
dataset_strategy,
dataset,
**dataset_kwargs,
)
return dataset_wrapper, dataset_prompter
def _handle_gpteacher_dataset(
dataset_prompt_style: str | None,
tokenizer: PreTrainedTokenizer,
cfg: DictDefault,
dataset: Dataset | IterableDataset,
dataset_kwargs: dict[str, Any],
) -> tuple[Dataset | IterableDataset, Prompter]:
"""Handle a GPTeacher dataset."""
dataset_prompter = GPTeacherPrompter(dataset_prompt_style)
dataset_strategy = GPTeacherPromptTokenizingStrategy(
dataset_prompter,
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
dataset_wrapper = wrap_dataset_for_tokenized_prompt(
dataset_strategy,
dataset,
**dataset_kwargs,
)
return dataset_wrapper, dataset_prompter
def _handle_reflection_dataset(
dataset_prompt_style: str | None,
tokenizer: PreTrainedTokenizer,
cfg: DictDefault,
dataset: Dataset | IterableDataset,
dataset_kwargs: dict[str, Any],
) -> tuple[Dataset | IterableDataset, Prompter]:
"""Handle a Reflection dataset."""
dataset_prompter = ReflectAlpacaPrompter(dataset_prompt_style)
dataset_strategy = AlpacaReflectionPTStrategy(
dataset_prompter,
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
dataset_wrapper = wrap_dataset_for_tokenized_prompt(
dataset_strategy,
dataset,
**dataset_kwargs,
)
return dataset_wrapper, dataset_prompter
DATASET_HANDLERS = {
"alpaca": _handle_alpaca_dataset,
"explainchoice": _handle_explainchoice_dataset,
"concisechoice": _handle_concisechoice_dataset,
"summarizetldr": _handle_summarizetldr_dataset,
"jeopardy": _handle_jeopardy_dataset,
"oasst": _handle_oasst_dataset,
"gpteacher": _handle_gpteacher_dataset,
"reflection": _handle_reflection_dataset,
}

View File

@@ -336,6 +336,14 @@ class AxolotlInputConfig(
plugins: list[str] | None = Field(default=None) plugins: list[str] | None = Field(default=None)
@field_validator("seed", mode="after")
@classmethod
def set_default_seed(cls, seed):
if seed is None:
LOG.info("`seed` not set in config; setting to 42")
seed = 42
return seed
@field_validator("datasets", mode="before") @field_validator("datasets", mode="before")
@classmethod @classmethod
def deprecate_sharegpt_datasets(cls, datasets): def deprecate_sharegpt_datasets(cls, datasets):
@@ -1199,7 +1207,7 @@ class AxolotlInputConfig(
"flash_attention: true must be set with sequence_parallel_degree > 1" "flash_attention: true must be set with sequence_parallel_degree > 1"
) )
if self.sample_packing and self.micro_batch_size > 1: if self.sample_packing and getattr(self, "micro_batch_size", 1) > 1:
raise ValueError( raise ValueError(
"micro_batch_size must be set to 1 when sample_packing is enabled " "micro_batch_size must be set to 1 when sample_packing is enabled "
"due to a `ring-flash-attn` requirement" "due to a `ring-flash-attn` requirement"

View File

@@ -12,7 +12,7 @@ from axolotl.common.datasets import load_datasets
from axolotl.core.builders import HFCausalTrainerBuilder, HFRLTrainerBuilder from axolotl.core.builders import HFCausalTrainerBuilder, HFRLTrainerBuilder
from axolotl.loaders import ModelLoader, load_tokenizer from axolotl.loaders import ModelLoader, load_tokenizer
from axolotl.utils.config import normalize_config from axolotl.utils.config import normalize_config
from axolotl.utils.data.rl import load_prepare_preference_datasets from axolotl.utils.data import prepare_preference_datasets
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from axolotl.utils.schemas.enums import RLType from axolotl.utils.schemas.enums import RLType
@@ -451,15 +451,19 @@ def rand_reward_func(prompts, completions) -> list[float]:
# Only use mock for the commented out configs # Only use mock for the commented out configs
if dataset_name is not None: if dataset_name is not None:
with patch( with patch(
"axolotl.utils.data.rl.load_dataset_w_config" "axolotl.utils.data.rl.load_dataset_with_config"
) as mock_load_dataset: ) as mock_load_dataset:
mock_load_dataset.return_value = request.getfixturevalue( mock_load_dataset.return_value = request.getfixturevalue(
dataset_name dataset_name
) )
train_dataset, eval_dataset = load_prepare_preference_datasets(cfg) train_dataset, eval_dataset = prepare_preference_datasets(
cfg, tokenizer
)
else: else:
# Load actual datasets for orpo_cfg and kto_cfg # Load actual datasets for orpo_cfg and kto_cfg
train_dataset, eval_dataset = load_prepare_preference_datasets(cfg) train_dataset, eval_dataset = prepare_preference_datasets(
cfg, tokenizer
)
builder.train_dataset = train_dataset builder.train_dataset = train_dataset
builder.eval_dataset = eval_dataset builder.eval_dataset = eval_dataset

View File

@@ -4,7 +4,6 @@ Simple end-to-end test for Cut Cross Entropy integration
import pytest import pytest
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets from axolotl.common.datasets import load_datasets
from axolotl.train import train from axolotl.train import train
from axolotl.utils import get_pytorch_version from axolotl.utils import get_pytorch_version
@@ -59,8 +58,7 @@ class TestCutCrossEntropyIntegration:
cfg = validate_config(cfg) cfg = validate_config(cfg)
prepare_plugins(cfg) prepare_plugins(cfg)
normalize_config(cfg) normalize_config(cfg)
cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg)
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
major, minor, _ = get_pytorch_version() major, minor, _ = get_pytorch_version()
if (major, minor) < (2, 4): if (major, minor) < (2, 4):
@@ -105,8 +103,7 @@ class TestCutCrossEntropyIntegration:
cfg = validate_config(cfg) cfg = validate_config(cfg)
prepare_plugins(cfg) prepare_plugins(cfg)
normalize_config(cfg) normalize_config(cfg)
cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg)
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
major, minor, _ = get_pytorch_version() major, minor, _ = get_pytorch_version()
if (major, minor) < (2, 4): if (major, minor) < (2, 4):
@@ -134,8 +131,7 @@ class TestCutCrossEntropyIntegration:
cfg = validate_config(cfg) cfg = validate_config(cfg)
prepare_plugins(cfg) prepare_plugins(cfg)
normalize_config(cfg) normalize_config(cfg)
cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg)
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
major, minor, _ = get_pytorch_version() major, minor, _ = get_pytorch_version()
if (major, minor) < (2, 4): if (major, minor) < (2, 4):

View File

@@ -5,7 +5,6 @@ e2e tests to make sure all the hooks are fired on the plugin
import os import os
from pathlib import Path from pathlib import Path
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets from axolotl.common.datasets import load_datasets
from axolotl.integrations.base import BasePlugin from axolotl.integrations.base import BasePlugin
from axolotl.train import train from axolotl.train import train
@@ -160,8 +159,7 @@ class TestPluginHooks:
cfg = validate_config(cfg) cfg = validate_config(cfg)
prepare_plugins(cfg) prepare_plugins(cfg)
normalize_config(cfg) normalize_config(cfg)
cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg)
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, dataset_meta=dataset_meta) train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg) check_model_output_exists(temp_dir, cfg)

View File

@@ -6,7 +6,6 @@ from pathlib import Path
import pytest import pytest
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets from axolotl.common.datasets import load_datasets
from axolotl.train import train from axolotl.train import train
from axolotl.utils.config import normalize_config, prepare_plugins, validate_config from axolotl.utils.config import normalize_config, prepare_plugins, validate_config
@@ -84,8 +83,7 @@ class TestKnowledgeDistillation:
cfg = validate_config(cfg) cfg = validate_config(cfg)
prepare_plugins(cfg) prepare_plugins(cfg)
normalize_config(cfg) normalize_config(cfg)
cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg)
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, dataset_meta=dataset_meta) train(cfg=cfg, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "model.safetensors").exists() assert (Path(temp_dir) / "model.safetensors").exists()
@@ -115,8 +113,7 @@ class TestKnowledgeDistillation:
cfg = validate_config(cfg) cfg = validate_config(cfg)
prepare_plugins(cfg) prepare_plugins(cfg)
normalize_config(cfg) normalize_config(cfg)
cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg)
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, dataset_meta=dataset_meta) train(cfg=cfg, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "adapter_model.safetensors").exists() assert (Path(temp_dir) / "adapter_model.safetensors").exists()

View File

@@ -2,7 +2,6 @@
Simple end-to-end test for Liger integration Simple end-to-end test for Liger integration
""" """
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets from axolotl.common.datasets import load_datasets
from axolotl.train import train from axolotl.train import train
from axolotl.utils.config import normalize_config, prepare_plugins, validate_config from axolotl.utils.config import normalize_config, prepare_plugins, validate_config
@@ -57,8 +56,7 @@ class LigerIntegrationTestCase:
cfg = validate_config(cfg) cfg = validate_config(cfg)
prepare_plugins(cfg) prepare_plugins(cfg)
normalize_config(cfg) normalize_config(cfg)
cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg)
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, dataset_meta=dataset_meta) train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg) check_model_output_exists(temp_dir, cfg)
@@ -104,8 +102,7 @@ class LigerIntegrationTestCase:
cfg = validate_config(cfg) cfg = validate_config(cfg)
prepare_plugins(cfg) prepare_plugins(cfg)
normalize_config(cfg) normalize_config(cfg)
cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg)
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, dataset_meta=dataset_meta) train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg) check_model_output_exists(temp_dir, cfg)

View File

@@ -6,7 +6,6 @@ from pathlib import Path
import pytest import pytest
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets from axolotl.common.datasets import load_datasets
from axolotl.train import train from axolotl.train import train
from axolotl.utils.config import normalize_config, prepare_plugins, validate_config from axolotl.utils.config import normalize_config, prepare_plugins, validate_config
@@ -88,8 +87,7 @@ class TestLLMCompressorIntegration:
prepare_plugins(cfg) prepare_plugins(cfg)
cfg = validate_config(cfg) cfg = validate_config(cfg)
normalize_config(cfg) normalize_config(cfg)
cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg)
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
try: try:
train(cfg=cfg, dataset_meta=dataset_meta) train(cfg=cfg, dataset_meta=dataset_meta)

View File

@@ -105,7 +105,7 @@ def start_vllm(
print(f"{i}: VLLM server failed to start: {str(exc)}") print(f"{i}: VLLM server failed to start: {str(exc)}")
# also check if the process.pid is still running # also check if the process.pid is still running
if not process.poll() is None: if process.poll() is not None:
break break
time.sleep(period_seconds) time.sleep(period_seconds)

View File

@@ -0,0 +1,192 @@
"""Tests for FileLockLoader class."""
import tempfile
import threading
import time
from pathlib import Path
from unittest.mock import MagicMock, Mock, patch
import pytest
from axolotl.utils.data.lock import FileLockLoader
from axolotl.utils.dict import DictDefault
class TestFileLockLoader:
"""Class with tests for FileLockLoader."""
@pytest.fixture
def temp_dir(self):
"""Create a temporary directory for testing."""
with tempfile.TemporaryDirectory() as tmp_dir:
yield Path(tmp_dir)
@pytest.fixture
def cfg(self, temp_dir):
"""Create a test configuration."""
return DictDefault({"dataset_prepared_path": str(temp_dir)})
@pytest.fixture
def loader(self, cfg):
"""Create a FileLockLoader instance for testing."""
return FileLockLoader(cfg)
def test_load_first_process(self, loader):
"""Test load() when no ready flag exists (first process)."""
mock_load_fn = Mock(return_value="test_data")
result = loader.load(mock_load_fn)
# Should call the load function
mock_load_fn.assert_called_once()
assert result == "test_data"
# Should create the ready flag
assert loader.ready_flag_path.exists()
def test_load_subsequent_process(self, loader):
"""Test load() when ready flag already exists (subsequent process)."""
# Create ready flag first
loader.ready_flag_path.touch()
mock_load_fn = Mock(return_value="loaded_data")
result = loader.load(mock_load_fn)
# Should still call load function (to load the prepared data)
mock_load_fn.assert_called_once()
assert result == "loaded_data"
def test_load_concurrent_processes(self, cfg):
"""Test that concurrent processes coordinate correctly."""
results = []
call_count = 0
def slow_load_fn():
nonlocal call_count
call_count += 1
time.sleep(0.1) # Simulate slow loading
return f"data_{call_count}"
def worker():
loader = FileLockLoader(cfg)
result = loader.load(slow_load_fn)
results.append(result)
# Start multiple threads simultaneously
threads = [threading.Thread(target=worker) for _ in range(3)]
for t in threads:
t.start()
for t in threads:
t.join()
# Only one thread should have done the initial loading
# All should return data, but the load function should be called
# once by the first process and once by each subsequent process
assert len(results) == 3
assert all(result.startswith("data_") for result in results)
@patch("time.sleep")
def test_load_waiting_for_ready_flag(self, mock_sleep, loader):
"""Test that processes wait for the ready flag to appear."""
mock_load_fn = Mock(return_value="waiting_data")
mock_ready_flag_path = Mock()
exists_call_count = 0
def mock_exists():
nonlocal exists_call_count
exists_call_count += 1
if exists_call_count == 1:
# First check: ready flag exists (not first process)
return True
if exists_call_count <= 3:
# While loop checks: flag doesn't exist yet
return False
return True
mock_ready_flag_path.exists.side_effect = mock_exists
# Replace the ready_flag_path with our mock
original_path = loader.ready_flag_path
loader.ready_flag_path = mock_ready_flag_path
try:
result = loader.load(mock_load_fn)
finally:
# Restore original path
loader.ready_flag_path = original_path
# Should have slept twice while waiting
assert mock_sleep.call_count == 2
mock_sleep.assert_called_with(1)
# Should eventually call load function
mock_load_fn.assert_called_once()
assert result == "waiting_data"
def test_complete_workflow_with_cleanup(self, loader):
"""Test the complete load -> cleanup workflow."""
mock_load_fn = Mock(return_value="test_data")
# First process calls load (this should set up counter)
result = loader.load(mock_load_fn)
assert result == "test_data"
assert loader.ready_flag_path.exists()
assert loader.counter_path.exists()
# Cleanup should remove everything since there's only one process
loader.cleanup()
assert not loader.ready_flag_path.exists()
assert not loader.counter_path.exists()
def test_multiple_processes_workflow(self, loader):
"""Test workflow with multiple processes."""
# Simulate multiple processes by manually setting up counter
loader.ready_flag_path.touch()
loader.counter_path.write_text("3") # 3 processes
# First process cleanup
loader.cleanup()
assert loader.ready_flag_path.exists()
assert loader.counter_path.read_text().strip() == "2"
# Second process cleanup
loader.cleanup()
assert loader.ready_flag_path.exists()
assert loader.counter_path.read_text().strip() == "1"
# Last process cleanup
loader.cleanup()
assert not loader.ready_flag_path.exists()
assert not loader.counter_path.exists()
def test_load_exception_handling(self, loader):
"""Test behavior when load_fn raises an exception."""
def failing_load_fn():
raise ValueError("Load failed")
with pytest.raises(ValueError, match="Load failed"):
loader.load(failing_load_fn)
# Ready flag should not be created on failure
assert not loader.ready_flag_path.exists()
def test_file_lock_called(self, loader):
"""Test that FileLock is properly used."""
mock_load_fn = Mock(return_value="locked_data")
with patch("axolotl.utils.data.lock.FileLock") as mock_filelock:
mock_context = MagicMock()
mock_filelock.return_value.__enter__ = Mock(return_value=mock_context)
mock_filelock.return_value.__exit__ = Mock(return_value=None)
loader.load(mock_load_fn)
# Verify FileLock was called with correct path
mock_filelock.assert_called_once_with(str(loader.lock_file_path))
# Verify context manager was used
mock_filelock.return_value.__enter__.assert_called_once()
mock_filelock.return_value.__exit__.assert_called_once()

View File

@@ -4,7 +4,6 @@ E2E tests for multipack fft llama using 4d attention masks
import unittest import unittest
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets from axolotl.common.datasets import load_datasets
from axolotl.train import train from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config from axolotl.utils.config import normalize_config, validate_config
@@ -60,8 +59,7 @@ class Test4dMultipackLlama(unittest.TestCase):
) )
cfg = validate_config(cfg) cfg = validate_config(cfg)
normalize_config(cfg) normalize_config(cfg)
cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg)
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, dataset_meta=dataset_meta) train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg) check_model_output_exists(temp_dir, cfg)
@@ -108,8 +106,7 @@ class Test4dMultipackLlama(unittest.TestCase):
) )
cfg = validate_config(cfg) cfg = validate_config(cfg)
normalize_config(cfg) normalize_config(cfg)
cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg)
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, dataset_meta=dataset_meta) train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg) check_model_output_exists(temp_dir, cfg)

View File

@@ -6,7 +6,6 @@ import pytest
import transformers import transformers
from torch.utils.checkpoint import checkpoint from torch.utils.checkpoint import checkpoint
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets from axolotl.common.datasets import load_datasets
from axolotl.train import train from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config from axolotl.utils.config import normalize_config, validate_config
@@ -75,8 +74,7 @@ class TestActivationCheckpointing:
cfg = validate_config(cfg) cfg = validate_config(cfg)
normalize_config(cfg) normalize_config(cfg)
cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg)
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, dataset_meta=dataset_meta) train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg) check_model_output_exists(temp_dir, cfg)

View File

@@ -5,7 +5,6 @@ E2E tests for lora llama
import pytest import pytest
from transformers.utils import is_torch_bf16_gpu_available from transformers.utils import is_torch_bf16_gpu_available
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets from axolotl.common.datasets import load_datasets
from axolotl.train import train from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config from axolotl.utils.config import normalize_config, validate_config
@@ -73,8 +72,7 @@ class TestFAXentropyLlama:
cfg = validate_config(cfg) cfg = validate_config(cfg)
normalize_config(cfg) normalize_config(cfg)
cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg)
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, dataset_meta=dataset_meta) train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg) check_model_output_exists(temp_dir, cfg)

View File

@@ -6,7 +6,6 @@ import unittest
import pytest import pytest
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets from axolotl.common.datasets import load_datasets
from axolotl.train import train from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config from axolotl.utils.config import normalize_config, validate_config
@@ -63,8 +62,7 @@ class TestFalconPatched(unittest.TestCase):
) )
cfg = validate_config(cfg) cfg = validate_config(cfg)
normalize_config(cfg) normalize_config(cfg)
cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg)
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, dataset_meta=dataset_meta) train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg) check_model_output_exists(temp_dir, cfg)
@@ -105,8 +103,7 @@ class TestFalconPatched(unittest.TestCase):
) )
cfg = validate_config(cfg) cfg = validate_config(cfg)
normalize_config(cfg) normalize_config(cfg)
cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg)
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, dataset_meta=dataset_meta) train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg) check_model_output_exists(temp_dir, cfg)

View File

@@ -7,7 +7,6 @@ import unittest
import pytest import pytest
from transformers.utils import is_torch_bf16_gpu_available from transformers.utils import is_torch_bf16_gpu_available
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets from axolotl.common.datasets import load_datasets
from axolotl.train import train from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config from axolotl.utils.config import normalize_config, validate_config
@@ -62,8 +61,7 @@ class TestFusedLlama(unittest.TestCase):
cfg.fp16 = True cfg.fp16 = True
cfg = validate_config(cfg) cfg = validate_config(cfg)
normalize_config(cfg) normalize_config(cfg)
cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg)
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, dataset_meta=dataset_meta) train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg) check_model_output_exists(temp_dir, cfg)

View File

@@ -6,7 +6,6 @@ import unittest
import pytest import pytest
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets from axolotl.common.datasets import load_datasets
from axolotl.train import train from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config from axolotl.utils.config import normalize_config, validate_config
@@ -64,8 +63,7 @@ class TestLlamaShiftedSparseAttention(unittest.TestCase):
cfg = validate_config(cfg) cfg = validate_config(cfg)
normalize_config(cfg) normalize_config(cfg)
cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg)
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, dataset_meta=dataset_meta) train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg) check_model_output_exists(temp_dir, cfg)
@@ -107,8 +105,7 @@ class TestLlamaShiftedSparseAttention(unittest.TestCase):
cfg = validate_config(cfg) cfg = validate_config(cfg)
normalize_config(cfg) normalize_config(cfg)
cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg)
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, dataset_meta=dataset_meta) train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg) check_model_output_exists(temp_dir, cfg)

View File

@@ -7,7 +7,6 @@ import unittest
import pytest import pytest
from transformers.utils import is_auto_gptq_available, is_torch_bf16_gpu_available from transformers.utils import is_auto_gptq_available, is_torch_bf16_gpu_available
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets from axolotl.common.datasets import load_datasets
from axolotl.train import train from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config from axolotl.utils.config import normalize_config, validate_config
@@ -65,8 +64,7 @@ class TestLoraLlama(unittest.TestCase):
cfg = validate_config(cfg) cfg = validate_config(cfg)
normalize_config(cfg) normalize_config(cfg)
cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg)
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, dataset_meta=dataset_meta) train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg) check_model_output_exists(temp_dir, cfg)
@@ -114,8 +112,7 @@ class TestLoraLlama(unittest.TestCase):
) )
cfg = validate_config(cfg) cfg = validate_config(cfg)
normalize_config(cfg) normalize_config(cfg)
cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg)
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, dataset_meta=dataset_meta) train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg) check_model_output_exists(temp_dir, cfg)

View File

@@ -4,7 +4,6 @@ E2E tests for lora llama
import unittest import unittest
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets from axolotl.common.datasets import load_datasets
from axolotl.train import train from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config from axolotl.utils.config import normalize_config, validate_config
@@ -60,8 +59,7 @@ class TestMistral(unittest.TestCase):
) )
cfg = validate_config(cfg) cfg = validate_config(cfg)
normalize_config(cfg) normalize_config(cfg)
cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg)
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, dataset_meta=dataset_meta) train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg) check_model_output_exists(temp_dir, cfg)
@@ -102,8 +100,7 @@ class TestMistral(unittest.TestCase):
) )
cfg = validate_config(cfg) cfg = validate_config(cfg)
normalize_config(cfg) normalize_config(cfg)
cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg)
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, dataset_meta=dataset_meta) train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg) check_model_output_exists(temp_dir, cfg)

View File

@@ -4,7 +4,6 @@ E2E tests for mixtral
import unittest import unittest
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets from axolotl.common.datasets import load_datasets
from axolotl.train import train from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config from axolotl.utils.config import normalize_config, validate_config
@@ -57,8 +56,7 @@ class TestMixtral(unittest.TestCase):
) )
cfg = validate_config(cfg) cfg = validate_config(cfg)
normalize_config(cfg) normalize_config(cfg)
cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg)
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, dataset_meta=dataset_meta) train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg) check_model_output_exists(temp_dir, cfg)
@@ -96,8 +94,7 @@ class TestMixtral(unittest.TestCase):
) )
cfg = validate_config(cfg) cfg = validate_config(cfg)
normalize_config(cfg) normalize_config(cfg)
cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg)
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, dataset_meta=dataset_meta) train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg) check_model_output_exists(temp_dir, cfg)

View File

@@ -4,7 +4,6 @@ E2E tests for lora llama
import unittest import unittest
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets from axolotl.common.datasets import load_datasets
from axolotl.train import train from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config from axolotl.utils.config import normalize_config, validate_config
@@ -60,8 +59,7 @@ class TestPhiMultipack(unittest.TestCase):
cfg = validate_config(cfg) cfg = validate_config(cfg)
normalize_config(cfg) normalize_config(cfg)
cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg)
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, dataset_meta=dataset_meta) train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg) check_model_output_exists(temp_dir, cfg)
@@ -112,8 +110,7 @@ class TestPhiMultipack(unittest.TestCase):
cfg = validate_config(cfg) cfg = validate_config(cfg)
normalize_config(cfg) normalize_config(cfg)
cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg)
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, dataset_meta=dataset_meta) train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg) check_model_output_exists(temp_dir, cfg)

View File

@@ -7,7 +7,6 @@ import subprocess
from transformers.utils import is_torch_bf16_gpu_available from transformers.utils import is_torch_bf16_gpu_available
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets from axolotl.common.datasets import load_datasets
from axolotl.train import train from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config from axolotl.utils.config import normalize_config, validate_config
@@ -67,8 +66,7 @@ class TestResumeLlama:
cfg.fp16 = True cfg.fp16 = True
cfg = validate_config(cfg) cfg = validate_config(cfg)
normalize_config(cfg) normalize_config(cfg)
cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg)
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, dataset_meta=dataset_meta) train(cfg=cfg, dataset_meta=dataset_meta)
@@ -78,7 +76,6 @@ class TestResumeLlama:
} }
) )
normalize_config(resume_cfg) normalize_config(resume_cfg)
cli_args = TrainerCliArgs()
train(cfg=resume_cfg, dataset_meta=dataset_meta) train(cfg=resume_cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg) check_model_output_exists(temp_dir, cfg)

View File

@@ -4,7 +4,6 @@ e2e tests for unsloth qlora
import pytest import pytest
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets from axolotl.common.datasets import load_datasets
from axolotl.train import train from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config from axolotl.utils.config import normalize_config, validate_config
@@ -68,8 +67,7 @@ class TestUnslothQLoRA:
cfg = validate_config(cfg) cfg = validate_config(cfg)
normalize_config(cfg) normalize_config(cfg)
cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg)
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, dataset_meta=dataset_meta) train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg) check_model_output_exists(temp_dir, cfg)
@@ -119,8 +117,7 @@ class TestUnslothQLoRA:
cfg = validate_config(cfg) cfg = validate_config(cfg)
normalize_config(cfg) normalize_config(cfg)
cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg)
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, dataset_meta=dataset_meta) train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg) check_model_output_exists(temp_dir, cfg)
@@ -175,8 +172,7 @@ class TestUnslothQLoRA:
cfg = validate_config(cfg) cfg = validate_config(cfg)
normalize_config(cfg) normalize_config(cfg)
cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg)
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, dataset_meta=dataset_meta) train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg) check_model_output_exists(temp_dir, cfg)

View File

@@ -6,7 +6,6 @@ import unittest
from transformers.utils import is_torch_bf16_gpu_available from transformers.utils import is_torch_bf16_gpu_available
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets from axolotl.common.datasets import load_datasets
from axolotl.train import train from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config from axolotl.utils.config import normalize_config, validate_config
@@ -59,8 +58,7 @@ class TestPackedFlex(unittest.TestCase):
cfg = validate_config(cfg) cfg = validate_config(cfg)
normalize_config(cfg) normalize_config(cfg)
cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg)
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, dataset_meta=dataset_meta) train(cfg=cfg, dataset_meta=dataset_meta)

View File

@@ -5,7 +5,6 @@ E2E tests for relora llama
import unittest import unittest
from pathlib import Path from pathlib import Path
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets from axolotl.common.datasets import load_datasets
from axolotl.train import train from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config from axolotl.utils.config import normalize_config, validate_config
@@ -71,8 +70,7 @@ class TestReLoraLlama(unittest.TestCase):
cfg = validate_config(cfg) cfg = validate_config(cfg)
normalize_config(cfg) normalize_config(cfg)
cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg)
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, dataset_meta=dataset_meta) train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(Path(temp_dir) / "checkpoint-100/adapter", cfg) check_model_output_exists(Path(temp_dir) / "checkpoint-100/adapter", cfg)

View File

@@ -6,7 +6,6 @@ from pathlib import Path
import pytest import pytest
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets from axolotl.common.datasets import load_datasets
from axolotl.train import train from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config from axolotl.utils.config import normalize_config, validate_config
@@ -72,8 +71,7 @@ class TestDeepseekV3:
) )
cfg = validate_config(cfg) cfg = validate_config(cfg)
normalize_config(cfg) normalize_config(cfg)
cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg)
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, dataset_meta=dataset_meta) train(cfg=cfg, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "adapter_model.safetensors").exists() assert (Path(temp_dir) / "adapter_model.safetensors").exists()
@@ -122,8 +120,7 @@ class TestDeepseekV3:
) )
cfg = validate_config(cfg) cfg = validate_config(cfg)
normalize_config(cfg) normalize_config(cfg)
cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg)
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, dataset_meta=dataset_meta) train(cfg=cfg, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "model.safetensors").exists() assert (Path(temp_dir) / "model.safetensors").exists()

View File

@@ -1,6 +1,4 @@
""" """E2E tests for lora llama"""
E2E tests for lora llama
"""
import unittest import unittest
from pathlib import Path from pathlib import Path

View File

@@ -4,7 +4,6 @@ E2E tests for llama pretrain
import unittest import unittest
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets from axolotl.common.datasets import load_datasets
from axolotl.train import train from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config from axolotl.utils.config import normalize_config, validate_config
@@ -54,8 +53,7 @@ class TestEmbeddingsLrScale(unittest.TestCase):
cfg = validate_config(cfg) cfg = validate_config(cfg)
normalize_config(cfg) normalize_config(cfg)
cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg)
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, dataset_meta=dataset_meta) train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg) check_model_output_exists(temp_dir, cfg)
@@ -99,8 +97,7 @@ class TestEmbeddingsLrScale(unittest.TestCase):
) )
cfg = validate_config(cfg) cfg = validate_config(cfg)
normalize_config(cfg) normalize_config(cfg)
cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg)
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, dataset_meta=dataset_meta) train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg) check_model_output_exists(temp_dir, cfg)

View File

@@ -6,7 +6,6 @@ import unittest
import pytest import pytest
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets from axolotl.common.datasets import load_datasets
from axolotl.train import train from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config from axolotl.utils.config import normalize_config, validate_config
@@ -66,8 +65,7 @@ class TestFalcon(unittest.TestCase):
cfg = validate_config(cfg) cfg = validate_config(cfg)
normalize_config(cfg) normalize_config(cfg)
cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg)
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, dataset_meta=dataset_meta) train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg) check_model_output_exists(temp_dir, cfg)
@@ -122,8 +120,7 @@ class TestFalcon(unittest.TestCase):
cfg = validate_config(cfg) cfg = validate_config(cfg)
normalize_config(cfg) normalize_config(cfg)
cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg)
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, dataset_meta=dataset_meta) train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg) check_model_output_exists(temp_dir, cfg)
@@ -164,8 +161,7 @@ class TestFalcon(unittest.TestCase):
cfg = validate_config(cfg) cfg = validate_config(cfg)
normalize_config(cfg) normalize_config(cfg)
cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg)
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, dataset_meta=dataset_meta) train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg) check_model_output_exists(temp_dir, cfg)

View File

@@ -6,7 +6,6 @@ from pathlib import Path
import pytest import pytest
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets from axolotl.common.datasets import load_datasets
from axolotl.train import train from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config from axolotl.utils.config import normalize_config, validate_config
@@ -69,8 +68,7 @@ class TestGemma2:
) )
cfg = validate_config(cfg) cfg = validate_config(cfg)
normalize_config(cfg) normalize_config(cfg)
cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg)
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, dataset_meta=dataset_meta) train(cfg=cfg, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "adapter_model.safetensors").exists() assert (Path(temp_dir) / "adapter_model.safetensors").exists()
@@ -121,8 +119,7 @@ class TestGemma2:
) )
cfg = validate_config(cfg) cfg = validate_config(cfg)
normalize_config(cfg) normalize_config(cfg)
cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg)
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, dataset_meta=dataset_meta) train(cfg=cfg, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "model.safetensors").exists() assert (Path(temp_dir) / "model.safetensors").exists()

View File

@@ -6,7 +6,6 @@ from pathlib import Path
import pytest import pytest
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets from axolotl.common.datasets import load_datasets
from axolotl.train import train from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config from axolotl.utils.config import normalize_config, validate_config
@@ -68,8 +67,7 @@ class TestGemma3Text:
) )
cfg = validate_config(cfg) cfg = validate_config(cfg)
normalize_config(cfg) normalize_config(cfg)
cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg)
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, dataset_meta=dataset_meta) train(cfg=cfg, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "adapter_model.safetensors").exists() assert (Path(temp_dir) / "adapter_model.safetensors").exists()
@@ -119,8 +117,7 @@ class TestGemma3Text:
) )
cfg = validate_config(cfg) cfg = validate_config(cfg)
normalize_config(cfg) normalize_config(cfg)
cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg)
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, dataset_meta=dataset_meta) train(cfg=cfg, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "model.safetensors").exists() assert (Path(temp_dir) / "model.safetensors").exists()

View File

@@ -2,7 +2,6 @@
E2E tests for llama E2E tests for llama
""" """
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets from axolotl.common.datasets import load_datasets
from axolotl.train import train from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config from axolotl.utils.config import normalize_config, validate_config
@@ -51,8 +50,7 @@ class TestLlama:
cfg = validate_config(cfg) cfg = validate_config(cfg)
normalize_config(cfg) normalize_config(cfg)
cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg)
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, dataset_meta=dataset_meta) train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg) check_model_output_exists(temp_dir, cfg)
@@ -99,8 +97,7 @@ class TestLlama:
cfg = validate_config(cfg) cfg = validate_config(cfg)
normalize_config(cfg) normalize_config(cfg)
cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg)
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, dataset_meta=dataset_meta) train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg) check_model_output_exists(temp_dir, cfg)
@@ -144,8 +141,7 @@ class TestLlama:
cfg = validate_config(cfg) cfg = validate_config(cfg)
normalize_config(cfg) normalize_config(cfg)
cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg)
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, dataset_meta=dataset_meta) train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg) check_model_output_exists(temp_dir, cfg)
@@ -185,8 +181,7 @@ class TestLlama:
cfg = validate_config(cfg) cfg = validate_config(cfg)
normalize_config(cfg) normalize_config(cfg)
cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg)
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, dataset_meta=dataset_meta) train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg) check_model_output_exists(temp_dir, cfg)

View File

@@ -1,10 +1,7 @@
""" """E2E tests for llama pretrain"""
E2E tests for llama pretrain
"""
import pytest import pytest
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets from axolotl.common.datasets import load_datasets
from axolotl.train import train from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config from axolotl.utils.config import normalize_config, validate_config
@@ -14,9 +11,7 @@ from .utils import check_model_output_exists, check_tensorboard
class TestPretrainLlama: class TestPretrainLlama:
""" """Test case for Llama models w pretraining"""
Test case for Llama models w pretraining
"""
@pytest.mark.parametrize( @pytest.mark.parametrize(
"sample_packing", "sample_packing",
@@ -66,8 +61,7 @@ class TestPretrainLlama:
cfg = validate_config(cfg) cfg = validate_config(cfg)
normalize_config(cfg) normalize_config(cfg)
cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg)
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, dataset_meta=dataset_meta) train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg) check_model_output_exists(temp_dir, cfg)

View File

@@ -4,7 +4,6 @@ E2E tests for lora llama
import unittest import unittest
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets from axolotl.common.datasets import load_datasets
from axolotl.train import train from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config from axolotl.utils.config import normalize_config, validate_config
@@ -60,8 +59,7 @@ class TestLlamaVision(unittest.TestCase):
cfg = validate_config(cfg) cfg = validate_config(cfg)
normalize_config(cfg) normalize_config(cfg)
cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg)
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, dataset_meta=dataset_meta) train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg) check_model_output_exists(temp_dir, cfg)
@@ -106,8 +104,7 @@ class TestLlamaVision(unittest.TestCase):
) )
cfg = validate_config(cfg) cfg = validate_config(cfg)
normalize_config(cfg) normalize_config(cfg)
cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg)
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, dataset_meta=dataset_meta) train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg) check_model_output_exists(temp_dir, cfg)

View File

@@ -4,7 +4,6 @@ E2E tests for lora llama
import unittest import unittest
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets from axolotl.common.datasets import load_datasets
from axolotl.train import train from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config from axolotl.utils.config import normalize_config, validate_config
@@ -55,8 +54,7 @@ class TestLoraLlama(unittest.TestCase):
cfg = validate_config(cfg) cfg = validate_config(cfg)
normalize_config(cfg) normalize_config(cfg)
cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg)
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, dataset_meta=dataset_meta) train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg) check_model_output_exists(temp_dir, cfg)

View File

@@ -6,7 +6,6 @@ import unittest
import pytest import pytest
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets from axolotl.common.datasets import load_datasets
from axolotl.train import train from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config from axolotl.utils.config import normalize_config, validate_config
@@ -57,8 +56,7 @@ class TestMamba(unittest.TestCase):
cfg = validate_config(cfg) cfg = validate_config(cfg)
normalize_config(cfg) normalize_config(cfg)
cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg)
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, dataset_meta=dataset_meta) train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg) check_model_output_exists(temp_dir, cfg)

View File

@@ -6,7 +6,6 @@ import unittest
from transformers.utils import is_torch_bf16_gpu_available from transformers.utils import is_torch_bf16_gpu_available
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets from axolotl.common.datasets import load_datasets
from axolotl.train import train from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config from axolotl.utils.config import normalize_config, validate_config
@@ -61,8 +60,7 @@ class TestMistral(unittest.TestCase):
cfg = validate_config(cfg) cfg = validate_config(cfg)
normalize_config(cfg) normalize_config(cfg)
cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg)
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, dataset_meta=dataset_meta) train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg) check_model_output_exists(temp_dir, cfg)
@@ -106,8 +104,7 @@ class TestMistral(unittest.TestCase):
cfg = validate_config(cfg) cfg = validate_config(cfg)
normalize_config(cfg) normalize_config(cfg)
cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg)
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, dataset_meta=dataset_meta) train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg) check_model_output_exists(temp_dir, cfg)

View File

@@ -7,7 +7,6 @@ import unittest
import torch import torch
from transformers.utils import is_torch_bf16_gpu_available from transformers.utils import is_torch_bf16_gpu_available
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets from axolotl.common.datasets import load_datasets
from axolotl.train import train from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config from axolotl.utils.config import normalize_config, validate_config
@@ -67,8 +66,7 @@ class TestMixtral(unittest.TestCase):
cfg = validate_config(cfg) cfg = validate_config(cfg)
normalize_config(cfg) normalize_config(cfg)
cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg)
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
model, _, _ = train(cfg=cfg, dataset_meta=dataset_meta) model, _, _ = train(cfg=cfg, dataset_meta=dataset_meta)
assert ( assert (
@@ -123,8 +121,7 @@ class TestMixtral(unittest.TestCase):
cfg = validate_config(cfg) cfg = validate_config(cfg)
normalize_config(cfg) normalize_config(cfg)
cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg)
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
model, _, _ = train(cfg=cfg, dataset_meta=dataset_meta) model, _, _ = train(cfg=cfg, dataset_meta=dataset_meta)
assert ( assert (
@@ -182,8 +179,7 @@ class TestMixtral(unittest.TestCase):
cfg = validate_config(cfg) cfg = validate_config(cfg)
normalize_config(cfg) normalize_config(cfg)
cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg)
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
model, _, _ = train(cfg=cfg, dataset_meta=dataset_meta) model, _, _ = train(cfg=cfg, dataset_meta=dataset_meta)
assert ( assert (
@@ -241,8 +237,7 @@ class TestMixtral(unittest.TestCase):
cfg.bf16 = True cfg.bf16 = True
else: else:
cfg.fp16 = True cfg.fp16 = True
cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg)
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
model, _, _ = train(cfg=cfg, dataset_meta=dataset_meta) model, _, _ = train(cfg=cfg, dataset_meta=dataset_meta)
assert ( assert (
@@ -287,8 +282,7 @@ class TestMixtral(unittest.TestCase):
cfg = validate_config(cfg) cfg = validate_config(cfg)
normalize_config(cfg) normalize_config(cfg)
cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg)
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, dataset_meta=dataset_meta) train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg) check_model_output_exists(temp_dir, cfg)

View File

@@ -4,7 +4,6 @@ E2E tests for custom optimizers using Llama
import unittest import unittest
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets from axolotl.common.datasets import load_datasets
from axolotl.train import train from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config from axolotl.utils.config import normalize_config, validate_config
@@ -61,8 +60,7 @@ class TestCustomOptimizers(unittest.TestCase):
cfg = validate_config(cfg) cfg = validate_config(cfg)
normalize_config(cfg) normalize_config(cfg)
cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg)
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
_, _, trainer = train(cfg=cfg, dataset_meta=dataset_meta) _, _, trainer = train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg) check_model_output_exists(temp_dir, cfg)
@@ -107,8 +105,7 @@ class TestCustomOptimizers(unittest.TestCase):
cfg = validate_config(cfg) cfg = validate_config(cfg)
normalize_config(cfg) normalize_config(cfg)
cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg)
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
_, _, trainer = train(cfg=cfg, dataset_meta=dataset_meta) _, _, trainer = train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg) check_model_output_exists(temp_dir, cfg)
@@ -154,8 +151,7 @@ class TestCustomOptimizers(unittest.TestCase):
cfg = validate_config(cfg) cfg = validate_config(cfg)
normalize_config(cfg) normalize_config(cfg)
cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg)
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
_, _, trainer = train(cfg=cfg, dataset_meta=dataset_meta) _, _, trainer = train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg) check_model_output_exists(temp_dir, cfg)
@@ -194,8 +190,7 @@ class TestCustomOptimizers(unittest.TestCase):
cfg = validate_config(cfg) cfg = validate_config(cfg)
normalize_config(cfg) normalize_config(cfg)
cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg)
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, dataset_meta=dataset_meta) train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg) check_model_output_exists(temp_dir, cfg)
@@ -242,8 +237,7 @@ class TestCustomOptimizers(unittest.TestCase):
cfg = validate_config(cfg) cfg = validate_config(cfg)
normalize_config(cfg) normalize_config(cfg)
cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg)
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, dataset_meta=dataset_meta) train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg) check_model_output_exists(temp_dir, cfg)

View File

@@ -6,7 +6,6 @@ import unittest
from transformers.utils import is_torch_bf16_gpu_available from transformers.utils import is_torch_bf16_gpu_available
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets from axolotl.common.datasets import load_datasets
from axolotl.train import train from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config from axolotl.utils.config import normalize_config, validate_config
@@ -58,8 +57,7 @@ class TestPackedLlama(unittest.TestCase):
cfg = validate_config(cfg) cfg = validate_config(cfg)
normalize_config(cfg) normalize_config(cfg)
cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg)
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, dataset_meta=dataset_meta) train(cfg=cfg, dataset_meta=dataset_meta)

View File

@@ -4,7 +4,6 @@ E2E tests for lora llama
import unittest import unittest
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets from axolotl.common.datasets import load_datasets
from axolotl.train import train from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config from axolotl.utils.config import normalize_config, validate_config
@@ -58,8 +57,7 @@ class TestPhi(unittest.TestCase):
) )
cfg = validate_config(cfg) cfg = validate_config(cfg)
normalize_config(cfg) normalize_config(cfg)
cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg)
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, dataset_meta=dataset_meta) train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg) check_model_output_exists(temp_dir, cfg)
@@ -108,8 +106,7 @@ class TestPhi(unittest.TestCase):
) )
cfg = validate_config(cfg) cfg = validate_config(cfg)
normalize_config(cfg) normalize_config(cfg)
cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg)
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, dataset_meta=dataset_meta) train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg) check_model_output_exists(temp_dir, cfg)

View File

@@ -4,7 +4,6 @@ E2E tests for process reward model w/ lora llama
import unittest import unittest
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets from axolotl.common.datasets import load_datasets
from axolotl.train import train from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config from axolotl.utils.config import normalize_config, validate_config
@@ -54,8 +53,7 @@ class TestProcessRewardSmolLM2(unittest.TestCase):
) )
cfg = validate_config(cfg) cfg = validate_config(cfg)
normalize_config(cfg) normalize_config(cfg)
cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg)
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, dataset_meta=dataset_meta) train(cfg=cfg, dataset_meta=dataset_meta)
check_tensorboard( check_tensorboard(

View File

@@ -5,7 +5,6 @@ E2E tests for QAT
import unittest import unittest
from pathlib import Path from pathlib import Path
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets from axolotl.common.datasets import load_datasets
from axolotl.train import train from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config from axolotl.utils.config import normalize_config, validate_config
@@ -64,8 +63,7 @@ class TestQATLlama(unittest.TestCase):
) )
cfg = validate_config(cfg) cfg = validate_config(cfg)
normalize_config(cfg) normalize_config(cfg)
cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg)
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, dataset_meta=dataset_meta) train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(Path(temp_dir) / "checkpoint-5", cfg) check_model_output_exists(Path(temp_dir) / "checkpoint-5", cfg)

View File

@@ -4,7 +4,6 @@ E2E tests for reward model lora llama
import unittest import unittest
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets from axolotl.common.datasets import load_datasets
from axolotl.train import train from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config from axolotl.utils.config import normalize_config, validate_config
@@ -63,8 +62,7 @@ class TestRewardModelLoraSmolLM2(unittest.TestCase):
) )
cfg = validate_config(cfg) cfg = validate_config(cfg)
normalize_config(cfg) normalize_config(cfg)
cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg)
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, dataset_meta=dataset_meta) train(cfg=cfg, dataset_meta=dataset_meta)
check_tensorboard( check_tensorboard(

View File

@@ -4,7 +4,6 @@ E2E tests for custom schedulers using Llama
import unittest import unittest
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets from axolotl.common.datasets import load_datasets
from axolotl.train import train from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config from axolotl.utils.config import normalize_config, validate_config
@@ -57,8 +56,7 @@ class TestCustomSchedulers(unittest.TestCase):
cfg = validate_config(cfg) cfg = validate_config(cfg)
normalize_config(cfg) normalize_config(cfg)
cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg)
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, dataset_meta=dataset_meta) train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg) check_model_output_exists(temp_dir, cfg)

View File

@@ -6,8 +6,9 @@ import unittest
import pytest import pytest
from axolotl.loaders.tokenizer import load_tokenizer
from axolotl.prompt_strategies.dpo import load as load_dpo from axolotl.prompt_strategies.dpo import load as load_dpo
from axolotl.utils.data.rl import load_prepare_preference_datasets from axolotl.utils.data.rl import prepare_preference_datasets
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from tests.hf_offline_utils import enable_hf_offline from tests.hf_offline_utils import enable_hf_offline
@@ -55,7 +56,8 @@ class TestDPOChatml:
# test that dpo.load works # test that dpo.load works
load_dpo("chatml", cfg) load_dpo("chatml", cfg)
# now actually load the datasets with the strategy # now actually load the datasets with the strategy
train_ds, _ = load_prepare_preference_datasets(cfg) tokenizer = load_tokenizer(cfg)
train_ds, _ = prepare_preference_datasets(cfg, tokenizer)
assert train_ds[0]["prompt"].startswith("<|im_start|>") assert train_ds[0]["prompt"].startswith("<|im_start|>")
assert train_ds[0]["prompt"].endswith("<|im_start|>assistant\n") assert train_ds[0]["prompt"].endswith("<|im_start|>assistant\n")
assert "chosen" in train_ds[0] assert "chosen" in train_ds[0]

View File

@@ -1,10 +1,9 @@
""" """Test dataset loading under various conditions."""
Test dataset loading under various conditions.
"""
import shutil import shutil
import tempfile import tempfile
from pathlib import Path from pathlib import Path
from typing import Any, Generator
from unittest.mock import patch from unittest.mock import patch
import pytest import pytest
@@ -12,8 +11,9 @@ from datasets import Dataset
from huggingface_hub import snapshot_download from huggingface_hub import snapshot_download
from transformers import PreTrainedTokenizer from transformers import PreTrainedTokenizer
from axolotl.utils.data import load_tokenized_prepared_datasets from axolotl.loaders.tokenizer import load_tokenizer
from axolotl.utils.data.rl import load_prepare_preference_datasets from axolotl.utils.data.rl import prepare_preference_datasets
from axolotl.utils.data.sft import _load_tokenized_prepared_datasets
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from tests.constants import ( from tests.constants import (
@@ -28,7 +28,9 @@ class TestDatasetPreparation:
"""Test a configured dataloader.""" """Test a configured dataloader."""
@pytest.fixture @pytest.fixture
def tokenizer(self, tokenizer_huggyllama) -> PreTrainedTokenizer: def tokenizer(
self, tokenizer_huggyllama
) -> Generator[PreTrainedTokenizer, Any, Any]:
tokenizer_huggyllama.add_special_tokens(SPECIAL_TOKENS) tokenizer_huggyllama.add_special_tokens(SPECIAL_TOKENS)
yield tokenizer_huggyllama yield tokenizer_huggyllama
@@ -63,7 +65,10 @@ class TestDatasetPreparation:
} }
) )
dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path) with patch(
"axolotl.common.const.DEFAULT_DATASET_PREPARED_PATH", str(prepared_path)
):
dataset, _ = _load_tokenized_prepared_datasets(tokenizer, cfg)
assert len(dataset) == 2000 assert len(dataset) == 2000
assert "input_ids" in dataset.features assert "input_ids" in dataset.features
@@ -107,7 +112,10 @@ class TestDatasetPreparation:
} }
) )
dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path) with patch(
"axolotl.common.const.DEFAULT_DATASET_PREPARED_PATH", str(prepared_path)
):
dataset, _ = _load_tokenized_prepared_datasets(tokenizer, cfg)
assert len(dataset) == 2000 assert len(dataset) == 2000
assert "input_ids" in dataset.features assert "input_ids" in dataset.features
@@ -136,7 +144,10 @@ class TestDatasetPreparation:
} }
) )
dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path) with patch(
"axolotl.common.const.DEFAULT_DATASET_PREPARED_PATH", str(prepared_path)
):
dataset, _ = _load_tokenized_prepared_datasets(tokenizer, cfg)
assert len(dataset) == 1 assert len(dataset) == 1
assert "input_ids" in dataset.features assert "input_ids" in dataset.features
@@ -145,7 +156,7 @@ class TestDatasetPreparation:
@enable_hf_offline @enable_hf_offline
def test_load_from_dir_of_parquet(self, tokenizer, dataset_fixture): def test_load_from_dir_of_parquet(self, tokenizer, dataset_fixture):
"""Usual use case. Verify a directory of parquet files can be loaded.""" """Usual use case. Verify a directory of parquet files can be loaded."""
with tempfile.TemporaryDirectory() as tmp_dir: with tempfile.TemporaryDirectory() as tmp_dir:
tmp_ds_dir = Path(tmp_dir) / "tmp_dataset" tmp_ds_dir = Path(tmp_dir) / "tmp_dataset"
tmp_ds_dir.mkdir() tmp_ds_dir.mkdir()
@@ -171,7 +182,10 @@ class TestDatasetPreparation:
} }
) )
dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path) with patch(
"axolotl.common.const.DEFAULT_DATASET_PREPARED_PATH", str(prepared_path)
):
dataset, _ = _load_tokenized_prepared_datasets(tokenizer, cfg)
assert len(dataset) == 1 assert len(dataset) == 1
assert "input_ids" in dataset.features assert "input_ids" in dataset.features
@@ -206,7 +220,10 @@ class TestDatasetPreparation:
} }
) )
dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path) with patch(
"axolotl.common.const.DEFAULT_DATASET_PREPARED_PATH", str(prepared_path)
):
dataset, _ = _load_tokenized_prepared_datasets(tokenizer, cfg)
assert len(dataset) == 1 assert len(dataset) == 1
assert "input_ids" in dataset.features assert "input_ids" in dataset.features
@@ -235,7 +252,10 @@ class TestDatasetPreparation:
} }
) )
dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path) with patch(
"axolotl.common.const.DEFAULT_DATASET_PREPARED_PATH", str(prepared_path)
):
dataset, _ = _load_tokenized_prepared_datasets(tokenizer, cfg)
assert len(dataset) == 1 assert len(dataset) == 1
assert "input_ids" in dataset.features assert "input_ids" in dataset.features
@@ -264,7 +284,10 @@ class TestDatasetPreparation:
} }
) )
dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path) with patch(
"axolotl.common.const.DEFAULT_DATASET_PREPARED_PATH", str(prepared_path)
):
dataset, _ = _load_tokenized_prepared_datasets(tokenizer, cfg)
assert len(dataset) == 1 assert len(dataset) == 1
assert "input_ids" in dataset.features assert "input_ids" in dataset.features
@@ -286,7 +309,8 @@ class TestDatasetPreparation:
} }
) )
train_dataset, _ = load_prepare_preference_datasets(cfg) tokenizer = load_tokenizer(cfg)
train_dataset, _ = prepare_preference_datasets(cfg, tokenizer)
assert len(train_dataset) == 1800 assert len(train_dataset) == 1800
assert "conversation" not in train_dataset.features assert "conversation" not in train_dataset.features
@@ -318,7 +342,10 @@ class TestDatasetPreparation:
} }
) )
dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path) with patch(
"axolotl.common.const.DEFAULT_DATASET_PREPARED_PATH", str(prepared_path)
):
dataset, _ = _load_tokenized_prepared_datasets(tokenizer, cfg)
assert len(dataset) == 2000 assert len(dataset) == 2000
assert "input_ids" in dataset.features assert "input_ids" in dataset.features
@@ -342,13 +369,16 @@ class TestDatasetPreparation:
) )
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
with patch("axolotl.utils.data.rl.load_dataset_w_config") as mock_load_dataset: with patch(
"axolotl.utils.data.rl.load_dataset_with_config"
) as mock_load_dataset:
# Set up the mock to return different values on successive calls # Set up the mock to return different values on successive calls
mock_load_dataset.return_value = ( mock_load_dataset.return_value = (
dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff
) )
train_dataset, _ = load_prepare_preference_datasets(cfg) tokenizer = load_tokenizer(cfg)
train_dataset, _ = prepare_preference_datasets(cfg, tokenizer)
assert len(train_dataset) == 1800 assert len(train_dataset) == 1800
assert "conversation" not in train_dataset.features assert "conversation" not in train_dataset.features
@@ -393,16 +423,18 @@ class TestDatasetPreparation:
) )
with patch( with patch(
"axolotl.utils.data.shared.load_dataset_w_config" "axolotl.utils.data.shared.load_dataset_with_config"
) as mock_load_dataset: ) as mock_load_dataset:
# Set up the mock to return different values on successive calls # Set up the mock to return different values on successive calls
mock_load_dataset.return_value = ( mock_load_dataset.return_value = (
dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff
) )
dataset, _ = load_tokenized_prepared_datasets( with patch(
tokenizer, cfg, prepared_path "axolotl.common.const.DEFAULT_DATASET_PREPARED_PATH",
) str(prepared_path),
):
dataset, _ = _load_tokenized_prepared_datasets(tokenizer, cfg)
assert len(dataset) == 2000 assert len(dataset) == 2000
assert "input_ids" in dataset.features assert "input_ids" in dataset.features
@@ -437,7 +469,10 @@ class TestDatasetPreparation:
} }
) )
dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path) with patch(
"axolotl.common.const.DEFAULT_DATASET_PREPARED_PATH", str(prepared_path)
):
dataset, _ = _load_tokenized_prepared_datasets(tokenizer, cfg)
assert len(dataset) == 2000 assert len(dataset) == 2000
assert "input_ids" in dataset.features assert "input_ids" in dataset.features

View File

@@ -5,7 +5,6 @@ Additionally, this test suite includes tests for functions that indirectly call
`deduplicate_and_log_datasets` during the execution of the preprocess command. `deduplicate_and_log_datasets` during the execution of the preprocess command.
""" """
import hashlib
import unittest import unittest
from unittest.mock import patch from unittest.mock import patch
@@ -14,8 +13,7 @@ from datasets import Dataset
from axolotl.loaders import load_processor, load_tokenizer from axolotl.loaders import load_processor, load_tokenizer
from axolotl.utils.config import normalize_config, validate_config from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.data import prepare_dataset from axolotl.utils.data import prepare_datasets, prepare_preference_datasets
from axolotl.utils.data.rl import load_prepare_preference_datasets
from axolotl.utils.data.utils import deduplicate_and_log_datasets from axolotl.utils.data.utils import deduplicate_and_log_datasets
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
@@ -71,36 +69,14 @@ class TestDeduplicateIndividualFunctions(unittest.TestCase):
self.expected_dataset = Dataset.from_dict(self.expected_data) self.expected_dataset = Dataset.from_dict(self.expected_data)
def test_deduplication(self): def test_deduplication(self):
train_dataset, _, _ = deduplicate_and_log_datasets(train_dataset=self.dataset) train_dataset, _ = deduplicate_and_log_datasets(dataset=self.dataset)
_, eval_dataset, _ = deduplicate_and_log_datasets(eval_dataset=self.dataset) eval_dataset, _ = deduplicate_and_log_datasets(
dataset=self.dataset, dataset_name="eval"
)
verify_deduplication(train_dataset, self.expected_dataset, "train_dataset") verify_deduplication(train_dataset, self.expected_dataset, "train_dataset")
verify_deduplication(eval_dataset, self.expected_dataset, "eval_dataset") verify_deduplication(eval_dataset, self.expected_dataset, "eval_dataset")
def test_datasets_are_none(self):
# Test when both datasets are None
train_dataset, eval_dataset, _ = deduplicate_and_log_datasets(
train_dataset=None, eval_dataset=None
)
self.assertIsNone(train_dataset, "Expected train_dataset to be None")
self.assertIsNone(eval_dataset, "Expected eval_dataset to be None")
def test_only_train_is_none(self):
# Test when only train_dataset is None
train_dataset, eval_dataset, _ = deduplicate_and_log_datasets(
train_dataset=None, eval_dataset=self.dataset
)
self.assertIsNone(train_dataset, "Expected train_dataset to be None")
verify_deduplication(eval_dataset, self.expected_dataset, "eval_dataset")
def test_only_eval_is_none(self):
# Test when only eval_dataset is None
train_dataset, eval_dataset, _ = deduplicate_and_log_datasets(
train_dataset=self.dataset, eval_dataset=None
)
self.assertIsNone(eval_dataset, "Expected eval_dataset to be None")
verify_deduplication(train_dataset, self.expected_dataset, "train_dataset")
def test_exact_duplicates(self): def test_exact_duplicates(self):
# Test when datasets are exact duplicates # Test when datasets are exact duplicates
duplicate_data = { duplicate_data = {
@@ -115,8 +91,10 @@ class TestDeduplicateIndividualFunctions(unittest.TestCase):
expected_dataset = Dataset.from_dict(expected_data) expected_dataset = Dataset.from_dict(expected_data)
# Run deduplication # Run deduplication
train_dataset, _, _ = deduplicate_and_log_datasets(train_dataset=dataset) train_dataset, _ = deduplicate_and_log_datasets(dataset=dataset)
_, eval_dataset, _ = deduplicate_and_log_datasets(eval_dataset=dataset) eval_dataset, _ = deduplicate_and_log_datasets(
dataset=dataset, dataset_name="eval"
)
verify_deduplication(train_dataset, expected_dataset, "train_dataset") verify_deduplication(train_dataset, expected_dataset, "train_dataset")
verify_deduplication(eval_dataset, expected_dataset, "eval_dataset") verify_deduplication(eval_dataset, expected_dataset, "eval_dataset")
@@ -139,8 +117,10 @@ class TestDeduplicateIndividualFunctions(unittest.TestCase):
expected_dataset = Dataset.from_dict(expected_data) expected_dataset = Dataset.from_dict(expected_data)
# Run deduplication # Run deduplication
train_dataset, _, _ = deduplicate_and_log_datasets(train_dataset=dataset) train_dataset, _ = deduplicate_and_log_datasets(dataset=dataset)
_, eval_dataset, _ = deduplicate_and_log_datasets(eval_dataset=dataset) eval_dataset, _ = deduplicate_and_log_datasets(
dataset=dataset, dataset_name="eval"
)
verify_deduplication(train_dataset, expected_dataset, "train_dataset") verify_deduplication(train_dataset, expected_dataset, "train_dataset")
verify_deduplication(eval_dataset, expected_dataset, "eval_dataset") verify_deduplication(eval_dataset, expected_dataset, "eval_dataset")
@@ -169,8 +149,8 @@ class TestDeduplicateIndividualFunctions(unittest.TestCase):
expected_dataset_eval = Dataset.from_dict(expected_data_eval) expected_dataset_eval = Dataset.from_dict(expected_data_eval)
# Run deduplication # Run deduplication
train_dataset, eval_dataset, _ = deduplicate_and_log_datasets( train_dataset, eval_dataset = deduplicate_and_log_datasets(
train_dataset=dataset, eval_dataset=dataset dataset=dataset, other_dataset=dataset
) )
verify_deduplication(train_dataset, expected_dataset_train, "train_dataset") verify_deduplication(train_dataset, expected_dataset_train, "train_dataset")
@@ -206,8 +186,8 @@ class TestDeduplicateIndividualFunctions(unittest.TestCase):
expected_dataset_eval = Dataset.from_dict(expected_data_eval) expected_dataset_eval = Dataset.from_dict(expected_data_eval)
# Run deduplication # Run deduplication
train_dataset, eval_dataset, _ = deduplicate_and_log_datasets( train_dataset, eval_dataset = deduplicate_and_log_datasets(
train_dataset=dataset_train, eval_dataset=dataset_eval dataset=dataset_train, other_dataset=dataset_eval
) )
verify_deduplication(train_dataset, expected_dataset_train, "train_dataset") verify_deduplication(train_dataset, expected_dataset_train, "train_dataset")
@@ -245,7 +225,9 @@ class TestDeduplicateRLDataset:
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
with ( with (
patch("axolotl.utils.data.rl.load_dataset_w_config") as mock_load_dataset, patch(
"axolotl.utils.data.rl.load_dataset_with_config"
) as mock_load_dataset,
patch("axolotl.loaders.load_tokenizer") as mock_load_tokenizer, patch("axolotl.loaders.load_tokenizer") as mock_load_tokenizer,
): ):
# Set up the mock to return different values on successive calls # Set up the mock to return different values on successive calls
@@ -255,7 +237,8 @@ class TestDeduplicateRLDataset:
] ]
mock_load_tokenizer.return_value = tokenizer_huggyllama mock_load_tokenizer.return_value = tokenizer_huggyllama
train_dataset, _ = load_prepare_preference_datasets(cfg) tokenizer = load_tokenizer(cfg)
train_dataset, _ = prepare_preference_datasets(cfg, tokenizer)
# Verify that the dataset has been deduplicated # Verify that the dataset has been deduplicated
assert len(train_dataset) == 1800, "Dataset was not properly deduplicated" assert len(train_dataset) == 1800, "Dataset was not properly deduplicated"
@@ -269,7 +252,9 @@ class TestDeduplicateRLDataset:
): ):
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
with ( with (
patch("axolotl.utils.data.rl.load_dataset_w_config") as mock_load_dataset, patch(
"axolotl.utils.data.rl.load_dataset_with_config"
) as mock_load_dataset,
patch("axolotl.loaders.load_tokenizer") as mock_load_tokenizer, patch("axolotl.loaders.load_tokenizer") as mock_load_tokenizer,
): ):
# Set up the mock to return different values on successive calls # Set up the mock to return different values on successive calls
@@ -279,9 +264,10 @@ class TestDeduplicateRLDataset:
] ]
mock_load_tokenizer.return_value = tokenizer_huggyllama mock_load_tokenizer.return_value = tokenizer_huggyllama
cfg.dataset_exact_deduplication = False
# Load the dataset without deduplication # Load the dataset without deduplication
train_dataset, _ = load_prepare_preference_datasets(cfg) cfg.dataset_exact_deduplication = False
tokenizer = load_tokenizer(cfg)
train_dataset, _ = prepare_preference_datasets(cfg, tokenizer)
# Verify that the dataset retains duplicates # Verify that the dataset retains duplicates
assert ( assert (
@@ -335,7 +321,7 @@ class TestDeduplicateNonRL(unittest.TestCase):
) )
# Prepare dataset using the prepare_dataset function # Prepare dataset using the prepare_dataset function
train_dataset, _, _, _ = prepare_dataset( train_dataset, _, _, _ = prepare_datasets(
self.cfg_1, self.cfg_1,
tokenizer, tokenizer,
processor=processor, processor=processor,
@@ -362,7 +348,7 @@ class TestDeduplicateNonRL(unittest.TestCase):
) )
# Prepare dataset using the prepare_dataset function # Prepare dataset using the prepare_dataset function
_, eval_dataset, _, _ = prepare_dataset( _, eval_dataset, _, _ = prepare_datasets(
self.cfg_1, self.cfg_1,
tokenizer, tokenizer,
processor=processor, processor=processor,
@@ -389,7 +375,7 @@ class TestDeduplicateNonRL(unittest.TestCase):
) )
# Prepare dataset using the prepare_dataset function # Prepare dataset using the prepare_dataset function
train_dataset, eval_dataset, _, _ = prepare_dataset( train_dataset, eval_dataset, _, _ = prepare_datasets(
self.cfg_1, self.cfg_1,
tokenizer, tokenizer,
processor=processor, processor=processor,
@@ -428,41 +414,8 @@ class TestWrongCollisions(unittest.TestCase):
self.eval_dataset = Dataset.from_dict(self.eval_data) self.eval_dataset = Dataset.from_dict(self.eval_data)
self.dataset = Dataset.from_dict(self.dataset_data) self.dataset = Dataset.from_dict(self.dataset_data)
@patch(
"axolotl.utils.data.utils.sha256",
side_effect=lambda x: (
hashlib.sha256("forced_collision_hash".encode("utf-8")).hexdigest()
if "sample 5" in x
else hashlib.sha256(x.encode("utf-8")).hexdigest()
),
)
def test_deduplication_wrong_collision_train_eval(self, _mock_sha256):
dedup_train, dedup_eval, _ = deduplicate_and_log_datasets(
train_dataset=self.train_dataset, eval_dataset=self.eval_dataset
)
self.assertEqual(
len(dedup_train),
2,
"train dataset should not deduplicate rows with forced hash collisions but different labels.",
)
self.assertEqual(
len(dedup_eval),
2,
"Eval dataset should not deduplicate rows with forced hash collisions but different labels.",
)
self.assertEqual(
len(dedup_eval),
len(self.eval_dataset),
"The output eval dataset should have the same number of rows as the input eval dataset.",
)
self.assertEqual(
str(dedup_eval),
str(self.eval_dataset),
"The string representation of the output eval dataset should be identical to the input eval dataset.",
)
def test_deduplication_dataset_only(self): def test_deduplication_dataset_only(self):
_, _, dedup_dataset = deduplicate_and_log_datasets(dataset=self.dataset) dedup_dataset, _ = deduplicate_and_log_datasets(dataset=self.dataset)
self.assertEqual( self.assertEqual(
len(dedup_dataset), 3, "Dataset should have all original values" len(dedup_dataset), 3, "Dataset should have all original values"
) )