option to not concatenate during pretraining (#2263)

* option to not concatenate during pretraining

* simplify conditional and add doc to config.qmd
This commit is contained in:
Wing Lian
2025-01-20 14:07:34 -05:00
committed by GitHub
parent 8606093921
commit af727eedf7
4 changed files with 19 additions and 0 deletions

View File

@@ -244,6 +244,8 @@ total_num_tokens:
sample_packing_group_size: 100000 sample_packing_group_size: 100000
# The number of samples which can be packed into one sequence. Increase if using a large sequence_len with many short samples. # The number of samples which can be packed into one sequence. Increase if using a large sequence_len with many short samples.
sample_packing_bin_size: 200 sample_packing_bin_size: 200
# whether to concatenate samples during pretraining
pretraining_sample_concatenation:
# Use batch flattening for speedups when not using sample_packing # Use batch flattening for speedups when not using sample_packing
batch_flattening: batch_flattening:

View File

@@ -1877,6 +1877,8 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
self, training_args: AxolotlTrainingArguments, is_eval=False, **kwargs self, training_args: AxolotlTrainingArguments, is_eval=False, **kwargs
): ):
if training_args.pretraining: if training_args.pretraining:
if self.cfg.pretraining_sample_concatenation is False:
return DataCollatorForSeq2Seq(self.tokenizer, **kwargs)
return None return None
if self.cfg.model_config_type == "mamba": if self.cfg.model_config_type == "mamba":

View File

@@ -706,6 +706,12 @@ class AxolotlInputConfig(
pad_to_sequence_len: Optional[bool] = None pad_to_sequence_len: Optional[bool] = None
curriculum_sampling: Optional[bool] = None curriculum_sampling: Optional[bool] = None
multipack_real_batches: Optional[bool] = None multipack_real_batches: Optional[bool] = None
pretraining_sample_concatenation: Optional[bool] = Field(
default=None,
json_schema_extra={
"description": "whether to soft pack/concatenate samples during pretraining",
},
)
batch_flattening: Optional[Union[Literal["auto"], bool]] = None batch_flattening: Optional[Union[Literal["auto"], bool]] = None

View File

@@ -22,6 +22,7 @@ def encode_pretraining(
max_tokens: int, max_tokens: int,
examples: Dict[str, List], examples: Dict[str, List],
text_column: str = "text", text_column: str = "text",
concatenate: bool = True,
) -> Dict[str, List]: ) -> Dict[str, List]:
res = tokenizer( res = tokenizer(
examples[text_column], examples[text_column],
@@ -33,6 +34,13 @@ def encode_pretraining(
input_ids = [torch.tensor(seq) for seq in res["input_ids"]] input_ids = [torch.tensor(seq) for seq in res["input_ids"]]
targets = [torch.tensor(seq) for seq in res["input_ids"]] targets = [torch.tensor(seq) for seq in res["input_ids"]]
attention_mask = [torch.tensor(seq) for seq in res["attention_mask"]] attention_mask = [torch.tensor(seq) for seq in res["attention_mask"]]
if not concatenate:
return {
"input_ids": [seq.tolist() for seq in input_ids],
"labels": [seq.tolist() for seq in targets],
"attention_mask": [seq.tolist() for seq in attention_mask],
}
new_input_ids = [] new_input_ids = []
new_labels = [] new_labels = []
new_attention_mask = [] new_attention_mask = []
@@ -204,6 +212,7 @@ def wrap_pretraining_dataset(
tokenizer, tokenizer,
max_tokens, max_tokens,
text_column=cfg.pretraining_dataset[0].text_column or "text", text_column=cfg.pretraining_dataset[0].text_column or "text",
concatenate=cfg.pretraining_sample_concatenation is True,
) )
if cfg.shuffle_merged_datasets: if cfg.shuffle_merged_datasets: