axolotl/src/axolotl/utils/data/shared.py
NanoCode012 243620394a fix: force train split for json,csv,txt for test_datasets and misc doc changes (#3226)
* fix: force train split for json,csv,txt for test_datasets

* feat(doc): add info on mixing datasets for VLM

* feat(doc): max memory

* fix(doc): clarify lr groups

* fix: add info on vision not being dropped

* feat: add qwen3-vl to multimodal docs

* fix: add moe blocks to arch list

* feat(doc): improve mistral docs

* chore: add helpful link [skip-e2e]

* fix: add vram usage for mistral small

* Update link in docs/faq.qmd

Co-authored-by: salman <salman.mohammadi@outlook.com>

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
Co-authored-by: salman <salman.mohammadi@outlook.com>
2025-10-22 15:23:20 -07:00


"""Dataset loading shared utils."""
from __future__ import annotations
import functools
import os
from pathlib import Path
from typing import TYPE_CHECKING, Any, Generator
from datasets import (
Dataset,
DatasetDict,
IterableDataset,
IterableDatasetDict,
concatenate_datasets,
load_dataset,
load_from_disk,
)
from huggingface_hub import hf_hub_download, snapshot_download
from huggingface_hub.errors import (
HFValidationError,
RepositoryNotFoundError,
RevisionNotFoundError,
)
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
from axolotl.utils.data.utils import deduplicate_and_log_datasets, md5
from axolotl.utils.datasets import get_default_process_count
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
if TYPE_CHECKING:
from adlfs import AzureBlobFileSystem
from gcsfs import GCSFileSystem
from ocifs import OCIFileSystem
from s3fs import S3FileSystem
LOG = get_logger(__name__)
EXTENSIONS_TO_DATASET_TYPES = {
".parquet": "parquet",
".arrow": "arrow",
".csv": "csv",
".txt": "text",
}


def get_dataset_type(dataset_config: DictDefault) -> str:
    """Get the dataset type from the path if it's not specified."""
    if dataset_config.ds_type:
        return dataset_config.ds_type
    for extension, dataset_type in EXTENSIONS_TO_DATASET_TYPES.items():
        if extension in dataset_config.path:
            return dataset_type
    return "json"


def datasets_with_name_generator(
    dataset_configs: list[DictDefault],
) -> Generator[DictDefault, None, None]:
    """Yields expanded dataset configurations based on multiple names or preprocessing
    shards.

    When a dataset config has a list of names, it yields separate configs for each
    name. When a dataset config specifies preprocessing shards, it yields configs for
    each shard.

    Args:
        dataset_configs: List of dataset configuration objects.

    Yields:
        Individual dataset configurations, expanded as needed for names or shards.
    """
    for config in dataset_configs:
        if config.name and isinstance(config.name, list):
            for name in config.name:
                yield DictDefault({**config, "name": name})
        elif config.preprocess_shards and not config.shards:
            for shard_idx in range(config.preprocess_shards):
                yield DictDefault(
                    {
                        **config,
                        "shards": config.preprocess_shards,
                        "shards_idx": shard_idx,
                    }
                )
        else:
            yield config
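

# Example (sketch with hypothetical config values): a config whose `name` is a list
# expands into one config per name, while `preprocess_shards` (without `shards`)
# expands into one config per shard index.
#
#     >>> configs = [DictDefault({"path": "org/dataset", "name": ["en", "fr"]})]
#     >>> [c.name for c in datasets_with_name_generator(configs)]
#     ['en', 'fr']
#     >>> configs = [DictDefault({"path": "org/dataset", "preprocess_shards": 2})]
#     >>> [(c.shards, c.shards_idx) for c in datasets_with_name_generator(configs)]
#     [(2, 0), (2, 1)]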


def load_dataset_with_config(
    dataset_config: DictDefault, use_auth_token: bool, streaming=False
) -> Dataset | IterableDataset:
    """Load a dataset from a config. Handles datasets that are stored locally, in the
    HuggingFace Hub, in a remote filesystem (S3, GCS, Azure, OCI), a URL, or
    `data_files`.

    Args:
        dataset_config: Single dataset config.
        use_auth_token: Whether to use HF auth token.
        streaming: Whether to stream the dataset.

    Returns:
        Loaded dataset.
    """
    # Set up common kwargs for dataset loading
    load_dataset_kwargs = {
        "split": dataset_config.split if dataset_config.split else None,
        "name": dataset_config.name,
        "streaming": streaming,
        "trust_remote_code": dataset_config.trust_remote_code,
    }

    # First check if it's a local path
    if Path(dataset_config.path).exists():
        return _load_from_local_path(dataset_config, load_dataset_kwargs)

    # Check if it's a HuggingFace dataset
    is_hub_dataset = _check_if_hub_dataset(dataset_config, use_auth_token)

    # Check if it's a cloud storage path and get appropriate filesystem
    remote_fs, storage_options = _get_remote_filesystem(dataset_config.path)
    is_cloud_dataset = False
    if remote_fs:
        try:
            is_cloud_dataset = remote_fs.exists(dataset_config.path)
        except (FileNotFoundError, ConnectionError):
            pass

    # Load from appropriate source
    if is_hub_dataset:
        return _load_from_hub(dataset_config, use_auth_token, load_dataset_kwargs)
    if is_cloud_dataset:
        return _load_from_cloud(
            dataset_config, remote_fs, storage_options, load_dataset_kwargs
        )
    if dataset_config.path.startswith("https://"):
        return _load_from_url(dataset_config, load_dataset_kwargs)
    if dataset_config.data_files:
        return _load_from_data_files(dataset_config, load_dataset_kwargs)

    raise ValueError(
        f"The dataset could not be loaded. This could be due to a misconfigured dataset path "
        f"({dataset_config.path}). Try double-checking your path / name / data_files. "
        f"This error is not related to the dataset type."
    )
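

# Usage sketch (hypothetical configs expressed as DictDefault): the loader resolves
# the source in order: local path, HF Hub repo, cloud storage (s3/gs/abfs/oci),
# https URL, then explicit `data_files`. The bucket path below is made up.
#
#     >>> config = DictDefault({"path": "tatsu-lab/alpaca", "split": "train"})
#     >>> ds = load_dataset_with_config(config, use_auth_token=False)  # Hub lookup
#     >>> config = DictDefault({"path": "s3://my-bucket/data.parquet"})  # hypothetical bucket
#     >>> ds = load_dataset_with_config(config, use_auth_token=False)  # cloud path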


def _check_if_hub_dataset(dataset_config: DictDefault, use_auth_token: bool) -> bool:
    """Check if a dataset exists on the HuggingFace Hub."""
    try:
        snapshot_download(
            repo_id=dataset_config.path,
            repo_type="dataset",
            token=use_auth_token,
            revision=dataset_config.revision,
            ignore_patterns=["*"],
        )
        return True
    except (
        RepositoryNotFoundError,
        RevisionNotFoundError,
        FileNotFoundError,
        ConnectionError,
        HFValidationError,
        ValueError,
    ):
        return False


def _get_remote_filesystem(
    path: str,
) -> tuple[
    S3FileSystem | GCSFileSystem | AzureBlobFileSystem | OCIFileSystem | None, dict
]:
    """Get the appropriate filesystem for a remote path."""
    if path.startswith("s3://"):
        try:
            import s3fs

            storage_options = {"anon": False}
            return s3fs.S3FileSystem(**storage_options), storage_options
        except ImportError as exc:
            raise ImportError("s3:// paths require s3fs to be installed") from exc
    elif path.startswith(("gs://", "gcs://")):
        try:
            import gcsfs

            storage_options = {"token": None}  # type: ignore
            return gcsfs.GCSFileSystem(**storage_options), storage_options
        except ImportError as exc:
            raise ImportError(
                "gs:// or gcs:// paths require gcsfs to be installed"
            ) from exc
    elif path.startswith(("adl://", "abfs://", "az://")):
        try:
            import adlfs

            storage_options = {"anon": False}
            return adlfs.AzureBlobFileSystem(**storage_options), storage_options
        except ImportError as exc:
            raise ImportError(
                "adl:// or abfs:// paths require adlfs to be installed"
            ) from exc
    elif path.startswith("oci://"):
        try:
            import ocifs

            storage_options = {}
            return ocifs.OCIFileSystem(**storage_options), storage_options
        except ImportError as exc:
            raise ImportError("oci:// paths require ocifs to be installed") from exc

    return None, {}


def _load_from_local_path(
    dataset_config: DictDefault, load_dataset_kwargs: dict
) -> Dataset | IterableDataset | DatasetDict | IterableDatasetDict:
    """Load a dataset from a local path."""
    local_path = Path(dataset_config.path)

    if local_path.is_dir():
        if dataset_config.data_files:
            dataset_type = get_dataset_type(dataset_config)
            return load_dataset(
                dataset_type,
                data_files=dataset_config.data_files,
                **load_dataset_kwargs,
            )
        try:
            return load_from_disk(dataset_config.path)
        except FileNotFoundError:
            return load_dataset(dataset_config.path, **load_dataset_kwargs)
    elif local_path.is_file():
        dataset_type = get_dataset_type(dataset_config)
        # For single-file datasets, HF always creates only a "train" split
        if dataset_type in ("json", "csv", "text"):
            load_dataset_kwargs["split"] = "train"
        return load_dataset(
            dataset_type,
            data_files=dataset_config.path,
            **load_dataset_kwargs,
        )
    else:
        raise ValueError(
            "Unhandled dataset load: local path exists, but is neither a directory nor a file"
        )


def _load_from_hub(
    dataset_config: DictDefault, use_auth_token: bool, load_dataset_kwargs: dict
) -> Dataset | IterableDataset | DatasetDict | IterableDatasetDict:
    """Load a dataset from the HuggingFace Hub."""
    return load_dataset(
        dataset_config.path,
        data_files=dataset_config.data_files,
        token=use_auth_token,
        revision=dataset_config.revision,
        **load_dataset_kwargs,
    )


def _load_from_cloud(
    dataset_config: DictDefault,
    remote_fs: S3FileSystem | GCSFileSystem | AzureBlobFileSystem | OCIFileSystem,
    storage_options: dict,
    load_dataset_kwargs: dict,
) -> Dataset | IterableDataset | DatasetDict | IterableDatasetDict:
    """Load a dataset from cloud storage."""
    if remote_fs.isdir(dataset_config.path):
        return load_from_disk(
            dataset_config.path,
            storage_options=storage_options,
        )

    if remote_fs.isfile(dataset_config.path):
        dataset_type = get_dataset_type(dataset_config)
        return load_dataset(
            dataset_type,
            data_files=dataset_config.path,
            storage_options=storage_options,
            **load_dataset_kwargs,
        )

    raise ValueError(
        f"Cloud path {dataset_config.path} is neither a directory nor a file"
    )


def _load_from_url(
    dataset_config: DictDefault, load_dataset_kwargs: dict
) -> Dataset | IterableDataset | DatasetDict | IterableDatasetDict:
    """Load a dataset from a URL."""
    dataset_type = get_dataset_type(dataset_config)
    return load_dataset(
        dataset_type,
        data_files=dataset_config.path,
        **load_dataset_kwargs,
    )


def _load_from_data_files(
    dataset_config: DictDefault, load_dataset_kwargs: dict
) -> Dataset | IterableDataset | DatasetDict | IterableDatasetDict:
    """Load a dataset from data files."""
    file_path = None

    if isinstance(dataset_config.data_files, str):
        file_path = hf_hub_download(
            repo_id=dataset_config.path,
            repo_type="dataset",
            filename=dataset_config.data_files,
            revision=dataset_config.revision,
        )
    elif isinstance(dataset_config.data_files, list):
        file_path = [
            hf_hub_download(
                repo_id=dataset_config.path,
                repo_type="dataset",
                filename=file,
                revision=dataset_config.revision,
            )
            for file in dataset_config.data_files
        ]
    else:
        raise ValueError("data_files must be either a string or list of strings")

    return load_dataset("json", data_files=file_path, **load_dataset_kwargs)


def generate_split_fingerprints(
    dataset: Dataset, val_set_size: int | float, seed: int
) -> tuple[str, str]:
    """Generate consistent fingerprints for train/test splits."""
    fingerprint = dataset._fingerprint

    train_hash_input = f"{fingerprint}|{val_set_size}|train|{seed}"
    test_hash_input = f"{fingerprint}|{val_set_size}|test|{seed}"

    train_fingerprint = md5(train_hash_input)
    test_fingerprint = md5(test_hash_input)

    return train_fingerprint, test_fingerprint


def get_prepared_dataset_path(cfg: DictDefault, dataset_hash: str) -> Path:
    """Get standardized path for prepared datasets.

    Args:
        cfg: Configuration object.
        dataset_hash: Hash identifying the specific dataset configuration.

    Returns:
        Path where the prepared dataset should be stored.
    """
    base_path = cfg.dataset_prepared_path or DEFAULT_DATASET_PREPARED_PATH
    return Path(base_path) / dataset_hash
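

# Example (sketch; the configured path and hash value below are hypothetical):
#
#     >>> cfg = DictDefault({"dataset_prepared_path": "last_run_prepared"})
#     >>> get_prepared_dataset_path(cfg, "3f2a9c")
#     PosixPath('last_run_prepared/3f2a9c')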


def create_train_validation_split(
    dataset: Dataset, cfg: DictDefault, val_set_size: int | float
) -> tuple[Dataset, Dataset]:
    """Create train/validation split with consistent fingerprinting.

    Args:
        dataset: Dataset to split.
        cfg: Configuration object containing seed and other settings.
        val_set_size: Size of validation set (absolute number or fraction).

    Returns:
        Tuple of (train_dataset, eval_dataset).
    """
    train_fingerprint, test_fingerprint = generate_split_fingerprints(
        dataset, val_set_size, cfg.seed
    )

    # Apply deduplication before splitting if configured
    if cfg.dataset_exact_deduplication:
        dataset, _ = deduplicate_and_log_datasets(dataset=dataset)

    split_dataset = dataset.train_test_split(
        test_size=val_set_size,
        shuffle=False,
        seed=cfg.seed,
        train_new_fingerprint=train_fingerprint,
        test_new_fingerprint=test_fingerprint,
    )

    return split_dataset["train"], split_dataset["test"]


def _generate_from_iterable_dataset(
    dataset: IterableDataset, worker_id: list[int], num_workers: list[int]
) -> Generator[Any, None, None]:
    """Generator function to correctly split the dataset for each worker"""
    for i, item in enumerate(dataset):
        if i % num_workers[0] == worker_id[0]:
            yield item
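

# Illustration (sketch): `Dataset.from_generator(..., num_proc=N)` splits list-valued
# `gen_kwargs` across workers, so worker k receives worker_id=[k] and num_workers=[N]
# and keeps every N-th example starting at offset k, i.e. examples are distributed
# round-robin across workers. The values below are purely illustrative.
#
#     >>> items = list(range(10))
#     >>> [i for i in items if i % 3 == 1]  # worker_id=[1], num_workers=[3]
#     [1, 4, 7]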


def save_preprocessed_dataset(
    cfg: DictDefault,
    dataset: Dataset,
    dataset_hash: str,
    split: str,
) -> None:
    """Save preprocessed dataset to disk and optionally push to the HF Hub."""
    prepared_ds_path = get_prepared_dataset_path(cfg, dataset_hash)
    num_workers = cfg.dataset_num_proc or get_default_process_count()

    if isinstance(dataset, IterableDataset):
        ds_from_iter = Dataset.from_generator(
            functools.partial(_generate_from_iterable_dataset, dataset),
            features=dataset.features,
            num_proc=num_workers,
            split=split,
            gen_kwargs={
                "worker_id": list(range(num_workers)),
                "num_workers": [num_workers] * num_workers,
            },
        )
        ds_from_iter.save_to_disk(
            str(prepared_ds_path),
            num_proc=num_workers,
            max_shard_size=None,
            num_shards=cfg.num_dataset_shards_to_save,
        )
    else:
        min_rows_per_proc = 256
        os.makedirs(prepared_ds_path, exist_ok=True)
        dataset.save_to_disk(
            str(prepared_ds_path),
            num_proc=min(max(1, len(dataset) // min_rows_per_proc), num_workers),
            max_shard_size=None,
            num_shards=cfg.num_dataset_shards_to_save,
        )

    if cfg.push_dataset_to_hub:
        LOG.info(
            "Pushing merged prepared dataset to Huggingface hub at "
            f"{cfg.push_dataset_to_hub} (version {dataset_hash})...",
            main_process_only=False,
        )
        dataset.push_to_hub(
            cfg.push_dataset_to_hub,
            dataset_hash,
            private=True,
        )


def load_preprocessed_dataset(cfg: DictDefault, dataset_hash: str) -> Dataset | None:
    """Load preprocessed dataset from disk if available.

    Args:
        cfg: Configuration object.
        dataset_hash: Hash identifying the dataset configuration.

    Returns:
        Loaded dataset if found and conditions are met, None otherwise.
    """
    prepared_ds_path = get_prepared_dataset_path(cfg, dataset_hash)

    if (
        cfg.dataset_prepared_path
        and any(prepared_ds_path.glob("*"))
        and not cfg.skip_prepare_dataset
        and not cfg.is_preprocess
    ):
        LOG.info(
            f"Loading prepared dataset from disk at {prepared_ds_path}...",
            main_process_only=True,
        )
        return load_from_disk(str(prepared_ds_path))

    LOG.info(
        f"Unable to find prepared dataset in {prepared_ds_path}",
        main_process_only=True,
    )
    return None


def try_load_from_hub(
    cfg: DictDefault, dataset_hash: str, split: str
) -> Dataset | None:
    """Try to load the prepared dataset from HuggingFace Hub."""
    try:
        LOG.info(
            "Attempting to load prepared dataset from HuggingFace Hub at "
            f"{cfg.push_dataset_to_hub} (version {dataset_hash})..."
        )
        dataset = load_dataset(
            cfg.push_dataset_to_hub,
            dataset_hash,
            token=cfg.hf_use_auth_token,
        )
        return dataset[split]
    except Exception:
        LOG.info("Unable to find prepared dataset in HuggingFace Hub")
        return None


def generate_dataset_hash_from_config(
    cfg: DictDefault, cfg_datasets: list, tokenizer_name: str
) -> str:
    """Generate a hash to uniquely identify a dataset configuration for SFT.

    Args:
        cfg: Main configuration object.
        cfg_datasets: List of dataset configurations.
        tokenizer_name: Name of the tokenizer being used.

    Returns:
        MD5 hash string representing the configuration.
    """
    config_str = (
        f"{cfg.sequence_len}@{cfg.sample_packing}@{cfg.eval_sample_packing}@"
        f"{cfg.group_by_length}@{cfg.kd_temperature or 1.0}|"
        f"{'|'.join(sorted([f'{d.path}:{d.type}:{d.shards}:{d.conversation}:{d.split}:{d.temperature or 1.0}' for d in cfg_datasets]))}"
        f"|{tokenizer_name}"
    )
    return str(md5(config_str))
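

# Note (sketch of the hashed string, with hypothetical values): the fingerprint mixes
# global settings with a sorted, per-dataset summary and the tokenizer name, so any
# change to one of them produces a new prepared-dataset path. The input might look
# like:
#
#     2048@True@None@False@1.0|tatsu-lab/alpaca:alpaca:None:None:None:1.0|NousResearch/Llama-2-7b-hf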


def merge_datasets(datasets: list[Dataset], cfg: DictDefault) -> Dataset:
    """Merge multiple datasets into one with optional shuffling.

    Args:
        datasets: List of datasets to merge.
        cfg: Configuration object containing shuffle settings.

    Returns:
        Merged dataset.
    """
    if len(datasets) == 1:
        ds = datasets[0]
        # Do not shuffle if curriculum sampling is enabled or
        # shuffle_merged_datasets is disabled
        if cfg.curriculum_sampling or not cfg.shuffle_merged_datasets:
            return ds
        return ds.shuffle(seed=cfg.seed)

    # If enabled, shuffle each dataset independently before merging.
    # This allows curriculum learning strategies to be applied at the dataset level.
    if cfg.shuffle_before_merging_datasets:
        LOG.info("Shuffling each dataset individually before merging...")
        datasets = [ds.shuffle(seed=cfg.seed) for ds in datasets]

    LOG.info("Merging datasets...")
    merged_dataset = concatenate_datasets(datasets)

    if cfg.shuffle_merged_datasets:
        LOG.debug("Shuffling merged datasets...")
        if cfg.curriculum_sampling:
            LOG.warning(
                "Shuffling merged datasets with curriculum sampling is not recommended. "
                "This will randomize the order of samples."
            )
        merged_dataset = merged_dataset.shuffle(seed=cfg.seed)
    else:
        LOG.debug("Not shuffling merged datasets.")

    return merged_dataset
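

# Usage sketch (hypothetical cfg and datasets): with shuffling enabled and curriculum
# sampling disabled, the concatenated dataset is shuffled with the configured seed;
# with `shuffle_merged_datasets: false`, the concatenation order is preserved.
#
#     >>> cfg = DictDefault({"seed": 42, "shuffle_merged_datasets": True})
#     >>> merged = merge_datasets([ds_a, ds_b], cfg)  # ds_a, ds_b: hypothetical Datasets
#     >>> merged.num_rows == ds_a.num_rows + ds_b.num_rows
#     True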