From 9640338d37d0398cd3c0c0ab6e629b6dd9dcd5d3 Mon Sep 17 00:00:00 2001
From: salman
Date: Tue, 9 Sep 2025 15:50:21 +0100
Subject: [PATCH] Default `include_tkps` to true (#3134)

* default true

* force e2e

* causal trainer only

* fix eval logging [skip-ci]

* revert setup.py

* force tests

* guarding

* guarding

* fix test case

* use evaluate [skip-e2e]

* use evaluate [skip-e2e]

* kick off ci

* fixing

* reverting

---
 src/axolotl/core/builders/base.py                |  7 -------
 src/axolotl/core/builders/causal.py              |  7 +++++++
 src/axolotl/core/trainers/base.py                |  4 ++--
 src/axolotl/utils/callbacks/tokens_per_second.py | 16 +++++++++-------
 src/axolotl/utils/schemas/config.py              |  4 ++--
 5 files changed, 20 insertions(+), 18 deletions(-)

diff --git a/src/axolotl/core/builders/base.py b/src/axolotl/core/builders/base.py
index bee291fa2..1ec818004 100644
--- a/src/axolotl/core/builders/base.py
+++ b/src/axolotl/core/builders/base.py
@@ -36,7 +36,6 @@ from axolotl.utils.callbacks import (
     SaveModelOnFirstStepCallback,
 )
 from axolotl.utils.callbacks.profiler import PytorchProfilerCallback
-from axolotl.utils.callbacks.tokens_per_second import TokensPerSecondCallback
 from axolotl.utils.distributed import build_parallelism_config
 from axolotl.utils.schemas.enums import CustomSupportedOptimizers
 
@@ -145,12 +144,6 @@ class TrainerBuilderBase(abc.ABC):
                     profiler_steps_start=self.cfg.profiler_steps_start,
                 )
             )
-        if self.cfg.include_tkps:
-            callbacks.append(
-                TokensPerSecondCallback(
-                    self.cfg.tensor_parallel_size, self.cfg.context_parallel_size
-                )
-            )
 
         return callbacks
 
diff --git a/src/axolotl/core/builders/causal.py b/src/axolotl/core/builders/causal.py
index 057d0ab5c..ee6383d47 100644
--- a/src/axolotl/core/builders/causal.py
+++ b/src/axolotl/core/builders/causal.py
@@ -39,6 +39,7 @@ from axolotl.utils.collators import (
     MambaDataCollator,
     V2BatchSamplerDataCollatorForSeq2Seq,
 )
+from axolotl.utils.callbacks.tokens_per_second import TokensPerSecondCallback
 from axolotl.utils.collators.mm_chat import MultiModalChatDataCollator
 from axolotl.utils.import_helper import get_cls_from_module_str
 from axolotl.utils.logging import get_logger
@@ -71,6 +72,12 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
         if self.cfg.qat:
             callbacks.append(QATCallback(self.cfg.qat))
+        if self.cfg.include_tkps:
+            callbacks.append(
+                TokensPerSecondCallback(
+                    self.cfg.tensor_parallel_size, self.cfg.context_parallel_size
+                )
+            )
 
         return callbacks
 
     def get_post_trainer_create_callbacks(self, trainer):
diff --git a/src/axolotl/core/trainers/base.py b/src/axolotl/core/trainers/base.py
index 06eef445b..d7555261f 100644
--- a/src/axolotl/core/trainers/base.py
+++ b/src/axolotl/core/trainers/base.py
@@ -342,10 +342,10 @@ class AxolotlTrainer(
         inputs_key = "labels" if "labels" in inputs else "input_ids"
         if hasattr(self.state, "num_tokens"):
             self.state.num_tokens = (
-                self.state.num_tokens + (inputs[inputs_key] != -100).sum()
+                self.state.num_tokens + (inputs[inputs_key] != -100).sum().cpu()
             )
         else:
-            self.state.num_tokens = (inputs[inputs_key] != -100).sum()
+            self.state.num_tokens = (inputs[inputs_key] != -100).sum().cpu()
 
         if self.args.orpo_alpha:
             return self.orpo_compute_loss(
diff --git a/src/axolotl/utils/callbacks/tokens_per_second.py b/src/axolotl/utils/callbacks/tokens_per_second.py
index 85bcd5041..ead129240 100644
--- a/src/axolotl/utils/callbacks/tokens_per_second.py
+++ b/src/axolotl/utils/callbacks/tokens_per_second.py
@@ -43,11 +43,12 @@ class TokensPerSecondCallback(TrainerCallback):
         control: TrainerControl,
         **kwargs,
     ):  # pylint: disable=unused-argument
-        step_time = time.perf_counter() - self.start_time
-        num_tokens_per_device = state.num_tokens.clone()
-        # non data parallel groups have duplicated tokens, so we avoid double-counting
-        num_tokens_per_device = num_tokens_per_device / self.non_data_parallel_size
-        state.last_tokens_per_second = num_tokens_per_device / step_time
+        if hasattr(state, "num_tokens"):
+            step_time = time.perf_counter() - self.start_time
+            num_tokens_per_device = state.num_tokens.clone()
+            # non data parallel groups have duplicated tokens, so we avoid double-counting
+            num_tokens_per_device = num_tokens_per_device / self.non_data_parallel_size
+            state.last_tokens_per_second = num_tokens_per_device / step_time
 
     def on_log(
         self,
@@ -58,5 +59,6 @@ class TokensPerSecondCallback(TrainerCallback):
         **kwargs,
     ):  # pylint: disable=unused-argument
         # after logging, clear the running metrics
-        state.last_tokens_per_second.zero_()
-        state.num_tokens = 0
+        if hasattr(state, "last_tokens_per_second"):
+            state.last_tokens_per_second.zero_()
+        state.num_tokens = torch.zeros(1)
diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py
index 32d7b68e7..e4c1fdf29 100644
--- a/src/axolotl/utils/schemas/config.py
+++ b/src/axolotl/utils/schemas/config.py
@@ -855,9 +855,9 @@ class AxolotlInputConfig(
         },
     )
     include_tkps: bool | None = Field(
-        default=None,
+        default=True,
         json_schema_extra={
-            "description": "bool of whether to report tokens per second during training by measuring throughput of non-padding tokens."
+            "description": "bool of whether to report tokens per second per-gpu during training by measuring throughput of non-padding tokens."
         },
     )
     neftune_noise_alpha: float | None = Field(
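
For reference, below is a minimal, self-contained sketch of the throughput math this patch guards and moves to CPU: count non-padding label positions, divide out the ranks that see duplicated tokens, then divide by step time. It assumes `non_data_parallel_size` is the product of the tensor- and context-parallel sizes, and the printed label is illustrative, not axolotl's actual log key:

import time

import torch

IGNORE_INDEX = -100  # HF convention for masked/padded label positions

tensor_parallel_size = 2  # hypothetical parallelism settings
context_parallel_size = 1
# assumption: tensor/context-parallel ranks see the same tokens, so the
# de-duplication factor is the product of the two sizes
non_data_parallel_size = tensor_parallel_size * context_parallel_size

labels = torch.tensor([[1, 2, -100, -100], [3, 4, 5, -100]])

start_time = time.perf_counter()
time.sleep(0.01)  # stand-in for one training step

# count non-padding tokens, keeping the running total on CPU as the
# trainer change above does
num_tokens = (labels != IGNORE_INDEX).sum().cpu()

step_time = time.perf_counter() - start_time
num_tokens_per_device = num_tokens / non_data_parallel_size
tokens_per_second = num_tokens_per_device / step_time
print(f"tokens/sec per GPU: {tokens_per_second.item():.1f}")

With the default flipped to true, configs that do not want the metric now need an explicit `include_tkps: false` to opt out of the callback.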