option to not concatenate during pretraining (#2263)
* option to not concatenate during pretraining * simplify conditional and add doc to config.qmd
This commit is contained in:
@@ -244,6 +244,8 @@ total_num_tokens:
|
|||||||
sample_packing_group_size: 100000
|
sample_packing_group_size: 100000
|
||||||
# The number of samples which can be packed into one sequence. Increase if using a large sequence_len with many short samples.
|
# The number of samples which can be packed into one sequence. Increase if using a large sequence_len with many short samples.
|
||||||
sample_packing_bin_size: 200
|
sample_packing_bin_size: 200
|
||||||
|
# whether to concatenate samples during pretraining
|
||||||
|
pretraining_sample_concatenation:
|
||||||
|
|
||||||
# Use batch flattening for speedups when not using sample_packing
|
# Use batch flattening for speedups when not using sample_packing
|
||||||
batch_flattening:
|
batch_flattening:
|
||||||
|
|||||||
@@ -1877,6 +1877,8 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
self, training_args: AxolotlTrainingArguments, is_eval=False, **kwargs
|
self, training_args: AxolotlTrainingArguments, is_eval=False, **kwargs
|
||||||
):
|
):
|
||||||
if training_args.pretraining:
|
if training_args.pretraining:
|
||||||
|
if self.cfg.pretraining_sample_concatenation is False:
|
||||||
|
return DataCollatorForSeq2Seq(self.tokenizer, **kwargs)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
if self.cfg.model_config_type == "mamba":
|
if self.cfg.model_config_type == "mamba":
|
||||||
|
|||||||
@@ -706,6 +706,12 @@ class AxolotlInputConfig(
|
|||||||
pad_to_sequence_len: Optional[bool] = None
|
pad_to_sequence_len: Optional[bool] = None
|
||||||
curriculum_sampling: Optional[bool] = None
|
curriculum_sampling: Optional[bool] = None
|
||||||
multipack_real_batches: Optional[bool] = None
|
multipack_real_batches: Optional[bool] = None
|
||||||
|
pretraining_sample_concatenation: Optional[bool] = Field(
|
||||||
|
default=None,
|
||||||
|
json_schema_extra={
|
||||||
|
"description": "whether to soft pack/concatenate samples during pretraining",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
batch_flattening: Optional[Union[Literal["auto"], bool]] = None
|
batch_flattening: Optional[Union[Literal["auto"], bool]] = None
|
||||||
|
|
||||||
|
|||||||
@@ -22,6 +22,7 @@ def encode_pretraining(
|
|||||||
max_tokens: int,
|
max_tokens: int,
|
||||||
examples: Dict[str, List],
|
examples: Dict[str, List],
|
||||||
text_column: str = "text",
|
text_column: str = "text",
|
||||||
|
concatenate: bool = True,
|
||||||
) -> Dict[str, List]:
|
) -> Dict[str, List]:
|
||||||
res = tokenizer(
|
res = tokenizer(
|
||||||
examples[text_column],
|
examples[text_column],
|
||||||
@@ -33,6 +34,13 @@ def encode_pretraining(
|
|||||||
input_ids = [torch.tensor(seq) for seq in res["input_ids"]]
|
input_ids = [torch.tensor(seq) for seq in res["input_ids"]]
|
||||||
targets = [torch.tensor(seq) for seq in res["input_ids"]]
|
targets = [torch.tensor(seq) for seq in res["input_ids"]]
|
||||||
attention_mask = [torch.tensor(seq) for seq in res["attention_mask"]]
|
attention_mask = [torch.tensor(seq) for seq in res["attention_mask"]]
|
||||||
|
if not concatenate:
|
||||||
|
return {
|
||||||
|
"input_ids": [seq.tolist() for seq in input_ids],
|
||||||
|
"labels": [seq.tolist() for seq in targets],
|
||||||
|
"attention_mask": [seq.tolist() for seq in attention_mask],
|
||||||
|
}
|
||||||
|
|
||||||
new_input_ids = []
|
new_input_ids = []
|
||||||
new_labels = []
|
new_labels = []
|
||||||
new_attention_mask = []
|
new_attention_mask = []
|
||||||
@@ -204,6 +212,7 @@ def wrap_pretraining_dataset(
|
|||||||
tokenizer,
|
tokenizer,
|
||||||
max_tokens,
|
max_tokens,
|
||||||
text_column=cfg.pretraining_dataset[0].text_column or "text",
|
text_column=cfg.pretraining_dataset[0].text_column or "text",
|
||||||
|
concatenate=cfg.pretraining_sample_concatenation is True,
|
||||||
)
|
)
|
||||||
|
|
||||||
if cfg.shuffle_merged_datasets:
|
if cfg.shuffle_merged_datasets:
|
||||||
|
|||||||
Reference in New Issue
Block a user