fix pretraining_ on odd datasets (#1463)
* can configure name of split of pretraining dataset
* streaming data and dataset map
* text column customized
* allow text_column to be set in pretrain
* pretrain type
* load a bit of the dataset
* fix dataset where splits have separate configs
* ok name param here is the config
* whitespace
This commit is contained in:
@@ -20,10 +20,11 @@ class PretrainTokenizationStrategy(PromptTokenizingStrategy):
|
|||||||
def supports_batched(self):
|
def supports_batched(self):
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def __init__(self, *args, max_length=None, **kwargs):
|
def __init__(self, *args, max_length=None, text_column="text", **kwargs):
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
if max_length:
|
if max_length:
|
||||||
self.max_length = max_length
|
self.max_length = max_length
|
||||||
|
self.text_column = text_column
|
||||||
|
|
||||||
def _tokenize(
|
def _tokenize(
|
||||||
self, prompt: str, add_eos_token: bool = True, strip_bos_token: bool = False
|
self, prompt: str, add_eos_token: bool = True, strip_bos_token: bool = False
|
||||||
@@ -44,7 +45,7 @@ class PretrainTokenizationStrategy(PromptTokenizingStrategy):
|
|||||||
return res
|
return res
|
||||||
|
|
||||||
def tokenize_prompt(self, prompt):
|
def tokenize_prompt(self, prompt):
|
||||||
return self._tokenize(prompt["text"])
|
return self._tokenize(prompt[self.text_column])
|
||||||
|
|
||||||
|
|
||||||
def load(tokenizer, cfg):
|
def load(tokenizer, cfg):
|
||||||
@@ -53,6 +54,7 @@ def load(tokenizer, cfg):
|
|||||||
tokenizer,
|
tokenizer,
|
||||||
cfg.train_on_inputs,
|
cfg.train_on_inputs,
|
||||||
cfg.sequence_len,
|
cfg.sequence_len,
|
||||||
|
text_column=cfg.pretraining_dataset[0]["text_column"] or "text",
|
||||||
max_length=cfg.sequence_len * 64,
|
max_length=cfg.sequence_len * 64,
|
||||||
)
|
)
|
||||||
return strat
|
return strat
|
||||||
|
|||||||
@@ -61,7 +61,11 @@ class RemappedParameters(BaseModel):
|
|||||||
class PretrainingDataset(BaseModel):
|
class PretrainingDataset(BaseModel):
|
||||||
"""pretraining dataset configuration subset"""
|
"""pretraining dataset configuration subset"""
|
||||||
|
|
||||||
|
name: Optional[str] = None
|
||||||
path: Optional[str] = None
|
path: Optional[str] = None
|
||||||
|
split: Optional[str] = "train"
|
||||||
|
text_column: Optional[str] = "text"
|
||||||
|
type: Optional[str] = "pretrain"
|
||||||
|
|
||||||
|
|
||||||
class UserDefinedPrompterType(BaseModel):
|
class UserDefinedPrompterType(BaseModel):
|
||||||
@@ -448,7 +452,7 @@ class AxolotlInputConfig(
|
|||||||
dataset_shard_idx: Optional[int] = None
|
dataset_shard_idx: Optional[int] = None
|
||||||
|
|
||||||
pretraining_dataset: Optional[ # type: ignore
|
pretraining_dataset: Optional[ # type: ignore
|
||||||
conlist(Union[SFTDataset, PretrainingDataset], min_length=1)
|
conlist(Union[PretrainingDataset, SFTDataset], min_length=1)
|
||||||
] = Field(
|
] = Field(
|
||||||
default=None, metadata={"help": {"streaming dataset to use for pretraining"}}
|
default=None, metadata={"help": {"streaming dataset to use for pretraining"}}
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -82,12 +82,15 @@ def prepare_dataset(cfg, tokenizer):
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
path = cfg.pretraining_dataset
|
path = cfg.pretraining_dataset
|
||||||
|
split = "train"
|
||||||
name = None
|
name = None
|
||||||
if isinstance(cfg.pretraining_dataset, list) and isinstance(
|
if isinstance(cfg.pretraining_dataset, list) and isinstance(
|
||||||
cfg.pretraining_dataset[0], dict
|
cfg.pretraining_dataset[0], dict
|
||||||
):
|
):
|
||||||
path = cfg.pretraining_dataset[0]["path"]
|
path = cfg.pretraining_dataset[0]["path"]
|
||||||
name = cfg.pretraining_dataset[0]["name"]
|
name = cfg.pretraining_dataset[0]["name"]
|
||||||
|
if "split" in cfg.pretraining_dataset[0]:
|
||||||
|
split = cfg.pretraining_dataset[0]["split"]
|
||||||
|
|
||||||
ds_wrapper_partial = functools.partial(
|
ds_wrapper_partial = functools.partial(
|
||||||
get_dataset_wrapper,
|
get_dataset_wrapper,
|
||||||
@@ -98,7 +101,7 @@ def prepare_dataset(cfg, tokenizer):
|
|||||||
)
|
)
|
||||||
|
|
||||||
train_dataset = wrap_pretraining_dataset(
|
train_dataset = wrap_pretraining_dataset(
|
||||||
load_dataset(path, streaming=True, split="train", name=name),
|
load_dataset(path, streaming=True, split=split, name=name),
|
||||||
tokenizer,
|
tokenizer,
|
||||||
cfg,
|
cfg,
|
||||||
ds_wrapper_partial,
|
ds_wrapper_partial,
|
||||||
@@ -831,14 +834,23 @@ def wrap_pretraining_dataset(
|
|||||||
else:
|
else:
|
||||||
LOG.debug("NOT shuffling merged pretraining datasets")
|
LOG.debug("NOT shuffling merged pretraining datasets")
|
||||||
|
|
||||||
|
# remove all the existing columns after mapping since they end up having
|
||||||
|
# a different length than the encoded/tokenized column
|
||||||
|
# this is empty during streaming/pretraining
|
||||||
|
remove_columns = []
|
||||||
|
if dataset.features is None:
|
||||||
|
for first_row in dataset:
|
||||||
|
remove_columns = first_row.keys()
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
remove_columns = dataset.features.keys()
|
||||||
|
|
||||||
dataset = dataset.map(
|
dataset = dataset.map(
|
||||||
encode,
|
encode,
|
||||||
batched=True,
|
batched=True,
|
||||||
batch_size=buffer_size,
|
batch_size=buffer_size,
|
||||||
# input_columns="text",
|
# input_columns="text",
|
||||||
# remove all the existing columns after mapping since they end up having
|
remove_columns=remove_columns,
|
||||||
# a different length than the encoded/tokenized column
|
|
||||||
remove_columns=dataset.features.keys(),
|
|
||||||
)
|
)
|
||||||
return dataset
|
return dataset
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user