skip over rows in pretraining dataset (#2223)

* skip over rows in pretraining dataset

* update docs
This commit is contained in:
Wing Lian
2025-01-13 10:44:45 -05:00
committed by GitHub
parent bc1c9c20e3
commit f89e962119
3 changed files with 16 additions and 4 deletions

View File

@@ -129,6 +129,7 @@ class PretrainingDataset(BaseModel):
type: Optional[str] = "pretrain"
trust_remote_code: Optional[bool] = False
data_files: Optional[str] = None
skip: Optional[int] = None
class UserDefinedPrompterType(BaseModel):

View File

@@ -89,11 +89,13 @@ def prepare_dataset(cfg, tokenizer, processor=None):
split = "train"
name = None
data_files = None
skip = 0
if isinstance(cfg.pretraining_dataset, list) and isinstance(
cfg.pretraining_dataset[0], dict
):
path = cfg.pretraining_dataset[0]["path"]
name = cfg.pretraining_dataset[0]["name"]
skip = cfg.pretraining_dataset[0]["skip"]
if "split" in cfg.pretraining_dataset[0]:
split = cfg.pretraining_dataset[0]["split"]
@@ -107,10 +109,12 @@ def prepare_dataset(cfg, tokenizer, processor=None):
cfg.pretraining_dataset[0]["type"] or "pretrain",
)
iter_ds = load_dataset(path, streaming=True, split=split, name=name, data_files=data_files)
if skip:
LOG.info(f"Skipping {skip} samples from the dataset")
iter_ds = iter_ds.skip(skip)
train_dataset = wrap_pretraining_dataset(
load_dataset(
path, streaming=True, split=split, name=name, data_files=data_files
),
iter_ds,
tokenizer,
cfg,
ds_wrapper_partial,