improve handling of train len

This commit is contained in:
Wing Lian
2025-06-06 22:07:29 -07:00
parent 2c66483a47
commit 2491303c46
4 changed files with 19 additions and 6 deletions

View File

@@ -102,7 +102,7 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, RngLoaderMixin, Trainer):
)
batch_max_len = train_batch_size * self.args.max_seq_length
return MultipackBatchSampler(
sampler = MultipackBatchSampler(
base_sampler,
lengths=get_dataset_lengths(dataset),
packing_efficiency_estimate=self.args.sample_packing_efficiency,
@@ -114,6 +114,9 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, RngLoaderMixin, Trainer):
drop_last=True,
)
len(sampler)
return sampler
def _get_train_sampler(
self, train_dataset: Optional[Dataset] = None
) -> Optional[Sampler]:

View File

@@ -397,7 +397,6 @@ class PluginManager: # pylint: disable=too-many-public-methods
training_args = []
for plugin in self.plugins.values():
training_args_from_plugin = plugin.get_training_args_mixin()
print(f"Training args from plugin: {plugin.__class__.__name__}")
if training_args_from_plugin is not None:
training_args.append(training_args_from_plugin)
return training_args

View File

@@ -443,10 +443,18 @@ class MultipackBatchSampler(BatchSampler):
if self._len_across_ranks is None:
# Sample multiple times to get stable estimate
len_batches = min( # pylint: disable=consider-using-generator
[len(self._batches) for _ in range(self.num_count_samples)]
)
_sampled_lens = []
for _ in range(self.num_count_samples):
self._batches = None # Reset cached batches
_sampled_lens.append(len(self.generate_batches(set_stats=False)))
len_batches = min(_sampled_lens)
# Gather minimum across all ranks
self._len_across_ranks = self.gather_len_batches(len_batches)
if self._len_across_ranks is None:
self._len_across_ranks = self.gather_len_batches(len_batches)
else:
self._len_across_ranks = min(
self._len_across_ranks, self.gather_len_batches(len_batches)
)
return self._len_across_ranks

View File

@@ -481,6 +481,9 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
data_loader_len * cfg.num_epochs * cfg.sequence_parallel_degree
)
)
if cfg.dataloader_drop_last:
# drop the last batch for each epoch
total_num_steps -= int(math.ceil(cfg.num_epochs))
def calc_sample_packing_eff_est(estimates: List[float]):
LOG.info(f"sample_packing_eff_est across ranks: {repr(estimates)}")