diff --git a/src/axolotl/core/trainers/base.py b/src/axolotl/core/trainers/base.py
index e73d2af8b..27f3ee410 100644
--- a/src/axolotl/core/trainers/base.py
+++ b/src/axolotl/core/trainers/base.py
@@ -102,7 +102,7 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, RngLoaderMixin, Trainer):
             )
             batch_max_len = train_batch_size * self.args.max_seq_length
 
-        return MultipackBatchSampler(
+        sampler = MultipackBatchSampler(
             base_sampler,
             lengths=get_dataset_lengths(dataset),
             packing_efficiency_estimate=self.args.sample_packing_efficiency,
@@ -114,6 +114,9 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, RngLoaderMixin, Trainer):
             drop_last=True,
         )
 
+        len(sampler)  # eagerly compute the length so the batch count is cached before training
+        return sampler
+
     def _get_train_sampler(
         self, train_dataset: Optional[Dataset] = None
     ) -> Optional[Sampler]:
diff --git a/src/axolotl/integrations/base.py b/src/axolotl/integrations/base.py
index 91b380138..9162bc745 100644
--- a/src/axolotl/integrations/base.py
+++ b/src/axolotl/integrations/base.py
@@ -397,7 +397,6 @@ class PluginManager:  # pylint: disable=too-many-public-methods
         training_args = []
         for plugin in self.plugins.values():
             training_args_from_plugin = plugin.get_training_args_mixin()
-            print(f"Training args from plugin: {plugin.__class__.__name__}")
             if training_args_from_plugin is not None:
                 training_args.append(training_args_from_plugin)
         return training_args
diff --git a/src/axolotl/utils/samplers/multipack.py b/src/axolotl/utils/samplers/multipack.py
index 13c9d4ea1..2a3793ad5 100644
--- a/src/axolotl/utils/samplers/multipack.py
+++ b/src/axolotl/utils/samplers/multipack.py
@@ -443,10 +443,13 @@ class MultipackBatchSampler(BatchSampler):
 
         if self._len_across_ranks is None:
             # Sample multiple times to get stable estimate
-            len_batches = min(  # pylint: disable=consider-using-generator
-                [len(self._batches) for _ in range(self.num_count_samples)]
-            )
+            _sampled_lens = []
+            for _ in range(self.num_count_samples):
+                self._batches = None  # reset the cache so each sample repacks from scratch
+                _sampled_lens.append(len(self.generate_batches(set_stats=False)))
+            len_batches = min(_sampled_lens)
+
             # Gather minimum across all ranks
             self._len_across_ranks = self.gather_len_batches(len_batches)
         return self._len_across_ranks
 
diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py
index 68146b07c..eb3bc351b 100644
--- a/src/axolotl/utils/trainer.py
+++ b/src/axolotl/utils/trainer.py
@@ -481,6 +481,9 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
                 data_loader_len * cfg.num_epochs * cfg.sequence_parallel_degree
             )
         )
+        if cfg.dataloader_drop_last:
+            # account for the last batch dropped in each epoch
+            total_num_steps -= int(math.ceil(cfg.num_epochs))
 
     def calc_sample_packing_eff_est(estimates: List[float]):
         LOG.info(f"sample_packing_eff_est across ranks: {repr(estimates)}")
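
Note on the `__len__` change in `multipack.py`: the pattern being introduced — repack several times on each rank, take the local minimum, then agree on a global minimum across ranks — can be summarized in the standalone sketch below. This is a minimal illustration, not the axolotl implementation; `pack_batches` and `stable_len` are hypothetical stand-ins, and only `num_count_samples` and the gather-the-minimum idea come from the diff.

```python
# Minimal sketch of the length-estimation pattern above, assuming an
# initialized torch.distributed process group. `pack_batches` is a
# hypothetical stand-in that returns a freshly packed list of batches.
from typing import Callable, List

import torch.distributed as dist


def stable_len(
    pack_batches: Callable[[], List[List[int]]], num_count_samples: int
) -> int:
    # Packing is stochastic (the underlying sampler shuffles), so repack
    # several times and keep the smallest batch count this rank observes.
    local_min = min(len(pack_batches()) for _ in range(num_count_samples))

    # Every rank must report the same __len__, or ranks disagree on the
    # number of steps per epoch and collectives can hang; gather all local
    # minima and take the global minimum.
    world = [0] * dist.get_world_size()
    dist.all_gather_object(world, local_min)
    return min(world)
```

Taking the minimum rather than the mean errs on the side of fewer steps, which keeps every rank within the number of batches it can actually produce.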