From a24957fa0438a94882f0976d18595b23de8c84fb Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Fri, 27 Jun 2025 10:35:23 -0400
Subject: [PATCH] fix for iterable datasets and pickling (#2831) [skip ci]

* fix for iterable datasets and pickling

* more fixes for pretraining

* can't pickle mock generator dataset
---
 src/axolotl/core/builders/causal.py     |  3 ++-
 src/axolotl/train.py                    |  2 ++
 src/axolotl/utils/data/pretraining.py   |  5 +++--
 src/axolotl/utils/samplers/multipack.py |  5 +++--
 src/axolotl/utils/schemas/validation.py | 14 ++++++++++++++
 src/axolotl/utils/trainer.py            |  1 +
 6 files changed, 25 insertions(+), 5 deletions(-)

diff --git a/src/axolotl/core/builders/causal.py b/src/axolotl/core/builders/causal.py
index 47e33a332..2b7d902fa 100644
--- a/src/axolotl/core/builders/causal.py
+++ b/src/axolotl/core/builders/causal.py
@@ -413,7 +413,8 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
             or self.cfg.micro_batch_size > 1
         ):
             return DataCollatorForSeq2Seq(self.tokenizer, **kwargs)
-        return None
+        if not (self.cfg.sample_packing and self.cfg.pretrain_multipack_attn):
+            return None
 
         if self.cfg.model_config_type == "mamba":
             return MambaDataCollator(tokenizer=self.tokenizer)
diff --git a/src/axolotl/train.py b/src/axolotl/train.py
index d5dd431c1..a476385d0 100644
--- a/src/axolotl/train.py
+++ b/src/axolotl/train.py
@@ -223,6 +223,8 @@ def execute_training(
     )
 
     LOG.info("Starting trainer...")
+    if cfg.bf16:
+        torch.set_default_dtype(torch.bfloat16)
     trainer.train(resume_from_checkpoint=resume_from_checkpoint)
diff --git a/src/axolotl/utils/data/pretraining.py b/src/axolotl/utils/data/pretraining.py
index 4ff108aee..f3422f990 100644
--- a/src/axolotl/utils/data/pretraining.py
+++ b/src/axolotl/utils/data/pretraining.py
@@ -224,10 +224,10 @@ def wrap_pretraining_dataset(
     remove_columns = []
     if dataset.features is None:
         for first_row in dataset:
-            remove_columns = first_row.keys()
+            remove_columns = list(first_row.keys())
             break
     else:
-        remove_columns = dataset.features.keys()
+        remove_columns = list(dataset.features.keys())
 
     dataset = dataset.map(
         encode,
@@ -267,6 +267,7 @@ def encode_packed_pretraining(
         batch_size=1,
         batch_max_len=batch_size * max_seq_length,
         drop_last=True,
+        num_processes=1,
     )
 
     chunked_data = defaultdict(list)
diff --git a/src/axolotl/utils/samplers/multipack.py b/src/axolotl/utils/samplers/multipack.py
index 95d97e7a0..ee8640f41 100644
--- a/src/axolotl/utils/samplers/multipack.py
+++ b/src/axolotl/utils/samplers/multipack.py
@@ -260,7 +260,7 @@ class MultipackBatchSampler(BatchSampler):
         lengths: np.ndarray,  # Sequence lengths
         packing_efficiency_estimate: float = 1.0,  # Initial efficiency estimate
         drop_last: bool = True,  # Whether to drop final batches (might be incomplete)
-        num_count_samples: int = 8,  # Number of times to estimate batch count
+        num_count_samples: int = 4,  # Number of times to estimate batch count
         sequential: bool = False,  # Whether to use sequential packing
         group_size: int = 100_000,  # Size of groups for parallel packing
         bin_size: int = 200,  # The max number of samples that can be packed in a single bin
@@ -335,12 +335,13 @@ class MultipackBatchSampler(BatchSampler):
             bins = [[indices[b_idx] for b_idx in bin_indices] for bin_indices in bins]
         else:
             # Use parallel packing
+            num_processes = self.num_processes or 1
             all_bins = pack_parallel(
                 lengths,
                 bin_capacity=self.batch_max_len,
                 group_size=self.group_size,
                 bin_size=self.bin_size,
-                num_processes=max(4, self.num_processes) if self.num_processes else 4,
+                num_processes=min(4, num_processes) if num_processes else 4,
                 safe_mode=self.safe_mode,
                 mp_start_method=self.mp_start_method,
             )
diff --git a/src/axolotl/utils/schemas/validation.py b/src/axolotl/utils/schemas/validation.py
index 5a6bf43b3..3a0c9cc9f 100644
--- a/src/axolotl/utils/schemas/validation.py
+++ b/src/axolotl/utils/schemas/validation.py
@@ -462,6 +462,20 @@ class TrainingValidationMixin:
 
         return data
 
+    @model_validator(mode="before")
+    @classmethod
+    def pretrain_with_tps(cls, data):
+        if data.get("pretraining_dataset") and data.get(
+            "include_tokens_per_second", False
+        ):
+            # combining these would raise `TypeError: cannot pickle 'dict_keys' object`
+            # due to trying to count the number of tokens total in the dataset
+            raise ValueError(
+                "pretraining_dataset and include_tokens_per_second cannot be used together."
+            )
+
+        return data
+
 
 class LoRAValidationMixin:
     """Validation methods related to LoRA/QLoRA configuration."""
diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py
index 554a55abc..278fbed5b 100644
--- a/src/axolotl/utils/trainer.py
+++ b/src/axolotl/utils/trainer.py
@@ -381,6 +381,7 @@ def process_pretraining_datasets_for_packing(
     if not skip_position_ids:
         train_dataset = train_dataset.map(
             add_position_ids,
+            batched=True,
             desc="Add position_id column (Pretraining Sample Packing)",
         )
     if drop_attention_mask:
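
Note: the `list(...)` wrappers in pretraining.py and the comment in the new
`pretrain_with_tps` validator point at the same root cause: `dict.keys()`
returns a `dict_keys` view object, which the standard-library pickle module
refuses to serialize, and arguments such as `remove_columns` can end up
pickled when `datasets` hands work to other processes. A minimal standalone
sketch (not part of the patch; the `row` dict is a made-up stand-in for a
dataset row) that reproduces the error and the fix:

    import pickle

    # stand-in for the first row of a streaming pretraining dataset
    row = {"input_ids": [1, 2, 3], "attention_mask": [1, 1, 1]}

    try:
        pickle.dumps(row.keys())  # dict_keys view: not picklable
    except TypeError as err:
        print(err)  # -> cannot pickle 'dict_keys' object

    pickle.dumps(list(row.keys()))  # materialized list: pickles fine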