fix pretraining_dataset handling on odd datasets (#1463)

* can configure name of split of pretraining dataset

* streaming data and dataset map

* text column customized

* allow text_column to be set in pretrain

* pretrain type

* load the first row of a streaming dataset to discover its column names when features are unavailable

* fix dataset where splits have separate configs

* the name param passed to load_dataset here is the dataset config name

* whitespace

This commit is contained in:
Nick Doiron
2024-04-01 23:48:59 -04:00
committed by GitHub
parent 86b7d22f35
commit 586bd8d221
3 changed files with 25 additions and 7 deletions

View File

@@ -20,10 +20,11 @@ class PretrainTokenizationStrategy(PromptTokenizingStrategy):
def supports_batched(self): def supports_batched(self):
return True return True
def __init__(self, *args, max_length=None, **kwargs): def __init__(self, *args, max_length=None, text_column="text", **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
if max_length: if max_length:
self.max_length = max_length self.max_length = max_length
self.text_column = text_column
def _tokenize( def _tokenize(
self, prompt: str, add_eos_token: bool = True, strip_bos_token: bool = False self, prompt: str, add_eos_token: bool = True, strip_bos_token: bool = False
@@ -44,7 +45,7 @@ class PretrainTokenizationStrategy(PromptTokenizingStrategy):
return res return res
def tokenize_prompt(self, prompt): def tokenize_prompt(self, prompt):
return self._tokenize(prompt["text"]) return self._tokenize(prompt[self.text_column])
def load(tokenizer, cfg): def load(tokenizer, cfg):
@@ -53,6 +54,7 @@ def load(tokenizer, cfg):
tokenizer, tokenizer,
cfg.train_on_inputs, cfg.train_on_inputs,
cfg.sequence_len, cfg.sequence_len,
text_column=cfg.pretraining_dataset[0]["text_column"] or "text",
max_length=cfg.sequence_len * 64, max_length=cfg.sequence_len * 64,
) )
return strat return strat

View File

@@ -61,7 +61,11 @@ class RemappedParameters(BaseModel):
class PretrainingDataset(BaseModel): class PretrainingDataset(BaseModel):
"""pretraining dataset configuration subset""" """pretraining dataset configuration subset"""
name: Optional[str] = None
path: Optional[str] = None path: Optional[str] = None
split: Optional[str] = "train"
text_column: Optional[str] = "text"
type: Optional[str] = "pretrain"
class UserDefinedPrompterType(BaseModel): class UserDefinedPrompterType(BaseModel):
@@ -448,7 +452,7 @@ class AxolotlInputConfig(
dataset_shard_idx: Optional[int] = None dataset_shard_idx: Optional[int] = None
pretraining_dataset: Optional[ # type: ignore pretraining_dataset: Optional[ # type: ignore
conlist(Union[SFTDataset, PretrainingDataset], min_length=1) conlist(Union[PretrainingDataset, SFTDataset], min_length=1)
] = Field( ] = Field(
default=None, metadata={"help": {"streaming dataset to use for pretraining"}} default=None, metadata={"help": {"streaming dataset to use for pretraining"}}
) )

View File

@@ -82,12 +82,15 @@ def prepare_dataset(cfg, tokenizer):
) )
else: else:
path = cfg.pretraining_dataset path = cfg.pretraining_dataset
split = "train"
name = None name = None
if isinstance(cfg.pretraining_dataset, list) and isinstance( if isinstance(cfg.pretraining_dataset, list) and isinstance(
cfg.pretraining_dataset[0], dict cfg.pretraining_dataset[0], dict
): ):
path = cfg.pretraining_dataset[0]["path"] path = cfg.pretraining_dataset[0]["path"]
name = cfg.pretraining_dataset[0]["name"] name = cfg.pretraining_dataset[0]["name"]
if "split" in cfg.pretraining_dataset[0]:
split = cfg.pretraining_dataset[0]["split"]
ds_wrapper_partial = functools.partial( ds_wrapper_partial = functools.partial(
get_dataset_wrapper, get_dataset_wrapper,
@@ -98,7 +101,7 @@ def prepare_dataset(cfg, tokenizer):
) )
train_dataset = wrap_pretraining_dataset( train_dataset = wrap_pretraining_dataset(
load_dataset(path, streaming=True, split="train", name=name), load_dataset(path, streaming=True, split=split, name=name),
tokenizer, tokenizer,
cfg, cfg,
ds_wrapper_partial, ds_wrapper_partial,
@@ -831,14 +834,23 @@ def wrap_pretraining_dataset(
else: else:
LOG.debug("NOT shuffling merged pretraining datasets") LOG.debug("NOT shuffling merged pretraining datasets")
# remove all the existing columns after mapping since they end up having
# a different length than the encoded/tokenized column
# this is empty during streaming/pretraining
remove_columns = []
if dataset.features is None:
for first_row in dataset:
remove_columns = first_row.keys()
break
else:
remove_columns = dataset.features.keys()
dataset = dataset.map( dataset = dataset.map(
encode, encode,
batched=True, batched=True,
batch_size=buffer_size, batch_size=buffer_size,
# input_columns="text", # input_columns="text",
# remove all the existing columns after mapping since they end up having remove_columns=remove_columns,
# a different length than the encoded/tokenized column
remove_columns=dataset.features.keys(),
) )
return dataset return dataset