fix for iterable datasets and pickling (#2831) [skip ci]

* fix for iterable datasets and pickling

* more fixes for pretraining

* can't pickle mock generator dataset
This commit is contained in:
Wing Lian
2025-06-27 10:35:23 -04:00
committed by GitHub
parent 927bf530bc
commit a24957fa04
6 changed files with 25 additions and 5 deletions

View File

@@ -413,7 +413,8 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
or self.cfg.micro_batch_size > 1 or self.cfg.micro_batch_size > 1
): ):
return DataCollatorForSeq2Seq(self.tokenizer, **kwargs) return DataCollatorForSeq2Seq(self.tokenizer, **kwargs)
return None if not (self.cfg.sample_packing and self.cfg.pretrain_multipack_attn):
return None
if self.cfg.model_config_type == "mamba": if self.cfg.model_config_type == "mamba":
return MambaDataCollator(tokenizer=self.tokenizer) return MambaDataCollator(tokenizer=self.tokenizer)

View File

@@ -223,6 +223,8 @@ def execute_training(
) )
LOG.info("Starting trainer...") LOG.info("Starting trainer...")
if cfg.bf16:
torch.set_default_dtype(torch.bfloat16)
trainer.train(resume_from_checkpoint=resume_from_checkpoint) trainer.train(resume_from_checkpoint=resume_from_checkpoint)

View File

@@ -224,10 +224,10 @@ def wrap_pretraining_dataset(
remove_columns = [] remove_columns = []
if dataset.features is None: if dataset.features is None:
for first_row in dataset: for first_row in dataset:
remove_columns = first_row.keys() remove_columns = list(first_row.keys())
break break
else: else:
remove_columns = dataset.features.keys() remove_columns = list(dataset.features.keys())
dataset = dataset.map( dataset = dataset.map(
encode, encode,
@@ -267,6 +267,7 @@ def encode_packed_pretraining(
batch_size=1, batch_size=1,
batch_max_len=batch_size * max_seq_length, batch_max_len=batch_size * max_seq_length,
drop_last=True, drop_last=True,
num_processes=1,
) )
chunked_data = defaultdict(list) chunked_data = defaultdict(list)

View File

@@ -260,7 +260,7 @@ class MultipackBatchSampler(BatchSampler):
lengths: np.ndarray, # Sequence lengths lengths: np.ndarray, # Sequence lengths
packing_efficiency_estimate: float = 1.0, # Initial efficiency estimate packing_efficiency_estimate: float = 1.0, # Initial efficiency estimate
drop_last: bool = True, # Whether to drop final batches (might be incomplete) drop_last: bool = True, # Whether to drop final batches (might be incomplete)
num_count_samples: int = 8, # Number of times to estimate batch count num_count_samples: int = 4, # Number of times to estimate batch count
sequential: bool = False, # Whether to use sequential packing sequential: bool = False, # Whether to use sequential packing
group_size: int = 100_000, # Size of groups for parallel packing group_size: int = 100_000, # Size of groups for parallel packing
bin_size: int = 200, # The max number of samples that can be packed in a single bin bin_size: int = 200, # The max number of samples that can be packed in a single bin
@@ -335,12 +335,13 @@ class MultipackBatchSampler(BatchSampler):
bins = [[indices[b_idx] for b_idx in bin_indices] for bin_indices in bins] bins = [[indices[b_idx] for b_idx in bin_indices] for bin_indices in bins]
else: else:
# Use parallel packing # Use parallel packing
num_processes = self.num_processes or 1
all_bins = pack_parallel( all_bins = pack_parallel(
lengths, lengths,
bin_capacity=self.batch_max_len, bin_capacity=self.batch_max_len,
group_size=self.group_size, group_size=self.group_size,
bin_size=self.bin_size, bin_size=self.bin_size,
num_processes=max(4, self.num_processes) if self.num_processes else 4, num_processes=min(4, num_processes) if num_processes else 4,
safe_mode=self.safe_mode, safe_mode=self.safe_mode,
mp_start_method=self.mp_start_method, mp_start_method=self.mp_start_method,
) )

View File

@@ -462,6 +462,20 @@ class TrainingValidationMixin:
return data return data
@model_validator(mode="before")
@classmethod
def pretrain_with_tps(cls, data):
    """Reject configs that pair a pretraining dataset with tokens-per-second reporting.

    Runs as a "before" validator on the raw config data and returns it
    unchanged unless both options are set.

    Raises:
        ValueError: if both `pretraining_dataset` and
            `include_tokens_per_second` are truthy in `data`.
    """
    # Nothing to check unless a pretraining dataset is configured.
    if not data.get("pretraining_dataset"):
        return data

    # Tokens-per-second reporting is off (or unset): the combination is fine.
    if not data.get("include_tokens_per_second", False):
        return data

    # Counting the total number of tokens in the dataset is what breaks here:
    # the combination would raise
    # `TypeError: cannot pickle 'dict_keys' object`, so fail early instead.
    raise ValueError(
        "pretraining_dataset and include_tokens_per_second cannot be used together."
    )
class LoRAValidationMixin: class LoRAValidationMixin:
"""Validation methods related to LoRA/QLoRA configuration.""" """Validation methods related to LoRA/QLoRA configuration."""

View File

@@ -381,6 +381,7 @@ def process_pretraining_datasets_for_packing(
if not skip_position_ids: if not skip_position_ids:
train_dataset = train_dataset.map( train_dataset = train_dataset.map(
add_position_ids, add_position_ids,
batched=True,
desc="Add position_id column (Pretraining Sample Packing)", desc="Add position_id column (Pretraining Sample Packing)",
) )
if drop_attention_mask: if drop_attention_mask: