From c196776996a813144be3aac013869d375d5a7fb0 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Wed, 15 Jan 2025 22:45:02 -0500
Subject: [PATCH] option to not concatenate during pretraining

---
 src/axolotl/core/trainer_builder.py               |  2 ++
 .../utils/config/models/input/v0_4_1/__init__.py  |  6 ++++++
 src/axolotl/utils/data/pretraining.py             | 16 +++++++++++++++-
 3 files changed, 23 insertions(+), 1 deletion(-)

diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py
index b52dc73a3..15fa1cd8a 100755
--- a/src/axolotl/core/trainer_builder.py
+++ b/src/axolotl/core/trainer_builder.py
@@ -1877,6 +1877,8 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
         self, training_args: AxolotlTrainingArguments, is_eval=False, **kwargs
     ):
         if training_args.pretraining:
+            if self.cfg.pretraining_sample_concatenation is False:
+                return DataCollatorForSeq2Seq(self.tokenizer, **kwargs)
             return None
 
         if self.cfg.model_config_type == "mamba":

diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
index bb88a0baa..a895a2eed 100644
--- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
+++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
@@ -698,6 +698,12 @@ class AxolotlInputConfig(
     pad_to_sequence_len: Optional[bool] = None
     curriculum_sampling: Optional[bool] = None
     multipack_real_batches: Optional[bool] = None
+    pretraining_sample_concatenation: Optional[bool] = Field(
+        default=None,
+        json_schema_extra={
+            "description": "whether to soft pack/concatenate samples during pretraining",
+        },
+    )
     batch_flattening: Optional[Union[Literal["auto"], bool]] = None
 

diff --git a/src/axolotl/utils/data/pretraining.py b/src/axolotl/utils/data/pretraining.py
index f493db70e..f05085ab9 100644
--- a/src/axolotl/utils/data/pretraining.py
+++ b/src/axolotl/utils/data/pretraining.py
@@ -18,7 +18,10 @@ LOG = logging.getLogger("axolotl")
 
 
 def encode_pretraining(
-    tokenizer: PreTrainedTokenizerBase, max_tokens: int, examples: Dict[str, List]
+    tokenizer: PreTrainedTokenizerBase,
+    max_tokens: int,
+    examples: Dict[str, List],
+    concatenate: bool = True,
 ) -> Dict[str, List]:
     res = tokenizer(
         examples["text"],
@@ -30,6 +33,13 @@ def encode_pretraining(
     input_ids = [torch.tensor(seq) for seq in res["input_ids"]]
     targets = [torch.tensor(seq) for seq in res["input_ids"]]
     attention_mask = [torch.tensor(seq) for seq in res["attention_mask"]]
+    if not concatenate:
+        return {
+            "input_ids": [seq.tolist() for seq in input_ids],
+            "labels": [seq.tolist() for seq in targets],
+            "attention_mask": [seq.tolist() for seq in attention_mask],
+        }
+
     new_input_ids = []
     new_labels = []
     new_attention_mask = []
@@ -195,6 +205,10 @@ def wrap_pretraining_dataset(
         )
         # set this to 1 so downstream data_loader doesn't try to increase the batch again
         cfg.micro_batch_size = 1
+    elif cfg.pretraining_sample_concatenation is False:
+        encode = functools.partial(
+            encode_pretraining, tokenizer, max_tokens, concatenate=False
+        )
     else:
         encode = functools.partial(encode_pretraining, tokenizer, max_tokens)
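
For reviewers, a minimal end-to-end sketch of what the new path amounts to. This is illustrative only, not axolotl code: the gpt2 tokenizer, toy texts, and simplified tokenizer kwargs are assumptions (the real encode_pretraining passes additional tokenizer arguments not visible in the hunk); only the concatenate=False early return and the DataCollatorForSeq2Seq fallback come from this patch.

    # sketch_no_concat.py -- illustrative sketch, not part of the patch
    from transformers import AutoTokenizer, DataCollatorForSeq2Seq

    tokenizer = AutoTokenizer.from_pretrained("gpt2")  # hypothetical model choice
    tokenizer.pad_token = tokenizer.eos_token  # gpt2 defines no pad token

    max_tokens = 16
    examples = {"text": ["a short sample", "a much longer sample " * 8]}

    # Mirrors the concatenate=False early return in encode_pretraining: one
    # sequence per sample, labels copied from input_ids, no cross-sample packing.
    res = tokenizer(examples["text"], truncation=True, max_length=max_tokens)
    features = [
        {"input_ids": ids, "attention_mask": mask, "labels": list(ids)}
        for ids, mask in zip(res["input_ids"], res["attention_mask"])
    ]

    # Because samples are now ragged, the trainer needs a padding collator; the
    # patch returns DataCollatorForSeq2Seq instead of None (None meant "already
    # packed to fixed length, nothing to collate").
    collator = DataCollatorForSeq2Seq(tokenizer)
    batch = collator(features)
    print(batch["input_ids"].shape)  # torch.Size([2, longest_sequence_in_batch])
    print(batch["labels"][0][-3:])   # label padding positions become -100

With the patch applied, setting pretraining_sample_concatenation: false in the config routes wrap_pretraining_dataset through the concatenate=False encode path, and the trainer builder returns the padding collator shown above rather than None.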