[Fixing #2149] load_from_disk for RL-type training (#2193)

* Update rl.py

* Update rl.py

* Update rl.py

* refactor pref dataset loading to reuse load_dataset_w_config

* refactor again after rebase from main

* chore: add docstring and types

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
Co-authored-by: NanoCode012 <nano@axolotl.ai>
This commit is contained in:
Lee Park
2025-02-13 08:31:07 -05:00
committed by GitHub
parent 30046315d9
commit fdbb1a207c
3 changed files with 49 additions and 43 deletions

View File

@@ -4,15 +4,16 @@ import inspect
import logging import logging
from functools import partial from functools import partial
from pathlib import Path from pathlib import Path
from typing import Any, List from typing import Any, List, Union
import yaml import yaml
from datasets import DatasetDict, concatenate_datasets, load_dataset, load_from_disk from datasets import Dataset, DatasetDict, concatenate_datasets, load_from_disk
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
from axolotl.prompt_strategies.dpo import load as load_dpo from axolotl.prompt_strategies.dpo import load as load_dpo
from axolotl.prompt_strategies.kto import load as load_kto from axolotl.prompt_strategies.kto import load as load_kto
from axolotl.prompt_strategies.orpo import load as load_orpo from axolotl.prompt_strategies.orpo import load as load_orpo
from axolotl.utils.data.shared import datasets_w_name_generator, load_dataset_w_config
from axolotl.utils.data.utils import deduplicate_and_log_datasets, md5 from axolotl.utils.data.utils import deduplicate_and_log_datasets, md5
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import is_main_process, zero_first from axolotl.utils.distributed import is_main_process, zero_first
@@ -118,23 +119,12 @@ def drop_long_rl_seq(
def load_prepare_preference_datasets(cfg): def load_prepare_preference_datasets(cfg):
def load_split(dataset_cfgs, _cfg): def load_split(dataset_cfgs, _cfg):
split_datasets: List[Any] = [] split_datasets: List[Any] = []
for i, ds_cfg in enumerate(dataset_cfgs): use_auth_token = _cfg.hf_use_auth_token
if ds_cfg["ds_type"] == "json": for config_dataset in datasets_w_name_generator(dataset_cfgs):
for data_file in ds_cfg["data_files"]: ds: Union[Dataset, DatasetDict] = load_dataset_w_config(
data_files = {ds_cfg["split"]: data_file} config_dataset, use_auth_token, streaming=False
ds = load_dataset( # pylint: disable=invalid-name )
"json", split_datasets.append(ds)
data_files=data_files,
split=ds_cfg["split"],
)
split_datasets.insert(i, ds)
else:
ds = load_dataset( # pylint: disable=invalid-name
ds_cfg["path"],
split=ds_cfg["split"],
revision=ds_cfg.get("revision", None),
)
split_datasets.insert(i, ds)
tokenizer = load_tokenizer(cfg) tokenizer = load_tokenizer(cfg)

View File

@@ -43,7 +43,7 @@ from axolotl.prompters import (
UnsupportedPrompter, UnsupportedPrompter,
) )
from axolotl.utils.data.pretraining import wrap_pretraining_dataset from axolotl.utils.data.pretraining import wrap_pretraining_dataset
from axolotl.utils.data.shared import load_dataset_w_config from axolotl.utils.data.shared import datasets_w_name_generator, load_dataset_w_config
from axolotl.utils.data.utils import ( from axolotl.utils.data.utils import (
deduplicate_and_log_datasets, deduplicate_and_log_datasets,
drop_long_seq_in_dataset, drop_long_seq_in_dataset,
@@ -263,30 +263,11 @@ def load_tokenized_prepared_datasets(
datasets = [] datasets = []
def for_d_in_datasets(dataset_configs):
for dataset in dataset_configs:
if dataset.name and isinstance(dataset.name, list):
# load_dataset doesn't properly handle multiple named configurations
# at the same time for a given dataset
for name in dataset.name:
yield DictDefault({**dataset, "name": name})
elif dataset.preprocess_shards and not dataset.shards:
for shard in range(dataset.preprocess_shards):
yield DictDefault(
{
**dataset,
"shards": dataset.preprocess_shards,
"shards_idx": shard,
}
)
else:
yield dataset
streaming_ds = False streaming_ds = False
if preprocess_iterable: if preprocess_iterable:
streaming_ds = True streaming_ds = True
# pylint: disable=invalid-name # pylint: disable=invalid-name
for config_dataset in for_d_in_datasets(cfg_datasets): for config_dataset in datasets_w_name_generator(cfg_datasets):
ds: Union[Dataset, DatasetDict] = load_dataset_w_config( ds: Union[Dataset, DatasetDict] = load_dataset_w_config(
config_dataset, use_auth_token, streaming=streaming_ds config_dataset, use_auth_token, streaming=streaming_ds
) )

View File

@@ -1,6 +1,7 @@
""" """
dataset loading shared utils dataset loading shared utils
""" """
from pathlib import Path from pathlib import Path
from typing import Optional, Union from typing import Optional, Union
@@ -29,9 +30,43 @@ def get_ds_type(config_dataset: DictDefault):
return ds_type return ds_type
def datasets_w_name_generator(dataset_configs: list[DictDefault]):
    """
    Expand dataset configs into individually loadable configs.

    A config whose ``name`` is a non-empty list yields one copy per name. A
    config with ``preprocess_shards`` set and no ``shards`` yields one copy per
    shard index, with ``shards`` and ``shards_idx`` filled in. Any other config
    is yielded unchanged.

    Args:
        dataset_configs: list of dataset configs (equivalent to cfg.datasets)
    """
    for ds_config in dataset_configs:
        names = ds_config.name
        if names and isinstance(names, list):
            # load_dataset doesn't properly handle multiple named configurations
            # at the same time for a given dataset, so emit one config per name
            yield from (
                DictDefault({**ds_config, "name": single_name})
                for single_name in names
            )
            continue

        shard_count = ds_config.preprocess_shards
        if shard_count and not ds_config.shards:
            # one config per shard index; "shards" carries the total count
            yield from (
                DictDefault(
                    {
                        **ds_config,
                        "shards": shard_count,
                        "shards_idx": shard_idx,
                    }
                )
                for shard_idx in range(shard_count)
            )
            continue

        yield ds_config
def load_dataset_w_config( def load_dataset_w_config(
config_dataset, auth_token, streaming=False config_dataset: DictDefault, use_auth_token: bool, streaming=False
) -> Union[Dataset, DatasetDict]: ) -> Union[Dataset, DatasetDict]:
"""
Load a dataset from a config
Args:
config_dataset: single dataset config
use_auth_token: whether to use HF auth token
streaming: whether to stream the dataset
"""
# pylint: disable=invalid-name # pylint: disable=invalid-name
ds: Optional[Union[Dataset, DatasetDict]] = None # pylint: disable=invalid-name ds: Optional[Union[Dataset, DatasetDict]] = None # pylint: disable=invalid-name
ds_from_hub = False ds_from_hub = False
@@ -43,7 +78,7 @@ def load_dataset_w_config(
config_dataset.path, config_dataset.path,
name=config_dataset.name, name=config_dataset.name,
streaming=True, streaming=True,
token=auth_token, token=use_auth_token,
revision=config_dataset.revision, revision=config_dataset.revision,
trust_remote_code=ds_trust_remote_code, trust_remote_code=ds_trust_remote_code,
) )
@@ -161,7 +196,7 @@ def load_dataset_w_config(
name=config_dataset.name, name=config_dataset.name,
streaming=streaming, streaming=streaming,
data_files=config_dataset.data_files, data_files=config_dataset.data_files,
token=auth_token, token=use_auth_token,
revision=config_dataset.revision, revision=config_dataset.revision,
trust_remote_code=config_dataset.trust_remote_code, trust_remote_code=config_dataset.trust_remote_code,
**load_ds_kwargs, **load_ds_kwargs,