From f89e9621191f4460afe4425b3785dd8ca0482a9c Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Mon, 13 Jan 2025 10:44:45 -0500 Subject: [PATCH] skip over rows in pretraining dataset (#2223) * skip over rows in pretraining dataset * update docs --- docs/dataset-formats/pretraining.qmd | 9 ++++++++- .../utils/config/models/input/v0_4_1/__init__.py | 1 + src/axolotl/utils/data/sft.py | 10 +++++++--- 3 files changed, 16 insertions(+), 4 deletions(-) diff --git a/docs/dataset-formats/pretraining.qmd b/docs/dataset-formats/pretraining.qmd index bb591328e..600fb63e0 100644 --- a/docs/dataset-formats/pretraining.qmd +++ b/docs/dataset-formats/pretraining.qmd @@ -19,7 +19,14 @@ For pretraining, there is no prompt template or roles. The only required field Axolotl usually loads the entire dataset into memory. This will be challenging for large datasets. Use the following config to enable streaming: ```{.yaml filename="config.yaml"} -pretraining_dataset: # hf path only +pretraining_dataset: + - name: + path: + split: + text_column: # column in dataset with the data, usually `text` + type: pretrain + trust_remote_code: + skip: # number of rows of data to skip over from the beginning ... ``` diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py index 19ce7b18c..4f368994a 100644 --- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py +++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py @@ -129,6 +129,7 @@ class PretrainingDataset(BaseModel): type: Optional[str] = "pretrain" trust_remote_code: Optional[bool] = False data_files: Optional[str] = None + skip: Optional[int] = None class UserDefinedPrompterType(BaseModel): diff --git a/src/axolotl/utils/data/sft.py b/src/axolotl/utils/data/sft.py index cfc40406e..aff047675 100644 --- a/src/axolotl/utils/data/sft.py +++ b/src/axolotl/utils/data/sft.py @@ -89,11 +89,13 @@ def prepare_dataset(cfg, tokenizer, processor=None): split = "train" name = None data_files = None + skip = 0 if isinstance(cfg.pretraining_dataset, list) and isinstance( cfg.pretraining_dataset[0], dict ): path = cfg.pretraining_dataset[0]["path"] name = cfg.pretraining_dataset[0]["name"] + skip = cfg.pretraining_dataset[0]["skip"] if "split" in cfg.pretraining_dataset[0]: split = cfg.pretraining_dataset[0]["split"] @@ -107,10 +109,12 @@ def prepare_dataset(cfg, tokenizer, processor=None): cfg.pretraining_dataset[0]["type"] or "pretrain", ) + iter_ds = load_dataset(path, streaming=True, split=split, name=name, data_files=data_files) + if skip: + LOG.info(f"Skipping {skip} samples from the dataset") + iter_ds = iter_ds.skip(skip) train_dataset = wrap_pretraining_dataset( - load_dataset( - path, streaming=True, split=split, name=name, data_files=data_files - ), + iter_ds, tokenizer, cfg, ds_wrapper_partial,