From fdbb1a207cd8f68f658c466698ba42c88fe600b6 Mon Sep 17 00:00:00 2001 From: Lee Park <56708372+leeparkuky@users.noreply.github.com> Date: Thu, 13 Feb 2025 08:31:07 -0500 Subject: [PATCH] [Fixing #2149] load_from_disk for RL-type training (#2193) * Update rl.py * Update rl.py * Update rl.py * refactor pref dataset loading to reuse load_dataset_w_config * refactor again after rebase from main * chore: add docstring and types --------- Co-authored-by: Wing Lian Co-authored-by: NanoCode012 --- src/axolotl/utils/data/rl.py | 28 +++++++--------------- src/axolotl/utils/data/sft.py | 23 ++---------------- src/axolotl/utils/data/shared.py | 41 +++++++++++++++++++++++++++++--- 3 files changed, 49 insertions(+), 43 deletions(-) diff --git a/src/axolotl/utils/data/rl.py b/src/axolotl/utils/data/rl.py index 9f5c726ab..67075cc9f 100644 --- a/src/axolotl/utils/data/rl.py +++ b/src/axolotl/utils/data/rl.py @@ -4,15 +4,16 @@ import inspect import logging from functools import partial from pathlib import Path -from typing import Any, List +from typing import Any, List, Union import yaml -from datasets import DatasetDict, concatenate_datasets, load_dataset, load_from_disk +from datasets import Dataset, DatasetDict, concatenate_datasets, load_from_disk from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH from axolotl.prompt_strategies.dpo import load as load_dpo from axolotl.prompt_strategies.kto import load as load_kto from axolotl.prompt_strategies.orpo import load as load_orpo +from axolotl.utils.data.shared import datasets_w_name_generator, load_dataset_w_config from axolotl.utils.data.utils import deduplicate_and_log_datasets, md5 from axolotl.utils.dict import DictDefault from axolotl.utils.distributed import is_main_process, zero_first @@ -118,23 +119,12 @@ def drop_long_rl_seq( def load_prepare_preference_datasets(cfg): def load_split(dataset_cfgs, _cfg): split_datasets: List[Any] = [] - for i, ds_cfg in enumerate(dataset_cfgs): - if ds_cfg["ds_type"] == "json": - for data_file in ds_cfg["data_files"]: - data_files = {ds_cfg["split"]: data_file} - ds = load_dataset( # pylint: disable=invalid-name - "json", - data_files=data_files, - split=ds_cfg["split"], - ) - split_datasets.insert(i, ds) - else: - ds = load_dataset( # pylint: disable=invalid-name - ds_cfg["path"], - split=ds_cfg["split"], - revision=ds_cfg.get("revision", None), - ) - split_datasets.insert(i, ds) + use_auth_token = _cfg.hf_use_auth_token + for config_dataset in datasets_w_name_generator(dataset_cfgs): + ds: Union[Dataset, DatasetDict] = load_dataset_w_config( + config_dataset, use_auth_token, streaming=False + ) + split_datasets.append(ds) tokenizer = load_tokenizer(cfg) diff --git a/src/axolotl/utils/data/sft.py b/src/axolotl/utils/data/sft.py index 722ad2de2..52ca91365 100644 --- a/src/axolotl/utils/data/sft.py +++ b/src/axolotl/utils/data/sft.py @@ -43,7 +43,7 @@ from axolotl.prompters import ( UnsupportedPrompter, ) from axolotl.utils.data.pretraining import wrap_pretraining_dataset -from axolotl.utils.data.shared import load_dataset_w_config +from axolotl.utils.data.shared import datasets_w_name_generator, load_dataset_w_config from axolotl.utils.data.utils import ( deduplicate_and_log_datasets, drop_long_seq_in_dataset, @@ -263,30 +263,11 @@ def load_tokenized_prepared_datasets( datasets = [] - def for_d_in_datasets(dataset_configs): - for dataset in dataset_configs: - if dataset.name and isinstance(dataset.name, list): - # load_dataset doesn't properly handle multiple named configurations - # at the same time for a given dataset - for name in dataset.name: - yield DictDefault({**dataset, "name": name}) - elif dataset.preprocess_shards and not dataset.shards: - for shard in range(dataset.preprocess_shards): - yield DictDefault( - { - **dataset, - "shards": dataset.preprocess_shards, - "shards_idx": shard, - } - ) - else: - yield dataset - streaming_ds = False if preprocess_iterable: streaming_ds = True # pylint: disable=invalid-name - for config_dataset in for_d_in_datasets(cfg_datasets): + for config_dataset in datasets_w_name_generator(cfg_datasets): ds: Union[Dataset, DatasetDict] = load_dataset_w_config( config_dataset, use_auth_token, streaming=streaming_ds ) diff --git a/src/axolotl/utils/data/shared.py b/src/axolotl/utils/data/shared.py index 013d7a895..405057efc 100644 --- a/src/axolotl/utils/data/shared.py +++ b/src/axolotl/utils/data/shared.py @@ -1,6 +1,7 @@ """ dataset loading shared utils """ + from pathlib import Path from typing import Optional, Union @@ -29,9 +30,43 @@ def get_ds_type(config_dataset: DictDefault): return ds_type +def datasets_w_name_generator(dataset_configs: list[DictDefault]): + """ + Yields dataset configs handling multiple names or preprocess_shards + + Args: + dataset_configs: list of dataset configs (equivalent to cfg.datasets) + """ + for dataset in dataset_configs: + if dataset.name and isinstance(dataset.name, list): + # load_dataset doesn't properly handle multiple named configurations + # at the same time for a given dataset + for name in dataset.name: + yield DictDefault({**dataset, "name": name}) + elif dataset.preprocess_shards and not dataset.shards: + for shard in range(dataset.preprocess_shards): + yield DictDefault( + { + **dataset, + "shards": dataset.preprocess_shards, + "shards_idx": shard, + } + ) + else: + yield dataset + + def load_dataset_w_config( - config_dataset, auth_token, streaming=False + config_dataset: DictDefault, use_auth_token: bool, streaming=False ) -> Union[Dataset, DatasetDict]: + """ + Load a dataset from a config + + Args: + config_dataset: single dataset config + use_auth_token: whether to use HF auth token + streaming: whether to stream the dataset + """ # pylint: disable=invalid-name ds: Optional[Union[Dataset, DatasetDict]] = None # pylint: disable=invalid-name ds_from_hub = False @@ -43,7 +78,7 @@ def load_dataset_w_config( config_dataset.path, name=config_dataset.name, streaming=True, - token=auth_token, + token=use_auth_token, revision=config_dataset.revision, trust_remote_code=ds_trust_remote_code, ) @@ -161,7 +196,7 @@ def load_dataset_w_config( name=config_dataset.name, streaming=streaming, data_files=config_dataset.data_files, - token=auth_token, + token=use_auth_token, revision=config_dataset.revision, trust_remote_code=config_dataset.trust_remote_code, **load_ds_kwargs,