skip over rows in pretraining dataset (#2223)
* skip over rows in pretraining dataset * update docs
This commit is contained in:
@@ -19,7 +19,14 @@ For pretraining, there is no prompt template or roles. The only required field
|
|||||||
Axolotl usually loads the entire dataset into memory, which can be challenging for large datasets. Use the following config to enable streaming instead:
|
Axolotl usually loads the entire dataset into memory, which can be challenging for large datasets. Use the following config to enable streaming instead:
|
||||||
|
|
||||||
```{.yaml filename="config.yaml"}
|
```{.yaml filename="config.yaml"}
|
||||||
pretraining_dataset: # hf path only
|
pretraining_dataset:
|
||||||
|
- name:
|
||||||
|
path:
|
||||||
|
split:
|
||||||
|
text_column: # column in dataset with the data, usually `text`
|
||||||
|
type: pretrain
|
||||||
|
trust_remote_code:
|
||||||
|
skip: # number of rows of data to skip over from the beginning
|
||||||
...
|
...
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|||||||
@@ -129,6 +129,7 @@ class PretrainingDataset(BaseModel):
|
|||||||
type: Optional[str] = "pretrain"
|
type: Optional[str] = "pretrain"
|
||||||
trust_remote_code: Optional[bool] = False
|
trust_remote_code: Optional[bool] = False
|
||||||
data_files: Optional[str] = None
|
data_files: Optional[str] = None
|
||||||
|
skip: Optional[int] = None
|
||||||
|
|
||||||
|
|
||||||
class UserDefinedPrompterType(BaseModel):
|
class UserDefinedPrompterType(BaseModel):
|
||||||
|
|||||||
@@ -89,11 +89,13 @@ def prepare_dataset(cfg, tokenizer, processor=None):
|
|||||||
split = "train"
|
split = "train"
|
||||||
name = None
|
name = None
|
||||||
data_files = None
|
data_files = None
|
||||||
|
skip = 0
|
||||||
if isinstance(cfg.pretraining_dataset, list) and isinstance(
|
if isinstance(cfg.pretraining_dataset, list) and isinstance(
|
||||||
cfg.pretraining_dataset[0], dict
|
cfg.pretraining_dataset[0], dict
|
||||||
):
|
):
|
||||||
path = cfg.pretraining_dataset[0]["path"]
|
path = cfg.pretraining_dataset[0]["path"]
|
||||||
name = cfg.pretraining_dataset[0]["name"]
|
name = cfg.pretraining_dataset[0]["name"]
|
||||||
|
skip = cfg.pretraining_dataset[0]["skip"]
|
||||||
if "split" in cfg.pretraining_dataset[0]:
|
if "split" in cfg.pretraining_dataset[0]:
|
||||||
split = cfg.pretraining_dataset[0]["split"]
|
split = cfg.pretraining_dataset[0]["split"]
|
||||||
|
|
||||||
@@ -107,10 +109,12 @@ def prepare_dataset(cfg, tokenizer, processor=None):
|
|||||||
cfg.pretraining_dataset[0]["type"] or "pretrain",
|
cfg.pretraining_dataset[0]["type"] or "pretrain",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
iter_ds = load_dataset(path, streaming=True, split=split, name=name, data_files=data_files)
|
||||||
|
if skip:
|
||||||
|
LOG.info(f"Skipping {skip} samples from the dataset")
|
||||||
|
iter_ds = iter_ds.skip(skip)
|
||||||
train_dataset = wrap_pretraining_dataset(
|
train_dataset = wrap_pretraining_dataset(
|
||||||
load_dataset(
|
iter_ds,
|
||||||
path, streaming=True, split=split, name=name, data_files=data_files
|
|
||||||
),
|
|
||||||
tokenizer,
|
tokenizer,
|
||||||
cfg,
|
cfg,
|
||||||
ds_wrapper_partial,
|
ds_wrapper_partial,
|
||||||
|
|||||||
Reference in New Issue
Block a user