* Update rl.py
* Update rl.py
* Update rl.py
* refactor pref dataset loading to reuse load_dataset_w_config
* refactor again after rebase from main
* chore: add docstring and types

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
Co-authored-by: NanoCode012 <nano@axolotl.ai>
This commit is contained in:
@@ -4,15 +4,16 @@ import inspect
|
|||||||
import logging
|
import logging
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, List
|
from typing import Any, List, Union
|
||||||
|
|
||||||
import yaml
|
import yaml
|
||||||
from datasets import DatasetDict, concatenate_datasets, load_dataset, load_from_disk
|
from datasets import Dataset, DatasetDict, concatenate_datasets, load_from_disk
|
||||||
|
|
||||||
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
|
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
|
||||||
from axolotl.prompt_strategies.dpo import load as load_dpo
|
from axolotl.prompt_strategies.dpo import load as load_dpo
|
||||||
from axolotl.prompt_strategies.kto import load as load_kto
|
from axolotl.prompt_strategies.kto import load as load_kto
|
||||||
from axolotl.prompt_strategies.orpo import load as load_orpo
|
from axolotl.prompt_strategies.orpo import load as load_orpo
|
||||||
|
from axolotl.utils.data.shared import datasets_w_name_generator, load_dataset_w_config
|
||||||
from axolotl.utils.data.utils import deduplicate_and_log_datasets, md5
|
from axolotl.utils.data.utils import deduplicate_and_log_datasets, md5
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
from axolotl.utils.distributed import is_main_process, zero_first
|
from axolotl.utils.distributed import is_main_process, zero_first
|
||||||
@@ -118,23 +119,12 @@ def drop_long_rl_seq(
|
|||||||
def load_prepare_preference_datasets(cfg):
|
def load_prepare_preference_datasets(cfg):
|
||||||
def load_split(dataset_cfgs, _cfg):
|
def load_split(dataset_cfgs, _cfg):
|
||||||
split_datasets: List[Any] = []
|
split_datasets: List[Any] = []
|
||||||
for i, ds_cfg in enumerate(dataset_cfgs):
|
use_auth_token = _cfg.hf_use_auth_token
|
||||||
if ds_cfg["ds_type"] == "json":
|
for config_dataset in datasets_w_name_generator(dataset_cfgs):
|
||||||
for data_file in ds_cfg["data_files"]:
|
ds: Union[Dataset, DatasetDict] = load_dataset_w_config(
|
||||||
data_files = {ds_cfg["split"]: data_file}
|
config_dataset, use_auth_token, streaming=False
|
||||||
ds = load_dataset( # pylint: disable=invalid-name
|
)
|
||||||
"json",
|
split_datasets.append(ds)
|
||||||
data_files=data_files,
|
|
||||||
split=ds_cfg["split"],
|
|
||||||
)
|
|
||||||
split_datasets.insert(i, ds)
|
|
||||||
else:
|
|
||||||
ds = load_dataset( # pylint: disable=invalid-name
|
|
||||||
ds_cfg["path"],
|
|
||||||
split=ds_cfg["split"],
|
|
||||||
revision=ds_cfg.get("revision", None),
|
|
||||||
)
|
|
||||||
split_datasets.insert(i, ds)
|
|
||||||
|
|
||||||
tokenizer = load_tokenizer(cfg)
|
tokenizer = load_tokenizer(cfg)
|
||||||
|
|
||||||
|
|||||||
@@ -43,7 +43,7 @@ from axolotl.prompters import (
|
|||||||
UnsupportedPrompter,
|
UnsupportedPrompter,
|
||||||
)
|
)
|
||||||
from axolotl.utils.data.pretraining import wrap_pretraining_dataset
|
from axolotl.utils.data.pretraining import wrap_pretraining_dataset
|
||||||
from axolotl.utils.data.shared import load_dataset_w_config
|
from axolotl.utils.data.shared import datasets_w_name_generator, load_dataset_w_config
|
||||||
from axolotl.utils.data.utils import (
|
from axolotl.utils.data.utils import (
|
||||||
deduplicate_and_log_datasets,
|
deduplicate_and_log_datasets,
|
||||||
drop_long_seq_in_dataset,
|
drop_long_seq_in_dataset,
|
||||||
@@ -263,30 +263,11 @@ def load_tokenized_prepared_datasets(
|
|||||||
|
|
||||||
datasets = []
|
datasets = []
|
||||||
|
|
||||||
def for_d_in_datasets(dataset_configs):
|
|
||||||
for dataset in dataset_configs:
|
|
||||||
if dataset.name and isinstance(dataset.name, list):
|
|
||||||
# load_dataset doesn't properly handle multiple named configurations
|
|
||||||
# at the same time for a given dataset
|
|
||||||
for name in dataset.name:
|
|
||||||
yield DictDefault({**dataset, "name": name})
|
|
||||||
elif dataset.preprocess_shards and not dataset.shards:
|
|
||||||
for shard in range(dataset.preprocess_shards):
|
|
||||||
yield DictDefault(
|
|
||||||
{
|
|
||||||
**dataset,
|
|
||||||
"shards": dataset.preprocess_shards,
|
|
||||||
"shards_idx": shard,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
yield dataset
|
|
||||||
|
|
||||||
streaming_ds = False
|
streaming_ds = False
|
||||||
if preprocess_iterable:
|
if preprocess_iterable:
|
||||||
streaming_ds = True
|
streaming_ds = True
|
||||||
# pylint: disable=invalid-name
|
# pylint: disable=invalid-name
|
||||||
for config_dataset in for_d_in_datasets(cfg_datasets):
|
for config_dataset in datasets_w_name_generator(cfg_datasets):
|
||||||
ds: Union[Dataset, DatasetDict] = load_dataset_w_config(
|
ds: Union[Dataset, DatasetDict] = load_dataset_w_config(
|
||||||
config_dataset, use_auth_token, streaming=streaming_ds
|
config_dataset, use_auth_token, streaming=streaming_ds
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
"""
|
"""
|
||||||
dataset loading shared utils
|
dataset loading shared utils
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional, Union
|
from typing import Optional, Union
|
||||||
|
|
||||||
@@ -29,9 +30,43 @@ def get_ds_type(config_dataset: DictDefault):
|
|||||||
return ds_type
|
return ds_type
|
||||||
|
|
||||||
|
|
||||||
|
def datasets_w_name_generator(dataset_configs: list[DictDefault]):
|
||||||
|
"""
|
||||||
|
Yields dataset configs handling multiple names or preprocess_shards
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dataset_configs: list of dataset configs (equivalent to cfg.datasets)
|
||||||
|
"""
|
||||||
|
for dataset in dataset_configs:
|
||||||
|
if dataset.name and isinstance(dataset.name, list):
|
||||||
|
# load_dataset doesn't properly handle multiple named configurations
|
||||||
|
# at the same time for a given dataset
|
||||||
|
for name in dataset.name:
|
||||||
|
yield DictDefault({**dataset, "name": name})
|
||||||
|
elif dataset.preprocess_shards and not dataset.shards:
|
||||||
|
for shard in range(dataset.preprocess_shards):
|
||||||
|
yield DictDefault(
|
||||||
|
{
|
||||||
|
**dataset,
|
||||||
|
"shards": dataset.preprocess_shards,
|
||||||
|
"shards_idx": shard,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
yield dataset
|
||||||
|
|
||||||
|
|
||||||
def load_dataset_w_config(
|
def load_dataset_w_config(
|
||||||
config_dataset, auth_token, streaming=False
|
config_dataset: DictDefault, use_auth_token: bool, streaming=False
|
||||||
) -> Union[Dataset, DatasetDict]:
|
) -> Union[Dataset, DatasetDict]:
|
||||||
|
"""
|
||||||
|
Load a dataset from a config
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config_dataset: single dataset config
|
||||||
|
use_auth_token: whether to use HF auth token
|
||||||
|
streaming: whether to stream the dataset
|
||||||
|
"""
|
||||||
# pylint: disable=invalid-name
|
# pylint: disable=invalid-name
|
||||||
ds: Optional[Union[Dataset, DatasetDict]] = None # pylint: disable=invalid-name
|
ds: Optional[Union[Dataset, DatasetDict]] = None # pylint: disable=invalid-name
|
||||||
ds_from_hub = False
|
ds_from_hub = False
|
||||||
@@ -43,7 +78,7 @@ def load_dataset_w_config(
|
|||||||
config_dataset.path,
|
config_dataset.path,
|
||||||
name=config_dataset.name,
|
name=config_dataset.name,
|
||||||
streaming=True,
|
streaming=True,
|
||||||
token=auth_token,
|
token=use_auth_token,
|
||||||
revision=config_dataset.revision,
|
revision=config_dataset.revision,
|
||||||
trust_remote_code=ds_trust_remote_code,
|
trust_remote_code=ds_trust_remote_code,
|
||||||
)
|
)
|
||||||
@@ -161,7 +196,7 @@ def load_dataset_w_config(
|
|||||||
name=config_dataset.name,
|
name=config_dataset.name,
|
||||||
streaming=streaming,
|
streaming=streaming,
|
||||||
data_files=config_dataset.data_files,
|
data_files=config_dataset.data_files,
|
||||||
token=auth_token,
|
token=use_auth_token,
|
||||||
revision=config_dataset.revision,
|
revision=config_dataset.revision,
|
||||||
trust_remote_code=config_dataset.trust_remote_code,
|
trust_remote_code=config_dataset.trust_remote_code,
|
||||||
**load_ds_kwargs,
|
**load_ds_kwargs,
|
||||||
|
|||||||
Reference in New Issue
Block a user