[Fixing #2149] load_from_disk for RL-type training (#2193)

* Update rl.py

* Update rl.py

* Update rl.py

* refactor pref dataset loading to reuse load_dataset_w_config

* refactor again after rebase from main

* chore: add docstring and types

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
Co-authored-by: NanoCode012 <nano@axolotl.ai>
This commit is contained in:
Lee Park
2025-02-13 08:31:07 -05:00
committed by GitHub
parent 30046315d9
commit fdbb1a207c
3 changed files with 49 additions and 43 deletions

View File

@@ -43,7 +43,7 @@ from axolotl.prompters import (
UnsupportedPrompter,
)
from axolotl.utils.data.pretraining import wrap_pretraining_dataset
from axolotl.utils.data.shared import load_dataset_w_config
from axolotl.utils.data.shared import datasets_w_name_generator, load_dataset_w_config
from axolotl.utils.data.utils import (
deduplicate_and_log_datasets,
drop_long_seq_in_dataset,
@@ -263,30 +263,11 @@ def load_tokenized_prepared_datasets(
datasets = []
def for_d_in_datasets(dataset_configs):
for dataset in dataset_configs:
if dataset.name and isinstance(dataset.name, list):
# load_dataset doesn't properly handle multiple named configurations
# at the same time for a given dataset
for name in dataset.name:
yield DictDefault({**dataset, "name": name})
elif dataset.preprocess_shards and not dataset.shards:
for shard in range(dataset.preprocess_shards):
yield DictDefault(
{
**dataset,
"shards": dataset.preprocess_shards,
"shards_idx": shard,
}
)
else:
yield dataset
streaming_ds = False
if preprocess_iterable:
streaming_ds = True
# pylint: disable=invalid-name
for config_dataset in for_d_in_datasets(cfg_datasets):
for config_dataset in datasets_w_name_generator(cfg_datasets):
ds: Union[Dataset, DatasetDict] = load_dataset_w_config(
config_dataset, use_auth_token, streaming=streaming_ds
)