Add a config not to shuffle merged dataset (#1394) [skip ci]

* Add a config not to shuffle merged dataset

* Update README.md

* Update src/axolotl/utils/config/models/input/v0_4_1/__init__.py

Co-authored-by: Wing Lian <wing.lian@gmail.com>

* invert the condition name

* update README

* info -> debug

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>
Seungduk Kim
2024-03-19 20:51:00 +09:00
committed by GitHub
parent b1e3e1b25f
commit 43bdc5d3de
3 changed files with 15 additions and 3 deletions

README.md

@@ -678,6 +678,10 @@ datasets:
   # For `completion` datasets only, uses the provided field instead of `text` column
   field:
 
+# If false, the datasets will not be shuffled and will keep their original order in `datasets`.
+# The same applies to the `test_datasets` option and the `pretraining_dataset` option. Default is true.
+shuffle_merged_datasets: true
+
 # A list of one or more datasets to eval the model with.
 # You can use either test_datasets, or val_set_size, but not both.
 test_datasets:
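
For reference, opting out would look like this in an axolotl YAML config (a minimal hypothetical excerpt; the dataset entries are illustrative):

datasets:
  - path: dataset_one
    type: alpaca
  - path: dataset_two
    type: alpaca

# keep rows grouped in the order the datasets are listed above
shuffle_merged_datasets: false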

src/axolotl/utils/config/models/input/v0_4_1/__init__.py

@@ -416,6 +416,7 @@ class AxolotlInputConfig(
     datasets: Optional[conlist(Union[SFTDataset, DPODataset], min_length=1)] = None  # type: ignore
     test_datasets: Optional[conlist(Union[SFTDataset, DPODataset], min_length=1)] = None  # type: ignore
+    shuffle_merged_datasets: Optional[bool] = True
     dataset_prepared_path: Optional[str] = None
     dataset_shard_num: Optional[int] = None
     dataset_shard_idx: Optional[int] = None
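
The new field behaves like any other pydantic default; a minimal standalone sketch of that behavior (MiniConfig is a hypothetical stand-in for AxolotlInputConfig, which has many more fields):

from typing import Optional

from pydantic import BaseModel


class MiniConfig(BaseModel):
    # stand-in for the real field added above
    shuffle_merged_datasets: Optional[bool] = True


# omitting the key keeps the default (shuffle); an explicit false opts out
assert MiniConfig().shuffle_merged_datasets is True
assert MiniConfig(shuffle_merged_datasets=False).shuffle_merged_datasets is False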

src/axolotl/utils/data.py

@@ -415,8 +415,11 @@ def load_tokenized_prepared_datasets(
     dataset = concatenate_datasets(datasets)
 
     if len(datasets) > 1:
-        LOG.info("shuffle merged datasets")
-        dataset = dataset.shuffle(seed=seed)
+        if cfg.shuffle_merged_datasets:
+            LOG.debug("shuffle merged datasets")
+            dataset = dataset.shuffle(seed=seed)
+        else:
+            LOG.debug("NOT shuffling merged datasets")
 
     dataset, _ = process_datasets_for_packing(cfg, dataset, None)
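
The branch above is easiest to see in isolation with the datasets library (a sketch with toy data; the seed and flag value are arbitrary):

from datasets import Dataset, concatenate_datasets

ds_a = Dataset.from_dict({"text": ["a1", "a2"]})
ds_b = Dataset.from_dict({"text": ["b1", "b2"]})
merged = concatenate_datasets([ds_a, ds_b])

shuffle_merged_datasets = False  # mirrors cfg.shuffle_merged_datasets

if shuffle_merged_datasets:
    merged = merged.shuffle(seed=42)  # deterministic reorder across both sources

# with the flag off, rows keep their original grouping: a1, a2, b1, b2
print(merged["text"])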
@@ -819,7 +822,11 @@ def wrap_pretraining_dataset(
     else:
         encode = functools.partial(encode_pretraining, tokenizer, max_tokens)
 
-    dataset = dataset.shuffle(seed=seed, buffer_size=buffer_size)
+    if cfg.shuffle_merged_datasets:
+        dataset = dataset.shuffle(seed=seed, buffer_size=buffer_size)
+    else:
+        LOG.debug("NOT shuffling merged pretraining datasets")
+
     dataset = dataset.map(
         encode,
         batched=True,
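
The pretraining path streams data, so the shuffle it guards is an approximate, buffer-based shuffle over an IterableDataset rather than a full permutation. A sketch of the same guard (the generator, seed, and buffer size are made up):

from datasets import IterableDataset


def gen():
    for i in range(100):
        yield {"text": f"doc {i}"}


stream = IterableDataset.from_generator(gen)

shuffle_merged_datasets = True  # mirrors cfg.shuffle_merged_datasets

if shuffle_merged_datasets:
    # buffer_size bounds how far an element can move; larger buffers shuffle better
    stream = stream.shuffle(seed=42, buffer_size=16)

print([row["text"] for row in stream.take(5)])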