Improve handling of train length
This commit is contained in:
@@ -102,7 +102,7 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, RngLoaderMixin, Trainer):
|
|||||||
)
|
)
|
||||||
batch_max_len = train_batch_size * self.args.max_seq_length
|
batch_max_len = train_batch_size * self.args.max_seq_length
|
||||||
|
|
||||||
return MultipackBatchSampler(
|
sampler = MultipackBatchSampler(
|
||||||
base_sampler,
|
base_sampler,
|
||||||
lengths=get_dataset_lengths(dataset),
|
lengths=get_dataset_lengths(dataset),
|
||||||
packing_efficiency_estimate=self.args.sample_packing_efficiency,
|
packing_efficiency_estimate=self.args.sample_packing_efficiency,
|
||||||
@@ -114,6 +114,9 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, RngLoaderMixin, Trainer):
|
|||||||
drop_last=True,
|
drop_last=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
len(sampler)
|
||||||
|
return sampler
|
||||||
|
|
||||||
def _get_train_sampler(
|
def _get_train_sampler(
|
||||||
self, train_dataset: Optional[Dataset] = None
|
self, train_dataset: Optional[Dataset] = None
|
||||||
) -> Optional[Sampler]:
|
) -> Optional[Sampler]:
|
||||||
|
|||||||
@@ -397,7 +397,6 @@ class PluginManager: # pylint: disable=too-many-public-methods
|
|||||||
training_args = []
|
training_args = []
|
||||||
for plugin in self.plugins.values():
|
for plugin in self.plugins.values():
|
||||||
training_args_from_plugin = plugin.get_training_args_mixin()
|
training_args_from_plugin = plugin.get_training_args_mixin()
|
||||||
print(f"Training args from plugin: {plugin.__class__.__name__}")
|
|
||||||
if training_args_from_plugin is not None:
|
if training_args_from_plugin is not None:
|
||||||
training_args.append(training_args_from_plugin)
|
training_args.append(training_args_from_plugin)
|
||||||
return training_args
|
return training_args
|
||||||
|
|||||||
@@ -443,10 +443,18 @@ class MultipackBatchSampler(BatchSampler):
|
|||||||
|
|
||||||
if self._len_across_ranks is None:
|
if self._len_across_ranks is None:
|
||||||
# Sample multiple times to get stable estimate
|
# Sample multiple times to get stable estimate
|
||||||
len_batches = min( # pylint: disable=consider-using-generator
|
_sampled_lens = []
|
||||||
[len(self._batches) for _ in range(self.num_count_samples)]
|
for _ in range(self.num_count_samples):
|
||||||
)
|
self._batches = None # Reset cached batches
|
||||||
|
_sampled_lens.append(len(self.generate_batches(set_stats=False)))
|
||||||
|
len_batches = min(_sampled_lens)
|
||||||
|
|
||||||
# Gather minimum across all ranks
|
# Gather minimum across all ranks
|
||||||
self._len_across_ranks = self.gather_len_batches(len_batches)
|
if self._len_across_ranks is None:
|
||||||
|
self._len_across_ranks = self.gather_len_batches(len_batches)
|
||||||
|
else:
|
||||||
|
self._len_across_ranks = min(
|
||||||
|
self._len_across_ranks, self.gather_len_batches(len_batches)
|
||||||
|
)
|
||||||
|
|
||||||
return self._len_across_ranks
|
return self._len_across_ranks
|
||||||
|
|||||||
@@ -481,6 +481,9 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
|
|||||||
data_loader_len * cfg.num_epochs * cfg.sequence_parallel_degree
|
data_loader_len * cfg.num_epochs * cfg.sequence_parallel_degree
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
if cfg.dataloader_drop_last:
|
||||||
|
# drop the last batch for each epoch
|
||||||
|
total_num_steps -= int(math.ceil(cfg.num_epochs))
|
||||||
|
|
||||||
def calc_sample_packing_eff_est(estimates: List[float]):
|
def calc_sample_packing_eff_est(estimates: List[float]):
|
||||||
LOG.info(f"sample_packing_eff_est across ranks: {repr(estimates)}")
|
LOG.info(f"sample_packing_eff_est across ranks: {repr(estimates)}")
|
||||||
|
|||||||
Reference in New Issue
Block a user