Data loader refactor (#2707)
* data loading refactor (wip) * updates * progress * pytest * pytest fix * lint * zero_first -> filelock, more simplifications * small simplification * import change * nit * lint * simplify dedup * couldnt resist * review comments WIP * continued wip * minor changes * fix; remove contrived test * further refactor * set default seed in pydantic config * lint * continued simplication * lint * renaming and nits * filelock tests * fix * fix * lint * remove nullable arg * remove unnecessary code * moving dataset save fn to shared module * remove debug print * matching var naming * fn name change * coderabbit comments * naming nit * fix test
This commit is contained in:
@@ -1,5 +1,3 @@
|
||||
"""
|
||||
Various shared constants
|
||||
"""
|
||||
"""Various shared constants"""
|
||||
|
||||
DEFAULT_DATASET_PREPARED_PATH = "last_run_prepared"
|
||||
|
||||
@@ -3,15 +3,13 @@
|
||||
import math
|
||||
import random
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Union
|
||||
|
||||
from datasets import Dataset
|
||||
|
||||
import axolotl.monkeypatch.data.batch_dataset_fetcher # pylint: disable=unused-import # noqa: F401
|
||||
from axolotl.cli.args import PreprocessCliArgs, TrainerCliArgs
|
||||
from axolotl.loaders import load_processor, load_tokenizer
|
||||
from axolotl.utils.data import prepare_dataset
|
||||
from axolotl.utils.data.rl import load_prepare_preference_datasets
|
||||
from axolotl.utils.data import prepare_datasets, prepare_preference_datasets
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.logging import get_logger
|
||||
from axolotl.utils.schemas.enums import RLType
|
||||
@@ -30,16 +28,7 @@ class TrainDatasetMeta:
|
||||
|
||||
|
||||
def sample_dataset(dataset: "Dataset", num_samples: int) -> "Dataset":
    """Randomly sample `num_samples` samples with replacement from `dataset`.

    Args:
        dataset: Dataset to draw from. Only `len(dataset)` and
            `dataset.select(indices)` are used.
        num_samples: Number of samples to return.

    Returns:
        Random sample (with replacement) of examples in `dataset`, as returned by
        `dataset.select`.

    Raises:
        ValueError: If `dataset` is empty and `num_samples` is positive (there is
            no valid index to draw).
    """
    dataset_len = len(dataset)
    # `randrange(n)` draws from [0, n), so every row -- including the last one -- can
    # be selected. The earlier `randrange(0, n - 1)` silently excluded the final row
    # and raised `ValueError` for single-row datasets.
    return dataset.select(
        [random.randrange(dataset_len) for _ in range(num_samples)]  # nosec
    )
|
||||
@@ -51,44 +40,37 @@ def load_datasets(
|
||||
cli_args: PreprocessCliArgs | TrainerCliArgs | None = None,
|
||||
debug: bool = False,
|
||||
) -> TrainDatasetMeta:
|
||||
"""
|
||||
Loads one or more training or evaluation datasets, calling
|
||||
`axolotl.utils.data.prepare_dataset`. Optionally, logs out debug information.
|
||||
"""Loads one or more training or evaluation datasets, calling
|
||||
`axolotl.utils.data.prepare_datasets`. Optionally, logs out debug information.
|
||||
|
||||
Args:
|
||||
cfg: Dictionary mapping `axolotl` config keys to values.
|
||||
cli_args: Command-specific CLI arguments.
|
||||
debug: Whether to print out tokenization of sample
|
||||
debug: Whether to print out tokenization of sample. This is duplicated in
|
||||
`cfg` and `cli_args`, but is kept due to use in our Colab notebooks.
|
||||
|
||||
Returns:
|
||||
Dataclass with fields for training and evaluation datasets and the computed
|
||||
`total_num_steps`.
|
||||
`total_num_steps`.
|
||||
"""
|
||||
tokenizer = load_tokenizer(cfg)
|
||||
processor = load_processor(cfg, tokenizer=tokenizer) if cfg.processor_type else None
|
||||
preprocess_iterable = (
|
||||
cli_args
|
||||
and hasattr(cli_args, "iterable")
|
||||
and cli_args.iterable is not None
|
||||
and cli_args.iterable
|
||||
)
|
||||
preprocess_iterable = getattr(cli_args, "iterable", False)
|
||||
|
||||
train_dataset, eval_dataset, total_num_steps, prompters = prepare_dataset(
|
||||
train_dataset, eval_dataset, total_num_steps, prompters = prepare_datasets(
|
||||
cfg,
|
||||
tokenizer,
|
||||
processor=processor,
|
||||
preprocess_iterable=preprocess_iterable,
|
||||
)
|
||||
|
||||
if ( # pylint: disable=too-many-boolean-expressions
|
||||
cli_args
|
||||
and (
|
||||
cli_args.debug
|
||||
or cfg.debug
|
||||
or cli_args.debug_text_only
|
||||
or int(cli_args.debug_num_examples) > 0
|
||||
)
|
||||
) or debug:
|
||||
if (
|
||||
cfg.debug
|
||||
or getattr(cli_args, "debug", False)
|
||||
or getattr(cli_args, "debug_text_only", False)
|
||||
or getattr(cli_args, "debug_num_examples", 0) > 0
|
||||
or debug
|
||||
):
|
||||
LOG.info("check_dataset_labels...")
|
||||
|
||||
num_examples = cli_args.debug_num_examples if cli_args else 1
|
||||
@@ -113,13 +95,10 @@ def load_datasets(
|
||||
|
||||
|
||||
def load_preference_datasets(
|
||||
*,
|
||||
cfg: DictDefault,
|
||||
cli_args: Union[PreprocessCliArgs, TrainerCliArgs],
|
||||
*, cfg: DictDefault, cli_args: PreprocessCliArgs | TrainerCliArgs
|
||||
) -> TrainDatasetMeta:
|
||||
"""
|
||||
Loads one or more training or evaluation datasets for RL training using paired
|
||||
preference data, calling `axolotl.utils.data.rl.load_prepare_preference_datasets`.
|
||||
"""Loads one or more training or evaluation datasets for RL training using paired
|
||||
preference data, calling `axolotl.utils.data.rl.prepare_preference_datasets`.
|
||||
Optionally, logs out debug information.
|
||||
|
||||
Args:
|
||||
@@ -130,12 +109,14 @@ def load_preference_datasets(
|
||||
Dataclass with fields for training and evaluation datasets and the computed
|
||||
`total_num_steps`.
|
||||
"""
|
||||
train_dataset, eval_dataset = load_prepare_preference_datasets(cfg)
|
||||
total_num_steps: Optional[int] = int(
|
||||
math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
|
||||
)
|
||||
if cfg.rl is RLType.GRPO:
|
||||
total_num_steps = None
|
||||
tokenizer = load_tokenizer(cfg)
|
||||
train_dataset, eval_dataset = prepare_preference_datasets(cfg, tokenizer)
|
||||
|
||||
total_num_steps: int | None = None
|
||||
if cfg.rl is not RLType.GRPO:
|
||||
total_num_steps = int(
|
||||
math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
|
||||
)
|
||||
|
||||
if cli_args.debug or cfg.debug:
|
||||
LOG.info("check_dataset_labels...")
|
||||
@@ -143,8 +124,8 @@ def load_preference_datasets(
|
||||
tokenizer = load_tokenizer(cfg)
|
||||
train_samples = sample_dataset(train_dataset, cli_args.debug_num_examples)
|
||||
check_dataset_labels(
|
||||
train_samples,
|
||||
tokenizer,
|
||||
dataset=train_samples,
|
||||
tokenizer=tokenizer,
|
||||
num_examples=cli_args.debug_num_examples,
|
||||
text_only=cli_args.debug_text_only,
|
||||
rl_mode=True,
|
||||
|
||||
@@ -381,7 +381,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
elif "tokenizer" in sig.parameters:
|
||||
trainer_kwargs["tokenizer"] = self.tokenizer
|
||||
if (
|
||||
not (trainer_cls in [AxolotlRewardTrainer, AxolotlPRMTrainer])
|
||||
trainer_cls not in [AxolotlRewardTrainer, AxolotlPRMTrainer]
|
||||
and self.cfg.datasets is not None
|
||||
):
|
||||
trainer_kwargs["dataset_tags"] = [
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
"""Module containing Dataset functionality"""
|
||||
|
||||
import os
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import torch
|
||||
from datasets import Dataset, IterableDataset
|
||||
@@ -20,21 +19,21 @@ LOG = get_logger(__name__)
|
||||
|
||||
|
||||
class TokenizedPromptDataset(Dataset):
|
||||
"""
|
||||
Dataset that returns tokenized prompts from a stream of text files.
|
||||
Args:
|
||||
prompt_tokenizer (PromptTokenizingStrategy): The prompt tokenizing method for processing the data.
|
||||
dataset (dataset.Dataset): Dataset with text files.
|
||||
process_count (int): Number of processes to use for tokenizing.
|
||||
keep_in_memory (bool): Whether to keep the tokenized dataset in memory.
|
||||
"""Dataset that returns tokenized prompts from a stream of text files.
|
||||
|
||||
Args:
|
||||
prompt_tokenizer: The prompt tokenizing method for processing the data.
|
||||
dataset: Dataset with text files.
|
||||
process_count: Number of processes to use for tokenizing.
|
||||
keep_in_memory: Whether to keep the tokenized dataset in memory.
|
||||
"""
|
||||
|
||||
def __init__( # pylint: disable=super-init-not-called
|
||||
self,
|
||||
prompt_tokenizer: PromptTokenizingStrategy,
|
||||
dataset: Dataset,
|
||||
process_count: Optional[int] = None,
|
||||
keep_in_memory: Optional[bool] = False,
|
||||
process_count: int | None = None,
|
||||
keep_in_memory: bool | None = False,
|
||||
**kwargs,
|
||||
):
|
||||
self.prompt_tokenizer = prompt_tokenizer
|
||||
@@ -76,14 +75,14 @@ class TokenizedPromptDataset(Dataset):
|
||||
|
||||
def wrap_dataset_for_tokenized_prompt(
|
||||
prompt_tokenizer: PromptTokenizingStrategy,
|
||||
dataset: Union[Dataset, IterableDataset],
|
||||
dataset: Dataset | IterableDataset,
|
||||
**kwargs,
|
||||
):
|
||||
if isinstance(dataset, IterableDataset):
|
||||
map_kwargs = {}
|
||||
if prompt_tokenizer.supports_batched:
|
||||
map_kwargs["batched"] = True
|
||||
features = dataset.features.keys()
|
||||
features = list(dataset.features.keys())
|
||||
return dataset.map(
|
||||
prompt_tokenizer.tokenize_prompt,
|
||||
remove_columns=features,
|
||||
@@ -94,12 +93,13 @@ def wrap_dataset_for_tokenized_prompt(
|
||||
|
||||
# TODO this isn't the best since it can't interleave datasets
|
||||
class ConstantLengthDataset(IterableDataset):
|
||||
"""
|
||||
Iterable dataset that returns constant length chunks of tokens from stream of text files.
|
||||
Args:
|
||||
tokenizer (Tokenizer): The processor used for processing the data.
|
||||
dataset (dataset.Dataset): Dataset with text files.
|
||||
seq_length (int): Length of token sequences to return.
|
||||
"""Iterable dataset that returns constant length chunks of tokens from stream of
|
||||
text files.
|
||||
|
||||
Args:
|
||||
tokenizer: The processor used for processing the data.
|
||||
dataset: Dataset with text files.
|
||||
seq_length: Length of token sequences to return.
|
||||
"""
|
||||
|
||||
def __init__( # pylint: disable=super-init-not-called
|
||||
@@ -110,7 +110,7 @@ class ConstantLengthDataset(IterableDataset):
|
||||
):
|
||||
self.tokenizer = tokenizer
|
||||
self.concat_token_id = tokenizer.eos_token_id
|
||||
self.datasets: List[IterableDataset] = datasets
|
||||
self.datasets: list[IterableDataset] = datasets
|
||||
self.seq_length = seq_length
|
||||
|
||||
vocab_size = len(tokenizer.get_vocab())
|
||||
@@ -174,7 +174,10 @@ class ConstantLengthDataset(IterableDataset):
|
||||
}
|
||||
else:
|
||||
LOG.warning(
|
||||
f"dropping batch due to tensor size mismatch input_ids: {input_ids.size()}, labels: {labels.size()}, attention_mask: {attention_mask.size()}"
|
||||
"Dropping batch due to tensor size mismatch "
|
||||
f"input_ids: {input_ids.size()}, "
|
||||
f"labels: {labels.size()}, "
|
||||
f"attention_mask: {attention_mask.size()}"
|
||||
)
|
||||
buffer = {
|
||||
"input_ids": [],
|
||||
|
||||
@@ -7,12 +7,14 @@ import transformers
|
||||
from transformers import (
|
||||
AddedToken,
|
||||
AutoTokenizer,
|
||||
PreTrainedTokenizer,
|
||||
)
|
||||
|
||||
from axolotl.integrations.base import PluginManager
|
||||
from axolotl.loaders.utils import get_linear_embedding_layers, load_model_config
|
||||
from axolotl.prompt_tokenizers import LLAMA_DEFAULT_EOS_TOKEN
|
||||
from axolotl.utils.chat_templates import get_chat_template_from_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.distributed import (
|
||||
barrier,
|
||||
is_local_main_process,
|
||||
@@ -117,7 +119,7 @@ def modify_tokenizer_files(
|
||||
return tokenizer_dir
|
||||
|
||||
|
||||
def load_tokenizer(cfg):
|
||||
def load_tokenizer(cfg: DictDefault) -> PreTrainedTokenizer:
|
||||
"""Load and configure the tokenizer based on the provided config."""
|
||||
model_config = load_model_config(cfg)
|
||||
tokenizer_kwargs = {}
|
||||
@@ -207,11 +209,12 @@ def load_tokenizer(cfg):
|
||||
)
|
||||
and k != "pad_token"
|
||||
):
|
||||
lora_modules_to_save = ", ".join(
|
||||
lora_modules_to_save_str = ", ".join(
|
||||
[f"`{x}`" for x in lora_modules_to_save]
|
||||
)
|
||||
raise ValueError(
|
||||
f"Please set lora_modules_to_save to [{lora_modules_to_save}] when using an adapter and changing the special tokens."
|
||||
f"Please set lora_modules_to_save to [{lora_modules_to_save_str}] "
|
||||
"when using an adapter and changing the special tokens."
|
||||
)
|
||||
|
||||
tokenizer.add_special_tokens(
|
||||
|
||||
@@ -32,4 +32,3 @@ def load(tokenizer, cfg, ds_cfg, processor=None):
|
||||
except Exception as exc: # pylint: disable=broad-exception-caught
|
||||
LOG.error(f"Failed to load prompt strategy `{strategy}`: {str(exc)}")
|
||||
raise exc
|
||||
return None
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
import abc
|
||||
from typing import Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
from datasets import Dataset
|
||||
from transformers import BatchEncoding, PreTrainedTokenizer
|
||||
|
||||
from axolotl.prompters import Prompter
|
||||
@@ -28,6 +29,16 @@ class DatasetWrappingStrategy(abc.ABC):
|
||||
Abstract class for wrapping datasets for Chat Messages
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def wrap_dataset(
|
||||
self,
|
||||
dataset,
|
||||
process_count: int | None = None,
|
||||
keep_in_memory: bool | None = False,
|
||||
**kwargs,
|
||||
) -> Dataset:
|
||||
pass
|
||||
|
||||
|
||||
class PromptTokenizingStrategy(abc.ABC):
|
||||
"""
|
||||
|
||||
@@ -53,8 +53,8 @@ def setup_model_and_tokenizer(
|
||||
) -> tuple[
|
||||
PreTrainedModel, PreTrainedTokenizer, PeftConfig | None, ProcessorMixin | None
|
||||
]:
|
||||
"""
|
||||
Load the tokenizer, processor (for multimodal models), and model based on configuration.
|
||||
"""Load the tokenizer, processor (for multimodal models), and model based on
|
||||
configuration.
|
||||
|
||||
Args:
|
||||
cfg: Dictionary mapping `axolotl` config keys to values.
|
||||
|
||||
@@ -1,16 +1,21 @@
|
||||
"""
|
||||
Data processing modules
|
||||
"""
|
||||
"""Init for `axolotl.utils.data` module."""
|
||||
|
||||
from axolotl.utils.data.pretraining import ( # noqa: F401
|
||||
from axolotl.utils.data.pretraining import (
|
||||
encode_pretraining,
|
||||
wrap_pretraining_dataset,
|
||||
)
|
||||
from axolotl.utils.data.rl import load_prepare_preference_datasets # noqa: F401
|
||||
from axolotl.utils.data.sft import ( # noqa: F401
|
||||
from axolotl.utils.data.rl import prepare_preference_datasets
|
||||
from axolotl.utils.data.sft import (
|
||||
get_dataset_wrapper,
|
||||
load_prepare_datasets,
|
||||
load_tokenized_prepared_datasets,
|
||||
prepare_dataset,
|
||||
prepare_datasets,
|
||||
)
|
||||
from axolotl.utils.data.utils import md5 # noqa: F401
|
||||
from axolotl.utils.data.utils import md5
|
||||
|
||||
__all__ = [
|
||||
"encode_pretraining",
|
||||
"wrap_pretraining_dataset",
|
||||
"prepare_preference_datasets",
|
||||
"get_dataset_wrapper",
|
||||
"prepare_datasets",
|
||||
"md5",
|
||||
]
|
||||
|
||||
66
src/axolotl/utils/data/lock.py
Normal file
66
src/axolotl/utils/data/lock.py
Normal file
@@ -0,0 +1,66 @@
|
||||
"""Logic for loading / preparing a dataset once over all processes."""
|
||||
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable
|
||||
|
||||
from filelock import FileLock
|
||||
|
||||
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
LOCK_FILE_NAME = "datasets_prep.lock"
|
||||
READY_FILE_NAME = "datasets_ready.flag"
|
||||
PROCESS_COUNTER_FILE_NAME = "process_counter.txt"
|
||||
|
||||
|
||||
class FileLockLoader:
    """
    Simple class for abstracting single process data loading / processing. The first
    process that acquires the file lock does the work and then creates a ready flag;
    the remaining processes simply load the preprocessed dataset once the first
    process is done.

    All coordination files (lock, ready flag, process counter) live in
    `cfg.dataset_prepared_path`, falling back to `DEFAULT_DATASET_PREPARED_PATH`.
    """

    def __init__(self, cfg: DictDefault):
        self.cfg = cfg
        self.dataset_prepared_path = (
            cfg.dataset_prepared_path or DEFAULT_DATASET_PREPARED_PATH
        )
        self.lock_file_path = Path(self.dataset_prepared_path) / LOCK_FILE_NAME
        self.ready_flag_path = Path(self.dataset_prepared_path) / READY_FILE_NAME
        self.counter_path = Path(self.dataset_prepared_path) / PROCESS_COUNTER_FILE_NAME

    def load(self, load_fn: Callable[[], Any]) -> Any:
        """Run `load_fn` under the file lock and return its result.

        The process that finds no ready flag runs `load_fn` and then creates the
        flag; processes that find the flag call `load_fn` as well (presumably
        hitting whatever cache the first call produced -- that depends on `load_fn`).

        Args:
            load_fn: Zero-argument callable that loads / prepares the data.

        Returns:
            Whatever `load_fn` returns.
        """
        # The lock, flag, and counter files all live in this directory, which may
        # not exist yet on a fresh run.
        Path(self.dataset_prepared_path).mkdir(parents=True, exist_ok=True)

        with FileLock(str(self.lock_file_path)):
            self._increment_counter()

            if not self.ready_flag_path.exists():
                result = load_fn()
                # Only mark ready after a successful load; if `load_fn` raises, the
                # next process to take the lock retries the work.
                self.ready_flag_path.touch()
                return result

            # Defensive wait: the flag was just observed above, so this normally
            # does not spin.
            while not self.ready_flag_path.exists():
                time.sleep(1)
            return load_fn()

    def _read_counter(self) -> int:
        """Return the stored process count, or 0 if it is missing or unreadable."""
        try:
            return int(self.counter_path.read_text().strip())
        except (FileNotFoundError, ValueError):
            # Missing file: nothing registered yet. Unparsable content: treat as a
            # stale leftover rather than failing every subsequent run.
            return 0

    def _increment_counter(self):
        """Safely increment the process counter."""
        self.counter_path.write_text(str(self._read_counter() + 1))

    def cleanup(self):
        """Clean up ready flag when last process is done.

        Safe to call from a `finally` block: it does not raise when `load` never
        registered this process (missing counter file or directory), and calling it
        more than once is harmless.
        """
        if not Path(self.dataset_prepared_path).is_dir():
            # Nothing was ever written, so there is nothing to clean up.
            return

        with FileLock(str(self.lock_file_path)):
            remaining = self._read_counter() - 1

            if remaining <= 0:
                # Last process cleans everything up
                self.ready_flag_path.unlink(missing_ok=True)
                self.counter_path.unlink(missing_ok=True)
            else:
                # Still have active processes
                self.counter_path.write_text(str(remaining))
|
||||
@@ -250,7 +250,7 @@ def encode_packed_pretraining(
|
||||
# pylint: disable=duplicate-code
|
||||
# tokenize all the examples
|
||||
# rows get split with stride (overlap)
|
||||
train_dataset = ds_wrapper(Dataset.from_dict(examples))[0]
|
||||
train_dataset = ds_wrapper(dataset=Dataset.from_dict(examples))[0]
|
||||
|
||||
train_dataset = process_pretraining_datasets_for_packing(
|
||||
train_dataset,
|
||||
|
||||
@@ -1,75 +1,117 @@
|
||||
"""data handling specific to DPO"""
|
||||
"""Data handling specific to RL trainers."""
|
||||
|
||||
import inspect
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
from typing import Any, List, Union
|
||||
from typing import Any, Callable, Literal
|
||||
|
||||
import yaml
|
||||
from datasets import Dataset, DatasetDict, concatenate_datasets, load_from_disk
|
||||
from datasets import Dataset, DatasetDict
|
||||
from transformers import PreTrainedTokenizer
|
||||
|
||||
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
|
||||
from axolotl.loaders import load_tokenizer
|
||||
from axolotl.prompt_strategies.dpo import load as load_dpo
|
||||
from axolotl.prompt_strategies.kto import load as load_kto
|
||||
from axolotl.prompt_strategies.orpo import load as load_orpo
|
||||
from axolotl.utils.data.shared import datasets_w_name_generator, load_dataset_w_config
|
||||
from axolotl.utils.data.utils import deduplicate_and_log_datasets, md5
|
||||
from axolotl.utils.data.lock import FileLockLoader
|
||||
from axolotl.utils.data.shared import (
|
||||
create_train_validation_split,
|
||||
datasets_with_name_generator,
|
||||
generate_dataset_hash_from_config,
|
||||
load_dataset_with_config,
|
||||
load_preprocessed_dataset,
|
||||
merge_datasets,
|
||||
save_preprocessed_dataset,
|
||||
try_load_from_hub,
|
||||
)
|
||||
from axolotl.utils.data.utils import (
|
||||
deduplicate_and_log_datasets,
|
||||
retry_on_request_exceptions,
|
||||
)
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.distributed import is_main_process, zero_first
|
||||
from axolotl.utils.logging import get_logger
|
||||
from axolotl.utils.schemas.enums import RLType
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
|
||||
def _get_path(ds_hash, cfg):
|
||||
prepared_ds_path = (
|
||||
Path(cfg.dataset_prepared_path) / ds_hash
|
||||
if cfg.dataset_prepared_path
|
||||
else Path(DEFAULT_DATASET_PREPARED_PATH) / ds_hash
|
||||
)
|
||||
@retry_on_request_exceptions(max_retries=3, delay=5)
def prepare_preference_datasets(
    cfg: DictDefault, tokenizer: PreTrainedTokenizer
) -> tuple[Dataset, Dataset | None]:
    """Build the train / eval preference datasets used for RL training.

    The splits are produced through a `FileLockLoader`, so multiple ranks coordinate
    via the file lock; exact deduplication is applied afterwards when
    `cfg.dataset_exact_deduplication` is set.

    Args:
        cfg: Configuration object containing dataset and training settings.
        tokenizer: Tokenizer to use for processing text.

    Returns:
        Tuple of (train_dataset, eval_dataset). eval_dataset is None when neither
        `cfg.test_datasets` nor `cfg.val_set_size` is configured.
    """

    def _build_splits():
        train_split = _load_or_create_dataset_split(cfg, tokenizer, split="train")

        if cfg.test_datasets:
            # Explicit test datasets take precedence over a validation split.
            test_split = _load_or_create_dataset_split(cfg, tokenizer, split="test")
            return train_split, test_split

        if cfg.val_set_size:
            # Carve the validation split out of the training data.
            train_split, held_out_split = create_train_validation_split(
                train_split, cfg, cfg.val_set_size
            )
            return train_split, held_out_split

        return train_split, None

    # File-lock coordinated preparation; cleanup runs even when loading fails.
    lock_loader = FileLockLoader(cfg)
    try:
        train_dataset, eval_dataset = lock_loader.load(_build_splits)
    finally:
        lock_loader.cleanup()

    if not cfg.dataset_exact_deduplication:
        return train_dataset, eval_dataset

    deduped_train, deduped_eval = deduplicate_and_log_datasets(
        dataset=train_dataset, other_dataset=eval_dataset
    )
    return deduped_train, deduped_eval
|
||||
|
||||
|
||||
def _load_preprocessed_ds(cfg, sub_cfg):
|
||||
ds_hash = md5(yaml.dump(sub_cfg, Dumper=yaml.Dumper))
|
||||
prepared_ds_path = _get_path(ds_hash, cfg)
|
||||
dataset = None
|
||||
def _map_dataset(
|
||||
cfg: DictDefault,
|
||||
dataset: Dataset | DatasetDict,
|
||||
ds_transform_fn: Callable[..., Any],
|
||||
tokenizer: Any | None = None,
|
||||
**map_kwargs: Any,
|
||||
) -> Dataset:
|
||||
"""Apply transformation function to dataset.
|
||||
|
||||
# pylint: disable=duplicate-code
|
||||
if (
|
||||
cfg.dataset_prepared_path
|
||||
and any(prepared_ds_path.glob("*"))
|
||||
and not cfg.is_preprocess
|
||||
):
|
||||
LOG.info(f"Loading prepared dataset from disk at {prepared_ds_path}...")
|
||||
dataset = load_from_disk(str(prepared_ds_path))
|
||||
Args:
|
||||
cfg: Configuration object.
|
||||
dataset: Dataset to transform.
|
||||
ds_transform_fn: Transformation function to apply.
|
||||
tokenizer: Optional tokenizer for transformation.
|
||||
**map_kwargs: Additional arguments for dataset mapping.
|
||||
|
||||
return dataset
|
||||
|
||||
|
||||
def _save_preprocessed_ds(cfg, sub_cfg, dataset):
|
||||
ds_hash = md5(yaml.dump(sub_cfg, Dumper=yaml.Dumper))
|
||||
prepared_ds_path = _get_path(ds_hash, cfg)
|
||||
|
||||
if cfg.is_preprocess and is_main_process():
|
||||
LOG.info(f"Loading prepared dataset from disk at {prepared_ds_path}...")
|
||||
dataset.save_to_disk(str(prepared_ds_path))
|
||||
|
||||
|
||||
def map_dataset(cfg, data_set, ds_transform_fn, tokenizer, **map_kwargs):
|
||||
Returns:
|
||||
Transformed dataset.
|
||||
"""
|
||||
sig = inspect.signature(ds_transform_fn)
|
||||
if "tokenizer" in sig.parameters:
|
||||
if not tokenizer:
|
||||
tokenizer = load_tokenizer(cfg)
|
||||
ds_transform_fn = partial(ds_transform_fn, tokenizer=tokenizer)
|
||||
|
||||
if isinstance(data_set, DatasetDict):
|
||||
data_set = data_set["train"]
|
||||
if isinstance(dataset, DatasetDict):
|
||||
dataset = dataset["train"]
|
||||
|
||||
data_set = data_set.map(
|
||||
dataset = dataset.map(
|
||||
ds_transform_fn,
|
||||
num_proc=cfg.dataset_processes,
|
||||
load_from_cache_file=not cfg.is_preprocess,
|
||||
@@ -77,13 +119,27 @@ def map_dataset(cfg, data_set, ds_transform_fn, tokenizer, **map_kwargs):
|
||||
**map_kwargs,
|
||||
)
|
||||
|
||||
return data_set
|
||||
return dataset
|
||||
|
||||
|
||||
def drop_long_rl_seq(
|
||||
sample, rl, tokenizer, sequence_len # pylint: disable=invalid-name
|
||||
):
|
||||
if rl in (RLType.DPO, RLType.IPO, RLType.ORPO, RLType.SIMPO):
|
||||
def _drop_long_sequences(
|
||||
sample: dict[str, Any], rl: RLType, tokenizer: Any, sequence_len: int
|
||||
) -> bool:
|
||||
"""Filter out samples that exceed maximum sequence length.
|
||||
|
||||
Args:
|
||||
sample: Dataset sample to check.
|
||||
rl: Reinforcement learning type.
|
||||
tokenizer: Tokenizer for length calculation.
|
||||
sequence_len: Maximum allowed sequence length.
|
||||
|
||||
Returns:
|
||||
True if sample should be kept, False if it should be dropped.
|
||||
|
||||
Raises:
|
||||
ValueError: If required keys are missing or RL type is unknown.
|
||||
"""
|
||||
if rl in {RLType.DPO, RLType.IPO, RLType.ORPO, RLType.SIMPO}:
|
||||
if not (
|
||||
sample.get("prompt") and sample.get("chosen") and sample.get("rejected")
|
||||
):
|
||||
@@ -123,132 +179,115 @@ def drop_long_rl_seq(
|
||||
raise ValueError("Unknown RL type")
|
||||
|
||||
|
||||
def load_prepare_preference_datasets(cfg):
|
||||
def load_split(dataset_cfgs, _cfg):
|
||||
split_datasets: List[Any] = []
|
||||
use_auth_token = _cfg.hf_use_auth_token
|
||||
for config_dataset in datasets_w_name_generator(dataset_cfgs):
|
||||
ds: Union[Dataset, DatasetDict] = load_dataset_w_config(
|
||||
config_dataset, use_auth_token, streaming=False
|
||||
)
|
||||
split_datasets.append(ds)
|
||||
def _load_split(cfg: DictDefault, split: Literal["train", "test"]) -> Dataset:
|
||||
"""Load and process dataset split for RL training.
|
||||
|
||||
tokenizer = load_tokenizer(cfg)
|
||||
Args:
|
||||
cfg: Configuration object containing dataset settings.
|
||||
split: Dataset split to load ("train" or "test").
|
||||
|
||||
for i, data_set in enumerate(split_datasets):
|
||||
_type = dataset_cfgs[i]["type"]
|
||||
if _type:
|
||||
if isinstance(_type, DictDefault):
|
||||
_type = "user_defined.default"
|
||||
if _cfg.rl is RLType.ORPO:
|
||||
ds_transform_fn = load_orpo(_type, _cfg, dataset_idx=i)
|
||||
elif _cfg.rl is RLType.KTO:
|
||||
ds_transform_fn = load_kto(_type, _cfg, dataset_idx=i)
|
||||
else:
|
||||
ds_transform_fn = load_dpo(_type, _cfg, dataset_idx=i)
|
||||
Returns:
|
||||
Combined and processed dataset for the specified split.
|
||||
"""
|
||||
datasets_configs = cfg.datasets if split == "train" else cfg.test_datasets
|
||||
split_datasets: list[Dataset | DatasetDict] = []
|
||||
|
||||
map_kwargs = {}
|
||||
if isinstance(ds_transform_fn, tuple):
|
||||
ds_transform_fn, map_kwargs = ds_transform_fn
|
||||
split_datasets[i] = map_dataset(
|
||||
cfg, data_set, ds_transform_fn, tokenizer, **map_kwargs
|
||||
)
|
||||
elif _cfg.rl is RLType.KTO:
|
||||
ds_transform_fn = load_kto(_type, _cfg, dataset_idx=i)
|
||||
map_kwargs = {}
|
||||
if isinstance(ds_transform_fn, tuple):
|
||||
ds_transform_fn, map_kwargs = ds_transform_fn
|
||||
split_datasets[i] = map_dataset(
|
||||
cfg, data_set, ds_transform_fn, tokenizer, **map_kwargs
|
||||
)
|
||||
else:
|
||||
# If no `type` is provided, assume the dataset is already in the expected format with
|
||||
# "prompt", "chosen" and "rejected" already preprocessed
|
||||
split_datasets[i] = data_set
|
||||
|
||||
if not cfg.skip_prepare_dataset:
|
||||
drop_long = partial(
|
||||
drop_long_rl_seq,
|
||||
rl=_cfg.rl,
|
||||
tokenizer=tokenizer,
|
||||
sequence_len=cfg.sequence_len,
|
||||
)
|
||||
|
||||
prior_len = len(split_datasets[i])
|
||||
split_datasets[i] = split_datasets[i].filter(
|
||||
drop_long,
|
||||
num_proc=cfg.dataset_processes,
|
||||
load_from_cache_file=not cfg.is_preprocess,
|
||||
desc="Dropping Long Sequences",
|
||||
)
|
||||
dropped = prior_len - len(split_datasets[i])
|
||||
if dropped:
|
||||
LOG.warning(
|
||||
f"Dropped {dropped} long samples from dataset index {i}"
|
||||
)
|
||||
|
||||
combined_datasets = concatenate_datasets(split_datasets)
|
||||
combined_datasets = combined_datasets.shuffle(seed=cfg.seed or 42)
|
||||
|
||||
return combined_datasets
|
||||
|
||||
with zero_first(is_main_process()):
|
||||
train_is_preprocessed = False
|
||||
eval_is_preprocessed = False
|
||||
if train_dataset := _load_preprocessed_ds(cfg, cfg.datasets):
|
||||
train_is_preprocessed = True
|
||||
else:
|
||||
train_dataset = load_split(cfg.datasets, cfg)
|
||||
|
||||
eval_dataset = None
|
||||
if cfg.test_datasets:
|
||||
if eval_dataset := _load_preprocessed_ds(cfg, cfg.test_datasets):
|
||||
eval_is_preprocessed = True
|
||||
else:
|
||||
eval_dataset = load_split(cfg.test_datasets, cfg)
|
||||
if not eval_dataset:
|
||||
if cfg.val_set_size:
|
||||
seed = cfg.seed if cfg.seed is not None else 42
|
||||
|
||||
# ensure we end up with the same fingerprint by doing rank0 first and being able to cache
|
||||
to_hash_train = (
|
||||
train_dataset._fingerprint # pylint: disable=protected-access
|
||||
+ "|"
|
||||
+ str(cfg.val_set_size)
|
||||
+ "|"
|
||||
+ "train"
|
||||
+ "|"
|
||||
+ str(cfg.seed or 42)
|
||||
)
|
||||
to_hash_test = (
|
||||
train_dataset._fingerprint # pylint: disable=protected-access
|
||||
+ "|"
|
||||
+ str(cfg.val_set_size)
|
||||
+ "|"
|
||||
+ "test"
|
||||
+ "|"
|
||||
+ str(cfg.seed or 42)
|
||||
)
|
||||
train_fingerprint = md5(to_hash_train)
|
||||
test_fingerprint = md5(to_hash_test)
|
||||
ds_w_test_split = train_dataset.train_test_split(
|
||||
test_size=cfg.val_set_size,
|
||||
seed=seed,
|
||||
shuffle=False,
|
||||
train_new_fingerprint=train_fingerprint,
|
||||
test_new_fingerprint=test_fingerprint,
|
||||
)
|
||||
eval_dataset = ds_w_test_split["test"]
|
||||
train_dataset = ds_w_test_split["train"]
|
||||
|
||||
if not train_is_preprocessed:
|
||||
_save_preprocessed_ds(cfg, cfg.datasets, train_dataset)
|
||||
if eval_dataset and not eval_is_preprocessed:
|
||||
_save_preprocessed_ds(cfg, cfg.test_datasets, eval_dataset)
|
||||
|
||||
if cfg.dataset_exact_deduplication:
|
||||
train_dataset, eval_dataset, _ = deduplicate_and_log_datasets(
|
||||
train_dataset=train_dataset, eval_dataset=eval_dataset
|
||||
for dataset_config in datasets_with_name_generator(datasets_configs):
|
||||
dataset: Dataset | DatasetDict = load_dataset_with_config(
|
||||
dataset_config, cfg.hf_use_auth_token, streaming=False
|
||||
)
|
||||
split_datasets.append(dataset)
|
||||
|
||||
return train_dataset, eval_dataset
|
||||
tokenizer = load_tokenizer(cfg)
|
||||
|
||||
for i, dataset in enumerate(split_datasets):
|
||||
_type = datasets_configs[i]["type"]
|
||||
if _type:
|
||||
if isinstance(_type, DictDefault):
|
||||
_type = "user_defined.default"
|
||||
if cfg.rl is RLType.ORPO:
|
||||
ds_transform_fn = load_orpo(_type, cfg, dataset_idx=i)
|
||||
elif cfg.rl is RLType.KTO:
|
||||
ds_transform_fn = load_kto(_type, cfg, dataset_idx=i)
|
||||
else:
|
||||
ds_transform_fn = load_dpo(_type, cfg, dataset_idx=i)
|
||||
|
||||
map_kwargs: dict[str, Any] = {}
|
||||
if isinstance(ds_transform_fn, tuple):
|
||||
ds_transform_fn, map_kwargs = ds_transform_fn
|
||||
split_datasets[i] = _map_dataset(
|
||||
cfg, dataset, ds_transform_fn, tokenizer, **map_kwargs
|
||||
)
|
||||
else:
|
||||
# If no `type` is provided, assume the dataset is already in the expected format with
|
||||
# "prompt", "chosen", and "rejected" already preprocessed
|
||||
split_datasets[i] = dataset
|
||||
|
||||
if not cfg.skip_prepare_dataset:
|
||||
drop_long = partial(
|
||||
_drop_long_sequences,
|
||||
rl=cfg.rl,
|
||||
tokenizer=tokenizer,
|
||||
sequence_len=cfg.sequence_len,
|
||||
)
|
||||
|
||||
prior_len = len(split_datasets[i])
|
||||
split_datasets[i] = split_datasets[i].filter(
|
||||
drop_long,
|
||||
num_proc=cfg.dataset_processes,
|
||||
load_from_cache_file=not cfg.is_preprocess,
|
||||
desc="Dropping Long Sequences",
|
||||
)
|
||||
dropped = prior_len - len(split_datasets[i])
|
||||
if dropped:
|
||||
LOG.warning(f"Dropped {dropped} long samples from dataset index {i}")
|
||||
|
||||
# Merge datasets
|
||||
dataset = merge_datasets(split_datasets, cfg)
|
||||
|
||||
if not cfg.skip_prepare_dataset:
|
||||
# Save preprocessed dataset
|
||||
dataset_hash = generate_dataset_hash_from_config(
|
||||
cfg, datasets_configs, tokenizer.name_or_path
|
||||
)
|
||||
save_preprocessed_dataset(cfg, dataset, dataset_hash, split)
|
||||
|
||||
return dataset
|
||||
|
||||
|
||||
# pylint: disable=duplicate-code
|
||||
def _load_or_create_dataset_split(
    cfg: DictDefault, tokenizer: PreTrainedTokenizer, split: Literal["train", "test"]
) -> Dataset:
    """Load a previously prepared dataset split, or build it from scratch.

    Sources are tried in order: the HuggingFace Hub (only when
    `cfg.push_dataset_to_hub` is set), the prepared copy looked up through
    `load_preprocessed_dataset`, and finally a fresh load via `_load_split`.

    Args:
        cfg: Configuration object.
        tokenizer: Tokenizer whose `name_or_path` is folded into the dataset hash.
        split: Dataset split to load.

    Returns:
        The dataset for the requested split.
    """
    # The train split reads `cfg.datasets`; any other split reads `cfg.test_datasets`
    datasets_config = cfg.datasets if split == "train" else cfg.test_datasets

    # The hash identifies this exact configuration for cache lookups
    dataset_hash = generate_dataset_hash_from_config(
        cfg, datasets_config, tokenizer.name_or_path
    )

    # The hub is only consulted when pushing to it is configured
    if cfg.push_dataset_to_hub:
        hub_dataset = try_load_from_hub(cfg, dataset_hash, split)
        if hub_dataset is not None:
            return hub_dataset

    prepared_dataset = load_preprocessed_dataset(cfg, dataset_hash)
    if prepared_dataset is not None:
        return prepared_dataset

    # Nothing cached anywhere: load and process the split now
    return _load_split(cfg, split=split)
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,11 +1,21 @@
|
||||
"""
|
||||
dataset loading shared utils
|
||||
"""
|
||||
"""Dataset loading shared utils."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import functools
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Optional, Union
|
||||
from typing import TYPE_CHECKING, Any, Generator
|
||||
|
||||
from datasets import Dataset, DatasetDict, load_dataset, load_from_disk
|
||||
from datasets import (
|
||||
Dataset,
|
||||
DatasetDict,
|
||||
IterableDataset,
|
||||
IterableDatasetDict,
|
||||
concatenate_datasets,
|
||||
load_dataset,
|
||||
load_from_disk,
|
||||
)
|
||||
from huggingface_hub import hf_hub_download, snapshot_download
|
||||
from huggingface_hub.errors import (
|
||||
HFValidationError,
|
||||
@@ -13,78 +23,141 @@ from huggingface_hub.errors import (
|
||||
RevisionNotFoundError,
|
||||
)
|
||||
|
||||
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
|
||||
from axolotl.utils.data.utils import deduplicate_and_log_datasets, md5
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.logging import get_logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from adlfs import AzureBlobFileSystem
|
||||
from gcsfs import GCSFileSystem
|
||||
from ocifs import OCIFileSystem
|
||||
from s3fs import S3FileSystem
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
EXTENSIONS_TO_DATASET_TYPES = {
|
||||
".parquet": "parquet",
|
||||
".arrow": "arrow",
|
||||
".csv": "csv",
|
||||
".txt": "text",
|
||||
}
|
||||
|
||||
|
||||
def get_ds_type(config_dataset: DictDefault):
|
||||
"""
|
||||
Get the dataset type from the path if it's not specified
|
||||
"""
|
||||
ds_type = "json"
|
||||
if config_dataset.ds_type:
|
||||
ds_type = config_dataset.ds_type
|
||||
elif ".parquet" in config_dataset.path:
|
||||
ds_type = "parquet"
|
||||
elif ".arrow" in config_dataset.path:
|
||||
ds_type = "arrow"
|
||||
elif ".csv" in config_dataset.path:
|
||||
ds_type = "csv"
|
||||
elif ".txt" in config_dataset.path:
|
||||
ds_type = "text"
|
||||
return ds_type
|
||||
def get_dataset_type(dataset_config: DictDefault) -> str:
    """Return the dataset type, inferring it from the path when not configured.

    An explicit `ds_type` always wins. Otherwise the first extension from
    `EXTENSIONS_TO_DATASET_TYPES` that occurs anywhere in the path decides the
    type, and `"json"` is the fallback.
    """
    explicit_type = dataset_config.ds_type
    if explicit_type:
        return explicit_type

    path = dataset_config.path
    matching_types = (
        dataset_type
        for extension, dataset_type in EXTENSIONS_TO_DATASET_TYPES.items()
        if extension in path
    )
    return next(matching_types, "json")
|
||||
|
||||
|
||||
def datasets_w_name_generator(dataset_configs: list[DictDefault]):
|
||||
"""
|
||||
Yields dataset configs handling multiple names or preprocess_shards
|
||||
def datasets_with_name_generator(
|
||||
dataset_configs: list[DictDefault],
|
||||
) -> Generator[DictDefault, None, None]:
|
||||
"""Yields expanded dataset configurations based on multiple names or preprocessing
|
||||
shards.
|
||||
|
||||
When a dataset config has a list of names, it yields separate configs for each
|
||||
name. When a dataset config specifies preprocessing shards, it yields configs for
|
||||
each shard.
|
||||
|
||||
Args:
|
||||
dataset_configs: list of dataset configs (equivalent to cfg.datasets)
|
||||
dataset_configs: List of dataset configuration objects.
|
||||
|
||||
Yields:
|
||||
Individual dataset configurations, expanded as needed for names or shards.
|
||||
"""
|
||||
for dataset in dataset_configs:
|
||||
if dataset.name and isinstance(dataset.name, list):
|
||||
# load_dataset doesn't properly handle multiple named configurations
|
||||
# at the same time for a given dataset
|
||||
for name in dataset.name:
|
||||
yield DictDefault({**dataset, "name": name})
|
||||
elif dataset.preprocess_shards and not dataset.shards:
|
||||
for shard in range(dataset.preprocess_shards):
|
||||
for config in dataset_configs:
|
||||
if config.name and isinstance(config.name, list):
|
||||
for name in config.name:
|
||||
yield DictDefault({**config, "name": name})
|
||||
elif config.preprocess_shards and not config.shards:
|
||||
for shard_idx in range(config.preprocess_shards):
|
||||
yield DictDefault(
|
||||
{
|
||||
**dataset,
|
||||
"shards": dataset.preprocess_shards,
|
||||
"shards_idx": shard,
|
||||
**config,
|
||||
"shards": config.preprocess_shards,
|
||||
"shards_idx": shard_idx,
|
||||
}
|
||||
)
|
||||
else:
|
||||
yield dataset
|
||||
yield config
|
||||
|
||||
|
||||
def load_dataset_w_config(
|
||||
config_dataset: DictDefault, use_auth_token: bool, streaming=False
|
||||
) -> Union[Dataset, DatasetDict]:
|
||||
"""
|
||||
Load a dataset from a config
|
||||
def load_dataset_with_config(
|
||||
dataset_config: DictDefault, use_auth_token: bool, streaming=False
|
||||
) -> Dataset | IterableDataset:
|
||||
"""Load a dataset from a config. Handles datasets that are stored locally, in the
|
||||
HuggingFace Hub, in a remote filesystem (S3, GCS, Azure, OCI), a URL, or
|
||||
`data_files`.
|
||||
|
||||
Args:
|
||||
config_dataset: single dataset config
|
||||
use_auth_token: whether to use HF auth token
|
||||
streaming: whether to stream the dataset
|
||||
dataset_config: Single dataset config.
|
||||
use_auth_token: Whether to use HF auth token.
|
||||
streaming: Whether to stream the dataset.
|
||||
|
||||
Returns:
|
||||
Loaded dataset.
|
||||
"""
|
||||
# pylint: disable=invalid-name
|
||||
ds: Optional[Union[Dataset, DatasetDict]] = None # pylint: disable=invalid-name
|
||||
ds_from_hub = False
|
||||
# Set up common kwargs for dataset loading
|
||||
load_dataset_kwargs = {
|
||||
"split": dataset_config.split if dataset_config.split else None,
|
||||
"name": dataset_config.name,
|
||||
"streaming": streaming,
|
||||
"trust_remote_code": dataset_config.trust_remote_code,
|
||||
}
|
||||
|
||||
# First check if it's a local path
|
||||
if Path(dataset_config.path).exists():
|
||||
return _load_from_local_path(dataset_config, load_dataset_kwargs)
|
||||
|
||||
# Check if it's a HuggingFace dataset
|
||||
is_hub_dataset = _check_if_hub_dataset(dataset_config, use_auth_token)
|
||||
|
||||
# Check if it's a cloud storage path and get appropriate filesystem
|
||||
remote_fs, storage_options = _get_remote_filesystem(dataset_config.path)
|
||||
is_cloud_dataset = False
|
||||
if remote_fs:
|
||||
try:
|
||||
is_cloud_dataset = remote_fs.exists(dataset_config.path)
|
||||
except (FileNotFoundError, ConnectionError):
|
||||
pass
|
||||
|
||||
# Load from appropriate source
|
||||
if is_hub_dataset:
|
||||
return _load_from_hub(dataset_config, use_auth_token, load_dataset_kwargs)
|
||||
if is_cloud_dataset:
|
||||
return _load_from_cloud(
|
||||
dataset_config, remote_fs, storage_options, load_dataset_kwargs
|
||||
)
|
||||
if dataset_config.path.startswith("https://"):
|
||||
return _load_from_url(dataset_config, load_dataset_kwargs)
|
||||
if dataset_config.data_files:
|
||||
return _load_from_data_files(dataset_config, load_dataset_kwargs)
|
||||
|
||||
raise ValueError(
|
||||
f"The dataset could not be loaded. This could be due to a misconfigured dataset path "
|
||||
f"({dataset_config.path}). Try double-check your path / name / data_files. "
|
||||
f"This is not caused by the dataset type."
|
||||
)
|
||||
|
||||
|
||||
def _check_if_hub_dataset(dataset_config: DictDefault, use_auth_token: bool) -> bool:
|
||||
"""Check if a dataset exists on the HuggingFace Hub."""
|
||||
try:
|
||||
# this is just a basic check to see if the path is a
|
||||
# valid HF dataset that's loadable
|
||||
snapshot_download(
|
||||
repo_id=config_dataset.path,
|
||||
repo_id=dataset_config.path,
|
||||
repo_type="dataset",
|
||||
token=use_auth_token,
|
||||
revision=config_dataset.revision,
|
||||
revision=dataset_config.revision,
|
||||
ignore_patterns=["*"],
|
||||
)
|
||||
ds_from_hub = True
|
||||
return True
|
||||
except (
|
||||
RepositoryNotFoundError,
|
||||
RevisionNotFoundError,
|
||||
@@ -93,198 +166,373 @@ def load_dataset_w_config(
|
||||
HFValidationError,
|
||||
ValueError,
|
||||
):
|
||||
pass
|
||||
return False
|
||||
|
||||
ds_from_cloud = False
|
||||
storage_options: dict = {}
|
||||
remote_file_system = None
|
||||
if config_dataset.path.startswith("s3://"):
|
||||
|
||||
def _get_remote_filesystem(
|
||||
path: str,
|
||||
) -> tuple[
|
||||
S3FileSystem | GCSFileSystem | AzureBlobFileSystem | OCIFileSystem | None, dict
|
||||
]:
|
||||
"""Get the appropriate filesystem for a remote path."""
|
||||
if path.startswith("s3://"):
|
||||
try:
|
||||
import s3fs # type: ignore
|
||||
import s3fs
|
||||
|
||||
storage_options = {"anon": False}
|
||||
return s3fs.S3FileSystem(**storage_options), storage_options
|
||||
except ImportError as exc:
|
||||
raise ImportError("s3:// paths require s3fs to be installed") from exc
|
||||
|
||||
# Reads env, credentials from ~/.aws/credentials, or IAM metadata provider
|
||||
# https://s3fs.readthedocs.io/en/latest/index.html?highlight=storage_options#credentials
|
||||
storage_options = {"anon": False}
|
||||
remote_file_system = s3fs.S3FileSystem(**storage_options)
|
||||
elif config_dataset.path.startswith("gs://") or config_dataset.path.startswith(
|
||||
"gcs://"
|
||||
):
|
||||
elif path.startswith(("gs://", "gcs://")):
|
||||
try:
|
||||
import gcsfs # type: ignore
|
||||
import gcsfs
|
||||
|
||||
storage_options = {"token": None} # type: ignore
|
||||
return gcsfs.GCSFileSystem(**storage_options), storage_options
|
||||
except ImportError as exc:
|
||||
raise ImportError(
|
||||
"gs:// or gcs:// paths require gcsfs to be installed"
|
||||
) from exc
|
||||
|
||||
# gcsfs will use default credentials from the environment else anon
|
||||
# https://gcsfs.readthedocs.io/en/latest/#credentials
|
||||
storage_options = {"token": None}
|
||||
remote_file_system = gcsfs.GCSFileSystem(**storage_options)
|
||||
elif (
|
||||
config_dataset.path.startswith("adl://")
|
||||
or config_dataset.path.startswith("abfs://")
|
||||
or config_dataset.path.startswith("az://")
|
||||
):
|
||||
elif path.startswith(("adl://", "abfs://", "az://")):
|
||||
try:
|
||||
import adlfs
|
||||
|
||||
storage_options = {"anon": False}
|
||||
return adlfs.AzureBlobFileSystem(**storage_options), storage_options
|
||||
except ImportError as exc:
|
||||
raise ImportError(
|
||||
"adl:// or abfs:// paths require adlfs to be installed"
|
||||
) from exc
|
||||
|
||||
# # Ensure you have the following environment variables set:
|
||||
# # Gen 1
|
||||
# storage_options = {
|
||||
# "tenant_id": AZURE_STORAGE_TENANT_ID,
|
||||
# "client_id": AZURE_STORAGE_CLIENT_ID,
|
||||
# "client_secret": AZURE_STORAGE_CLIENT_SECRET,
|
||||
# }
|
||||
# # Gen 2
|
||||
# storage_options = {
|
||||
# "account_name": AZURE_STORAGE_ACCOUNT_NAME,
|
||||
# "account_key": AZURE_STORAGE_ACCOUNT_KEY,
|
||||
# }
|
||||
|
||||
# Reads env
|
||||
# https://github.com/fsspec/adlfs?tab=readme-ov-file#setting-credentials
|
||||
storage_options = {"anon": False}
|
||||
remote_file_system = adlfs.AzureBlobFileSystem(**storage_options)
|
||||
elif config_dataset.path.startswith("oci://"):
|
||||
elif path.startswith("oci://"):
|
||||
try:
|
||||
import ocifs
|
||||
|
||||
storage_options = {}
|
||||
return ocifs.OCIFileSystem(**storage_options), storage_options
|
||||
except ImportError as exc:
|
||||
raise ImportError("oci:// paths require ocifs to be installed") from exc
|
||||
|
||||
# https://ocifs.readthedocs.io/en/latest/getting-connected.html#Using-Environment-Variables
|
||||
remote_file_system = ocifs.OCIFileSystem(**storage_options)
|
||||
return None, {}
|
||||
|
||||
try:
|
||||
if remote_file_system and remote_file_system.exists(config_dataset.path):
|
||||
ds_from_cloud = True
|
||||
except (FileNotFoundError, ConnectionError):
|
||||
pass
|
||||
|
||||
# gather extra args from the config
|
||||
load_ds_kwargs = {}
|
||||
if config_dataset.split:
|
||||
load_ds_kwargs["split"] = config_dataset.split
|
||||
def _load_from_local_path(
|
||||
dataset_config: DictDefault, load_dataset_kwargs: dict
|
||||
) -> Dataset | IterableDataset | DatasetDict | IterableDatasetDict:
|
||||
"""Load a dataset from a local path."""
|
||||
local_path = Path(dataset_config.path)
|
||||
|
||||
if local_path.is_dir():
|
||||
if dataset_config.data_files:
|
||||
dataset_type = get_dataset_type(dataset_config)
|
||||
return load_dataset(
|
||||
dataset_type,
|
||||
data_files=dataset_config.data_files,
|
||||
**load_dataset_kwargs,
|
||||
)
|
||||
try:
|
||||
return load_from_disk(dataset_config.path)
|
||||
except FileNotFoundError:
|
||||
load_dataset_kwargs["streaming"] = False
|
||||
return load_dataset(dataset_config.path, **load_dataset_kwargs)
|
||||
elif local_path.is_file():
|
||||
dataset_type = get_dataset_type(dataset_config)
|
||||
load_dataset_kwargs["streaming"] = False
|
||||
return load_dataset(
|
||||
dataset_type,
|
||||
data_files=dataset_config.path,
|
||||
**load_dataset_kwargs,
|
||||
)
|
||||
else:
|
||||
load_ds_kwargs["split"] = None
|
||||
|
||||
# prefer local dataset, even if hub exists
|
||||
local_path = Path(config_dataset.path)
|
||||
if local_path.exists():
|
||||
if local_path.is_dir():
|
||||
if config_dataset.data_files:
|
||||
ds_type = get_ds_type(config_dataset)
|
||||
ds = load_dataset( # pylint: disable=invalid-name
|
||||
ds_type,
|
||||
name=config_dataset.name,
|
||||
data_files=config_dataset.data_files,
|
||||
streaming=streaming,
|
||||
**load_ds_kwargs,
|
||||
)
|
||||
else:
|
||||
try:
|
||||
ds = load_from_disk(
|
||||
config_dataset.path
|
||||
) # pylint: disable=invalid-name
|
||||
except FileNotFoundError:
|
||||
ds = load_dataset(
|
||||
config_dataset.path,
|
||||
name=config_dataset.name,
|
||||
streaming=False,
|
||||
**load_ds_kwargs,
|
||||
)
|
||||
elif local_path.is_file():
|
||||
ds_type = get_ds_type(config_dataset)
|
||||
|
||||
ds = load_dataset( # pylint: disable=invalid-name
|
||||
ds_type,
|
||||
name=config_dataset.name,
|
||||
data_files=config_dataset.path,
|
||||
streaming=False,
|
||||
**load_ds_kwargs,
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
"unhandled dataset load: local path exists, but is neither a directory or a file"
|
||||
)
|
||||
elif ds_from_hub:
|
||||
ds = load_dataset(
|
||||
config_dataset.path,
|
||||
name=config_dataset.name,
|
||||
streaming=streaming,
|
||||
data_files=config_dataset.data_files,
|
||||
token=use_auth_token,
|
||||
revision=config_dataset.revision,
|
||||
trust_remote_code=config_dataset.trust_remote_code,
|
||||
**load_ds_kwargs,
|
||||
)
|
||||
elif ds_from_cloud and remote_file_system:
|
||||
if remote_file_system.isdir(config_dataset.path):
|
||||
ds = load_from_disk(
|
||||
config_dataset.path,
|
||||
storage_options=storage_options,
|
||||
)
|
||||
elif remote_file_system.isfile(config_dataset.path):
|
||||
ds_type = get_ds_type(config_dataset)
|
||||
ds = load_dataset(
|
||||
ds_type,
|
||||
name=config_dataset.name,
|
||||
data_files=config_dataset.path,
|
||||
streaming=streaming,
|
||||
storage_options=storage_options,
|
||||
trust_remote_code=config_dataset.trust_remote_code,
|
||||
**load_ds_kwargs,
|
||||
)
|
||||
elif config_dataset.path.startswith("https://"):
|
||||
ds_type = get_ds_type(config_dataset)
|
||||
ds = load_dataset(
|
||||
ds_type,
|
||||
name=config_dataset.name,
|
||||
data_files=config_dataset.path,
|
||||
streaming=streaming,
|
||||
storage_options=storage_options,
|
||||
trust_remote_code=config_dataset.trust_remote_code,
|
||||
**load_ds_kwargs,
|
||||
)
|
||||
elif config_dataset.data_files:
|
||||
fp: str | list[str] | None = None
|
||||
if isinstance(config_dataset.data_files, str):
|
||||
fp = hf_hub_download(
|
||||
repo_id=config_dataset.path,
|
||||
repo_type="dataset",
|
||||
filename=config_dataset.data_files,
|
||||
revision=config_dataset.revision,
|
||||
)
|
||||
elif isinstance(config_dataset.data_files, list):
|
||||
fp = []
|
||||
for file in config_dataset.data_files:
|
||||
fp.append(
|
||||
hf_hub_download(
|
||||
repo_id=config_dataset.path,
|
||||
repo_type="dataset",
|
||||
filename=file,
|
||||
revision=config_dataset.revision,
|
||||
)
|
||||
)
|
||||
else:
|
||||
raise ValueError("data_files must be either a string or list of strings")
|
||||
ds = load_dataset(
|
||||
"json",
|
||||
name=config_dataset.name,
|
||||
data_files=fp,
|
||||
streaming=streaming,
|
||||
**load_ds_kwargs,
|
||||
)
|
||||
if not ds:
|
||||
raise ValueError(
|
||||
"The dataset could not be loaded. This could be due to a misconfigured dataset path "
|
||||
f"({config_dataset.path}). Try double-check your path / name / data_files. "
|
||||
"This is not caused by the dataset type."
|
||||
"Unhandled dataset load: local path exists, but is neither a directory or a file"
|
||||
)
|
||||
|
||||
return ds
|
||||
|
||||
def _load_from_hub(
    dataset_config: DictDefault, use_auth_token: bool, load_dataset_kwargs: dict
) -> Dataset | IterableDataset | DatasetDict | IterableDatasetDict:
    """Load a dataset from the HuggingFace Hub.

    Args:
        dataset_config: Single dataset config providing `path`, `data_files` and
            `revision`.
        use_auth_token: Value forwarded to `load_dataset` as `token`.
        load_dataset_kwargs: Extra keyword arguments forwarded to `load_dataset`.

    Returns:
        Whatever `load_dataset` returns for the given repository.
    """
    repo_id = dataset_config.path
    data_files = dataset_config.data_files
    revision = dataset_config.revision

    return load_dataset(
        repo_id,
        data_files=data_files,
        token=use_auth_token,
        revision=revision,
        **load_dataset_kwargs,
    )
|
||||
|
||||
|
||||
def _load_from_cloud(
|
||||
dataset_config: DictDefault,
|
||||
remote_fs: S3FileSystem | GCSFileSystem | AzureBlobFileSystem | OCIFileSystem,
|
||||
storage_options: dict,
|
||||
load_dataset_kwargs: dict,
|
||||
) -> Dataset | IterableDataset | DatasetDict | IterableDatasetDict:
|
||||
"""Load a dataset from cloud storage."""
|
||||
if remote_fs.isdir(dataset_config.path):
|
||||
return load_from_disk(
|
||||
dataset_config.path,
|
||||
storage_options=storage_options,
|
||||
)
|
||||
|
||||
if remote_fs.isfile(dataset_config.path):
|
||||
dataset_type = get_dataset_type(dataset_config)
|
||||
return load_dataset(
|
||||
dataset_type,
|
||||
data_files=dataset_config.path,
|
||||
storage_options=storage_options,
|
||||
**load_dataset_kwargs,
|
||||
)
|
||||
|
||||
raise ValueError(
|
||||
f"Cloud path {dataset_config.path} is neither a directory nor a file"
|
||||
)
|
||||
|
||||
|
||||
def _load_from_url(
    dataset_config: DictDefault, load_dataset_kwargs: dict
) -> Dataset | IterableDataset | DatasetDict | IterableDatasetDict:
    """Load a dataset from a URL.

    The URL stored in `dataset_config.path` is passed to `load_dataset` as
    `data_files`, with the type taken from `get_dataset_type`.
    """
    url = dataset_config.path
    return load_dataset(
        get_dataset_type(dataset_config), data_files=url, **load_dataset_kwargs
    )
|
||||
|
||||
|
||||
def _load_from_data_files(
|
||||
dataset_config: DictDefault, load_dataset_kwargs: dict
|
||||
) -> Dataset | IterableDataset | DatasetDict | IterableDatasetDict:
|
||||
"""Load a dataset from data files."""
|
||||
file_path = None
|
||||
|
||||
if isinstance(dataset_config.data_files, str):
|
||||
file_path = hf_hub_download(
|
||||
repo_id=dataset_config.path,
|
||||
repo_type="dataset",
|
||||
filename=dataset_config.data_files,
|
||||
revision=dataset_config.revision,
|
||||
)
|
||||
elif isinstance(dataset_config.data_files, list):
|
||||
file_path = [
|
||||
hf_hub_download(
|
||||
repo_id=dataset_config.path,
|
||||
repo_type="dataset",
|
||||
filename=file,
|
||||
revision=dataset_config.revision,
|
||||
)
|
||||
for file in dataset_config.data_files
|
||||
]
|
||||
else:
|
||||
raise ValueError("data_files must be either a string or list of strings")
|
||||
|
||||
return load_dataset("json", data_files=file_path, **load_dataset_kwargs)
|
||||
|
||||
|
||||
def generate_split_fingerprints(
    dataset: Dataset, val_set_size: int | float, seed: int
) -> tuple[str, str]:
    """Derive reproducible fingerprints for the train and test splits.

    Each fingerprint is the MD5 of the dataset's current fingerprint combined
    with the validation set size, the split name and the seed.

    Returns:
        Tuple of (train_fingerprint, test_fingerprint).
    """
    base_fingerprint = dataset._fingerprint  # pylint: disable=protected-access

    train_fingerprint, test_fingerprint = (
        md5(f"{base_fingerprint}|{val_set_size}|{split_name}|{seed}")
        for split_name in ("train", "test")
    )
    return train_fingerprint, test_fingerprint
|
||||
|
||||
|
||||
def get_prepared_dataset_path(cfg: DictDefault, dataset_hash: str) -> Path:
    """Build the directory path used for a prepared dataset.

    Args:
        cfg: Configuration object; `dataset_prepared_path` overrides the default
            base directory when it is set.
        dataset_hash: Hash identifying the specific dataset configuration.

    Returns:
        `<base directory>/<dataset_hash>`.
    """
    configured_base = cfg.dataset_prepared_path
    if configured_base:
        return Path(configured_base) / dataset_hash
    return Path(DEFAULT_DATASET_PREPARED_PATH) / dataset_hash
|
||||
|
||||
|
||||
def create_train_validation_split(
    dataset: Dataset, cfg: DictDefault, val_set_size: int | float
) -> tuple[Dataset, Dataset]:
    """Create train/validation split with consistent fingerprinting.

    Args:
        dataset: Dataset to split.
        cfg: Configuration object providing `seed` and
            `dataset_exact_deduplication`.
        val_set_size: Size of validation set (absolute number or fraction).

    Returns:
        Tuple of (train_dataset, eval_dataset).
    """
    # Apply deduplication before splitting if configured
    if cfg.dataset_exact_deduplication:
        dataset, _ = deduplicate_and_log_datasets(dataset=dataset)

    # Fingerprints are derived from the dataset that is actually split. Deriving
    # them before deduplication would give deduplicated and non-deduplicated
    # splits the same fingerprint even though their rows differ.
    train_fingerprint, test_fingerprint = generate_split_fingerprints(
        dataset, val_set_size, cfg.seed
    )

    split_dataset = dataset.train_test_split(
        test_size=val_set_size,
        shuffle=False,
        seed=cfg.seed,
        train_new_fingerprint=train_fingerprint,
        test_new_fingerprint=test_fingerprint,
    )

    return split_dataset["train"], split_dataset["test"]
|
||||
|
||||
|
||||
def _generate_from_iterable_dataset(
|
||||
dataset: IterableDataset, worker_id: list[int], num_workers: list[int]
|
||||
) -> Generator[Any, None, None]:
|
||||
"""Generator function to correctly split the dataset for each worker"""
|
||||
for i, item in enumerate(dataset):
|
||||
if i % num_workers[0] == worker_id[0]:
|
||||
yield item
|
||||
|
||||
|
||||
def save_preprocessed_dataset(
    cfg: DictDefault,
    dataset: Dataset,
    dataset_hash: str,
    split: str,
) -> None:
    """Save preprocessed dataset to disk and optionally push to the HF Hub.

    An `IterableDataset` is first materialized into a regular `Dataset` with
    `Dataset.from_generator`, spreading the rows over `cfg.dataset_processes`
    workers. The materialized dataset is what gets saved and pushed.

    Args:
        cfg: Configuration object.
        dataset: Dataset to persist.
        dataset_hash: Hash identifying the dataset configuration; used as the
            directory name on disk and as the config name on the hub.
        split: Split name recorded on a dataset materialized from a stream.
    """
    prepared_ds_path = get_prepared_dataset_path(cfg, dataset_hash)

    if isinstance(dataset, IterableDataset):
        # Fall back to a single worker when no process count is configured;
        # `range(None)` would otherwise raise a TypeError.
        num_workers = cfg.dataset_processes or 1

        # Rebind `dataset` so the hub push below uploads the materialized rows
        # instead of the original stream.
        dataset = Dataset.from_generator(
            functools.partial(_generate_from_iterable_dataset, dataset),
            features=dataset.features,
            num_proc=num_workers,
            split=split,
            gen_kwargs={
                "worker_id": list(range(num_workers)),
                "num_workers": [num_workers] * num_workers,
            },
        )
    else:
        os.makedirs(prepared_ds_path, exist_ok=True)

    dataset.save_to_disk(str(prepared_ds_path))

    if cfg.push_dataset_to_hub:
        LOG.info(
            "Pushing merged prepared dataset to Huggingface hub at "
            f"{cfg.push_dataset_to_hub} (version {dataset_hash})...",
            main_process_only=False,
        )
        dataset.push_to_hub(
            cfg.push_dataset_to_hub,
            dataset_hash,
            private=True,
        )
|
||||
|
||||
|
||||
def load_preprocessed_dataset(cfg: DictDefault, dataset_hash: str) -> Dataset | None:
    """Load a prepared dataset from disk when one is available and usable.

    The prepared copy is only used when `cfg.dataset_prepared_path` is set, the
    prepared directory contains at least one entry, and neither
    `cfg.skip_prepare_dataset` nor `cfg.is_preprocess` is enabled.

    Args:
        cfg: Configuration object.
        dataset_hash: Hash identifying the dataset configuration.

    Returns:
        The dataset loaded from disk, or None when the conditions are not met.
    """
    prepared_ds_path = get_prepared_dataset_path(cfg, dataset_hash)

    cache_unusable = (
        not cfg.dataset_prepared_path
        or not any(prepared_ds_path.glob("*"))
        or cfg.skip_prepare_dataset
        or cfg.is_preprocess
    )
    if cache_unusable:
        LOG.info(
            f"Unable to find prepared dataset in {prepared_ds_path}",
            main_process_only=False,
        )
        return None

    LOG.info(
        f"Loading prepared dataset from disk at {prepared_ds_path}...",
        main_process_only=False,
    )
    return load_from_disk(str(prepared_ds_path))
|
||||
|
||||
|
||||
def try_load_from_hub(
    cfg: DictDefault, dataset_hash: str, split: str
) -> Dataset | None:
    """Try to fetch one split of the prepared dataset from the HuggingFace Hub.

    The repository is `cfg.push_dataset_to_hub` and the dataset hash is passed as
    the configuration name. Any failure is treated as "not available".

    Returns:
        The requested split, or None when it could not be loaded.
    """
    try:
        LOG.info(
            "Attempting to load prepared dataset from HuggingFace Hub at "
            f"{cfg.push_dataset_to_hub} (version {dataset_hash})..."
        )
        hub_dataset = load_dataset(
            cfg.push_dataset_to_hub,
            dataset_hash,
            token=cfg.hf_use_auth_token,
        )
        requested_split = hub_dataset[split]
    except Exception:  # pylint: disable=broad-except # nosec
        # Best effort: a missing repo, config or split all mean "not on the hub"
        LOG.info("Unable to find prepared dataset in HuggingFace Hub")
        return None
    else:
        return requested_split
|
||||
|
||||
|
||||
def generate_dataset_hash_from_config(
    cfg: DictDefault, cfg_datasets: list, tokenizer_name: str
) -> str:
    """Generate a hash to uniquely identify a dataset configuration for SFT.

    The hashed string combines a few global settings, one sorted entry per
    dataset config, and the tokenizer name.

    Args:
        cfg: Main configuration object.
        cfg_datasets: List of dataset configurations.
        tokenizer_name: Name of the tokenizer being used.

    Returns:
        MD5 hash string representing the configuration.
    """
    # Global settings that are part of the cache key
    global_part = (
        f"{cfg.sequence_len}@{cfg.sample_packing}@{cfg.eval_sample_packing}@"
        f"{cfg.group_by_length}@{cfg.kd_temperature or 1.0}"
    )

    # Sorting makes the hash independent of the order of the dataset configs
    dataset_parts = sorted(
        f"{d.path}:{d.type}:{d.shards}:{d.conversation}:{d.split}:{d.temperature or 1.0}"
        for d in cfg_datasets
    )

    config_str = f"{global_part}|{'|'.join(dataset_parts)}|{tokenizer_name}"
    return str(md5(config_str))
|
||||
|
||||
|
||||
def merge_datasets(datasets: list[Dataset], cfg: DictDefault) -> Dataset:
    """Concatenate several datasets into one, optionally shuffling the result.

    A single dataset is returned as-is: it is neither concatenated nor shuffled.

    Args:
        datasets: List of datasets to merge.
        cfg: Configuration object providing `shuffle_merged_datasets` and `seed`.

    Returns:
        Merged dataset.
    """
    if len(datasets) != 1:
        LOG.info("Merging datasets...")
        combined = concatenate_datasets(datasets)

        if not cfg.shuffle_merged_datasets:
            LOG.debug("Not shuffling merged datasets.")
            return combined

        LOG.debug("Shuffling merged datasets...")
        return combined.shuffle(seed=cfg.seed)

    return datasets[0]
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
"""data handling helpers"""
|
||||
"""Data handling helpers"""
|
||||
|
||||
import contextlib
|
||||
import functools
|
||||
import hashlib
|
||||
import time
|
||||
from enum import Enum
|
||||
from typing import Callable
|
||||
|
||||
import huggingface_hub
|
||||
import numpy as np
|
||||
@@ -19,9 +21,7 @@ LOG = get_logger(__name__)
|
||||
|
||||
|
||||
class RetryStrategy(Enum):
|
||||
"""
|
||||
Enum for retry strategies.
|
||||
"""
|
||||
"""Enum for retry strategies."""
|
||||
|
||||
CONSTANT = 1
|
||||
LINEAR = 2
|
||||
@@ -30,7 +30,18 @@ class RetryStrategy(Enum):
|
||||
|
||||
def retry_on_request_exceptions(
|
||||
max_retries=3, delay=1, retry_strategy: RetryStrategy = RetryStrategy.LINEAR
|
||||
):
|
||||
) -> Callable:
|
||||
"""Decorator that retries function calls on specific request exceptions.
|
||||
|
||||
Args:
|
||||
max_retries: Maximum number of retry attempts.
|
||||
delay: Base delay between retries in seconds.
|
||||
retry_strategy: Strategy for calculating retry delays.
|
||||
|
||||
Returns:
|
||||
Decorated function with retry logic.
|
||||
"""
|
||||
|
||||
def decorator(func):
|
||||
@functools.wraps(func)
|
||||
def wrapper(*args, **kwargs): # pylint: disable=inconsistent-return-statements
|
||||
@@ -59,6 +70,7 @@ def retry_on_request_exceptions(
|
||||
|
||||
|
||||
def md5(to_hash: str, encoding: str = "utf-8") -> str:
|
||||
"""Generate MD5 hash of a string."""
|
||||
try:
|
||||
return hashlib.md5(to_hash.encode(encoding), usedforsecurity=False).hexdigest()
|
||||
except TypeError:
|
||||
@@ -66,102 +78,89 @@ def md5(to_hash: str, encoding: str = "utf-8") -> str:
|
||||
|
||||
|
||||
def sha256(to_hash: str, encoding: str = "utf-8") -> str:
    """Return the hexadecimal SHA256 digest of a string.

    Args:
        to_hash: Text to hash.
        encoding: Encoding used to turn the text into bytes.
    """
    hasher = hashlib.sha256()
    hasher.update(to_hash.encode(encoding))
    return hasher.hexdigest()
|
||||
|
||||
|
||||
def deduplicate_dataset(
|
||||
dataset: Dataset, seen_hashes: dict[str, list[int]], other_dataset: Dataset = None
|
||||
) -> Dataset:
|
||||
unique_indices = []
|
||||
def _deduplicate_dataset(
|
||||
dataset: Dataset,
|
||||
seen_hashes: set[str] | None = None,
|
||||
) -> tuple[Dataset, set[str]]:
|
||||
"""Remove duplicate rows from a dataset using SHA256 hashes.
|
||||
|
||||
Args:
|
||||
dataset: Dataset to deduplicate.
|
||||
seen_hashes: Set of previously seen row hashes (for cross-deduplication).
|
||||
|
||||
Returns:
|
||||
Tuple of deduplicated dataset and the set of seen hashes.
|
||||
"""
|
||||
if seen_hashes is None:
|
||||
seen_hashes = set()
|
||||
|
||||
unique_indices = []
|
||||
for idx, row in enumerate(dataset):
|
||||
row_hash = sha256(str(row)) # Using SHA256 for collision resistance.
|
||||
row_hash = sha256(str(row)) # Using SHA256 for collision resistance
|
||||
if row_hash not in seen_hashes:
|
||||
seen_hashes[row_hash] = [idx]
|
||||
seen_hashes.add(row_hash)
|
||||
unique_indices.append(idx)
|
||||
else:
|
||||
# Check for collision by looking up the original dataset indices
|
||||
original_indices = seen_hashes[row_hash]
|
||||
is_duplicate = False
|
||||
for original_idx in original_indices:
|
||||
if (
|
||||
not idx == original_idx
|
||||
and original_idx < len(dataset)
|
||||
and str(dataset[original_idx]) == str(row)
|
||||
):
|
||||
is_duplicate = True
|
||||
break
|
||||
# Check in the other dataset if provided
|
||||
if other_dataset is not None:
|
||||
if original_idx < len(other_dataset) and str(
|
||||
other_dataset[original_idx]
|
||||
) == str(row):
|
||||
is_duplicate = True
|
||||
break
|
||||
if not is_duplicate:
|
||||
seen_hashes[row_hash].append(idx)
|
||||
unique_indices.append(idx)
|
||||
continue
|
||||
return dataset.select(unique_indices)
|
||||
|
||||
return dataset.select(unique_indices), seen_hashes
|
||||
|
||||
|
||||
def deduplicate_and_log_datasets(
|
||||
*,
|
||||
train_dataset: Dataset = None,
|
||||
eval_dataset: Dataset = None,
|
||||
dataset: Dataset = None,
|
||||
) -> tuple[Dataset, Dataset, Dataset]:
|
||||
"""
|
||||
Deduplicates train, eval, and an optional dataset if provided, logging original and new sizes.
|
||||
dataset: Dataset,
|
||||
other_dataset: Dataset | None = None,
|
||||
dataset_name: str | None = "train",
|
||||
other_name: str | None = "eval",
|
||||
) -> tuple[Dataset, Dataset | None]:
|
||||
"""Deduplicate datasets, with optional cross-dataset deduplication.
|
||||
|
||||
Args:
|
||||
dataset: Primary dataset to deduplicate.
|
||||
other_dataset: Optional second dataset to deduplicate against the first.
|
||||
dataset_name: Name for the primary dataset (for logging).
|
||||
other_name: Name for the second dataset (for logging).
|
||||
|
||||
Returns:
|
||||
tuple: Deduplicated train, eval, and additional datasets.
|
||||
Tuple of (deduplicated_dataset, deduplicated_other_dataset).
|
||||
"""
|
||||
seen_hashes: dict[str, list[int]] = {}
|
||||
# Deduplicate primary dataset
|
||||
LOG.info(
|
||||
f"Starting deduplication for {dataset_name} dataset. Original size: {len(dataset)}"
|
||||
)
|
||||
dataset, seen_rows = _deduplicate_dataset(dataset)
|
||||
LOG.info(
|
||||
f"Deduplication complete for {dataset_name} dataset. New size: {len(dataset)}"
|
||||
)
|
||||
|
||||
# Handle cases where datasets are None
|
||||
if train_dataset is not None:
|
||||
# Deduplicate second dataset if provided
|
||||
if other_dataset is not None:
|
||||
LOG.info(
|
||||
f"Starting deduplication for train dataset. Original size: {len(train_dataset)}"
|
||||
)
|
||||
train_dataset = deduplicate_dataset(
|
||||
dataset=train_dataset, seen_hashes=seen_hashes
|
||||
f"Starting deduplication for {other_name} dataset. Original size: {len(other_dataset)}"
|
||||
)
|
||||
other_dataset, _ = _deduplicate_dataset(other_dataset, seen_rows)
|
||||
LOG.info(
|
||||
f"Deduplication complete for train dataset. New size: {len(train_dataset)}"
|
||||
)
|
||||
else:
|
||||
LOG.info("Train dataset is None. Skipping deduplication.")
|
||||
|
||||
if eval_dataset is not None:
|
||||
LOG.info(
|
||||
f"Starting deduplication for eval dataset. Original size: {len(eval_dataset)}"
|
||||
)
|
||||
eval_dataset = deduplicate_dataset(
|
||||
dataset=eval_dataset, seen_hashes=seen_hashes, other_dataset=train_dataset
|
||||
)
|
||||
LOG.info(
|
||||
f"Deduplication complete for eval dataset. New size: {len(eval_dataset)}"
|
||||
)
|
||||
else:
|
||||
LOG.info("Eval dataset is None. Skipping deduplication.")
|
||||
|
||||
if dataset is not None and (eval_dataset is None and train_dataset is None):
|
||||
LOG.info(
|
||||
f"Starting deduplication for combined dataset. Original size: {len(dataset)}"
|
||||
)
|
||||
dataset = deduplicate_dataset(dataset=dataset, seen_hashes=seen_hashes)
|
||||
LOG.info(
|
||||
f"Deduplication complete for combined dataset. New size: {len(dataset)}"
|
||||
f"Deduplication complete for {other_name} dataset. New size: {len(other_dataset)}"
|
||||
)
|
||||
|
||||
return train_dataset, eval_dataset, dataset
|
||||
return dataset, other_dataset
|
||||
|
||||
|
||||
def drop_long_seq_in_dataset(dataset: Dataset, cfg: DictDefault):
|
||||
def drop_long_seq_in_dataset(dataset: Dataset, cfg: DictDefault) -> Dataset:
|
||||
"""Remove sequences longer than configured maximum from dataset.
|
||||
|
||||
Args:
|
||||
dataset: Dataset to filter.
|
||||
cfg: Dictionary mapping `axolotl` config keys to values.
|
||||
|
||||
Returns:
|
||||
Filtered dataset with long sequences removed.
|
||||
"""
|
||||
if "input_ids" not in dataset.column_names:
|
||||
LOG.warning(
|
||||
"Dataset does not contain 'input_ids' column. Skip drop long seq. This is expected for RewardModeling."
|
||||
"Dataset does not contain 'input_ids' column. Skip drop long seq. This is "
|
||||
"expected for reward modeling."
|
||||
)
|
||||
return dataset
|
||||
|
||||
@@ -171,20 +170,14 @@ def drop_long_seq_in_dataset(dataset: Dataset, cfg: DictDefault):
|
||||
min_sequence_len=cfg.min_sample_len,
|
||||
)
|
||||
|
||||
try:
|
||||
with contextlib.suppress(AttributeError):
|
||||
ds_lengths = get_dataset_lengths(dataset, from_arrow=True)
|
||||
min_input_len = np.min(ds_lengths)
|
||||
LOG.info(f"min_input_len: {min_input_len}")
|
||||
max_input_len = np.max(ds_lengths)
|
||||
LOG.info(f"max_input_len: {max_input_len}")
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
try:
|
||||
prior_len = len(dataset)
|
||||
except TypeError:
|
||||
# handle iterable datasets case
|
||||
prior_len = None
|
||||
prior_len = len(dataset) if hasattr(dataset, "__len__") else None
|
||||
|
||||
filter_map_kwargs = {}
|
||||
if not isinstance(dataset, IterableDataset):
|
||||
|
||||
425
src/axolotl/utils/data/wrappers.py
Normal file
425
src/axolotl/utils/data/wrappers.py
Normal file
@@ -0,0 +1,425 @@
|
||||
"""Data handling specific to SFT."""
|
||||
|
||||
import logging
|
||||
from typing import Any, NoReturn, cast
|
||||
|
||||
from datasets import (
|
||||
Dataset,
|
||||
IterableDataset,
|
||||
Sequence,
|
||||
Value,
|
||||
)
|
||||
from transformers import PreTrainedTokenizer
|
||||
from transformers.processing_utils import ProcessorMixin
|
||||
|
||||
from axolotl.datasets import TokenizedPromptDataset, wrap_dataset_for_tokenized_prompt
|
||||
from axolotl.prompt_strategies import load
|
||||
from axolotl.prompt_strategies.bradley_terry import load as bradley_terry_load
|
||||
from axolotl.prompt_tokenizers import (
|
||||
AlpacaMultipleChoicePromptTokenizingStrategy,
|
||||
AlpacaPromptTokenizingStrategy,
|
||||
AlpacaReflectionPTStrategy,
|
||||
DatasetWrappingStrategy,
|
||||
GPTeacherPromptTokenizingStrategy,
|
||||
JeopardyPromptTokenizingStrategy,
|
||||
OpenAssistantPromptTokenizingStrategy,
|
||||
PromptTokenizingStrategy,
|
||||
SummarizeTLDRPromptTokenizingStrategy,
|
||||
)
|
||||
from axolotl.prompters import (
|
||||
AlpacaPrompter,
|
||||
GPTeacherPrompter,
|
||||
JeopardyPrompter,
|
||||
MultipleChoiceConcisePrompter,
|
||||
MultipleChoiceExplainPrompter,
|
||||
Prompter,
|
||||
ReflectAlpacaPrompter,
|
||||
SummarizeTLDRPrompter,
|
||||
UnsupportedPrompter,
|
||||
)
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def handle_unknown_dataset_strategy(dataset_config: DictDefault) -> NoReturn:
    """Log and raise a `ValueError` for a dataset type no handler recognizes.

    Args:
        dataset_config: Dataset configuration whose `type` could not be resolved.

    Raises:
        ValueError: Always. The message names the unresolved type and, when the
            type contains `:load_`, suggests the same string with `.load_`.
    """
    strategy_name = dataset_config.type

    # A `module:load_fn` spelling gets a hint pointing at `module.load_fn`.
    hint = (
        f"Did you mean {strategy_name.replace(':load_', '.load_')}?"
        if ":load_" in strategy_name
        else ""
    )

    message = f"unhandled prompt tokenization strategy: {strategy_name}. {hint}"
    LOG.error(message)
    raise ValueError(message)
|
||||
|
||||
|
||||
# pylint: disable=too-many-return-statements
|
||||
def get_dataset_wrapper(
    dataset_config: DictDefault,
    tokenizer: PreTrainedTokenizer,
    cfg: DictDefault,
    dataset_base_type: str | None,
    dataset: Dataset | IterableDataset,
    dataset_prompt_style: str | None = None,
    processor: ProcessorMixin | None = None,  # pylint: disable=unused-argument
) -> tuple[Dataset | IterableDataset, Prompter | None]:
    """Select and build the dataset wrapper and prompter for one dataset entry.

    The checks below run in order and the first match wins:

    1. dataset already has `input_ids`, `attention_mask` and `labels`;
    2. `dataset_config.type` is an inline (`DictDefault`) definition;
    3. `cfg.skip_prepare_dataset` is set;
    4. type starts with `bradley_terry`;
    5. type starts with `stepwise_supervised`;
    6. a strategy can be loaded from the prompt-strategy registry;
    7. `dataset_base_type` has an entry in `DATASET_HANDLERS`.

    Args:
        dataset_config: Configuration for the dataset.
        tokenizer: Tokenizer to use for processing text.
        cfg: Global configuration object.
        dataset_base_type: The base type of the dataset.
        dataset: The actual dataset object.
        dataset_prompt_style: Optional prompt style specification.
        processor: Optional processor, forwarded to the strategy loader.

    Returns:
        Tuple of (dataset_wrapper, dataset_prompter).

    Raises:
        ValueError: If none of the checks above matches the dataset type.
    """
    # Keyword arguments shared by every wrapper constructor.
    wrap_kwargs: dict[str, Any] = {}
    wrap_kwargs["process_count"] = cfg.dataset_processes
    wrap_kwargs["keep_in_memory"] = cfg.dataset_keep_in_memory is True

    LOG.info(
        f"Loading dataset: {dataset_config['path']} with base_type: "
        f"{dataset_base_type} and prompt_style: {dataset_prompt_style}"
    )

    # 1. Nothing to do for data that is already tokenized.
    if _is_dataset_already_tokenized(dataset):
        return dataset, UnsupportedPrompter()

    ds_type = dataset_config.type

    # 2. Inline, user-defined dataset type.
    if isinstance(ds_type, DictDefault):
        return _handle_custom_dataset_type(
            dataset_config, tokenizer, cfg, dataset, wrap_kwargs
        )

    # 3. Caller asked for the dataset to be passed through untouched.
    if cfg.skip_prepare_dataset:
        return dataset, None

    # 4 & 5. Types recognized by prefix, checked in this order.
    prefix_handlers = (
        ("bradley_terry", _handle_bradley_terry_dataset),
        ("stepwise_supervised", _handle_stepwise_supervised_dataset),
    )
    for prefix, prefix_handler in prefix_handlers:
        if ds_type.startswith(prefix):
            return prefix_handler(dataset_config, tokenizer, cfg, dataset, wrap_kwargs)

    # 6. Prompt tokenizer / dataset wrapper strategy from the registry.
    registry_strategy = load(
        ds_type, tokenizer, cfg, dataset_config, processor=processor
    )
    if registry_strategy:
        return _handle_loaded_strategy(registry_strategy, dataset, wrap_kwargs)

    # 7. Known base types with a dedicated handler.
    base_handler = DATASET_HANDLERS.get(dataset_base_type)
    if base_handler is not None:
        return base_handler(
            dataset_prompt_style, tokenizer, cfg, dataset, wrap_kwargs
        )

    # Nothing matched: this call always raises.
    handle_unknown_dataset_strategy(dataset_config)
|
||||
|
||||
|
||||
def _is_dataset_already_tokenized(dataset: Dataset | IterableDataset) -> bool:
    """Return True when `dataset` is a `Dataset` that already has token columns.

    Only `Dataset` instances qualify; any other object yields False. The
    dataset counts as tokenized when `input_ids`, `attention_mask` and `labels`
    are all present in its features.
    """
    if not isinstance(dataset, Dataset):
        return False

    required_columns = ("input_ids", "attention_mask", "labels")
    return all(column in dataset.features for column in required_columns)
|
||||
|
||||
|
||||
def _handle_custom_dataset_type(
    dataset_config: DictDefault,
    tokenizer: PreTrainedTokenizer,
    cfg: DictDefault,
    dataset: Dataset | IterableDataset,
    dataset_kwargs: dict[str, Any],
) -> tuple[Dataset | IterableDataset, Prompter]:
    """Wrap a dataset whose type is defined inline in the configuration.

    Args:
        dataset_config: Dataset configuration; its `type` is the inline
            definition and is converted to a plain dict for the loader.
        tokenizer: Tokenizer passed to the strategy loader.
        cfg: Global configuration object.
        dataset: Dataset to wrap.
        dataset_kwargs: Keyword arguments forwarded to the wrapper.

    Returns:
        Tuple of (wrapped dataset, `UnsupportedPrompter` instance).
    """
    # The inline definition is handed to the "user_defined" strategy.
    user_defined_spec = dataset_config.type.to_dict()
    strategy = cast(
        PromptTokenizingStrategy,
        load("user_defined", tokenizer, cfg, user_defined_spec),
    )

    wrapped = wrap_dataset_for_tokenized_prompt(strategy, dataset, **dataset_kwargs)
    return wrapped, UnsupportedPrompter()
|
||||
|
||||
|
||||
def _handle_bradley_terry_dataset(
    dataset_config: DictDefault,
    tokenizer: PreTrainedTokenizer,
    cfg: DictDefault,
    dataset: Dataset | IterableDataset,
    dataset_kwargs: dict[str, Any],
) -> tuple[Dataset | IterableDataset, Prompter | None]:
    """Handle a Bradley-Terry dataset.

    The dataset type is expected to look like `bradley_terry.<subtype>`; the
    part after the first `.` is what gets passed to the Bradley-Terry loader.

    Args:
        dataset_config: Dataset configuration; `type` selects the strategy.
        tokenizer: Tokenizer passed to the strategy loader.
        cfg: Global configuration object.
        dataset: Dataset to wrap.
        dataset_kwargs: Keyword arguments forwarded to the wrapper.

    Returns:
        Tuple of (wrapped dataset, `UnsupportedPrompter` instance).

    Raises:
        ValueError: If the type has no `.<subtype>` part, or if the loader
            returns no strategy for it.
    """
    _, separator, bt_type = dataset_config.type.partition(".")

    # A type without a "." (e.g. plain "bradley_terry") has no subtype to load.
    # Report it through the standard unknown-strategy error instead of letting
    # an IndexError escape from indexing the split result.
    if not separator:
        handle_unknown_dataset_strategy(dataset_config)

    dataset_strategy = bradley_terry_load(bt_type, tokenizer, cfg, dataset_config)

    if not dataset_strategy:
        handle_unknown_dataset_strategy(dataset_config)

    dataset_prompter = UnsupportedPrompter()
    dataset_wrapper = wrap_dataset_for_tokenized_prompt(
        dataset_strategy,
        dataset,
        **dataset_kwargs,
    )

    return dataset_wrapper, dataset_prompter
|
||||
|
||||
|
||||
def _handle_stepwise_supervised_dataset(
    dataset_config: DictDefault,
    tokenizer: PreTrainedTokenizer,
    cfg: DictDefault,
    dataset: Dataset | IterableDataset,
    dataset_kwargs: dict[str, Any],
) -> tuple[Dataset | IterableDataset, Prompter]:
    """Wrap a stepwise supervised dataset in a `TokenizedPromptDataset`.

    Args:
        dataset_config: Dataset configuration; `type` selects the strategy.
        tokenizer: Tokenizer passed to the strategy loader.
        cfg: Global configuration object.
        dataset: Dataset to wrap.
        dataset_kwargs: Keyword arguments forwarded to `TokenizedPromptDataset`.

    Returns:
        Tuple of (wrapped dataset, `UnsupportedPrompter` instance).
    """
    strategy = load(dataset_config.type, tokenizer, cfg, dataset_config)

    # Boolean labels must be explicitly cast to int for compatibility with how
    # trl's PRMTrainer works. Only `Dataset` instances are cast.
    prepared = dataset
    if isinstance(prepared, Dataset):
        prepared = prepared.cast_column("labels", Sequence(Value("int64")))

    wrapped = TokenizedPromptDataset(strategy, prepared, **dataset_kwargs)
    return wrapped, UnsupportedPrompter()
|
||||
|
||||
|
||||
def _handle_loaded_strategy(
    dataset_strategy: PromptTokenizingStrategy | DatasetWrappingStrategy,
    dataset: Dataset | IterableDataset,
    dataset_kwargs: dict[str, Any],
) -> tuple[Dataset | IterableDataset, Prompter | None]:
    """Wrap a dataset using a strategy that was loaded from the registry.

    Args:
        dataset_strategy: Strategy object returned by the registry loader.
        dataset: Dataset to wrap.
        dataset_kwargs: Keyword arguments forwarded to the wrapper.

    Returns:
        For a `DatasetWrappingStrategy`, the result of its own `wrap_dataset`
        paired with `None`. Otherwise the tokenized-prompt wrapper paired with
        an `UnsupportedPrompter` instance.
    """
    if not isinstance(dataset_strategy, DatasetWrappingStrategy):
        wrapped = wrap_dataset_for_tokenized_prompt(
            dataset_strategy, dataset, **dataset_kwargs
        )
        return wrapped, UnsupportedPrompter()

    # Wrapping strategies know how to wrap the dataset themselves.
    return dataset_strategy.wrap_dataset(dataset, **dataset_kwargs), None
|
||||
|
||||
|
||||
def _wrap_with_prompter_strategy(
    prompter_cls: type[Prompter],
    strategy_cls: type[PromptTokenizingStrategy],
    dataset_prompt_style: str | None,
    tokenizer: PreTrainedTokenizer,
    cfg: DictDefault,
    dataset: Dataset | IterableDataset,
    dataset_kwargs: dict[str, Any],
) -> tuple[Dataset | IterableDataset, Prompter]:
    """Build a prompter and tokenizing strategy, then wrap `dataset` with them.

    Shared implementation behind every entry in `DATASET_HANDLERS`; the
    handlers differ only in which prompter and strategy classes they pass.

    Args:
        prompter_cls: Prompter class, instantiated with `dataset_prompt_style`.
        strategy_cls: Tokenizing strategy class, instantiated with the
            prompter, the tokenizer, `cfg.train_on_inputs` and
            `cfg.sequence_len`.
        dataset_prompt_style: Optional prompt style passed to the prompter.
        tokenizer: Tokenizer passed to the strategy.
        cfg: Global configuration object.
        dataset: Dataset to wrap.
        dataset_kwargs: Keyword arguments forwarded to the wrapper.

    Returns:
        Tuple of (wrapped dataset, prompter instance).
    """
    dataset_prompter = prompter_cls(dataset_prompt_style)
    dataset_strategy = strategy_cls(
        dataset_prompter,
        tokenizer,
        cfg.train_on_inputs,
        cfg.sequence_len,
    )
    dataset_wrapper = wrap_dataset_for_tokenized_prompt(
        dataset_strategy,
        dataset,
        **dataset_kwargs,
    )
    return dataset_wrapper, dataset_prompter


def _handle_alpaca_dataset(
    dataset_prompt_style: str | None,
    tokenizer: PreTrainedTokenizer,
    cfg: DictDefault,
    dataset: Dataset | IterableDataset,
    dataset_kwargs: dict[str, Any],
) -> tuple[Dataset | IterableDataset, Prompter]:
    """Handle an Alpaca dataset."""
    return _wrap_with_prompter_strategy(
        AlpacaPrompter,
        AlpacaPromptTokenizingStrategy,
        dataset_prompt_style,
        tokenizer,
        cfg,
        dataset,
        dataset_kwargs,
    )


def _handle_explainchoice_dataset(
    dataset_prompt_style: str | None,
    tokenizer: PreTrainedTokenizer,
    cfg: DictDefault,
    dataset: Dataset | IterableDataset,
    dataset_kwargs: dict[str, Any],
) -> tuple[Dataset | IterableDataset, Prompter]:
    """Handle an ExplainChoice dataset."""
    return _wrap_with_prompter_strategy(
        MultipleChoiceExplainPrompter,
        AlpacaMultipleChoicePromptTokenizingStrategy,
        dataset_prompt_style,
        tokenizer,
        cfg,
        dataset,
        dataset_kwargs,
    )


def _handle_concisechoice_dataset(
    dataset_prompt_style: str | None,
    tokenizer: PreTrainedTokenizer,
    cfg: DictDefault,
    dataset: Dataset | IterableDataset,
    dataset_kwargs: dict[str, Any],
) -> tuple[Dataset | IterableDataset, Prompter]:
    """Handle a ConciseChoice dataset."""
    return _wrap_with_prompter_strategy(
        MultipleChoiceConcisePrompter,
        AlpacaMultipleChoicePromptTokenizingStrategy,
        dataset_prompt_style,
        tokenizer,
        cfg,
        dataset,
        dataset_kwargs,
    )


def _handle_summarizetldr_dataset(
    dataset_prompt_style: str | None,
    tokenizer: PreTrainedTokenizer,
    cfg: DictDefault,
    dataset: Dataset | IterableDataset,
    dataset_kwargs: dict[str, Any],
) -> tuple[Dataset | IterableDataset, Prompter]:
    """Handle a SummarizeTLDR dataset."""
    return _wrap_with_prompter_strategy(
        SummarizeTLDRPrompter,
        SummarizeTLDRPromptTokenizingStrategy,
        dataset_prompt_style,
        tokenizer,
        cfg,
        dataset,
        dataset_kwargs,
    )


def _handle_jeopardy_dataset(
    dataset_prompt_style: str | None,
    tokenizer: PreTrainedTokenizer,
    cfg: DictDefault,
    dataset: Dataset | IterableDataset,
    dataset_kwargs: dict[str, Any],
) -> tuple[Dataset | IterableDataset, Prompter]:
    """Handle a Jeopardy dataset."""
    return _wrap_with_prompter_strategy(
        JeopardyPrompter,
        JeopardyPromptTokenizingStrategy,
        dataset_prompt_style,
        tokenizer,
        cfg,
        dataset,
        dataset_kwargs,
    )


def _handle_oasst_dataset(
    dataset_prompt_style: str | None,
    tokenizer: PreTrainedTokenizer,
    cfg: DictDefault,
    dataset: Dataset | IterableDataset,
    dataset_kwargs: dict[str, Any],
) -> tuple[Dataset | IterableDataset, Prompter]:
    """Handle an OpenAssistant dataset."""
    # OpenAssistant data is prompted with the Alpaca prompter.
    return _wrap_with_prompter_strategy(
        AlpacaPrompter,
        OpenAssistantPromptTokenizingStrategy,
        dataset_prompt_style,
        tokenizer,
        cfg,
        dataset,
        dataset_kwargs,
    )


def _handle_gpteacher_dataset(
    dataset_prompt_style: str | None,
    tokenizer: PreTrainedTokenizer,
    cfg: DictDefault,
    dataset: Dataset | IterableDataset,
    dataset_kwargs: dict[str, Any],
) -> tuple[Dataset | IterableDataset, Prompter]:
    """Handle a GPTeacher dataset."""
    return _wrap_with_prompter_strategy(
        GPTeacherPrompter,
        GPTeacherPromptTokenizingStrategy,
        dataset_prompt_style,
        tokenizer,
        cfg,
        dataset,
        dataset_kwargs,
    )


def _handle_reflection_dataset(
    dataset_prompt_style: str | None,
    tokenizer: PreTrainedTokenizer,
    cfg: DictDefault,
    dataset: Dataset | IterableDataset,
    dataset_kwargs: dict[str, Any],
) -> tuple[Dataset | IterableDataset, Prompter]:
    """Handle a Reflection dataset."""
    return _wrap_with_prompter_strategy(
        ReflectAlpacaPrompter,
        AlpacaReflectionPTStrategy,
        dataset_prompt_style,
        tokenizer,
        cfg,
        dataset,
        dataset_kwargs,
    )


# Maps a dataset base type to the handler that builds its wrapper and prompter.
DATASET_HANDLERS = {
    "alpaca": _handle_alpaca_dataset,
    "explainchoice": _handle_explainchoice_dataset,
    "concisechoice": _handle_concisechoice_dataset,
    "summarizetldr": _handle_summarizetldr_dataset,
    "jeopardy": _handle_jeopardy_dataset,
    "oasst": _handle_oasst_dataset,
    "gpteacher": _handle_gpteacher_dataset,
    "reflection": _handle_reflection_dataset,
}
|
||||
@@ -336,6 +336,14 @@ class AxolotlInputConfig(
|
||||
|
||||
plugins: list[str] | None = Field(default=None)
|
||||
|
||||
@field_validator("seed", mode="after")
|
||||
@classmethod
|
||||
def set_default_seed(cls, seed):
|
||||
if seed is None:
|
||||
LOG.info("`seed` not set in config; setting to 42")
|
||||
seed = 42
|
||||
return seed
|
||||
|
||||
@field_validator("datasets", mode="before")
|
||||
@classmethod
|
||||
def deprecate_sharegpt_datasets(cls, datasets):
|
||||
@@ -1199,7 +1207,7 @@ class AxolotlInputConfig(
|
||||
"flash_attention: true must be set with sequence_parallel_degree > 1"
|
||||
)
|
||||
|
||||
if self.sample_packing and self.micro_batch_size > 1:
|
||||
if self.sample_packing and getattr(self, "micro_batch_size", 1) > 1:
|
||||
raise ValueError(
|
||||
"micro_batch_size must be set to 1 when sample_packing is enabled "
|
||||
"due to a `ring-flash-attn` requirement"
|
||||
|
||||
@@ -12,7 +12,7 @@ from axolotl.common.datasets import load_datasets
|
||||
from axolotl.core.builders import HFCausalTrainerBuilder, HFRLTrainerBuilder
|
||||
from axolotl.loaders import ModelLoader, load_tokenizer
|
||||
from axolotl.utils.config import normalize_config
|
||||
from axolotl.utils.data.rl import load_prepare_preference_datasets
|
||||
from axolotl.utils.data import prepare_preference_datasets
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.schemas.enums import RLType
|
||||
|
||||
@@ -451,15 +451,19 @@ def rand_reward_func(prompts, completions) -> list[float]:
|
||||
# Only use mock for the commented out configs
|
||||
if dataset_name is not None:
|
||||
with patch(
|
||||
"axolotl.utils.data.rl.load_dataset_w_config"
|
||||
"axolotl.utils.data.rl.load_dataset_with_config"
|
||||
) as mock_load_dataset:
|
||||
mock_load_dataset.return_value = request.getfixturevalue(
|
||||
dataset_name
|
||||
)
|
||||
train_dataset, eval_dataset = load_prepare_preference_datasets(cfg)
|
||||
train_dataset, eval_dataset = prepare_preference_datasets(
|
||||
cfg, tokenizer
|
||||
)
|
||||
else:
|
||||
# Load actual datasets for orpo_cfg and kto_cfg
|
||||
train_dataset, eval_dataset = load_prepare_preference_datasets(cfg)
|
||||
train_dataset, eval_dataset = prepare_preference_datasets(
|
||||
cfg, tokenizer
|
||||
)
|
||||
|
||||
builder.train_dataset = train_dataset
|
||||
builder.eval_dataset = eval_dataset
|
||||
|
||||
@@ -4,7 +4,6 @@ Simple end-to-end test for Cut Cross Entropy integration
|
||||
|
||||
import pytest
|
||||
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils import get_pytorch_version
|
||||
@@ -59,8 +58,7 @@ class TestCutCrossEntropyIntegration:
|
||||
cfg = validate_config(cfg)
|
||||
prepare_plugins(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
major, minor, _ = get_pytorch_version()
|
||||
if (major, minor) < (2, 4):
|
||||
@@ -105,8 +103,7 @@ class TestCutCrossEntropyIntegration:
|
||||
cfg = validate_config(cfg)
|
||||
prepare_plugins(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
major, minor, _ = get_pytorch_version()
|
||||
if (major, minor) < (2, 4):
|
||||
@@ -134,8 +131,7 @@ class TestCutCrossEntropyIntegration:
|
||||
cfg = validate_config(cfg)
|
||||
prepare_plugins(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
major, minor, _ = get_pytorch_version()
|
||||
if (major, minor) < (2, 4):
|
||||
|
||||
@@ -5,7 +5,6 @@ e2e tests to make sure all the hooks are fired on the plugin
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.integrations.base import BasePlugin
|
||||
from axolotl.train import train
|
||||
@@ -160,8 +159,7 @@ class TestPluginHooks:
|
||||
cfg = validate_config(cfg)
|
||||
prepare_plugins(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
@@ -6,7 +6,6 @@ from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config, prepare_plugins, validate_config
|
||||
@@ -84,8 +83,7 @@ class TestKnowledgeDistillation:
|
||||
cfg = validate_config(cfg)
|
||||
prepare_plugins(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
assert (Path(temp_dir) / "model.safetensors").exists()
|
||||
@@ -115,8 +113,7 @@ class TestKnowledgeDistillation:
|
||||
cfg = validate_config(cfg)
|
||||
prepare_plugins(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
assert (Path(temp_dir) / "adapter_model.safetensors").exists()
|
||||
|
||||
@@ -2,7 +2,6 @@
|
||||
Simple end-to-end test for Liger integration
|
||||
"""
|
||||
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config, prepare_plugins, validate_config
|
||||
@@ -57,8 +56,7 @@ class LigerIntegrationTestCase:
|
||||
cfg = validate_config(cfg)
|
||||
prepare_plugins(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
@@ -104,8 +102,7 @@ class LigerIntegrationTestCase:
|
||||
cfg = validate_config(cfg)
|
||||
prepare_plugins(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
@@ -6,7 +6,6 @@ from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config, prepare_plugins, validate_config
|
||||
@@ -88,8 +87,7 @@ class TestLLMCompressorIntegration:
|
||||
prepare_plugins(cfg)
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
try:
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
|
||||
@@ -105,7 +105,7 @@ def start_vllm(
|
||||
print(f"{i}: VLLM server failed to start: {str(exc)}")
|
||||
|
||||
# also check if the process.pid is still running
|
||||
if not process.poll() is None:
|
||||
if process.poll() is not None:
|
||||
break
|
||||
|
||||
time.sleep(period_seconds)
|
||||
|
||||
192
tests/e2e/multigpu/test_locking.py
Normal file
192
tests/e2e/multigpu/test_locking.py
Normal file
@@ -0,0 +1,192 @@
|
||||
"""Tests for FileLockLoader class."""
|
||||
|
||||
import tempfile
|
||||
import threading
|
||||
import time
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from axolotl.utils.data.lock import FileLockLoader
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
|
||||
class TestFileLockLoader:
|
||||
"""Class with tests for FileLockLoader."""
|
||||
|
||||
@pytest.fixture
|
||||
def temp_dir(self):
|
||||
"""Create a temporary directory for testing."""
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
yield Path(tmp_dir)
|
||||
|
||||
@pytest.fixture
|
||||
def cfg(self, temp_dir):
|
||||
"""Create a test configuration."""
|
||||
return DictDefault({"dataset_prepared_path": str(temp_dir)})
|
||||
|
||||
@pytest.fixture
|
||||
def loader(self, cfg):
|
||||
"""Create a FileLockLoader instance for testing."""
|
||||
return FileLockLoader(cfg)
|
||||
|
||||
def test_load_first_process(self, loader):
|
||||
"""Test load() when no ready flag exists (first process)."""
|
||||
mock_load_fn = Mock(return_value="test_data")
|
||||
|
||||
result = loader.load(mock_load_fn)
|
||||
|
||||
# Should call the load function
|
||||
mock_load_fn.assert_called_once()
|
||||
assert result == "test_data"
|
||||
|
||||
# Should create the ready flag
|
||||
assert loader.ready_flag_path.exists()
|
||||
|
||||
def test_load_subsequent_process(self, loader):
|
||||
"""Test load() when ready flag already exists (subsequent process)."""
|
||||
# Create ready flag first
|
||||
loader.ready_flag_path.touch()
|
||||
|
||||
mock_load_fn = Mock(return_value="loaded_data")
|
||||
|
||||
result = loader.load(mock_load_fn)
|
||||
|
||||
# Should still call load function (to load the prepared data)
|
||||
mock_load_fn.assert_called_once()
|
||||
assert result == "loaded_data"
|
||||
|
||||
def test_load_concurrent_processes(self, cfg):
|
||||
"""Test that concurrent processes coordinate correctly."""
|
||||
results = []
|
||||
call_count = 0
|
||||
|
||||
def slow_load_fn():
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
time.sleep(0.1) # Simulate slow loading
|
||||
return f"data_{call_count}"
|
||||
|
||||
def worker():
|
||||
loader = FileLockLoader(cfg)
|
||||
result = loader.load(slow_load_fn)
|
||||
results.append(result)
|
||||
|
||||
# Start multiple threads simultaneously
|
||||
threads = [threading.Thread(target=worker) for _ in range(3)]
|
||||
for t in threads:
|
||||
t.start()
|
||||
for t in threads:
|
||||
t.join()
|
||||
|
||||
# Only one thread should have done the initial loading
|
||||
# All should return data, but the load function should be called
|
||||
# once by the first process and once by each subsequent process
|
||||
assert len(results) == 3
|
||||
assert all(result.startswith("data_") for result in results)
|
||||
|
||||
@patch("time.sleep")
|
||||
def test_load_waiting_for_ready_flag(self, mock_sleep, loader):
|
||||
"""Test that processes wait for the ready flag to appear."""
|
||||
mock_load_fn = Mock(return_value="waiting_data")
|
||||
mock_ready_flag_path = Mock()
|
||||
exists_call_count = 0
|
||||
|
||||
def mock_exists():
|
||||
nonlocal exists_call_count
|
||||
exists_call_count += 1
|
||||
|
||||
if exists_call_count == 1:
|
||||
# First check: ready flag exists (not first process)
|
||||
return True
|
||||
if exists_call_count <= 3:
|
||||
# While loop checks: flag doesn't exist yet
|
||||
return False
|
||||
return True
|
||||
|
||||
mock_ready_flag_path.exists.side_effect = mock_exists
|
||||
|
||||
# Replace the ready_flag_path with our mock
|
||||
original_path = loader.ready_flag_path
|
||||
loader.ready_flag_path = mock_ready_flag_path
|
||||
|
||||
try:
|
||||
result = loader.load(mock_load_fn)
|
||||
finally:
|
||||
# Restore original path
|
||||
loader.ready_flag_path = original_path
|
||||
|
||||
# Should have slept twice while waiting
|
||||
assert mock_sleep.call_count == 2
|
||||
mock_sleep.assert_called_with(1)
|
||||
|
||||
# Should eventually call load function
|
||||
mock_load_fn.assert_called_once()
|
||||
assert result == "waiting_data"
|
||||
|
||||
def test_complete_workflow_with_cleanup(self, loader):
    """Run a full load followed by cleanup and check the marker files.

    After `load` both the ready flag and the counter file must exist;
    after `cleanup` (with only this one process registered) both must be
    gone.
    """
    load_fn = Mock(return_value="test_data")

    # Loading as the first process is expected to create the bookkeeping files.
    assert loader.load(load_fn) == "test_data"
    for marker in (loader.ready_flag_path, loader.counter_path):
        assert marker.exists()

    # A single registered process means cleanup removes everything.
    loader.cleanup()
    for marker in (loader.ready_flag_path, loader.counter_path):
        assert not marker.exists()
|
||||
|
||||
def test_multiple_processes_workflow(self, loader):
    """Check that cleanup counts down and only the last call removes files.

    Three processes are simulated by writing "3" into the counter file by
    hand. Each `cleanup` call is expected to decrement the counter; the
    ready flag stays until the final call, which removes both files.
    """
    # Hand-craft the state left behind by three processes having loaded.
    loader.ready_flag_path.touch()
    loader.counter_path.write_text("3")

    # The first two cleanups only decrement the counter.
    for remaining in ("2", "1"):
        loader.cleanup()
        assert loader.ready_flag_path.exists()
        assert loader.counter_path.read_text().strip() == remaining

    # The last cleanup tears everything down.
    loader.cleanup()
    assert not loader.ready_flag_path.exists()
    assert not loader.counter_path.exists()
|
||||
|
||||
def test_load_exception_handling(self, loader):
    """Check that an error raised by the load function propagates.

    The `ValueError` raised by the load function must surface from
    `loader.load`, and the ready flag must not exist afterwards.
    """

    def raising_load_fn():
        raise ValueError("Load failed")

    with pytest.raises(ValueError, match="Load failed"):
        loader.load(raising_load_fn)

    # A failed load must not leave the ready flag behind.
    flag_present = loader.ready_flag_path.exists()
    assert not flag_present
|
||||
|
||||
def test_file_lock_called(self, loader):
    """Check that `load` takes a `FileLock` on the loader's lock file path.

    `FileLock` is patched in `axolotl.utils.data.lock`; the test then
    asserts it was constructed once with the stringified lock path and
    that the resulting object was entered and exited exactly once.
    """
    load_fn = Mock(return_value="locked_data")

    with patch("axolotl.utils.data.lock.FileLock") as filelock_cls:
        # The object returned by FileLock(...) is what `with` enters/exits.
        lock_instance = filelock_cls.return_value
        lock_instance.__enter__ = Mock(return_value=MagicMock())
        lock_instance.__exit__ = Mock(return_value=None)

        loader.load(load_fn)

        # Constructed exactly once, with the lock file path as a string.
        filelock_cls.assert_called_once_with(str(loader.lock_file_path))

        # Used as a context manager: one enter, one exit.
        lock_instance.__enter__.assert_called_once()
        lock_instance.__exit__.assert_called_once()
|
||||
@@ -4,7 +4,6 @@ E2E tests for multipack fft llama using 4d attention masks
|
||||
|
||||
import unittest
|
||||
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
@@ -60,8 +59,7 @@ class Test4dMultipackLlama(unittest.TestCase):
|
||||
)
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
@@ -108,8 +106,7 @@ class Test4dMultipackLlama(unittest.TestCase):
|
||||
)
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
@@ -6,7 +6,6 @@ import pytest
|
||||
import transformers
|
||||
from torch.utils.checkpoint import checkpoint
|
||||
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
@@ -75,8 +74,7 @@ class TestActivationCheckpointing:
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
@@ -5,7 +5,6 @@ E2E tests for lora llama
|
||||
import pytest
|
||||
from transformers.utils import is_torch_bf16_gpu_available
|
||||
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
@@ -73,8 +72,7 @@ class TestFAXentropyLlama:
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
@@ -6,7 +6,6 @@ import unittest
|
||||
|
||||
import pytest
|
||||
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
@@ -63,8 +62,7 @@ class TestFalconPatched(unittest.TestCase):
|
||||
)
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
@@ -105,8 +103,7 @@ class TestFalconPatched(unittest.TestCase):
|
||||
)
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
@@ -7,7 +7,6 @@ import unittest
|
||||
import pytest
|
||||
from transformers.utils import is_torch_bf16_gpu_available
|
||||
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
@@ -62,8 +61,7 @@ class TestFusedLlama(unittest.TestCase):
|
||||
cfg.fp16 = True
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
@@ -6,7 +6,6 @@ import unittest
|
||||
|
||||
import pytest
|
||||
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
@@ -64,8 +63,7 @@ class TestLlamaShiftedSparseAttention(unittest.TestCase):
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
@@ -107,8 +105,7 @@ class TestLlamaShiftedSparseAttention(unittest.TestCase):
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
@@ -7,7 +7,6 @@ import unittest
|
||||
import pytest
|
||||
from transformers.utils import is_auto_gptq_available, is_torch_bf16_gpu_available
|
||||
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
@@ -65,8 +64,7 @@ class TestLoraLlama(unittest.TestCase):
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
@@ -114,8 +112,7 @@ class TestLoraLlama(unittest.TestCase):
|
||||
)
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
@@ -4,7 +4,6 @@ E2E tests for lora llama
|
||||
|
||||
import unittest
|
||||
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
@@ -60,8 +59,7 @@ class TestMistral(unittest.TestCase):
|
||||
)
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
@@ -102,8 +100,7 @@ class TestMistral(unittest.TestCase):
|
||||
)
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
@@ -4,7 +4,6 @@ E2E tests for mixtral
|
||||
|
||||
import unittest
|
||||
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
@@ -57,8 +56,7 @@ class TestMixtral(unittest.TestCase):
|
||||
)
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
@@ -96,8 +94,7 @@ class TestMixtral(unittest.TestCase):
|
||||
)
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
@@ -4,7 +4,6 @@ E2E tests for lora llama
|
||||
|
||||
import unittest
|
||||
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
@@ -60,8 +59,7 @@ class TestPhiMultipack(unittest.TestCase):
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
@@ -112,8 +110,7 @@ class TestPhiMultipack(unittest.TestCase):
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
@@ -7,7 +7,6 @@ import subprocess
|
||||
|
||||
from transformers.utils import is_torch_bf16_gpu_available
|
||||
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
@@ -67,8 +66,7 @@ class TestResumeLlama:
|
||||
cfg.fp16 = True
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
|
||||
@@ -78,7 +76,6 @@ class TestResumeLlama:
|
||||
}
|
||||
)
|
||||
normalize_config(resume_cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
|
||||
train(cfg=resume_cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
@@ -4,7 +4,6 @@ e2e tests for unsloth qlora
|
||||
|
||||
import pytest
|
||||
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
@@ -68,8 +67,7 @@ class TestUnslothQLoRA:
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
@@ -119,8 +117,7 @@ class TestUnslothQLoRA:
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
@@ -175,8 +172,7 @@ class TestUnslothQLoRA:
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
@@ -6,7 +6,6 @@ import unittest
|
||||
|
||||
from transformers.utils import is_torch_bf16_gpu_available
|
||||
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
@@ -59,8 +58,7 @@ class TestPackedFlex(unittest.TestCase):
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
|
||||
|
||||
@@ -5,7 +5,6 @@ E2E tests for relora llama
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
@@ -71,8 +70,7 @@ class TestReLoraLlama(unittest.TestCase):
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(Path(temp_dir) / "checkpoint-100/adapter", cfg)
|
||||
|
||||
@@ -6,7 +6,6 @@ from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
@@ -72,8 +71,7 @@ class TestDeepseekV3:
|
||||
)
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
assert (Path(temp_dir) / "adapter_model.safetensors").exists()
|
||||
@@ -122,8 +120,7 @@ class TestDeepseekV3:
|
||||
)
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
assert (Path(temp_dir) / "model.safetensors").exists()
|
||||
|
||||
@@ -1,6 +1,4 @@
|
||||
"""
|
||||
E2E tests for lora llama
|
||||
"""
|
||||
"""E2E tests for lora llama"""
|
||||
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
@@ -4,7 +4,6 @@ E2E tests for llama pretrain
|
||||
|
||||
import unittest
|
||||
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
@@ -54,8 +53,7 @@ class TestEmbeddingsLrScale(unittest.TestCase):
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
@@ -99,8 +97,7 @@ class TestEmbeddingsLrScale(unittest.TestCase):
|
||||
)
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
@@ -6,7 +6,6 @@ import unittest
|
||||
|
||||
import pytest
|
||||
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
@@ -66,8 +65,7 @@ class TestFalcon(unittest.TestCase):
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
@@ -122,8 +120,7 @@ class TestFalcon(unittest.TestCase):
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
@@ -164,8 +161,7 @@ class TestFalcon(unittest.TestCase):
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
@@ -6,7 +6,6 @@ from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
@@ -69,8 +68,7 @@ class TestGemma2:
|
||||
)
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
assert (Path(temp_dir) / "adapter_model.safetensors").exists()
|
||||
@@ -121,8 +119,7 @@ class TestGemma2:
|
||||
)
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
assert (Path(temp_dir) / "model.safetensors").exists()
|
||||
|
||||
@@ -6,7 +6,6 @@ from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
@@ -68,8 +67,7 @@ class TestGemma3Text:
|
||||
)
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
assert (Path(temp_dir) / "adapter_model.safetensors").exists()
|
||||
@@ -119,8 +117,7 @@ class TestGemma3Text:
|
||||
)
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
assert (Path(temp_dir) / "model.safetensors").exists()
|
||||
|
||||
@@ -2,7 +2,6 @@
|
||||
E2E tests for llama
|
||||
"""
|
||||
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
@@ -51,8 +50,7 @@ class TestLlama:
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
@@ -99,8 +97,7 @@ class TestLlama:
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
@@ -144,8 +141,7 @@ class TestLlama:
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
@@ -185,8 +181,7 @@ class TestLlama:
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
@@ -1,10 +1,7 @@
|
||||
"""
|
||||
E2E tests for llama pretrain
|
||||
"""
|
||||
"""E2E tests for llama pretrain"""
|
||||
|
||||
import pytest
|
||||
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
@@ -14,9 +11,7 @@ from .utils import check_model_output_exists, check_tensorboard
|
||||
|
||||
|
||||
class TestPretrainLlama:
|
||||
"""
|
||||
Test case for Llama models w pretraining
|
||||
"""
|
||||
"""Test case for Llama models w pretraining"""
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"sample_packing",
|
||||
@@ -66,8 +61,7 @@ class TestPretrainLlama:
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
@@ -4,7 +4,6 @@ E2E tests for lora llama
|
||||
|
||||
import unittest
|
||||
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
@@ -60,8 +59,7 @@ class TestLlamaVision(unittest.TestCase):
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
@@ -106,8 +104,7 @@ class TestLlamaVision(unittest.TestCase):
|
||||
)
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
@@ -4,7 +4,6 @@ E2E tests for lora llama
|
||||
|
||||
import unittest
|
||||
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
@@ -55,8 +54,7 @@ class TestLoraLlama(unittest.TestCase):
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
@@ -6,7 +6,6 @@ import unittest
|
||||
|
||||
import pytest
|
||||
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
@@ -57,8 +56,7 @@ class TestMamba(unittest.TestCase):
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
@@ -6,7 +6,6 @@ import unittest
|
||||
|
||||
from transformers.utils import is_torch_bf16_gpu_available
|
||||
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
@@ -61,8 +60,7 @@ class TestMistral(unittest.TestCase):
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
@@ -106,8 +104,7 @@ class TestMistral(unittest.TestCase):
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
@@ -7,7 +7,6 @@ import unittest
|
||||
import torch
|
||||
from transformers.utils import is_torch_bf16_gpu_available
|
||||
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
@@ -67,8 +66,7 @@ class TestMixtral(unittest.TestCase):
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
model, _, _ = train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
assert (
|
||||
@@ -123,8 +121,7 @@ class TestMixtral(unittest.TestCase):
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
model, _, _ = train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
assert (
|
||||
@@ -182,8 +179,7 @@ class TestMixtral(unittest.TestCase):
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
model, _, _ = train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
assert (
|
||||
@@ -241,8 +237,7 @@ class TestMixtral(unittest.TestCase):
|
||||
cfg.bf16 = True
|
||||
else:
|
||||
cfg.fp16 = True
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
model, _, _ = train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
assert (
|
||||
@@ -287,8 +282,7 @@ class TestMixtral(unittest.TestCase):
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
@@ -4,7 +4,6 @@ E2E tests for custom optimizers using Llama
|
||||
|
||||
import unittest
|
||||
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
@@ -61,8 +60,7 @@ class TestCustomOptimizers(unittest.TestCase):
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
_, _, trainer = train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
@@ -107,8 +105,7 @@ class TestCustomOptimizers(unittest.TestCase):
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
_, _, trainer = train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
@@ -154,8 +151,7 @@ class TestCustomOptimizers(unittest.TestCase):
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
_, _, trainer = train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
@@ -194,8 +190,7 @@ class TestCustomOptimizers(unittest.TestCase):
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
@@ -242,8 +237,7 @@ class TestCustomOptimizers(unittest.TestCase):
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
@@ -6,7 +6,6 @@ import unittest
|
||||
|
||||
from transformers.utils import is_torch_bf16_gpu_available
|
||||
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
@@ -58,8 +57,7 @@ class TestPackedLlama(unittest.TestCase):
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
|
||||
|
||||
@@ -4,7 +4,6 @@ E2E tests for lora llama
|
||||
|
||||
import unittest
|
||||
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
@@ -58,8 +57,7 @@ class TestPhi(unittest.TestCase):
|
||||
)
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
@@ -108,8 +106,7 @@ class TestPhi(unittest.TestCase):
|
||||
)
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
@@ -4,7 +4,6 @@ E2E tests for process reward model w/ lora llama
|
||||
|
||||
import unittest
|
||||
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
@@ -54,8 +53,7 @@ class TestProcessRewardSmolLM2(unittest.TestCase):
|
||||
)
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_tensorboard(
|
||||
|
||||
@@ -5,7 +5,6 @@ E2E tests for QAT
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
@@ -64,8 +63,7 @@ class TestQATLlama(unittest.TestCase):
|
||||
)
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(Path(temp_dir) / "checkpoint-5", cfg)
|
||||
|
||||
@@ -4,7 +4,6 @@ E2E tests for reward model lora llama
|
||||
|
||||
import unittest
|
||||
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
@@ -63,8 +62,7 @@ class TestRewardModelLoraSmolLM2(unittest.TestCase):
|
||||
)
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_tensorboard(
|
||||
|
||||
@@ -4,7 +4,6 @@ E2E tests for custom schedulers using Llama
|
||||
|
||||
import unittest
|
||||
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
@@ -57,8 +56,7 @@ class TestCustomSchedulers(unittest.TestCase):
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
@@ -6,8 +6,9 @@ import unittest
|
||||
|
||||
import pytest
|
||||
|
||||
from axolotl.loaders.tokenizer import load_tokenizer
|
||||
from axolotl.prompt_strategies.dpo import load as load_dpo
|
||||
from axolotl.utils.data.rl import load_prepare_preference_datasets
|
||||
from axolotl.utils.data.rl import prepare_preference_datasets
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from tests.hf_offline_utils import enable_hf_offline
|
||||
@@ -55,7 +56,8 @@ class TestDPOChatml:
|
||||
# test that dpo.load works
|
||||
load_dpo("chatml", cfg)
|
||||
# now actually load the datasets with the strategy
|
||||
train_ds, _ = load_prepare_preference_datasets(cfg)
|
||||
tokenizer = load_tokenizer(cfg)
|
||||
train_ds, _ = prepare_preference_datasets(cfg, tokenizer)
|
||||
assert train_ds[0]["prompt"].startswith("<|im_start|>")
|
||||
assert train_ds[0]["prompt"].endswith("<|im_start|>assistant\n")
|
||||
assert "chosen" in train_ds[0]
|
||||
|
||||
@@ -1,10 +1,9 @@
|
||||
"""
|
||||
Test dataset loading under various conditions.
|
||||
"""
|
||||
"""Test dataset loading under various conditions."""
|
||||
|
||||
import shutil
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from typing import Any, Generator
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
@@ -12,8 +11,9 @@ from datasets import Dataset
|
||||
from huggingface_hub import snapshot_download
|
||||
from transformers import PreTrainedTokenizer
|
||||
|
||||
from axolotl.utils.data import load_tokenized_prepared_datasets
|
||||
from axolotl.utils.data.rl import load_prepare_preference_datasets
|
||||
from axolotl.loaders.tokenizer import load_tokenizer
|
||||
from axolotl.utils.data.rl import prepare_preference_datasets
|
||||
from axolotl.utils.data.sft import _load_tokenized_prepared_datasets
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from tests.constants import (
|
||||
@@ -28,7 +28,9 @@ class TestDatasetPreparation:
|
||||
"""Test a configured dataloader."""
|
||||
|
||||
@pytest.fixture
|
||||
def tokenizer(self, tokenizer_huggyllama) -> PreTrainedTokenizer:
|
||||
def tokenizer(
|
||||
self, tokenizer_huggyllama
|
||||
) -> Generator[PreTrainedTokenizer, Any, Any]:
|
||||
tokenizer_huggyllama.add_special_tokens(SPECIAL_TOKENS)
|
||||
yield tokenizer_huggyllama
|
||||
|
||||
@@ -63,7 +65,10 @@ class TestDatasetPreparation:
|
||||
}
|
||||
)
|
||||
|
||||
dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path)
|
||||
with patch(
|
||||
"axolotl.common.const.DEFAULT_DATASET_PREPARED_PATH", str(prepared_path)
|
||||
):
|
||||
dataset, _ = _load_tokenized_prepared_datasets(tokenizer, cfg)
|
||||
|
||||
assert len(dataset) == 2000
|
||||
assert "input_ids" in dataset.features
|
||||
@@ -107,7 +112,10 @@ class TestDatasetPreparation:
|
||||
}
|
||||
)
|
||||
|
||||
dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path)
|
||||
with patch(
|
||||
"axolotl.common.const.DEFAULT_DATASET_PREPARED_PATH", str(prepared_path)
|
||||
):
|
||||
dataset, _ = _load_tokenized_prepared_datasets(tokenizer, cfg)
|
||||
|
||||
assert len(dataset) == 2000
|
||||
assert "input_ids" in dataset.features
|
||||
@@ -136,7 +144,10 @@ class TestDatasetPreparation:
|
||||
}
|
||||
)
|
||||
|
||||
dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path)
|
||||
with patch(
|
||||
"axolotl.common.const.DEFAULT_DATASET_PREPARED_PATH", str(prepared_path)
|
||||
):
|
||||
dataset, _ = _load_tokenized_prepared_datasets(tokenizer, cfg)
|
||||
|
||||
assert len(dataset) == 1
|
||||
assert "input_ids" in dataset.features
|
||||
@@ -145,7 +156,7 @@ class TestDatasetPreparation:
|
||||
|
||||
@enable_hf_offline
|
||||
def test_load_from_dir_of_parquet(self, tokenizer, dataset_fixture):
|
||||
"""Usual use case. Verify a directory of parquet files can be loaded."""
|
||||
"""Usual use case. Verify a directory of parquet files can be loaded."""
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
tmp_ds_dir = Path(tmp_dir) / "tmp_dataset"
|
||||
tmp_ds_dir.mkdir()
|
||||
@@ -171,7 +182,10 @@ class TestDatasetPreparation:
|
||||
}
|
||||
)
|
||||
|
||||
dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path)
|
||||
with patch(
|
||||
"axolotl.common.const.DEFAULT_DATASET_PREPARED_PATH", str(prepared_path)
|
||||
):
|
||||
dataset, _ = _load_tokenized_prepared_datasets(tokenizer, cfg)
|
||||
|
||||
assert len(dataset) == 1
|
||||
assert "input_ids" in dataset.features
|
||||
@@ -206,7 +220,10 @@ class TestDatasetPreparation:
|
||||
}
|
||||
)
|
||||
|
||||
dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path)
|
||||
with patch(
|
||||
"axolotl.common.const.DEFAULT_DATASET_PREPARED_PATH", str(prepared_path)
|
||||
):
|
||||
dataset, _ = _load_tokenized_prepared_datasets(tokenizer, cfg)
|
||||
|
||||
assert len(dataset) == 1
|
||||
assert "input_ids" in dataset.features
|
||||
@@ -235,7 +252,10 @@ class TestDatasetPreparation:
|
||||
}
|
||||
)
|
||||
|
||||
dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path)
|
||||
with patch(
|
||||
"axolotl.common.const.DEFAULT_DATASET_PREPARED_PATH", str(prepared_path)
|
||||
):
|
||||
dataset, _ = _load_tokenized_prepared_datasets(tokenizer, cfg)
|
||||
|
||||
assert len(dataset) == 1
|
||||
assert "input_ids" in dataset.features
|
||||
@@ -264,7 +284,10 @@ class TestDatasetPreparation:
|
||||
}
|
||||
)
|
||||
|
||||
dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path)
|
||||
with patch(
|
||||
"axolotl.common.const.DEFAULT_DATASET_PREPARED_PATH", str(prepared_path)
|
||||
):
|
||||
dataset, _ = _load_tokenized_prepared_datasets(tokenizer, cfg)
|
||||
|
||||
assert len(dataset) == 1
|
||||
assert "input_ids" in dataset.features
|
||||
@@ -286,7 +309,8 @@ class TestDatasetPreparation:
|
||||
}
|
||||
)
|
||||
|
||||
train_dataset, _ = load_prepare_preference_datasets(cfg)
|
||||
tokenizer = load_tokenizer(cfg)
|
||||
train_dataset, _ = prepare_preference_datasets(cfg, tokenizer)
|
||||
|
||||
assert len(train_dataset) == 1800
|
||||
assert "conversation" not in train_dataset.features
|
||||
@@ -318,7 +342,10 @@ class TestDatasetPreparation:
|
||||
}
|
||||
)
|
||||
|
||||
dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path)
|
||||
with patch(
|
||||
"axolotl.common.const.DEFAULT_DATASET_PREPARED_PATH", str(prepared_path)
|
||||
):
|
||||
dataset, _ = _load_tokenized_prepared_datasets(tokenizer, cfg)
|
||||
|
||||
assert len(dataset) == 2000
|
||||
assert "input_ids" in dataset.features
|
||||
@@ -342,13 +369,16 @@ class TestDatasetPreparation:
|
||||
)
|
||||
|
||||
# pylint: disable=duplicate-code
|
||||
with patch("axolotl.utils.data.rl.load_dataset_w_config") as mock_load_dataset:
|
||||
with patch(
|
||||
"axolotl.utils.data.rl.load_dataset_with_config"
|
||||
) as mock_load_dataset:
|
||||
# Set up the mock to return different values on successive calls
|
||||
mock_load_dataset.return_value = (
|
||||
dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff
|
||||
)
|
||||
|
||||
train_dataset, _ = load_prepare_preference_datasets(cfg)
|
||||
tokenizer = load_tokenizer(cfg)
|
||||
train_dataset, _ = prepare_preference_datasets(cfg, tokenizer)
|
||||
|
||||
assert len(train_dataset) == 1800
|
||||
assert "conversation" not in train_dataset.features
|
||||
@@ -393,16 +423,18 @@ class TestDatasetPreparation:
|
||||
)
|
||||
|
||||
with patch(
|
||||
"axolotl.utils.data.shared.load_dataset_w_config"
|
||||
"axolotl.utils.data.shared.load_dataset_with_config"
|
||||
) as mock_load_dataset:
|
||||
# Set up the mock to return different values on successive calls
|
||||
mock_load_dataset.return_value = (
|
||||
dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff
|
||||
)
|
||||
|
||||
dataset, _ = load_tokenized_prepared_datasets(
|
||||
tokenizer, cfg, prepared_path
|
||||
)
|
||||
with patch(
|
||||
"axolotl.common.const.DEFAULT_DATASET_PREPARED_PATH",
|
||||
str(prepared_path),
|
||||
):
|
||||
dataset, _ = _load_tokenized_prepared_datasets(tokenizer, cfg)
|
||||
|
||||
assert len(dataset) == 2000
|
||||
assert "input_ids" in dataset.features
|
||||
@@ -437,7 +469,10 @@ class TestDatasetPreparation:
|
||||
}
|
||||
)
|
||||
|
||||
dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path)
|
||||
with patch(
|
||||
"axolotl.common.const.DEFAULT_DATASET_PREPARED_PATH", str(prepared_path)
|
||||
):
|
||||
dataset, _ = _load_tokenized_prepared_datasets(tokenizer, cfg)
|
||||
|
||||
assert len(dataset) == 2000
|
||||
assert "input_ids" in dataset.features
|
||||
|
||||
@@ -5,7 +5,6 @@ Additionally, this test suite includes tests for functions that indirectly call
|
||||
`deduplicate_and_log_datasets` during the execution of the preprocess command.
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import unittest
|
||||
from unittest.mock import patch
|
||||
|
||||
@@ -14,8 +13,7 @@ from datasets import Dataset
|
||||
|
||||
from axolotl.loaders import load_processor, load_tokenizer
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
from axolotl.utils.data import prepare_dataset
|
||||
from axolotl.utils.data.rl import load_prepare_preference_datasets
|
||||
from axolotl.utils.data import prepare_datasets, prepare_preference_datasets
|
||||
from axolotl.utils.data.utils import deduplicate_and_log_datasets
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
@@ -71,36 +69,14 @@ class TestDeduplicateIndividualFunctions(unittest.TestCase):
|
||||
self.expected_dataset = Dataset.from_dict(self.expected_data)
|
||||
|
||||
def test_deduplication(self):
|
||||
train_dataset, _, _ = deduplicate_and_log_datasets(train_dataset=self.dataset)
|
||||
_, eval_dataset, _ = deduplicate_and_log_datasets(eval_dataset=self.dataset)
|
||||
train_dataset, _ = deduplicate_and_log_datasets(dataset=self.dataset)
|
||||
eval_dataset, _ = deduplicate_and_log_datasets(
|
||||
dataset=self.dataset, dataset_name="eval"
|
||||
)
|
||||
|
||||
verify_deduplication(train_dataset, self.expected_dataset, "train_dataset")
|
||||
verify_deduplication(eval_dataset, self.expected_dataset, "eval_dataset")
|
||||
|
||||
def test_datasets_are_none(self):
|
||||
# Test when both datasets are None
|
||||
train_dataset, eval_dataset, _ = deduplicate_and_log_datasets(
|
||||
train_dataset=None, eval_dataset=None
|
||||
)
|
||||
self.assertIsNone(train_dataset, "Expected train_dataset to be None")
|
||||
self.assertIsNone(eval_dataset, "Expected eval_dataset to be None")
|
||||
|
||||
def test_only_train_is_none(self):
|
||||
# Test when only train_dataset is None
|
||||
train_dataset, eval_dataset, _ = deduplicate_and_log_datasets(
|
||||
train_dataset=None, eval_dataset=self.dataset
|
||||
)
|
||||
self.assertIsNone(train_dataset, "Expected train_dataset to be None")
|
||||
verify_deduplication(eval_dataset, self.expected_dataset, "eval_dataset")
|
||||
|
||||
def test_only_eval_is_none(self):
|
||||
# Test when only eval_dataset is None
|
||||
train_dataset, eval_dataset, _ = deduplicate_and_log_datasets(
|
||||
train_dataset=self.dataset, eval_dataset=None
|
||||
)
|
||||
self.assertIsNone(eval_dataset, "Expected eval_dataset to be None")
|
||||
verify_deduplication(train_dataset, self.expected_dataset, "train_dataset")
|
||||
|
||||
def test_exact_duplicates(self):
|
||||
# Test when datasets are exact duplicates
|
||||
duplicate_data = {
|
||||
@@ -115,8 +91,10 @@ class TestDeduplicateIndividualFunctions(unittest.TestCase):
|
||||
expected_dataset = Dataset.from_dict(expected_data)
|
||||
|
||||
# Run deduplication
|
||||
train_dataset, _, _ = deduplicate_and_log_datasets(train_dataset=dataset)
|
||||
_, eval_dataset, _ = deduplicate_and_log_datasets(eval_dataset=dataset)
|
||||
train_dataset, _ = deduplicate_and_log_datasets(dataset=dataset)
|
||||
eval_dataset, _ = deduplicate_and_log_datasets(
|
||||
dataset=dataset, dataset_name="eval"
|
||||
)
|
||||
|
||||
verify_deduplication(train_dataset, expected_dataset, "train_dataset")
|
||||
verify_deduplication(eval_dataset, expected_dataset, "eval_dataset")
|
||||
@@ -139,8 +117,10 @@ class TestDeduplicateIndividualFunctions(unittest.TestCase):
|
||||
expected_dataset = Dataset.from_dict(expected_data)
|
||||
|
||||
# Run deduplication
|
||||
train_dataset, _, _ = deduplicate_and_log_datasets(train_dataset=dataset)
|
||||
_, eval_dataset, _ = deduplicate_and_log_datasets(eval_dataset=dataset)
|
||||
train_dataset, _ = deduplicate_and_log_datasets(dataset=dataset)
|
||||
eval_dataset, _ = deduplicate_and_log_datasets(
|
||||
dataset=dataset, dataset_name="eval"
|
||||
)
|
||||
|
||||
verify_deduplication(train_dataset, expected_dataset, "train_dataset")
|
||||
verify_deduplication(eval_dataset, expected_dataset, "eval_dataset")
|
||||
@@ -169,8 +149,8 @@ class TestDeduplicateIndividualFunctions(unittest.TestCase):
|
||||
expected_dataset_eval = Dataset.from_dict(expected_data_eval)
|
||||
|
||||
# Run deduplication
|
||||
train_dataset, eval_dataset, _ = deduplicate_and_log_datasets(
|
||||
train_dataset=dataset, eval_dataset=dataset
|
||||
train_dataset, eval_dataset = deduplicate_and_log_datasets(
|
||||
dataset=dataset, other_dataset=dataset
|
||||
)
|
||||
|
||||
verify_deduplication(train_dataset, expected_dataset_train, "train_dataset")
|
||||
@@ -206,8 +186,8 @@ class TestDeduplicateIndividualFunctions(unittest.TestCase):
|
||||
expected_dataset_eval = Dataset.from_dict(expected_data_eval)
|
||||
|
||||
# Run deduplication
|
||||
train_dataset, eval_dataset, _ = deduplicate_and_log_datasets(
|
||||
train_dataset=dataset_train, eval_dataset=dataset_eval
|
||||
train_dataset, eval_dataset = deduplicate_and_log_datasets(
|
||||
dataset=dataset_train, other_dataset=dataset_eval
|
||||
)
|
||||
|
||||
verify_deduplication(train_dataset, expected_dataset_train, "train_dataset")
|
||||
@@ -245,7 +225,9 @@ class TestDeduplicateRLDataset:
|
||||
|
||||
# pylint: disable=duplicate-code
|
||||
with (
|
||||
patch("axolotl.utils.data.rl.load_dataset_w_config") as mock_load_dataset,
|
||||
patch(
|
||||
"axolotl.utils.data.rl.load_dataset_with_config"
|
||||
) as mock_load_dataset,
|
||||
patch("axolotl.loaders.load_tokenizer") as mock_load_tokenizer,
|
||||
):
|
||||
# Set up the mock to return different values on successive calls
|
||||
@@ -255,7 +237,8 @@ class TestDeduplicateRLDataset:
|
||||
]
|
||||
mock_load_tokenizer.return_value = tokenizer_huggyllama
|
||||
|
||||
train_dataset, _ = load_prepare_preference_datasets(cfg)
|
||||
tokenizer = load_tokenizer(cfg)
|
||||
train_dataset, _ = prepare_preference_datasets(cfg, tokenizer)
|
||||
|
||||
# Verify that the dataset has been deduplicated
|
||||
assert len(train_dataset) == 1800, "Dataset was not properly deduplicated"
|
||||
@@ -269,7 +252,9 @@ class TestDeduplicateRLDataset:
|
||||
):
|
||||
# pylint: disable=duplicate-code
|
||||
with (
|
||||
patch("axolotl.utils.data.rl.load_dataset_w_config") as mock_load_dataset,
|
||||
patch(
|
||||
"axolotl.utils.data.rl.load_dataset_with_config"
|
||||
) as mock_load_dataset,
|
||||
patch("axolotl.loaders.load_tokenizer") as mock_load_tokenizer,
|
||||
):
|
||||
# Set up the mock to return different values on successive calls
|
||||
@@ -279,9 +264,10 @@ class TestDeduplicateRLDataset:
|
||||
]
|
||||
mock_load_tokenizer.return_value = tokenizer_huggyllama
|
||||
|
||||
cfg.dataset_exact_deduplication = False
|
||||
# Load the dataset without deduplication
|
||||
train_dataset, _ = load_prepare_preference_datasets(cfg)
|
||||
cfg.dataset_exact_deduplication = False
|
||||
tokenizer = load_tokenizer(cfg)
|
||||
train_dataset, _ = prepare_preference_datasets(cfg, tokenizer)
|
||||
|
||||
# Verify that the dataset retains duplicates
|
||||
assert (
|
||||
@@ -335,7 +321,7 @@ class TestDeduplicateNonRL(unittest.TestCase):
|
||||
)
|
||||
|
||||
# Prepare dataset using the prepare_dataset function
|
||||
train_dataset, _, _, _ = prepare_dataset(
|
||||
train_dataset, _, _, _ = prepare_datasets(
|
||||
self.cfg_1,
|
||||
tokenizer,
|
||||
processor=processor,
|
||||
@@ -362,7 +348,7 @@ class TestDeduplicateNonRL(unittest.TestCase):
|
||||
)
|
||||
|
||||
# Prepare dataset using the prepare_dataset function
|
||||
_, eval_dataset, _, _ = prepare_dataset(
|
||||
_, eval_dataset, _, _ = prepare_datasets(
|
||||
self.cfg_1,
|
||||
tokenizer,
|
||||
processor=processor,
|
||||
@@ -389,7 +375,7 @@ class TestDeduplicateNonRL(unittest.TestCase):
|
||||
)
|
||||
|
||||
# Prepare dataset using the prepare_dataset function
|
||||
train_dataset, eval_dataset, _, _ = prepare_dataset(
|
||||
train_dataset, eval_dataset, _, _ = prepare_datasets(
|
||||
self.cfg_1,
|
||||
tokenizer,
|
||||
processor=processor,
|
||||
@@ -428,41 +414,8 @@ class TestWrongCollisions(unittest.TestCase):
|
||||
self.eval_dataset = Dataset.from_dict(self.eval_data)
|
||||
self.dataset = Dataset.from_dict(self.dataset_data)
|
||||
|
||||
@patch(
|
||||
"axolotl.utils.data.utils.sha256",
|
||||
side_effect=lambda x: (
|
||||
hashlib.sha256("forced_collision_hash".encode("utf-8")).hexdigest()
|
||||
if "sample 5" in x
|
||||
else hashlib.sha256(x.encode("utf-8")).hexdigest()
|
||||
),
|
||||
)
|
||||
def test_deduplication_wrong_collision_train_eval(self, _mock_sha256):
|
||||
dedup_train, dedup_eval, _ = deduplicate_and_log_datasets(
|
||||
train_dataset=self.train_dataset, eval_dataset=self.eval_dataset
|
||||
)
|
||||
self.assertEqual(
|
||||
len(dedup_train),
|
||||
2,
|
||||
"train dataset should not deduplicate rows with forced hash collisions but different labels.",
|
||||
)
|
||||
self.assertEqual(
|
||||
len(dedup_eval),
|
||||
2,
|
||||
"Eval dataset should not deduplicate rows with forced hash collisions but different labels.",
|
||||
)
|
||||
self.assertEqual(
|
||||
len(dedup_eval),
|
||||
len(self.eval_dataset),
|
||||
"The output eval dataset should have the same number of rows as the input eval dataset.",
|
||||
)
|
||||
self.assertEqual(
|
||||
str(dedup_eval),
|
||||
str(self.eval_dataset),
|
||||
"The string representation of the output eval dataset should be identical to the input eval dataset.",
|
||||
)
|
||||
|
||||
def test_deduplication_dataset_only(self):
|
||||
_, _, dedup_dataset = deduplicate_and_log_datasets(dataset=self.dataset)
|
||||
dedup_dataset, _ = deduplicate_and_log_datasets(dataset=self.dataset)
|
||||
self.assertEqual(
|
||||
len(dedup_dataset), 3, "Dataset should have all original values"
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user