improve handling of train len
@@ -102,7 +102,7 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, RngLoaderMixin, Trainer):
             )
             batch_max_len = train_batch_size * self.args.max_seq_length

-            return MultipackBatchSampler(
+            sampler = MultipackBatchSampler(
                 base_sampler,
                 lengths=get_dataset_lengths(dataset),
                 packing_efficiency_estimate=self.args.sample_packing_efficiency,
@@ -114,6 +114,9 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, RngLoaderMixin, Trainer):
                 drop_last=True,
             )
+
+            len(sampler)
+            return sampler

     def _get_train_sampler(
         self, train_dataset: Optional[Dataset] = None
     ) -> Optional[Sampler]:
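Note: the two hunks above replace the direct return with a named sampler so that len(sampler) can run once before returning. Computing this sampler's length is expensive and stochastic (it simulates batch packing), so the call forces the estimate eagerly at construction time instead of lazily mid-training. A minimal sketch of the pattern with a hypothetical stand-in class (none of the names below are from the commit):

class LazyLenSampler:
    """Hypothetical stand-in: caches an expensive length estimate."""

    def __init__(self):
        self._cached_len = None

    def __len__(self):
        if self._cached_len is None:
            # Stand-in for simulating batch packing to count batches.
            self._cached_len = self._expensive_estimate()
        return self._cached_len

    def _expensive_estimate(self):
        return 123

sampler = LazyLenSampler()
len(sampler)  # warm the cache at a predictable point, as the diff does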
@@ -397,7 +397,6 @@ class PluginManager: # pylint: disable=too-many-public-methods
         training_args = []
         for plugin in self.plugins.values():
             training_args_from_plugin = plugin.get_training_args_mixin()
-            print(f"Training args from plugin: {plugin.__class__.__name__}")
             if training_args_from_plugin is not None:
                 training_args.append(training_args_from_plugin)
         return training_args
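This hunk simply drops an unconditional print that fired for every plugin on every call. If the trace is still wanted, a hedged alternative (assuming the project's stdlib-style LOG, as the LOG.info call in the last hunk suggests) is to demote it to debug level:

import logging

LOG = logging.getLogger(__name__)

def collect_training_args(plugins):
    # Sketch only: the same loop as above, with the print demoted to
    # LOG.debug so it is silent unless debug logging is enabled.
    training_args = []
    for plugin in plugins:
        training_args_from_plugin = plugin.get_training_args_mixin()
        LOG.debug("Training args from plugin: %s", plugin.__class__.__name__)
        if training_args_from_plugin is not None:
            training_args.append(training_args_from_plugin)
    return training_args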
@@ -443,10 +443,18 @@ class MultipackBatchSampler(BatchSampler):
         if self._len_across_ranks is None:
             # Sample multiple times to get stable estimate
-            len_batches = min(  # pylint: disable=consider-using-generator
-                [len(self._batches) for _ in range(self.num_count_samples)]
-            )
+            _sampled_lens = []
+            for _ in range(self.num_count_samples):
+                self._batches = None  # Reset cached batches
+                _sampled_lens.append(len(self.generate_batches(set_stats=False)))
+            len_batches = min(_sampled_lens)

             # Gather minimum across all ranks
-            self._len_across_ranks = self.gather_len_batches(len_batches)
+            if self._len_across_ranks is None:
+                self._len_across_ranks = self.gather_len_batches(len_batches)
+            else:
+                self._len_across_ranks = min(
+                    self._len_across_ranks, self.gather_len_batches(len_batches)
+                )

         return self._len_across_ranks
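Two fixes are visible here. First, the old estimate called len(self._batches) num_count_samples times without regenerating the batches, so every sample was identical; the new loop resets the cache, regenerates, and keeps the minimum. Second, the gather result is now guarded: the first estimate is stored as-is, and any later estimate can only lower it via min() with the previous value, so a rank's step count never grows after the fact. The commit does not show gather_len_batches; below is a sketch of what such a helper typically does with torch.distributed (an assumption, not the project's code):

import torch.distributed as dist

def gather_len_batches(local_len: int) -> int:
    # Take the minimum batch count across ranks so no rank runs out of
    # batches before the others (which would hang collective ops).
    if not (dist.is_available() and dist.is_initialized()):
        return local_len
    world_lens = [None] * dist.get_world_size()
    dist.all_gather_object(world_lens, local_len)
    return min(world_lens)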
@@ -481,6 +481,9 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
                 data_loader_len * cfg.num_epochs * cfg.sequence_parallel_degree
             )
         )
+        if cfg.dataloader_drop_last:
+            # drop the last batch for each epoch
+            total_num_steps -= int(math.ceil(cfg.num_epochs))

 def calc_sample_packing_eff_est(estimates: List[float]):
     LOG.info(f"sample_packing_eff_est across ranks: {repr(estimates)}")
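The added guard accounts for dataloader_drop_last: with drop_last enabled, roughly one partial batch is discarded per epoch, so the step estimate is reduced by ceil(num_epochs). A quick worked example with made-up numbers:

import math

num_epochs = 3
total_num_steps = 300  # estimate before the correction
dataloader_drop_last = True  # stand-in for cfg.dataloader_drop_last

if dataloader_drop_last:
    # one dropped (partial) batch per epoch
    total_num_steps -= int(math.ceil(num_epochs))

print(total_num_steps)  # 297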