fix for iterable datasets and pickling (#2831) [skip ci]
* fix for iterable datasets and pickling
* more fixes for pretraining
* can't pickle mock generator dataset
This commit is contained in:
@@ -413,7 +413,8 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
or self.cfg.micro_batch_size > 1
|
or self.cfg.micro_batch_size > 1
|
||||||
):
|
):
|
||||||
return DataCollatorForSeq2Seq(self.tokenizer, **kwargs)
|
return DataCollatorForSeq2Seq(self.tokenizer, **kwargs)
|
||||||
return None
|
if not (self.cfg.sample_packing and self.cfg.pretrain_multipack_attn):
|
||||||
|
return None
|
||||||
|
|
||||||
if self.cfg.model_config_type == "mamba":
|
if self.cfg.model_config_type == "mamba":
|
||||||
return MambaDataCollator(tokenizer=self.tokenizer)
|
return MambaDataCollator(tokenizer=self.tokenizer)
|
||||||
|
|||||||
@@ -223,6 +223,8 @@ def execute_training(
|
|||||||
)
|
)
|
||||||
|
|
||||||
LOG.info("Starting trainer...")
|
LOG.info("Starting trainer...")
|
||||||
|
if cfg.bf16:
|
||||||
|
torch.set_default_dtype(torch.bfloat16)
|
||||||
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
|
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -224,10 +224,10 @@ def wrap_pretraining_dataset(
|
|||||||
remove_columns = []
|
remove_columns = []
|
||||||
if dataset.features is None:
|
if dataset.features is None:
|
||||||
for first_row in dataset:
|
for first_row in dataset:
|
||||||
remove_columns = first_row.keys()
|
remove_columns = list(first_row.keys())
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
remove_columns = dataset.features.keys()
|
remove_columns = list(dataset.features.keys())
|
||||||
|
|
||||||
dataset = dataset.map(
|
dataset = dataset.map(
|
||||||
encode,
|
encode,
|
||||||
@@ -267,6 +267,7 @@ def encode_packed_pretraining(
|
|||||||
batch_size=1,
|
batch_size=1,
|
||||||
batch_max_len=batch_size * max_seq_length,
|
batch_max_len=batch_size * max_seq_length,
|
||||||
drop_last=True,
|
drop_last=True,
|
||||||
|
num_processes=1,
|
||||||
)
|
)
|
||||||
|
|
||||||
chunked_data = defaultdict(list)
|
chunked_data = defaultdict(list)
|
||||||
|
|||||||
@@ -260,7 +260,7 @@ class MultipackBatchSampler(BatchSampler):
|
|||||||
lengths: np.ndarray, # Sequence lengths
|
lengths: np.ndarray, # Sequence lengths
|
||||||
packing_efficiency_estimate: float = 1.0, # Initial efficiency estimate
|
packing_efficiency_estimate: float = 1.0, # Initial efficiency estimate
|
||||||
drop_last: bool = True, # Whether to drop final batches (might be incomplete)
|
drop_last: bool = True, # Whether to drop final batches (might be incomplete)
|
||||||
num_count_samples: int = 8, # Number of times to estimate batch count
|
num_count_samples: int = 4, # Number of times to estimate batch count
|
||||||
sequential: bool = False, # Whether to use sequential packing
|
sequential: bool = False, # Whether to use sequential packing
|
||||||
group_size: int = 100_000, # Size of groups for parallel packing
|
group_size: int = 100_000, # Size of groups for parallel packing
|
||||||
bin_size: int = 200, # The max number of samples that can be packed in a single bin
|
bin_size: int = 200, # The max number of samples that can be packed in a single bin
|
||||||
@@ -335,12 +335,13 @@ class MultipackBatchSampler(BatchSampler):
|
|||||||
bins = [[indices[b_idx] for b_idx in bin_indices] for bin_indices in bins]
|
bins = [[indices[b_idx] for b_idx in bin_indices] for bin_indices in bins]
|
||||||
else:
|
else:
|
||||||
# Use parallel packing
|
# Use parallel packing
|
||||||
|
num_processes = self.num_processes or 1
|
||||||
all_bins = pack_parallel(
|
all_bins = pack_parallel(
|
||||||
lengths,
|
lengths,
|
||||||
bin_capacity=self.batch_max_len,
|
bin_capacity=self.batch_max_len,
|
||||||
group_size=self.group_size,
|
group_size=self.group_size,
|
||||||
bin_size=self.bin_size,
|
bin_size=self.bin_size,
|
||||||
num_processes=max(4, self.num_processes) if self.num_processes else 4,
|
num_processes=min(4, num_processes) if num_processes else 4,
|
||||||
safe_mode=self.safe_mode,
|
safe_mode=self.safe_mode,
|
||||||
mp_start_method=self.mp_start_method,
|
mp_start_method=self.mp_start_method,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -462,6 +462,20 @@ class TrainingValidationMixin:
|
|||||||
|
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
@model_validator(mode="before")
|
||||||
|
@classmethod
|
||||||
|
def pretrain_with_tps(cls, data):
|
||||||
|
if data.get("pretraining_dataset") and data.get(
|
||||||
|
"include_tokens_per_second", False
|
||||||
|
):
|
||||||
|
# combining these would raise `TypeError: cannot pickle 'dict_keys' object`
|
||||||
|
# due to trying to count the total number of tokens in the dataset
|
||||||
|
raise ValueError(
|
||||||
|
"pretraining_dataset and include_tokens_per_second cannot be used together."
|
||||||
|
)
|
||||||
|
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
class LoRAValidationMixin:
|
class LoRAValidationMixin:
|
||||||
"""Validation methods related to LoRA/QLoRA configuration."""
|
"""Validation methods related to LoRA/QLoRA configuration."""
|
||||||
|
|||||||
@@ -381,6 +381,7 @@ def process_pretraining_datasets_for_packing(
|
|||||||
if not skip_position_ids:
|
if not skip_position_ids:
|
||||||
train_dataset = train_dataset.map(
|
train_dataset = train_dataset.map(
|
||||||
add_position_ids,
|
add_position_ids,
|
||||||
|
batched=True,
|
||||||
desc="Add position_id column (Pretraining Sample Packing)",
|
desc="Add position_id column (Pretraining Sample Packing)",
|
||||||
)
|
)
|
||||||
if drop_attention_mask:
|
if drop_attention_mask:
|
||||||
|
|||||||
Reference in New Issue
Block a user