fix pretraining_ on odd datasets (#1463)
* can configure name of split of pretraining dataset
* streaming data and dataset map
* text column customized
* allow text_column to be set in pretrain
* pretrain type
* load a bit of the dataset
* fix dataset where splits have separate configs
* ok name param here is the config
* whitespace
This commit is contained in:
@@ -20,10 +20,11 @@ class PretrainTokenizationStrategy(PromptTokenizingStrategy):
|
||||
def supports_batched(self):
|
||||
return True
|
||||
|
||||
def __init__(self, *args, max_length=None, **kwargs):
|
||||
def __init__(self, *args, max_length=None, text_column="text", **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
if max_length:
|
||||
self.max_length = max_length
|
||||
self.text_column = text_column
|
||||
|
||||
def _tokenize(
|
||||
self, prompt: str, add_eos_token: bool = True, strip_bos_token: bool = False
|
||||
@@ -44,7 +45,7 @@ class PretrainTokenizationStrategy(PromptTokenizingStrategy):
|
||||
return res
|
||||
|
||||
def tokenize_prompt(self, prompt):
|
||||
return self._tokenize(prompt["text"])
|
||||
return self._tokenize(prompt[self.text_column])
|
||||
|
||||
|
||||
def load(tokenizer, cfg):
|
||||
@@ -53,6 +54,7 @@ def load(tokenizer, cfg):
|
||||
tokenizer,
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
text_column=cfg.pretraining_dataset[0]["text_column"] or "text",
|
||||
max_length=cfg.sequence_len * 64,
|
||||
)
|
||||
return strat
|
||||
|
||||
@@ -61,7 +61,11 @@ class RemappedParameters(BaseModel):
|
||||
class PretrainingDataset(BaseModel):
|
||||
"""pretraining dataset configuration subset"""
|
||||
|
||||
name: Optional[str] = None
|
||||
path: Optional[str] = None
|
||||
split: Optional[str] = "train"
|
||||
text_column: Optional[str] = "text"
|
||||
type: Optional[str] = "pretrain"
|
||||
|
||||
|
||||
class UserDefinedPrompterType(BaseModel):
|
||||
@@ -448,7 +452,7 @@ class AxolotlInputConfig(
|
||||
dataset_shard_idx: Optional[int] = None
|
||||
|
||||
pretraining_dataset: Optional[ # type: ignore
|
||||
conlist(Union[SFTDataset, PretrainingDataset], min_length=1)
|
||||
conlist(Union[PretrainingDataset, SFTDataset], min_length=1)
|
||||
] = Field(
|
||||
default=None, metadata={"help": {"streaming dataset to use for pretraining"}}
|
||||
)
|
||||
|
||||
@@ -82,12 +82,15 @@ def prepare_dataset(cfg, tokenizer):
|
||||
)
|
||||
else:
|
||||
path = cfg.pretraining_dataset
|
||||
split = "train"
|
||||
name = None
|
||||
if isinstance(cfg.pretraining_dataset, list) and isinstance(
|
||||
cfg.pretraining_dataset[0], dict
|
||||
):
|
||||
path = cfg.pretraining_dataset[0]["path"]
|
||||
name = cfg.pretraining_dataset[0]["name"]
|
||||
if "split" in cfg.pretraining_dataset[0]:
|
||||
split = cfg.pretraining_dataset[0]["split"]
|
||||
|
||||
ds_wrapper_partial = functools.partial(
|
||||
get_dataset_wrapper,
|
||||
@@ -98,7 +101,7 @@ def prepare_dataset(cfg, tokenizer):
|
||||
)
|
||||
|
||||
train_dataset = wrap_pretraining_dataset(
|
||||
load_dataset(path, streaming=True, split="train", name=name),
|
||||
load_dataset(path, streaming=True, split=split, name=name),
|
||||
tokenizer,
|
||||
cfg,
|
||||
ds_wrapper_partial,
|
||||
@@ -831,14 +834,23 @@ def wrap_pretraining_dataset(
|
||||
else:
|
||||
LOG.debug("NOT shuffling merged pretraining datasets")
|
||||
|
||||
# remove all the existing columns after mapping since they end up having
|
||||
# a different length than the encoded/tokenized column
|
||||
# this is empty during streaming/pretraining
|
||||
remove_columns = []
|
||||
if dataset.features is None:
|
||||
for first_row in dataset:
|
||||
remove_columns = first_row.keys()
|
||||
break
|
||||
else:
|
||||
remove_columns = dataset.features.keys()
|
||||
|
||||
dataset = dataset.map(
|
||||
encode,
|
||||
batched=True,
|
||||
batch_size=buffer_size,
|
||||
# input_columns="text",
|
||||
# remove all the existing columns after mapping since they end up having
|
||||
# a different length than the encoded/tokenized column
|
||||
remove_columns=dataset.features.keys(),
|
||||
remove_columns=remove_columns,
|
||||
)
|
||||
return dataset
|
||||
|
||||
|
||||
Reference in New Issue
Block a user