fix pretraining_ on odd datasets (#1463)

* can configure name of split of pretraining dataset

* streaming data and dataset map

* text column customized

* allow text_column to be set in pretrain

* pretrain type

* load a bit of the dataset

* fix dataset where splits have separate configs

* ok name param here is the config

* whitespace
This commit is contained in:
Nick Doiron
2024-04-01 23:48:59 -04:00
committed by GitHub
parent 86b7d22f35
commit 586bd8d221
3 changed files with 25 additions and 7 deletions

View File

@@ -20,10 +20,11 @@ class PretrainTokenizationStrategy(PromptTokenizingStrategy):
def supports_batched(self):
return True
def __init__(self, *args, max_length=None, **kwargs):
def __init__(self, *args, max_length=None, text_column="text", **kwargs):
super().__init__(*args, **kwargs)
if max_length:
self.max_length = max_length
self.text_column = text_column
def _tokenize(
self, prompt: str, add_eos_token: bool = True, strip_bos_token: bool = False
@@ -44,7 +45,7 @@ class PretrainTokenizationStrategy(PromptTokenizingStrategy):
return res
def tokenize_prompt(self, prompt):
return self._tokenize(prompt["text"])
return self._tokenize(prompt[self.text_column])
def load(tokenizer, cfg):
@@ -53,6 +54,7 @@ def load(tokenizer, cfg):
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
text_column=cfg.pretraining_dataset[0]["text_column"] or "text",
max_length=cfg.sequence_len * 64,
)
return strat

View File

@@ -61,7 +61,11 @@ class RemappedParameters(BaseModel):
class PretrainingDataset(BaseModel):
"""pretraining dataset configuration subset"""
name: Optional[str] = None
path: Optional[str] = None
split: Optional[str] = "train"
text_column: Optional[str] = "text"
type: Optional[str] = "pretrain"
class UserDefinedPrompterType(BaseModel):
@@ -448,7 +452,7 @@ class AxolotlInputConfig(
dataset_shard_idx: Optional[int] = None
pretraining_dataset: Optional[ # type: ignore
conlist(Union[SFTDataset, PretrainingDataset], min_length=1)
conlist(Union[PretrainingDataset, SFTDataset], min_length=1)
] = Field(
default=None, metadata={"help": {"streaming dataset to use for pretraining"}}
)

View File

@@ -82,12 +82,15 @@ def prepare_dataset(cfg, tokenizer):
)
else:
path = cfg.pretraining_dataset
split = "train"
name = None
if isinstance(cfg.pretraining_dataset, list) and isinstance(
cfg.pretraining_dataset[0], dict
):
path = cfg.pretraining_dataset[0]["path"]
name = cfg.pretraining_dataset[0]["name"]
if "split" in cfg.pretraining_dataset[0]:
split = cfg.pretraining_dataset[0]["split"]
ds_wrapper_partial = functools.partial(
get_dataset_wrapper,
@@ -98,7 +101,7 @@ def prepare_dataset(cfg, tokenizer):
)
train_dataset = wrap_pretraining_dataset(
load_dataset(path, streaming=True, split="train", name=name),
load_dataset(path, streaming=True, split=split, name=name),
tokenizer,
cfg,
ds_wrapper_partial,
@@ -831,14 +834,23 @@ def wrap_pretraining_dataset(
else:
LOG.debug("NOT shuffling merged pretraining datasets")
# remove all the existing columns after mapping since they end up having
# a different length than the encoded/tokenized column
# this is empty during streaming/pretraining
remove_columns = []
if dataset.features is None:
for first_row in dataset:
remove_columns = first_row.keys()
break
else:
remove_columns = dataset.features.keys()
dataset = dataset.map(
encode,
batched=True,
batch_size=buffer_size,
# input_columns="text",
# remove all the existing columns after mapping since they end up having
# a different length than the encoded/tokenized column
remove_columns=dataset.features.keys(),
remove_columns=remove_columns,
)
return dataset