[Fixing #2149] load_from_disk for RL-type training (#2193)

* Update rl.py

* Update rl.py

* Update rl.py

* refactor pref dataset loading to reuse load_dataset_w_config

* refactor again after rebase from main

* chore: add docstring and types

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
Co-authored-by: NanoCode012 <nano@axolotl.ai>
Author: Lee Park
Date: 2025-02-13 08:31:07 -05:00
Committed by: GitHub
Parent: 30046315d9
Commit: fdbb1a207c
3 changed files with 49 additions and 43 deletions


@@ -4,15 +4,16 @@ import inspect
 import logging
 from functools import partial
 from pathlib import Path
-from typing import Any, List
+from typing import Any, List, Union

 import yaml
-from datasets import DatasetDict, concatenate_datasets, load_dataset, load_from_disk
+from datasets import Dataset, DatasetDict, concatenate_datasets, load_from_disk

 from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
 from axolotl.prompt_strategies.dpo import load as load_dpo
 from axolotl.prompt_strategies.kto import load as load_kto
 from axolotl.prompt_strategies.orpo import load as load_orpo
+from axolotl.utils.data.shared import datasets_w_name_generator, load_dataset_w_config
 from axolotl.utils.data.utils import deduplicate_and_log_datasets, md5
 from axolotl.utils.dict import DictDefault
 from axolotl.utils.distributed import is_main_process, zero_first
@@ -118,23 +119,12 @@ def drop_long_rl_seq(
 def load_prepare_preference_datasets(cfg):
     def load_split(dataset_cfgs, _cfg):
         split_datasets: List[Any] = []
-        for i, ds_cfg in enumerate(dataset_cfgs):
-            if ds_cfg["ds_type"] == "json":
-                for data_file in ds_cfg["data_files"]:
-                    data_files = {ds_cfg["split"]: data_file}
-                    ds = load_dataset(  # pylint: disable=invalid-name
-                        "json",
-                        data_files=data_files,
-                        split=ds_cfg["split"],
-                    )
-                    split_datasets.insert(i, ds)
-            else:
-                ds = load_dataset(  # pylint: disable=invalid-name
-                    ds_cfg["path"],
-                    split=ds_cfg["split"],
-                    revision=ds_cfg.get("revision", None),
-                )
-                split_datasets.insert(i, ds)
+        use_auth_token = _cfg.hf_use_auth_token
+        for config_dataset in datasets_w_name_generator(dataset_cfgs):
+            ds: Union[Dataset, DatasetDict] = load_dataset_w_config(
+                config_dataset, use_auth_token, streaming=False
+            )
+            split_datasets.append(ds)

     tokenizer = load_tokenizer(cfg)
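
With that, the old per-ds_type branching in rl.py (one path for local JSON files, another for hub datasets, each calling load_dataset directly) is gone: RL preference loading now funnels through the same shared loader as the SFT pipeline, which, per the commit title, is what makes load_from_disk-prepared datasets reachable from RL configs. A minimal self-contained sketch of the resulting load_split flow, assuming the post-commit axolotl helpers are importable; it mirrors the diff above, condensed slightly to take the resolved use_auth_token flag directly:

from typing import Any, List, Union

from datasets import Dataset, DatasetDict

from axolotl.utils.data.shared import (
    datasets_w_name_generator,
    load_dataset_w_config,
)
from axolotl.utils.dict import DictDefault


def load_split(dataset_cfgs: List[DictDefault], use_auth_token: bool) -> List[Any]:
    """Load every dataset config for one split through the shared loader."""
    split_datasets: List[Any] = []
    for config_dataset in datasets_w_name_generator(dataset_cfgs):
        # the shared loader decides how to resolve each config (hub repo,
        # local files, or a prepared directory) instead of load_split
        ds: Union[Dataset, DatasetDict] = load_dataset_w_config(
            config_dataset, use_auth_token, streaming=False
        )
        split_datasets.append(ds)
    return split_datasets

The consumer in the tokenized-dataset loader gets the same treatment in the next file.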


@@ -43,7 +43,7 @@ from axolotl.prompters import (
     UnsupportedPrompter,
 )
 from axolotl.utils.data.pretraining import wrap_pretraining_dataset
-from axolotl.utils.data.shared import load_dataset_w_config
+from axolotl.utils.data.shared import datasets_w_name_generator, load_dataset_w_config
 from axolotl.utils.data.utils import (
     deduplicate_and_log_datasets,
     drop_long_seq_in_dataset,
@@ -263,30 +263,11 @@ def load_tokenized_prepared_datasets(
     datasets = []

-    def for_d_in_datasets(dataset_configs):
-        for dataset in dataset_configs:
-            if dataset.name and isinstance(dataset.name, list):
-                # load_dataset doesn't properly handle multiple named configurations
-                # at the same time for a given dataset
-                for name in dataset.name:
-                    yield DictDefault({**dataset, "name": name})
-            elif dataset.preprocess_shards and not dataset.shards:
-                for shard in range(dataset.preprocess_shards):
-                    yield DictDefault(
-                        {
-                            **dataset,
-                            "shards": dataset.preprocess_shards,
-                            "shards_idx": shard,
-                        }
-                    )
-            else:
-                yield dataset
-
     streaming_ds = False
     if preprocess_iterable:
         streaming_ds = True

     # pylint: disable=invalid-name
-    for config_dataset in for_d_in_datasets(cfg_datasets):
+    for config_dataset in datasets_w_name_generator(cfg_datasets):
         ds: Union[Dataset, DatasetDict] = load_dataset_w_config(
             config_dataset, use_auth_token, streaming=streaming_ds
         )
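
The loop body is untouched; only the generator feeding it changed. datasets_w_name_generator is the former for_d_in_datasets hoisted verbatim into axolotl.utils.data.shared (diffed below), so the SFT and RL paths now share one fan-out rule. A short illustration of that fan-out, using hypothetical dataset paths:

from axolotl.utils.data.shared import datasets_w_name_generator
from axolotl.utils.dict import DictDefault

# hypothetical configs, purely to exercise the three branches
multi_name = DictDefault({"path": "org/multi", "name": ["config_a", "config_b"]})
sharded = DictDefault({"path": "org/big", "preprocess_shards": 2})
plain = DictDefault({"path": "org/simple"})

# list-valued name: one config per name, since load_dataset can't take several at once
print([d.name for d in datasets_w_name_generator([multi_name])])
# -> ['config_a', 'config_b']

# preprocess_shards without shards: one config per shard index
print([(d.shards, d.shards_idx) for d in datasets_w_name_generator([sharded])])
# -> [(2, 0), (2, 1)]

# anything else passes through unchanged
print(next(datasets_w_name_generator([plain])) is plain)
# -> True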


@@ -1,6 +1,7 @@
 """
 dataset loading shared utils
 """

+from pathlib import Path
 from typing import Optional, Union
@@ -29,9 +30,43 @@ def get_ds_type(config_dataset: DictDefault):
     return ds_type


+def datasets_w_name_generator(dataset_configs: list[DictDefault]):
+    """
+    Yields dataset configs handling multiple names or preprocess_shards
+
+    Args:
+        dataset_configs: list of dataset configs (equivalent to cfg.datasets)
+    """
+    for dataset in dataset_configs:
+        if dataset.name and isinstance(dataset.name, list):
+            # load_dataset doesn't properly handle multiple named configurations
+            # at the same time for a given dataset
+            for name in dataset.name:
+                yield DictDefault({**dataset, "name": name})
+        elif dataset.preprocess_shards and not dataset.shards:
+            for shard in range(dataset.preprocess_shards):
+                yield DictDefault(
+                    {
+                        **dataset,
+                        "shards": dataset.preprocess_shards,
+                        "shards_idx": shard,
+                    }
+                )
+        else:
+            yield dataset
+
+
 def load_dataset_w_config(
-    config_dataset, auth_token, streaming=False
+    config_dataset: DictDefault, use_auth_token: bool, streaming=False
 ) -> Union[Dataset, DatasetDict]:
+    """
+    Load a dataset from a config
+
+    Args:
+        config_dataset: single dataset config
+        use_auth_token: whether to use HF auth token
+        streaming: whether to stream the dataset
+    """
     # pylint: disable=invalid-name
     ds: Optional[Union[Dataset, DatasetDict]] = None  # pylint: disable=invalid-name
     ds_from_hub = False
ds_from_hub = False
@@ -43,7 +78,7 @@ def load_dataset_w_config(
             config_dataset.path,
             name=config_dataset.name,
             streaming=True,
-            token=auth_token,
+            token=use_auth_token,
             revision=config_dataset.revision,
             trust_remote_code=ds_trust_remote_code,
         )
@@ -161,7 +196,7 @@ def load_dataset_w_config(
             name=config_dataset.name,
             streaming=streaming,
             data_files=config_dataset.data_files,
-            token=auth_token,
+            token=use_auth_token,
             revision=config_dataset.revision,
             trust_remote_code=config_dataset.trust_remote_code,
             **load_ds_kwargs,
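
Both token= call sites above pick up the renamed use_auth_token parameter, keeping the typed signature and its callers in sync. To close, a hedged usage sketch of the loader as the RL path now invokes it; the local directory is hypothetical, and the assumption (per the commit title and #2149) is that a dataset previously saved with Dataset.save_to_disk is resolved through load_from_disk rather than rejected:

from axolotl.utils.data.shared import load_dataset_w_config
from axolotl.utils.dict import DictDefault

# hypothetical: a preference dataset prepared earlier with Dataset.save_to_disk
config_dataset = DictDefault(
    {
        "path": "/data/prepared/dpo_pairs",
        "split": "train",
    }
)

# new signature: load_dataset_w_config(config_dataset, use_auth_token, streaming=False)
ds = load_dataset_w_config(config_dataset, use_auth_token=False, streaming=False)
print(ds)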