fixing sample packing

Dan Saunders
2025-03-12 20:44:02 +00:00
parent b7738d57c4
commit 5731cdc0cf
3 changed files with 69 additions and 63 deletions

View File

@@ -398,34 +398,60 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
            )
        return super()._wrap_model(model, training=training, dataloader=dataloader)

    def _create_multipack_sampler(self, base_sampler):
        """Helper method to create a MultipackBatchSampler"""
        if self.args.multipack_real_batches:
            batch_size = self.args.per_device_train_batch_size
            batch_max_len = self.args.max_seq_length
        else:
            batch_size = 1
            train_batch_size = (
                self.state.train_batch_size or self.args.per_device_train_batch_size
            )
            batch_max_len = train_batch_size * self.args.max_seq_length

        return MultipackBatchSampler(
            base_sampler,
            lengths=get_dataset_lengths(self.train_dataset),
            packing_efficiency_estimate=self.args.sample_packing_efficiency,
            batch_max_len=batch_max_len,
            batch_size=batch_size,
            group_size=self.args.sample_packing_group_size,
            bin_size=self.args.sample_packing_bin_size,
            drop_last=True,
        )
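
A quick worked example of the token budget this helper computes when multipack_real_batches is off: the packed "batch" is a single bin whose length budget absorbs what would otherwise be a full per-device batch of unpacked sequences. Hypothetical numbers, not part of this commit:

    # Illustrative only -- hypothetical values, not the trainer's config.
    per_device_train_batch_size = 4
    max_seq_length = 2048

    batch_size = 1                                                  # one packed bin per step
    batch_max_len = per_device_train_batch_size * max_seq_length   # 8192-token budget per bin
    print(batch_size, batch_max_len)                                # 1 8192
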
    def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
        if self.args.sample_packing and not self.args.pretraining:
            if self.args.multipack_real_batches:
                batch_size = self.args.per_device_train_batch_size
                batch_max_len = self.args.max_seq_length
            else:
                batch_size = 1
                train_batch_size = (
                    self.state.train_batch_size or self.args.per_device_train_batch_size
                )
                batch_max_len = train_batch_size * self.args.max_seq_length
        # Handle sequence parallelism
        if self.args.sequence_parallel_size > 1:
            num_sp_groups = self.args.world_size // self.args.sequence_parallel_size
            sp_group_id = dist.get_rank() // self.args.sequence_parallel_size
            if self.args.curriculum_sampling:
                sampler = SequentialSampler(self.train_dataset)
            else:
                sampler = RandomSampler(self.train_dataset)
            return MultipackBatchSampler(
                sampler,
                lengths=get_dataset_lengths(self.train_dataset),
                packing_efficiency_estimate=self.args.sample_packing_efficiency,
                batch_max_len=batch_max_len,
                batch_size=batch_size,
                group_size=self.args.sample_packing_group_size,
                bin_size=self.args.sample_packing_bin_size,
            # Create base sampler for SP groups
            base_sampler = torch.utils.data.distributed.DistributedSampler(
                self.train_dataset,
                num_replicas=num_sp_groups,
                rank=sp_group_id,
                seed=self.args.seed if not self.args.curriculum_sampling else None,
                shuffle=not self.args.curriculum_sampling,
                drop_last=True,
            )

            # Apply multipack wrapper if needed
            if self.args.sample_packing and not self.args.pretraining:
                return self._create_multipack_sampler(base_sampler)
            return base_sampler

        # Handle non-SP mode
        if self.args.sample_packing and not self.args.pretraining:
            sampler = (
                SequentialSampler(self.train_dataset)
                if self.args.curriculum_sampling
                else RandomSampler(self.train_dataset)
            )
            return self._create_multipack_sampler(sampler)

        if self.args.curriculum_sampling:
            return SequentialSampler(self.train_dataset)
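
The sequence-parallel branch above shards the dataset per SP group rather than per rank: ranks that cooperate on the same sequence need the same samples, so the DistributedSampler is built with num_replicas=num_sp_groups and rank=sp_group_id. A minimal standalone sketch of that rank-to-group arithmetic, with hypothetical sizes and plain Python in place of torch.distributed:

    # Illustrative only -- hypothetical sizes, not part of this commit.
    world_size = 8               # total ranks
    sequence_parallel_size = 2   # ranks cooperating on one sequence

    num_sp_groups = world_size // sequence_parallel_size  # 4 sampler replicas
    for rank in range(world_size):
        sp_group_id = rank // sequence_parallel_size
        print(f"rank {rank} -> SP group {sp_group_id}")
    # Ranks 0-1 share group 0, ranks 2-3 share group 1, and so on, so every rank
    # in a group draws the same shard and sees the same packed batch.
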
@@ -489,13 +515,7 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
        if self.args.dataloader_prefetch_factor:
            dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor

        # Handle sequence parallelism (takes precedence over other sampling methods)
        if self.args.sequence_parallel_size > 1:
            return self._get_sequence_parallel_dataloader(
                train_dataset, dataloader_params
            )
        # Handle other sampler cases
        # Use the same sampling logic for all modes, including sequence parallelism
        if not isinstance(train_dataset, torch.utils.data.IterableDataset):
            sampler = self._get_train_sampler()
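
For context on how the sampler returned above is consumed: PyTorch's DataLoader takes a plain Sampler (yielding single indices) via sampler=, while a BatchSampler such as MultipackBatchSampler yields whole index lists and is passed via batch_sampler=, which excludes batch_size/shuffle/drop_last. A small generic sketch of that contract, not the trainer's actual wiring:

    # Generic PyTorch illustration -- not code from this commit.
    import torch
    from torch.utils.data import BatchSampler, DataLoader, SequentialSampler, TensorDataset

    dataset = TensorDataset(torch.arange(10))
    batch_sampler = BatchSampler(SequentialSampler(dataset), batch_size=4, drop_last=True)

    loader = DataLoader(dataset, batch_sampler=batch_sampler)  # no batch_size here
    for batch in loader:
        print(batch)  # each batch is built from one index list the sampler yields
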
@@ -510,43 +530,20 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
dataloader_params["worker_init_fn"] = seed_worker
# Create and possibly prepare the dataloader
# Create dataloader
dataloader = DataLoader(train_dataset, **dataloader_params)
# Sample packing with accelerator preparation
if self.args.sample_packing and not self.args.pretraining:
self.accelerator.even_batches = False
return self.accelerator.prepare_data_loader(dataloader)
return self.accelerator.prepare(dataloader)
# Don't prepare dataloader for sequence parallelism
# We use a distributed sampler in this case
if self.args.sequence_parallel_size > 1:
return dataloader
def _get_sequence_parallel_dataloader(self, dataset, dataloader_params):
"""Create a dataloader with sequence-parallel-aware sampling."""
# Calculate SP group information
num_sp_groups = self.args.world_size // self.args.sequence_parallel_size
global_rank = dist.get_rank()
sp_group_id = global_rank // self.args.sequence_parallel_size
# Create SP-group-aware sampler
if isinstance(dataset, torch.utils.data.IterableDataset):
sampler = None
else:
# Use different seeds for different SP groups to ensure they get different samples
sampler = torch.utils.data.distributed.DistributedSampler(
dataset,
num_replicas=num_sp_groups,
rank=sp_group_id,
seed=self.args.seed,
drop_last=self.args.dataloader_drop_last,
)
# Create dataloader without accelerator preparation
return DataLoader(
dataset,
sampler=sampler,
worker_init_fn=seed_worker,
**dataloader_params,
)
# Prepare dataloader for accelerate distributed training
return self.accelerator.prepare_data_loader(dataloader)
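
The new ending above leaves the sequence-parallel dataloader unprepared: the DistributedSampler already splits the data across SP groups, so running it through the accelerator's per-rank preparation would presumably shard it a second time. A standalone check that the per-group split is disjoint (hypothetical sizes, not part of this commit):

    # Illustrative only -- a 12-sample dataset split across 4 SP groups.
    import torch
    from torch.utils.data import TensorDataset
    from torch.utils.data.distributed import DistributedSampler

    dataset = TensorDataset(torch.arange(12))
    num_sp_groups = 4
    shards = [
        list(DistributedSampler(dataset, num_replicas=num_sp_groups, rank=g, shuffle=False))
        for g in range(num_sp_groups)
    ]
    print(shards)  # four disjoint index lists, one per SP group; every rank inside
                   # a group reuses the same list, so no further sharding is needed
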
    def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
        if self.args.sample_packing and self.args.eval_sample_packing is False:

View File

@@ -261,6 +261,12 @@ class DataCollatorForSeq2Seq:
            # # Replace invalid labels with -100 (ignore index)
            # batch["labels"][invalid_mask] = -100

            if key == "attention_mask":
                logger.info(
                    f"GPU {self.rank}/{self.world_size}, SP Rank {self.local_rank}/{self.local_world_size}: "
                    f"Attention mask: {batch['attention_mask']}, "
                )

        # Handle position_ids if present
        if "position_ids" in batch:
            pos_ids = batch["position_ids"]
@@ -281,6 +287,11 @@ class DataCollatorForSeq2Seq:
batch["position_ids"], start_idx
)
logger.info(
f"GPU {self.rank}/{self.world_size}, SP Rank {self.local_rank}/{self.local_world_size}: "
f"Position IDs: {batch['position_ids']}, "
)
# if dist.get_rank() == 0:
# import ipdb; ipdb.set_trace()
# dist.barrier()
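
The collator changes above log the attention mask and position IDs each rank ends up with after the sequence-parallel slicing. As a rough mental model only (assumed semantics, hypothetical sizes; not the collator's actual code), each SP rank keeps a contiguous slice of the packed sequence and the position IDs are cut with the same start offset:

    # Assumed illustration of per-rank slicing, not part of this commit.
    import torch

    position_ids = torch.arange(16).unsqueeze(0)  # one packed sequence, length 16
    local_world_size = 4                          # SP degree (hypothetical)
    chunk = position_ids.shape[1] // local_world_size

    for local_rank in range(local_world_size):
        start_idx = local_rank * chunk
        print(local_rank, position_ids[:, start_idx : start_idx + chunk].tolist())
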

View File

@@ -104,9 +104,7 @@ def allocate(
class MultipackBatchSampler(BatchSampler):
    """
    Batch Sampler class for multipack
    """
    """Batch sampler class for multipack"""

    def __init__(
        self,