Use SequentialSampler if curriculum_sampling is enabled with sample_packing (#2235)

This commit is contained in:
Vincenzo di Cicco
2025-01-09 22:01:22 +01:00
committed by GitHub
parent 5e0124e2ab
commit 6553683170

View File

@@ -608,8 +608,14 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
self.state.train_batch_size or self.args.per_device_train_batch_size
)
batch_max_len = train_batch_size * self.args.max_seq_length
if self.args.curriculum_sampling:
sampler = SequentialSampler(self.train_dataset)
else:
sampler = RandomSampler(self.train_dataset)
return MultipackBatchSampler(
RandomSampler(self.train_dataset),
sampler,
lengths=get_dataset_lengths(self.train_dataset),
packing_efficiency_estimate=self.args.sample_packing_efficiency,
batch_max_len=batch_max_len,