eval dataloader and sampler changes
@@ -398,7 +398,7 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
         )
 
         return super()._wrap_model(model, training=training, dataloader=dataloader)
 
-    def _create_multipack_sampler(self, base_sampler):
+    def _create_multipack_sampler(self, base_sampler, dataset, group_size):
         """Helper method to create a MultipackBatchSampler"""
         if self.args.multipack_real_batches:
             batch_size = self.args.per_device_train_batch_size
@@ -412,11 +412,11 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
 
         return MultipackBatchSampler(
             base_sampler,
-            lengths=get_dataset_lengths(self.train_dataset),
+            lengths=get_dataset_lengths(dataset),
             packing_efficiency_estimate=self.args.sample_packing_efficiency,
             batch_max_len=batch_max_len,
             batch_size=batch_size,
-            group_size=self.args.sample_packing_group_size,
+            group_size=group_size,
             bin_size=self.args.sample_packing_bin_size,
             drop_last=True,
         )
@@ -439,18 +439,26 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
 
             # Apply multipack wrapper if needed
             if self.args.sample_packing and not self.args.pretraining:
-                return self._create_multipack_sampler(base_sampler)
+                return self._create_multipack_sampler(
+                    base_sampler=base_sampler,
+                    dataset=self.train_dataset,
+                    group_size=self.args.sample_packing_group_size,
+                )
 
             return base_sampler
 
         # Handle non-SP mode
         if self.args.sample_packing and not self.args.pretraining:
-            sampler = (
+            base_sampler = (
                 SequentialSampler(self.train_dataset)
                 if self.args.curriculum_sampling
                 else RandomSampler(self.train_dataset)
             )
-            return self._create_multipack_sampler(sampler)
+            return self._create_multipack_sampler(
+                base_sampler=base_sampler,
+                dataset=self.train_dataset,
+                group_size=self.args.sample_packing_group_size,
+            )
 
         if self.args.curriculum_sampling:
             return SequentialSampler(self.train_dataset)
@@ -458,27 +466,55 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
         return super()._get_train_sampler()
 
     def _get_eval_sampler(
-        self, eval_dataset: Dataset
+        self, eval_dataset: Optional[Dataset] = None
     ) -> Optional[torch.utils.data.Sampler]:
-        if self.args.sample_packing and self.args.eval_sample_packing is not False:
-            if self.args.multipack_real_batches:
-                batch_size = self.args.per_device_eval_batch_size
-                batch_max_len = self.args.max_seq_length
-            else:
-                batch_size = 1
-                batch_max_len = (
-                    self.args.per_device_eval_batch_size * self.args.max_seq_length
-                )
-            return MultipackBatchSampler(
-                SequentialSampler(eval_dataset),
-                lengths=get_dataset_lengths(self.eval_dataset),
-                packing_efficiency_estimate=self.args.sample_packing_efficiency,
-                batch_max_len=batch_max_len,
-                batch_size=batch_size,
-                group_size=self.args.sample_packing_group_size,
-                bin_size=self.args.sample_packing_bin_size,
-                drop_last=True,
-            )
+        """Get evaluation sampler"""
+        eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
+
+        # Handle sequence parallelism
+        if self.args.sequence_parallel_size > 1:
+            # Create sampler for SP groups
+            num_sp_groups = self.args.world_size // self.args.sequence_parallel_size
+            sp_group_id = dist.get_rank() // self.args.sequence_parallel_size
+
+            # Create distributed sampler for the SP group
+            base_sampler = torch.utils.data.distributed.DistributedSampler(
+                eval_dataset,
+                num_replicas=num_sp_groups,
+                rank=sp_group_id,
+                shuffle=False,
+                drop_last=False,
+            )
+
+            if self.args.sample_packing and self.args.eval_sample_packing is not False:
+                group_size = (
+                    self.args.eval_packing_group_size
+                    if hasattr(self.args, "eval_packing_group_size")
+                    else self.args.sample_packing_group_size
+                )
+
+                return self._create_multipack_sampler(
+                    base_sampler=base_sampler,
+                    dataset=eval_dataset,
+                    group_size=group_size,
+                )
+
+            return base_sampler
+
+        if self.args.sample_packing and self.args.eval_sample_packing is not False:
+            base_sampler = SequentialSampler(eval_dataset)
+            group_size = (
+                self.args.eval_packing_group_size
+                if hasattr(self.args, "eval_packing_group_size")
+                else self.args.sample_packing_group_size
+            )
+
+            return self._create_multipack_sampler(
+                base_sampler=base_sampler,
+                dataset=eval_dataset,
+                group_size=group_size,
+            )
+
         return super()._get_eval_sampler(eval_dataset)
 
     def get_train_dataloader(self) -> DataLoader:
@@ -546,25 +582,30 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
         return self.accelerator.prepare_data_loader(dataloader)
 
     def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
         """Get dataloader for evaluation"""
         if self.args.sample_packing and self.args.eval_sample_packing is False:
             self.data_collator = (  # pylint: disable=attribute-defined-outside-init
                 self.eval_data_collator
             )
-            if eval_dataset:
+            if eval_dataset and "length" in eval_dataset.features:
                 eval_dataset = eval_dataset.remove_columns(["length"])
             dataloader = super().get_eval_dataloader(eval_dataset)
             self.data_collator = (  # pylint: disable=attribute-defined-outside-init
                 self.train_data_collator
             )
             return dataloader
 
         if self.args.sample_packing and self.args.eval_sample_packing is not False:
             eval_dataset = (
                 eval_dataset if eval_dataset is not None else self.eval_dataset
             )
 
             eval_sampler = self._get_eval_sampler(eval_dataset)
-            eval_dataset = eval_dataset.remove_columns(["length"])
+
+            # Only remove length column if it exists
+            if "length" in eval_dataset.features:
+                eval_dataset = eval_dataset.remove_columns(["length"])
+
             data_collator = self.data_collator
             dataloader_params = {
                 "batch_size": self.args.eval_batch_size,
@@ -572,6 +613,7 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
                 "num_workers": self.args.dataloader_num_workers,
                 "pin_memory": self.args.dataloader_pin_memory,
             }
+
             if self.args.dataloader_prefetch_factor:
                 dataloader_params["prefetch_factor"] = (
                     self.args.dataloader_prefetch_factor
@@ -585,9 +627,50 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
                 dataloader_params["drop_last"] = self.args.dataloader_drop_last
 
             self.accelerator.even_batches = False
-            return self.accelerator.prepare_data_loader(
-                DataLoader(eval_dataset, **dataloader_params)
-            )
+
+            # Create dataloader
+            dataloader = DataLoader(eval_dataset, **dataloader_params)
+
+            # Don't prepare dataloader for sequence parallelism
+            if self.args.sequence_parallel_size > 1:
+                return dataloader
+
+            return self.accelerator.prepare_data_loader(dataloader)
+
+        if self.args.sequence_parallel_size > 1:
+            # We need to customize the default dataloader for sequence parallelism
+            eval_dataset = (
+                eval_dataset if eval_dataset is not None else self.eval_dataset
+            )
+            data_collator = (
+                self.eval_data_collator
+                if self.eval_data_collator
+                else self.data_collator
+            )
+
+            # Handle dataset preprocessing as in the parent implementation
+            if is_datasets_available() and isinstance(eval_dataset, datasets.Dataset):
+                eval_dataset = self._remove_unused_columns(
+                    eval_dataset, description="evaluation"
+                )
+            else:
+                data_collator = self._get_collator_with_removed_columns(
+                    data_collator, description="evaluation"
+                )
+
+            # Build dataloader parameters
+            dataloader_params = {
+                "batch_size": self.args.per_device_eval_batch_size,
+                "collate_fn": data_collator,
+                "num_workers": self.args.dataloader_num_workers,
+                "pin_memory": self.args.dataloader_pin_memory,
+            }
+
+            if not isinstance(eval_dataset, torch.utils.data.IterableDataset):
+                sampler = self._get_eval_sampler(eval_dataset)
+                dataloader_params["sampler"] = sampler
+
+            # Create dataloader without accelerator preparation for sequence parallelism
+            return DataLoader(eval_dataset, **dataloader_params)
+
         return super().get_eval_dataloader(eval_dataset)
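Note on the sequence-parallel sampler arithmetic in `_get_eval_sampler`: ranks are partitioned into SP groups by two integer divisions, so each group of `sequence_parallel_size` consecutive ranks acts as a single data-parallel replica and every rank in the group receives the same evaluation shard. A minimal standalone sketch of that mapping (the function name `sp_eval_sampler` and the example numbers are illustrative, not part of the commit):

    # Illustrative sketch (not from the commit): how eval ranks map to SP groups.
    # Assumes world_size is divisible by sequence_parallel_size (sp_size).
    from torch.utils.data import Dataset
    from torch.utils.data.distributed import DistributedSampler

    def sp_eval_sampler(
        eval_dataset: Dataset, world_size: int, rank: int, sp_size: int
    ) -> DistributedSampler:
        num_sp_groups = world_size // sp_size  # each SP group is one sampling replica
        sp_group_id = rank // sp_size          # consecutive ranks share one group
        # shuffle=False / drop_last=False mirror the diff: deterministic eval
        # order, and no examples are silently dropped.
        return DistributedSampler(
            eval_dataset,
            num_replicas=num_sp_groups,
            rank=sp_group_id,
            shuffle=False,
            drop_last=False,
        )

    # With world_size=8 and sp_size=2 there are 4 SP groups; ranks 0 and 1 both
    # receive shard 0, ranks 2 and 3 shard 1, and so on.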
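For context on the refactor itself: `_create_multipack_sampler` becomes the single packing entry point, and the train and eval paths differ only in the base sampler, the dataset whose lengths get packed, and the group size. A hedged usage sketch, wrapped as a hypothetical driver function since it assumes an already-constructed trainer (nothing below is code from the commit):

    from torch.utils.data import RandomSampler, SequentialSampler

    def build_packed_samplers(trainer):
        """Hypothetical driver: exercise the shared helper for both paths."""
        # Train path: shuffled order (a SequentialSampler would be used instead
        # when curriculum_sampling is enabled).
        train_sampler = trainer._create_multipack_sampler(
            base_sampler=RandomSampler(trainer.train_dataset),
            dataset=trainer.train_dataset,
            group_size=trainer.args.sample_packing_group_size,
        )

        # Eval path: deterministic order, falling back to the train group size
        # when no eval-specific value exists (same effect as the hasattr check
        # in the diff).
        eval_group_size = getattr(
            trainer.args,
            "eval_packing_group_size",
            trainer.args.sample_packing_group_size,
        )
        eval_sampler = trainer._create_multipack_sampler(
            base_sampler=SequentialSampler(trainer.eval_dataset),
            dataset=trainer.eval_dataset,
            group_size=eval_group_size,
        )
        return train_sampler, eval_sampler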