"""Dataset loading shared utils."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import functools
|
|
import os
|
|
from pathlib import Path
|
|
from typing import TYPE_CHECKING, Any, Generator
|
|
|
|
from datasets import (
|
|
Dataset,
|
|
DatasetDict,
|
|
IterableDataset,
|
|
IterableDatasetDict,
|
|
concatenate_datasets,
|
|
load_dataset,
|
|
load_from_disk,
|
|
)
|
|
from huggingface_hub import hf_hub_download, snapshot_download
|
|
from huggingface_hub.errors import (
|
|
HFValidationError,
|
|
RepositoryNotFoundError,
|
|
RevisionNotFoundError,
|
|
)
|
|
|
|
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
|
|
from axolotl.utils.data.utils import deduplicate_and_log_datasets, md5
|
|
from axolotl.utils.datasets import get_default_process_count
|
|
from axolotl.utils.dict import DictDefault
|
|
from axolotl.utils.logging import get_logger
|
|
|
|
if TYPE_CHECKING:
|
|
from adlfs import AzureBlobFileSystem
|
|
from gcsfs import GCSFileSystem
|
|
from ocifs import OCIFileSystem
|
|
from s3fs import S3FileSystem
|
|
|
|
LOG = get_logger(__name__)
|
|
|
|
EXTENSIONS_TO_DATASET_TYPES = {
|
|
".parquet": "parquet",
|
|
".arrow": "arrow",
|
|
".csv": "csv",
|
|
".txt": "text",
|
|
}
|
|
|
|
|
|
def get_dataset_type(dataset_config: DictDefault) -> str:
    """Get the dataset type from the path if it's not specified."""
    if dataset_config.ds_type:
        return dataset_config.ds_type

    for extension, dataset_type in EXTENSIONS_TO_DATASET_TYPES.items():
        if extension in dataset_config.path:
            return dataset_type

    return "json"

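# Illustrative usage of `get_dataset_type` (a sketch; these paths and values
# are made-up examples, not taken from a real config):
#
#   get_dataset_type(DictDefault({"path": "data/train.parquet"}))   # -> "parquet"
#   get_dataset_type(DictDefault({"path": "data/train.jsonl"}))     # -> "json" (fallback)
#   get_dataset_type(DictDefault({"path": "x", "ds_type": "csv"}))  # -> "csv" (explicit ds_type wins)

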
def datasets_with_name_generator(
    dataset_configs: list[DictDefault],
) -> Generator[DictDefault, None, None]:
    """Yields expanded dataset configurations based on multiple names or preprocessing
    shards.

    When a dataset config has a list of names, it yields separate configs for each
    name. When a dataset config specifies preprocessing shards, it yields configs for
    each shard.

    Args:
        dataset_configs: List of dataset configuration objects.

    Yields:
        Individual dataset configurations, expanded as needed for names or shards.
    """
    for config in dataset_configs:
        if config.name and isinstance(config.name, list):
            for name in config.name:
                yield DictDefault({**config, "name": name})
        elif config.preprocess_shards and not config.shards:
            for shard_idx in range(config.preprocess_shards):
                yield DictDefault(
                    {
                        **config,
                        "shards": config.preprocess_shards,
                        "shards_idx": shard_idx,
                    }
                )
        else:
            yield config

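# Illustrative expansion (a sketch; the config values are made-up):
#
#   configs = [DictDefault({"path": "org/dataset", "name": ["en", "fr"]})]
#   [c.name for c in datasets_with_name_generator(configs)]  # -> ["en", "fr"]

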
def load_dataset_with_config(
    dataset_config: DictDefault, use_auth_token: bool, streaming=False
) -> Dataset | IterableDataset:
    """Load a dataset from a config. Handles datasets that are stored locally, in the
    HuggingFace Hub, in a remote filesystem (S3, GCS, Azure, OCI), at a URL, or in
    `data_files`.

    Args:
        dataset_config: Single dataset config.
        use_auth_token: Whether to use HF auth token.
        streaming: Whether to stream the dataset.

    Returns:
        Loaded dataset.
    """
    # Set up common kwargs for dataset loading
    load_dataset_kwargs = {
        "split": dataset_config.split if dataset_config.split else None,
        "name": dataset_config.name,
        "streaming": streaming,
        "trust_remote_code": dataset_config.trust_remote_code,
    }

    # First check if it's a local path
    if Path(dataset_config.path).exists():
        return _load_from_local_path(dataset_config, load_dataset_kwargs)

    # Check if it's a HuggingFace dataset
    is_hub_dataset = _check_if_hub_dataset(dataset_config, use_auth_token)

    # Check if it's a cloud storage path and get the appropriate filesystem
    remote_fs, storage_options = _get_remote_filesystem(dataset_config.path)
    is_cloud_dataset = False
    if remote_fs:
        try:
            is_cloud_dataset = remote_fs.exists(dataset_config.path)
        except (FileNotFoundError, ConnectionError):
            pass

    # Load from the appropriate source
    if is_hub_dataset:
        return _load_from_hub(dataset_config, use_auth_token, load_dataset_kwargs)
    if is_cloud_dataset:
        return _load_from_cloud(
            dataset_config, remote_fs, storage_options, load_dataset_kwargs
        )
    if dataset_config.path.startswith("https://"):
        return _load_from_url(dataset_config, load_dataset_kwargs)
    if dataset_config.data_files:
        return _load_from_data_files(dataset_config, load_dataset_kwargs)

    raise ValueError(
        f"The dataset could not be loaded. This is likely due to a misconfigured "
        f"dataset path ({dataset_config.path}). Double-check your path / name / "
        f"data_files. This error is not caused by the dataset type."
    )

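# Illustrative call (a sketch; "org/dataset" is a placeholder Hub repo id):
#
#   ds = load_dataset_with_config(
#       DictDefault({"path": "org/dataset", "split": "train"}),
#       use_auth_token=True,
#   )
#
# Resolution order implemented above: local path -> HF Hub -> cloud storage
# (s3/gs/azure/oci) -> https:// URL -> explicit data_files.

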
def _check_if_hub_dataset(dataset_config: DictDefault, use_auth_token: bool) -> bool:
    """Check if a dataset exists on the HuggingFace Hub."""
    try:
        # `ignore_patterns=["*"]` skips downloading any files; this call only
        # verifies that the repo and revision exist.
        snapshot_download(
            repo_id=dataset_config.path,
            repo_type="dataset",
            token=use_auth_token,
            revision=dataset_config.revision,
            ignore_patterns=["*"],
        )
        return True
    except (
        RepositoryNotFoundError,
        RevisionNotFoundError,
        FileNotFoundError,
        ConnectionError,
        HFValidationError,
        ValueError,
    ):
        return False


def _get_remote_filesystem(
    path: str,
) -> tuple[
    S3FileSystem | GCSFileSystem | AzureBlobFileSystem | OCIFileSystem | None, dict
]:
    """Get the appropriate filesystem for a remote path."""
    if path.startswith("s3://"):
        try:
            import s3fs

            storage_options = {"anon": False}
            return s3fs.S3FileSystem(**storage_options), storage_options
        except ImportError as exc:
            raise ImportError("s3:// paths require s3fs to be installed") from exc

    elif path.startswith(("gs://", "gcs://")):
        try:
            import gcsfs

            storage_options = {"token": None}  # type: ignore
            return gcsfs.GCSFileSystem(**storage_options), storage_options
        except ImportError as exc:
            raise ImportError(
                "gs:// or gcs:// paths require gcsfs to be installed"
            ) from exc

    elif path.startswith(("adl://", "abfs://", "az://")):
        try:
            import adlfs

            storage_options = {"anon": False}
            return adlfs.AzureBlobFileSystem(**storage_options), storage_options
        except ImportError as exc:
            raise ImportError(
                "adl://, abfs://, or az:// paths require adlfs to be installed"
            ) from exc

    elif path.startswith("oci://"):
        try:
            import ocifs

            storage_options = {}
            return ocifs.OCIFileSystem(**storage_options), storage_options
        except ImportError as exc:
            raise ImportError("oci:// paths require ocifs to be installed") from exc

    return None, {}

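# Scheme-to-filesystem mapping implemented above (any other scheme returns
# (None, {})):
#
#   s3://                  -> s3fs.S3FileSystem
#   gs:// or gcs://        -> gcsfs.GCSFileSystem
#   adl://, abfs://, az:// -> adlfs.AzureBlobFileSystem
#   oci://                 -> ocifs.OCIFileSystem

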
def _load_from_local_path(
    dataset_config: DictDefault, load_dataset_kwargs: dict
) -> Dataset | IterableDataset | DatasetDict | IterableDatasetDict:
    """Load a dataset from a local path."""
    local_path = Path(dataset_config.path)

    if local_path.is_dir():
        if dataset_config.data_files:
            dataset_type = get_dataset_type(dataset_config)
            return load_dataset(
                dataset_type,
                data_files=dataset_config.data_files,
                **load_dataset_kwargs,
            )
        try:
            return load_from_disk(dataset_config.path)
        except FileNotFoundError:
            return load_dataset(dataset_config.path, **load_dataset_kwargs)
    elif local_path.is_file():
        dataset_type = get_dataset_type(dataset_config)

        # For single-file datasets, HF always creates only a "train" split
        if dataset_type in ("json", "csv", "text"):
            load_dataset_kwargs["split"] = "train"

        return load_dataset(
            dataset_type,
            data_files=dataset_config.path,
            **load_dataset_kwargs,
        )
    else:
        raise ValueError(
            "Unhandled dataset load: local path exists, but is neither a directory nor a file"
        )


def _load_from_hub(
    dataset_config: DictDefault, use_auth_token: bool, load_dataset_kwargs: dict
) -> Dataset | IterableDataset | DatasetDict | IterableDatasetDict:
    """Load a dataset from the HuggingFace Hub."""
    return load_dataset(
        dataset_config.path,
        data_files=dataset_config.data_files,
        token=use_auth_token,
        revision=dataset_config.revision,
        **load_dataset_kwargs,
    )


def _load_from_cloud(
    dataset_config: DictDefault,
    remote_fs: S3FileSystem | GCSFileSystem | AzureBlobFileSystem | OCIFileSystem,
    storage_options: dict,
    load_dataset_kwargs: dict,
) -> Dataset | IterableDataset | DatasetDict | IterableDatasetDict:
    """Load a dataset from cloud storage."""
    if remote_fs.isdir(dataset_config.path):
        return load_from_disk(
            dataset_config.path,
            storage_options=storage_options,
        )

    if remote_fs.isfile(dataset_config.path):
        dataset_type = get_dataset_type(dataset_config)
        return load_dataset(
            dataset_type,
            data_files=dataset_config.path,
            storage_options=storage_options,
            **load_dataset_kwargs,
        )

    raise ValueError(
        f"Cloud path {dataset_config.path} is neither a directory nor a file"
    )


def _load_from_url(
    dataset_config: DictDefault, load_dataset_kwargs: dict
) -> Dataset | IterableDataset | DatasetDict | IterableDatasetDict:
    """Load a dataset from a URL."""
    dataset_type = get_dataset_type(dataset_config)
    return load_dataset(
        dataset_type,
        data_files=dataset_config.path,
        **load_dataset_kwargs,
    )


def _load_from_data_files(
    dataset_config: DictDefault, load_dataset_kwargs: dict
) -> Dataset | IterableDataset | DatasetDict | IterableDatasetDict:
    """Load a dataset from data files."""
    file_path = None

    if isinstance(dataset_config.data_files, str):
        file_path = hf_hub_download(
            repo_id=dataset_config.path,
            repo_type="dataset",
            filename=dataset_config.data_files,
            revision=dataset_config.revision,
        )
    elif isinstance(dataset_config.data_files, list):
        file_path = [
            hf_hub_download(
                repo_id=dataset_config.path,
                repo_type="dataset",
                filename=file,
                revision=dataset_config.revision,
            )
            for file in dataset_config.data_files
        ]
    else:
        raise ValueError("data_files must be either a string or list of strings")

    return load_dataset("json", data_files=file_path, **load_dataset_kwargs)

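# Note: `data_files` may be a single filename or a list of filenames inside the
# Hub repo named by `dataset_config.path`; each file is fetched locally via
# `hf_hub_download` and then loaded with the "json" loader.

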
def generate_split_fingerprints(
    dataset: Dataset, val_set_size: int | float, seed: int
) -> tuple[str, str]:
    """Generate consistent fingerprints for train/test splits."""
    fingerprint = dataset._fingerprint

    train_hash_input = f"{fingerprint}|{val_set_size}|train|{seed}"
    test_hash_input = f"{fingerprint}|{val_set_size}|test|{seed}"

    train_fingerprint = md5(train_hash_input)
    test_fingerprint = md5(test_hash_input)

    return train_fingerprint, test_fingerprint

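# The fingerprints are pure functions of (dataset fingerprint, val_set_size,
# split name, seed), so re-running preprocessing with identical inputs yields
# identical split fingerprints, which lets `datasets` reuse its cache.

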
def get_prepared_dataset_path(cfg: DictDefault, dataset_hash: str) -> Path:
    """Get standardized path for prepared datasets.

    Args:
        cfg: Configuration object.
        dataset_hash: Hash identifying the specific dataset configuration.

    Returns:
        Path where the prepared dataset should be stored.
    """
    base_path = cfg.dataset_prepared_path or DEFAULT_DATASET_PREPARED_PATH
    return Path(base_path) / dataset_hash

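# Example layout (a sketch; "abc123" stands in for a real hash, and the base
# directory comes from cfg.dataset_prepared_path or DEFAULT_DATASET_PREPARED_PATH):
#
#   get_prepared_dataset_path(cfg, "abc123")  # -> Path(<base_path>) / "abc123"

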
def create_train_validation_split(
    dataset: Dataset, cfg: DictDefault, val_set_size: int | float
) -> tuple[Dataset, Dataset]:
    """Create train/validation split with consistent fingerprinting.

    Args:
        dataset: Dataset to split.
        cfg: Configuration object containing seed and other settings.
        val_set_size: Size of validation set (absolute number or fraction).

    Returns:
        Tuple of (train_dataset, eval_dataset).
    """
    train_fingerprint, test_fingerprint = generate_split_fingerprints(
        dataset, val_set_size, cfg.seed
    )

    # Apply deduplication before splitting if configured
    if cfg.dataset_exact_deduplication:
        dataset, _ = deduplicate_and_log_datasets(dataset=dataset)

    split_dataset = dataset.train_test_split(
        test_size=val_set_size,
        shuffle=False,
        seed=cfg.seed,
        train_new_fingerprint=train_fingerprint,
        test_new_fingerprint=test_fingerprint,
    )

    return split_dataset["train"], split_dataset["test"]

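# Illustrative usage (a sketch; assumes `cfg.seed` is set):
#
#   train_ds, eval_ds = create_train_validation_split(ds, cfg, val_set_size=0.05)
#
# `shuffle=False` preserves row order; the explicit fingerprints make the split
# deterministic across runs for caching purposes.

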
def _generate_from_iterable_dataset(
    dataset: IterableDataset, worker_id: list[int], num_workers: list[int]
) -> Generator[Any, None, None]:
    """Generator function to correctly split the dataset for each worker."""
    # `worker_id` and `num_workers` arrive as single-element lists because
    # `Dataset.from_generator` shards list-valued `gen_kwargs` across workers.
    for i, item in enumerate(dataset):
        if i % num_workers[0] == worker_id[0]:
            yield item

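# Round-robin sharding illustrated (a sketch): with num_workers == [2], the
# worker holding worker_id == [0] yields items 0, 2, 4, ... while the worker
# holding worker_id == [1] yields items 1, 3, 5, ...

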
def save_preprocessed_dataset(
    cfg: DictDefault,
    dataset: Dataset,
    dataset_hash: str,
    split: str,
) -> None:
    """Save preprocessed dataset to disk and optionally push to the HF Hub."""
    prepared_ds_path = get_prepared_dataset_path(cfg, dataset_hash)
    num_workers = cfg.dataset_num_proc or get_default_process_count()
    if isinstance(dataset, IterableDataset):
        # Materialize the iterable dataset in parallel; each worker pulls its
        # round-robin share via `_generate_from_iterable_dataset`.
        ds_from_iter = Dataset.from_generator(
            functools.partial(_generate_from_iterable_dataset, dataset),
            features=dataset.features,
            num_proc=num_workers,
            split=split,
            gen_kwargs={
                "worker_id": list(range(num_workers)),
                "num_workers": [num_workers] * num_workers,
            },
        )
        ds_from_iter.save_to_disk(
            str(prepared_ds_path),
            num_proc=num_workers,
            max_shard_size=None,
            num_shards=cfg.num_dataset_shards_to_save,
        )
    else:
        # Cap parallelism so each process writes at least `min_rows_per_proc` rows
        min_rows_per_proc = 256
        os.makedirs(prepared_ds_path, exist_ok=True)
        dataset.save_to_disk(
            str(prepared_ds_path),
            num_proc=min(max(1, len(dataset) // min_rows_per_proc), num_workers),
            max_shard_size=None,
            num_shards=cfg.num_dataset_shards_to_save,
        )
    if cfg.push_dataset_to_hub:
        LOG.info(
            "Pushing merged prepared dataset to HuggingFace Hub at "
            f"{cfg.push_dataset_to_hub} (version {dataset_hash})...",
            main_process_only=False,
        )
        dataset.push_to_hub(
            cfg.push_dataset_to_hub,
            dataset_hash,
            private=True,
        )


def load_preprocessed_dataset(cfg: DictDefault, dataset_hash: str) -> Dataset | None:
    """Load preprocessed dataset from disk if available.

    Args:
        cfg: Configuration object.
        dataset_hash: Hash identifying the dataset configuration.

    Returns:
        Loaded dataset if found and conditions are met, None otherwise.
    """
    prepared_ds_path = get_prepared_dataset_path(cfg, dataset_hash)

    if (
        cfg.dataset_prepared_path
        and any(prepared_ds_path.glob("*"))
        and not cfg.skip_prepare_dataset
        and not cfg.is_preprocess
    ):
        LOG.info(
            f"Loading prepared dataset from disk at {prepared_ds_path}...",
            main_process_only=True,
        )
        return load_from_disk(str(prepared_ds_path))

    LOG.info(
        f"Unable to find prepared dataset in {prepared_ds_path}",
        main_process_only=True,
    )
    return None


def try_load_from_hub(
    cfg: DictDefault, dataset_hash: str, split: str
) -> Dataset | None:
    """Try to load the prepared dataset from HuggingFace Hub."""
    try:
        LOG.info(
            "Attempting to load prepared dataset from HuggingFace Hub at "
            f"{cfg.push_dataset_to_hub} (version {dataset_hash})..."
        )
        dataset = load_dataset(
            cfg.push_dataset_to_hub,
            dataset_hash,
            token=cfg.hf_use_auth_token,
        )
        return dataset[split]
    except Exception:
        LOG.info("Unable to find prepared dataset in HuggingFace Hub")
        return None


def generate_dataset_hash_from_config(
    cfg: DictDefault, cfg_datasets: list, tokenizer_name: str
) -> str:
    """Generate a hash to uniquely identify a dataset configuration for SFT.

    Args:
        cfg: Main configuration object.
        cfg_datasets: List of dataset configurations.
        tokenizer_name: Name of the tokenizer being used.

    Returns:
        MD5 hash string representing the configuration.
    """
    config_str = (
        f"{cfg.sequence_len}@{cfg.sample_packing}@{cfg.eval_sample_packing}@"
        f"{cfg.group_by_length}@{cfg.kd_temperature or 1.0}|"
        f"{'|'.join(sorted([f'{d.path}:{d.type}:{d.shards}:{d.conversation}:{d.split}:{d.temperature or 1.0}' for d in cfg_datasets]))}"
        f"|{tokenizer_name}"
    )
    return str(md5(config_str))

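# The hash covers tokenization-relevant settings (sequence length, packing,
# grouping by length, KD temperature), a sorted per-dataset signature, and the
# tokenizer name, so changing any of these invalidates the prepared-dataset
# cache entry.

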
def merge_datasets(datasets: list[Dataset], cfg: DictDefault) -> Dataset:
    """Merge multiple datasets into one with optional shuffling.

    Args:
        datasets: List of datasets to merge.
        cfg: Configuration object containing shuffle settings.

    Returns:
        Merged dataset.
    """
    if len(datasets) == 1:
        ds = datasets[0]

        # Do not shuffle if curriculum sampling is enabled or
        # shuffle_merged_datasets is disabled
        if cfg.curriculum_sampling or not cfg.shuffle_merged_datasets:
            return ds

        return ds.shuffle(seed=cfg.seed)

    # If enabled, shuffle each dataset independently before merging.
    # This allows curriculum learning strategies to be applied at the dataset level.
    if cfg.shuffle_before_merging_datasets:
        LOG.info("Shuffling each dataset individually before merging...")
        datasets = [ds.shuffle(seed=cfg.seed) for ds in datasets]

    LOG.info("Merging datasets...")
    merged_dataset = concatenate_datasets(datasets)

    if cfg.shuffle_merged_datasets:
        LOG.debug("Shuffling merged datasets...")
        if cfg.curriculum_sampling:
            LOG.warning(
                "Shuffling merged datasets with curriculum sampling is not recommended. "
                "This will randomize the order of samples."
            )
        merged_dataset = merged_dataset.shuffle(seed=cfg.seed)
    else:
        LOG.debug("Not shuffling merged datasets.")

    return merged_dataset
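

# Shuffle behavior summary (derived from the branches above):
#
#   shuffle_before_merging_datasets -> shuffle each dataset, then concatenate
#   shuffle_merged_datasets         -> shuffle the concatenated result
#   curriculum_sampling             -> with a single dataset, skips shuffling;
#                                      with multiple, only logs a warning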