Use SequentialSampler if curriculum_sampling is enabled with sample_packing (#2235)
This commit is contained in:
committed by
GitHub
parent
5e0124e2ab
commit
6553683170
@@ -608,8 +608,14 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
|
|||||||
self.state.train_batch_size or self.args.per_device_train_batch_size
|
self.state.train_batch_size or self.args.per_device_train_batch_size
|
||||||
)
|
)
|
||||||
batch_max_len = train_batch_size * self.args.max_seq_length
|
batch_max_len = train_batch_size * self.args.max_seq_length
|
||||||
|
|
||||||
|
if self.args.curriculum_sampling:
|
||||||
|
sampler = SequentialSampler(self.train_dataset)
|
||||||
|
else:
|
||||||
|
sampler = RandomSampler(self.train_dataset)
|
||||||
|
|
||||||
return MultipackBatchSampler(
|
return MultipackBatchSampler(
|
||||||
RandomSampler(self.train_dataset),
|
sampler,
|
||||||
lengths=get_dataset_lengths(self.train_dataset),
|
lengths=get_dataset_lengths(self.train_dataset),
|
||||||
packing_efficiency_estimate=self.args.sample_packing_efficiency,
|
packing_efficiency_estimate=self.args.sample_packing_efficiency,
|
||||||
batch_max_len=batch_max_len,
|
batch_max_len=batch_max_len,
|
||||||
|
|||||||
Reference in New Issue
Block a user