From af727eedf75518bc603545b03a54a28fa99beeec Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Mon, 20 Jan 2025 14:07:34 -0500
Subject: [PATCH] option to not concatenate during pretraining (#2263)

* option to not concatenate during pretraining

* simplify conditional and add doc to config.qmd
---
 docs/config.qmd                                          | 2 ++
 src/axolotl/core/trainer_builder.py                      | 2 ++
 src/axolotl/utils/config/models/input/v0_4_1/__init__.py | 6 ++++++
 src/axolotl/utils/data/pretraining.py                    | 9 +++++++++
 4 files changed, 19 insertions(+)

diff --git a/docs/config.qmd b/docs/config.qmd
index 70679791e..179ee9ed1 100644
--- a/docs/config.qmd
+++ b/docs/config.qmd
@@ -244,6 +244,8 @@ total_num_tokens:
 sample_packing_group_size: 100000
 # The number of samples which can be packed into one sequence. Increase if using a large sequence_len with many short samples.
 sample_packing_bin_size: 200
+# whether to concatenate samples during pretraining
+pretraining_sample_concatenation:
 
 # Use batch flattening for speedups when not using sample_packing
 batch_flattening:
diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py
index 176ce4174..6f1bae1ef 100755
--- a/src/axolotl/core/trainer_builder.py
+++ b/src/axolotl/core/trainer_builder.py
@@ -1877,6 +1877,8 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
         self, training_args: AxolotlTrainingArguments, is_eval=False, **kwargs
     ):
         if training_args.pretraining:
+            if self.cfg.pretraining_sample_concatenation is False:
+                return DataCollatorForSeq2Seq(self.tokenizer, **kwargs)
             return None
 
         if self.cfg.model_config_type == "mamba":
diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
index 4f368994a..98cdee009 100644
--- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
+++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
@@ -706,6 +706,12 @@ class AxolotlInputConfig(
     pad_to_sequence_len: Optional[bool] = None
     curriculum_sampling: Optional[bool] = None
     multipack_real_batches: Optional[bool] = None
+    pretraining_sample_concatenation: Optional[bool] = Field(
+        default=None,
+        json_schema_extra={
+            "description": "whether to soft pack/concatenate samples during pretraining",
+        },
+    )
 
     batch_flattening: Optional[Union[Literal["auto"], bool]] = None
 
diff --git a/src/axolotl/utils/data/pretraining.py b/src/axolotl/utils/data/pretraining.py
index 369d2d6fe..c30d62575 100644
--- a/src/axolotl/utils/data/pretraining.py
+++ b/src/axolotl/utils/data/pretraining.py
@@ -22,6 +22,7 @@ def encode_pretraining(
     max_tokens: int,
     examples: Dict[str, List],
     text_column: str = "text",
+    concatenate: bool = True,
 ) -> Dict[str, List]:
     res = tokenizer(
         examples[text_column],
@@ -33,6 +34,13 @@
     input_ids = [torch.tensor(seq) for seq in res["input_ids"]]
     targets = [torch.tensor(seq) for seq in res["input_ids"]]
     attention_mask = [torch.tensor(seq) for seq in res["attention_mask"]]
+    if not concatenate:
+        return {
+            "input_ids": [seq.tolist() for seq in input_ids],
+            "labels": [seq.tolist() for seq in targets],
+            "attention_mask": [seq.tolist() for seq in attention_mask],
+        }
+
     new_input_ids = []
     new_labels = []
     new_attention_mask = []
@@ -204,6 +212,7 @@ def wrap_pretraining_dataset(
             tokenizer,
             max_tokens,
             text_column=cfg.pretraining_dataset[0].text_column or "text",
+            concatenate=cfg.pretraining_sample_concatenation is True,
        )
 
     if cfg.shuffle_merged_datasets:
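
Example usage, as a minimal sketch: only pretraining_sample_concatenation is introduced by this patch; the dataset path and the surrounding fields are assumed placeholders for a typical streaming pretraining config. With the flag set to false, encode_pretraining returns one example per document and the trainer uses DataCollatorForSeq2Seq for padding instead of soft-packing samples up to the context length.

    # hypothetical streaming corpus; path is a placeholder
    pretraining_dataset:
      - path: my-org/my-pretraining-corpus
        text_column: text
    # keep each tokenized document as its own example instead of concatenating
    pretraining_sample_concatenation: false
    sequence_len: 2048
    micro_batch_size: 2
    max_steps: 10000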