eval dataloader and sampler changes

Author: Dan Saunders
Date:   2025-03-13 19:24:30 +00:00
parent  d0e178d52f
commit  4ff97bc9d4

2 changed files with 173 additions and 53 deletions

@@ -398,7 +398,7 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
         )
         return super()._wrap_model(model, training=training, dataloader=dataloader)

-    def _create_multipack_sampler(self, base_sampler):
+    def _create_multipack_sampler(self, base_sampler, dataset, group_size):
         """Helper method to create a MultipackBatchSampler"""
         if self.args.multipack_real_batches:
             batch_size = self.args.per_device_train_batch_size
@@ -412,11 +412,11 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
         return MultipackBatchSampler(
             base_sampler,
-            lengths=get_dataset_lengths(self.train_dataset),
+            lengths=get_dataset_lengths(dataset),
             packing_efficiency_estimate=self.args.sample_packing_efficiency,
             batch_max_len=batch_max_len,
             batch_size=batch_size,
-            group_size=self.args.sample_packing_group_size,
+            group_size=group_size,
             bin_size=self.args.sample_packing_bin_size,
             drop_last=True,
         )
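
Note: the helper previously read everything from `self.args` and `self.train_dataset`; threading `dataset` and `group_size` through as parameters lets the same packing logic serve both the train and eval paths. A minimal standalone sketch of the batch-sizing arithmetic the helper performs, with made-up config values standing in for `self.args`:

    # Sketch of the helper's batch sizing; all values here are assumed.
    multipack_real_batches = False      # assumed config value
    per_device_train_batch_size = 4     # assumed config value
    max_seq_length = 2048               # assumed config value

    if multipack_real_batches:
        # One bin per real batch, capped at the model's context length.
        batch_size = per_device_train_batch_size
        batch_max_len = max_seq_length
    else:
        # Pack many short samples into a single long row per step.
        batch_size = 1
        batch_max_len = per_device_train_batch_size * max_seq_length

    print(batch_size, batch_max_len)  # -> 1 8192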
@@ -439,18 +439,26 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
             # Apply multipack wrapper if needed
             if self.args.sample_packing and not self.args.pretraining:
-                return self._create_multipack_sampler(base_sampler)
+                return self._create_multipack_sampler(
+                    base_sampler=base_sampler,
+                    dataset=self.train_dataset,
+                    group_size=self.args.sample_packing_group_size,
+                )
             return base_sampler

         # Handle non-SP mode
         if self.args.sample_packing and not self.args.pretraining:
-            sampler = (
+            base_sampler = (
                 SequentialSampler(self.train_dataset)
                 if self.args.curriculum_sampling
                 else RandomSampler(self.train_dataset)
             )
-            return self._create_multipack_sampler(sampler)
+            return self._create_multipack_sampler(
+                base_sampler=base_sampler,
+                dataset=self.train_dataset,
+                group_size=self.args.sample_packing_group_size,
+            )

         if self.args.curriculum_sampling:
             return SequentialSampler(self.train_dataset)
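
The non-SP branch keeps the existing behavior: curriculum sampling walks the dataset in order, otherwise examples are shuffled, and the result feeds the packing helper as `base_sampler`. A runnable toy illustration (the dataset here is a stand-in; `curriculum_sampling` mirrors the config flag):

    # Toy illustration of the base-sampler choice on the non-SP train path.
    import torch
    from torch.utils.data import RandomSampler, SequentialSampler, TensorDataset

    train_dataset = TensorDataset(torch.arange(10))
    curriculum_sampling = True  # assumed config value

    base_sampler = (
        SequentialSampler(train_dataset)   # preserve dataset order (curriculum)
        if curriculum_sampling
        else RandomSampler(train_dataset)  # shuffle examples each epoch
    )
    print(list(base_sampler))  # -> [0, 1, 2, ..., 9]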
@@ -458,27 +466,55 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
         return super()._get_train_sampler()

     def _get_eval_sampler(
-        self, eval_dataset: Dataset
+        self, eval_dataset: Optional[Dataset] = None
     ) -> Optional[torch.utils.data.Sampler]:
-        if self.args.sample_packing and self.args.eval_sample_packing is not False:
-            if self.args.multipack_real_batches:
-                batch_size = self.args.per_device_eval_batch_size
-                batch_max_len = self.args.max_seq_length
-            else:
-                batch_size = 1
-                batch_max_len = (
-                    self.args.per_device_eval_batch_size * self.args.max_seq_length
-                )
-            return MultipackBatchSampler(
-                SequentialSampler(eval_dataset),
-                lengths=get_dataset_lengths(self.eval_dataset),
-                packing_efficiency_estimate=self.args.sample_packing_efficiency,
-                batch_max_len=batch_max_len,
-                batch_size=batch_size,
-                group_size=self.args.sample_packing_group_size,
-                bin_size=self.args.sample_packing_bin_size,
-                drop_last=True,
-            )
+        """Get evaluation sampler"""
+        eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
+
+        # Handle sequence parallelism
+        if self.args.sequence_parallel_size > 1:
+            # Create sampler for SP groups
+            num_sp_groups = self.args.world_size // self.args.sequence_parallel_size
+            sp_group_id = dist.get_rank() // self.args.sequence_parallel_size
+
+            # Create distributed sampler for the SP group
+            base_sampler = torch.utils.data.distributed.DistributedSampler(
+                eval_dataset,
+                num_replicas=num_sp_groups,
+                rank=sp_group_id,
+                shuffle=False,
+                drop_last=False,
+            )
+
+            if self.args.sample_packing and self.args.eval_sample_packing is not False:
+                group_size = (
+                    self.args.eval_packing_group_size
+                    if hasattr(self.args, "eval_packing_group_size")
+                    else self.args.sample_packing_group_size
+                )
+                return self._create_multipack_sampler(
+                    base_sampler=base_sampler,
+                    dataset=eval_dataset,
+                    group_size=group_size,
+                )
+            return base_sampler
+
+        if self.args.sample_packing and self.args.eval_sample_packing is not False:
+            base_sampler = SequentialSampler(eval_dataset)
+            group_size = (
+                self.args.eval_packing_group_size
+                if hasattr(self.args, "eval_packing_group_size")
+                else self.args.sample_packing_group_size
+            )
+            return self._create_multipack_sampler(
+                base_sampler=base_sampler,
+                dataset=eval_dataset,
+                group_size=group_size,
+            )

         return super()._get_eval_sampler(eval_dataset)

     def get_train_dataloader(self) -> DataLoader:
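
The SP arithmetic above partitions the world into `world_size // sequence_parallel_size` groups, and every rank in a group computes the same `sp_group_id`, so all of them build an identical `DistributedSampler` shard. A worked example, assuming a launch of 8 processes with `sequence_parallel_size` set to 2:

    # SP-group arithmetic from _get_eval_sampler, with assumed sizes.
    world_size = 8              # assumed number of processes
    sequence_parallel_size = 2  # assumed config value

    num_sp_groups = world_size // sequence_parallel_size  # -> 4
    for rank in range(world_size):
        sp_group_id = rank // sequence_parallel_size
        print(f"rank {rank} -> SP group {sp_group_id}")
    # ranks 0,1 -> group 0; 2,3 -> group 1; 4,5 -> group 2; 6,7 -> group 3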
@@ -546,25 +582,30 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
         return self.accelerator.prepare_data_loader(dataloader)

     def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
         """Get dataloader for evaluation"""
         if self.args.sample_packing and self.args.eval_sample_packing is False:
             self.data_collator = (  # pylint: disable=attribute-defined-outside-init
                 self.eval_data_collator
             )
-            if eval_dataset:
+            if eval_dataset and "length" in eval_dataset.features:
                 eval_dataset = eval_dataset.remove_columns(["length"])
             dataloader = super().get_eval_dataloader(eval_dataset)
             self.data_collator = (  # pylint: disable=attribute-defined-outside-init
                 self.train_data_collator
             )
             return dataloader

         if self.args.sample_packing and self.args.eval_sample_packing is not False:
+            eval_dataset = (
+                eval_dataset if eval_dataset is not None else self.eval_dataset
+            )
             eval_sampler = self._get_eval_sampler(eval_dataset)
-            eval_dataset = eval_dataset.remove_columns(["length"])
+            # Only remove length column if it exists
+            if "length" in eval_dataset.features:
+                eval_dataset = eval_dataset.remove_columns(["length"])
             data_collator = self.data_collator
             dataloader_params = {
                 "batch_size": self.args.eval_batch_size,
@@ -572,6 +613,7 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
                 "num_workers": self.args.dataloader_num_workers,
                 "pin_memory": self.args.dataloader_pin_memory,
             }
+
             if self.args.dataloader_prefetch_factor:
                 dataloader_params["prefetch_factor"] = (
                     self.args.dataloader_prefetch_factor
@@ -585,9 +627,50 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
             dataloader_params["drop_last"] = self.args.dataloader_drop_last
             self.accelerator.even_batches = False
-            return self.accelerator.prepare_data_loader(
-                DataLoader(eval_dataset, **dataloader_params)
-            )
+
+            # Create dataloader
+            dataloader = DataLoader(eval_dataset, **dataloader_params)
+
+            # Don't prepare dataloader for sequence parallelism
+            if self.args.sequence_parallel_size > 1:
+                return dataloader
+
+            return self.accelerator.prepare_data_loader(dataloader)
+
+        if self.args.sequence_parallel_size > 1:
+            # We need to customize the default dataloader for sequence parallelism
+            eval_dataset = (
+                eval_dataset if eval_dataset is not None else self.eval_dataset
+            )
+            data_collator = (
+                self.eval_data_collator
+                if self.eval_data_collator
+                else self.data_collator
+            )
+
+            # Handle dataset preprocessing as in the parent implementation
+            if is_datasets_available() and isinstance(eval_dataset, datasets.Dataset):
+                eval_dataset = self._remove_unused_columns(
+                    eval_dataset, description="evaluation"
+                )
+            else:
+                data_collator = self._get_collator_with_removed_columns(
+                    data_collator, description="evaluation"
+                )
+
+            # Build dataloader parameters
+            dataloader_params = {
+                "batch_size": self.args.per_device_eval_batch_size,
+                "collate_fn": data_collator,
+                "num_workers": self.args.dataloader_num_workers,
+                "pin_memory": self.args.dataloader_pin_memory,
+            }
+
+            if not isinstance(eval_dataset, torch.utils.data.IterableDataset):
+                sampler = self._get_eval_sampler(eval_dataset)
+                dataloader_params["sampler"] = sampler
+
+            # Create dataloader without accelerator preparation for sequence parallelism
+            return DataLoader(eval_dataset, **dataloader_params)

         return super().get_eval_dataloader(eval_dataset)
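
Skipping `accelerator.prepare_data_loader` in the SP branches is deliberate: every rank in an SP group constructs its sampler with the same `rank=sp_group_id`, so the ranks in a group already draw identical batches (presumably split along the sequence dimension within the group downstream), and letting accelerate re-shard the loader per process would break that invariant. A small sketch of the invariant, using a toy dataset and plain PyTorch (no process group needed, since `num_replicas` and `rank` are passed explicitly):

    # Processes in the same SP group build identical index shards.
    import torch
    from torch.utils.data import TensorDataset
    from torch.utils.data.distributed import DistributedSampler

    eval_dataset = TensorDataset(torch.arange(8))
    num_sp_groups = 2  # assumed: world_size=4, sequence_parallel_size=2

    def sp_shard(sp_group_id):
        # shuffle=False / drop_last=False mirror the sampler built in the diff.
        return list(
            DistributedSampler(
                eval_dataset,
                num_replicas=num_sp_groups,
                rank=sp_group_id,
                shuffle=False,
                drop_last=False,
            )
        )

    # Ranks 0 and 1 both map to sp_group_id=0 and see the same indices.
    assert sp_shard(0) == sp_shard(0)
    print(sp_shard(0), sp_shard(1))  # -> [0, 2, 4, 6] [1, 3, 5, 7]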