fixing sample packing
@@ -398,34 +398,60 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
        )

        return super()._wrap_model(model, training=training, dataloader=dataloader)

    def _create_multipack_sampler(self, base_sampler):
        """Helper method to create a MultipackBatchSampler"""
        if self.args.multipack_real_batches:
            batch_size = self.args.per_device_train_batch_size
            batch_max_len = self.args.max_seq_length
        else:
            batch_size = 1
            train_batch_size = (
                self.state.train_batch_size or self.args.per_device_train_batch_size
            )
            batch_max_len = train_batch_size * self.args.max_seq_length

        return MultipackBatchSampler(
            base_sampler,
            lengths=get_dataset_lengths(self.train_dataset),
            packing_efficiency_estimate=self.args.sample_packing_efficiency,
            batch_max_len=batch_max_len,
            batch_size=batch_size,
            group_size=self.args.sample_packing_group_size,
            bin_size=self.args.sample_packing_bin_size,
            drop_last=True,
        )
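
Note: the two branches above decide how the packing budget is sized. A minimal sketch of the arithmetic with assumed values (the numbers below are illustrative, not taken from the commit):

# Illustrative only: assumed config values.
per_device_train_batch_size = 4
max_seq_length = 2048

# multipack_real_batches unset: one packed bin per step, with a token budget
# roughly equal to what 4 unpacked sequences of 2048 tokens would occupy.
batch_size = 1
batch_max_len = per_device_train_batch_size * max_seq_length  # 8192 tokens

# multipack_real_batches set: conventional batching, each sample capped at
# max_seq_length, i.e. batch_size = 4 and batch_max_len = 2048.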

    def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
        if self.args.sample_packing and not self.args.pretraining:
            if self.args.multipack_real_batches:
                batch_size = self.args.per_device_train_batch_size
                batch_max_len = self.args.max_seq_length
            else:
                batch_size = 1
                train_batch_size = (
                    self.state.train_batch_size or self.args.per_device_train_batch_size
                )
                batch_max_len = train_batch_size * self.args.max_seq_length
        # Handle sequence parallelism
        if self.args.sequence_parallel_size > 1:
            num_sp_groups = self.args.world_size // self.args.sequence_parallel_size
            sp_group_id = dist.get_rank() // self.args.sequence_parallel_size

            if self.args.curriculum_sampling:
                sampler = SequentialSampler(self.train_dataset)
            else:
                sampler = RandomSampler(self.train_dataset)

            return MultipackBatchSampler(
                sampler,
                lengths=get_dataset_lengths(self.train_dataset),
                packing_efficiency_estimate=self.args.sample_packing_efficiency,
                batch_max_len=batch_max_len,
                batch_size=batch_size,
                group_size=self.args.sample_packing_group_size,
                bin_size=self.args.sample_packing_bin_size,
            # Create base sampler for SP groups
            base_sampler = torch.utils.data.distributed.DistributedSampler(
                self.train_dataset,
                num_replicas=num_sp_groups,
                rank=sp_group_id,
                seed=self.args.seed if not self.args.curriculum_sampling else None,
                shuffle=not self.args.curriculum_sampling,
                drop_last=True,
            )

            # Apply multipack wrapper if needed
            if self.args.sample_packing and not self.args.pretraining:
                return self._create_multipack_sampler(base_sampler)

            return base_sampler

        # Handle non-SP mode
        if self.args.sample_packing and not self.args.pretraining:
            sampler = (
                SequentialSampler(self.train_dataset)
                if self.args.curriculum_sampling
                else RandomSampler(self.train_dataset)
            )
            return self._create_multipack_sampler(sampler)

        if self.args.curriculum_sampling:
            return SequentialSampler(self.train_dataset)

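A quick sanity check of the sequence-parallel group arithmetic used above (the topology values are assumptions chosen for illustration, not from the commit):

# Illustrative only: assumed topology.
world_size = 8                 # total ranks
sequence_parallel_size = 2     # ranks that share one sequence

num_sp_groups = world_size // sequence_parallel_size   # 4 groups
for rank in range(world_size):
    sp_group_id = rank // sequence_parallel_size
    # ranks 0-1 -> group 0, ranks 2-3 -> group 1, ranks 4-5 -> group 2, ranks 6-7 -> group 3
    print(rank, sp_group_id)

# The DistributedSampler is built with num_replicas=num_sp_groups and
# rank=sp_group_id, so all ranks inside one SP group receive the same data
# shard (they split the sequence dimension between them, not the batch).
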
@@ -489,13 +515,7 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):

        if self.args.dataloader_prefetch_factor:
            dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor

        # Handle sequence parallelism (takes precedence over other sampling methods)
        if self.args.sequence_parallel_size > 1:
            return self._get_sequence_parallel_dataloader(
                train_dataset, dataloader_params
            )

        # Handle other sampler cases
        # Use the same sampling logic for all modes, including sequence parallelism
        if not isinstance(train_dataset, torch.utils.data.IterableDataset):
            sampler = self._get_train_sampler()

@@ -510,43 +530,20 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):

            dataloader_params["worker_init_fn"] = seed_worker

        # Create and possibly prepare the dataloader
        # Create dataloader
        dataloader = DataLoader(train_dataset, **dataloader_params)

        # Sample packing with accelerator preparation
        if self.args.sample_packing and not self.args.pretraining:
            self.accelerator.even_batches = False
            return self.accelerator.prepare_data_loader(dataloader)

        return self.accelerator.prepare(dataloader)
        # Don't prepare dataloader for sequence parallelism
        # We use a distributed sampler in this case
        if self.args.sequence_parallel_size > 1:
            return dataloader

    def _get_sequence_parallel_dataloader(self, dataset, dataloader_params):
        """Create a dataloader with sequence-parallel-aware sampling."""
        # Calculate SP group information
        num_sp_groups = self.args.world_size // self.args.sequence_parallel_size
        global_rank = dist.get_rank()
        sp_group_id = global_rank // self.args.sequence_parallel_size

        # Create SP-group-aware sampler
        if isinstance(dataset, torch.utils.data.IterableDataset):
            sampler = None
        else:
            # Use different seeds for different SP groups to ensure they get different samples
            sampler = torch.utils.data.distributed.DistributedSampler(
                dataset,
                num_replicas=num_sp_groups,
                rank=sp_group_id,
                seed=self.args.seed,
                drop_last=self.args.dataloader_drop_last,
            )

        # Create dataloader without accelerator preparation
        return DataLoader(
            dataset,
            sampler=sampler,
            worker_init_fn=seed_worker,
            **dataloader_params,
        )
        # Prepare dataloader for accelerate distributed training
        return self.accelerator.prepare_data_loader(dataloader)

    def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
        if self.args.sample_packing and self.args.eval_sample_packing is False:

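To see what the DistributedSampler over SP groups buys: each group gets a disjoint shard of the dataset, while ranks within a group (which share `sp_group_id`) see identical indices. A toy, self-contained sketch using only standard PyTorch APIs (dataset and topology are made up for illustration):

# Illustrative only: toy dataset and topology, not from the commit.
import torch
from torch.utils.data import TensorDataset, DataLoader
from torch.utils.data.distributed import DistributedSampler

dataset = TensorDataset(torch.arange(12))
num_sp_groups, sp_group_id = 2, 1   # pretend this rank belongs to SP group 1

sampler = DistributedSampler(
    dataset, num_replicas=num_sp_groups, rank=sp_group_id, shuffle=False
)
loader = DataLoader(dataset, sampler=sampler, batch_size=2)
print([batch[0].tolist() for batch in loader])
# -> [[1, 3], [5, 7], [9, 11]]  (group 1's shard; group 0 would see the even indices)
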
@@ -261,6 +261,12 @@ class DataCollatorForSeq2Seq:
            # # Replace invalid labels with -100 (ignore index)
            # batch["labels"][invalid_mask] = -100

            if key == "attention_mask":
                logger.info(
                    f"GPU {self.rank}/{self.world_size}, SP Rank {self.local_rank}/{self.local_world_size}: "
                    f"Attention mask: {batch['attention_mask']}, "
                )

        # Handle position_ids if present
        if "position_ids" in batch:
            pos_ids = batch["position_ids"]
@@ -281,6 +287,11 @@ class DataCollatorForSeq2Seq:
                batch["position_ids"], start_idx
            )

            logger.info(
                f"GPU {self.rank}/{self.world_size}, SP Rank {self.local_rank}/{self.local_world_size}: "
                f"Position IDs: {batch['position_ids']}, "
            )

        # if dist.get_rank() == 0:
        #     import ipdb; ipdb.set_trace()
        #     dist.barrier()

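The per-rank logging above is there so the sequence-parallel slicing can be eyeballed: each SP rank should hold the chunk of the packed sequence starting at its `start_idx`, with position ids that stay consistent with the full sequence. A rough sketch of that invariant (shapes and the hand-rolled slice are assumptions for illustration, not the collator's actual code):

# Illustrative only: assumed shapes and a naive slice.
import torch

seq_len, sp_world = 8, 2
position_ids = torch.arange(seq_len).unsqueeze(0)   # packed sample: [[0..7]]

chunk = seq_len // sp_world
for sp_rank in range(sp_world):
    start_idx = sp_rank * chunk
    shard = position_ids[:, start_idx : start_idx + chunk]
    # rank 0 holds positions 0-3, rank 1 holds positions 4-7; concatenated
    # across ranks they reproduce the original packed sequence.
    print(sp_rank, start_idx, shard.tolist())
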
@@ -104,9 +104,7 @@ def allocate(


class MultipackBatchSampler(BatchSampler):
    """
    Batch Sampler class for multipack
    """
    """Batch sampler class for multipack"""

    def __init__(
        self,

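For readers new to the class, a toy sketch of the packing idea behind MultipackBatchSampler: sample lengths are packed into bins of at most `batch_max_len` tokens, and each bin becomes one batch. The naive first-fit loop below is illustrative only; it is not the sampler's actual allocation algorithm and ignores `packing_efficiency_estimate`, `group_size`, and `bin_size`.

# Illustrative only: a naive packing loop, not the real allocator.
def pack_into_bins(lengths, batch_max_len):
    bins, bin_lens = [], []
    for idx, length in enumerate(lengths):
        for b, used in enumerate(bin_lens):
            if used + length <= batch_max_len:
                bins[b].append(idx)
                bin_lens[b] += length
                break
        else:
            bins.append([idx])
            bin_lens.append(length)
    return bins

print(pack_into_bins([900, 300, 700, 100, 500], batch_max_len=1024))
# -> [[0, 3], [1, 2], [4]]  # each inner list is one packed batch of sample indices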