diff --git a/src/axolotl/core/trainers/base.py b/src/axolotl/core/trainers/base.py index 7d7420fb8..7896c6088 100644 --- a/src/axolotl/core/trainers/base.py +++ b/src/axolotl/core/trainers/base.py @@ -43,7 +43,7 @@ from axolotl.core.trainers.utils import ( from axolotl.utils import get_not_null from axolotl.utils.bench import get_gpu_memory_usage from axolotl.utils.dict import DictDefault -from axolotl.utils.distributed import is_main_process +from axolotl.utils.distributed import is_distributed, is_main_process from axolotl.utils.logging import get_logger from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths @@ -350,6 +350,11 @@ class AxolotlTrainer( # track number of tokens for tokens per second calculation if self.args.include_tkps: inputs_key = "labels" if "labels" in inputs else "input_ids" + num_tokens = (inputs[inputs_key] != -100).sum() + if is_distributed(): + torch.distributed.all_reduce( + num_tokens, op=torch.distributed.ReduceOp.SUM + ) if hasattr(self.state, "num_tokens"): self.state.num_tokens = ( self.state.num_tokens + (inputs[inputs_key] != -100).sum().cpu() @@ -357,6 +362,11 @@ class AxolotlTrainer( else: self.state.num_tokens = (inputs[inputs_key] != -100).sum().cpu() + if hasattr(self.state, "total_tokens"): + self.state.total_tokens += num_tokens + else: + self.state.total_tokens = num_tokens + if self.args.orpo_alpha: return self.orpo_compute_loss( model, @@ -621,6 +631,7 @@ class AxolotlTrainer( logs["tokens_per_second_per_gpu"] = round( self.state.last_tokens_per_second.item() / self.args.logging_steps, 2 ) + logs["total_tokens"] = int(self.state.total_tokens.item()) del self._stored_metrics[train_eval]