Default include_tkps to true (#3134)

* default true

* force e2e

* causal trainer only

* fix eval logging [skip-ci]

* revert setup.py

* force tests

* guarding

* guarding

* fix test case

* use evaluate [skip-e2e]

* use evaluate [skip-e2e]

* kick off ci

* fixing

* reverting
This commit is contained in:
salman
2025-09-09 15:50:21 +01:00
committed by GitHub
parent b5d4c7ff54
commit 9640338d37
5 changed files with 20 additions and 18 deletions

View File

@@ -36,7 +36,6 @@ from axolotl.utils.callbacks import (
SaveModelOnFirstStepCallback, SaveModelOnFirstStepCallback,
) )
from axolotl.utils.callbacks.profiler import PytorchProfilerCallback from axolotl.utils.callbacks.profiler import PytorchProfilerCallback
from axolotl.utils.callbacks.tokens_per_second import TokensPerSecondCallback
from axolotl.utils.distributed import build_parallelism_config from axolotl.utils.distributed import build_parallelism_config
from axolotl.utils.schemas.enums import CustomSupportedOptimizers from axolotl.utils.schemas.enums import CustomSupportedOptimizers
@@ -145,12 +144,6 @@ class TrainerBuilderBase(abc.ABC):
profiler_steps_start=self.cfg.profiler_steps_start, profiler_steps_start=self.cfg.profiler_steps_start,
) )
) )
if self.cfg.include_tkps:
callbacks.append(
TokensPerSecondCallback(
self.cfg.tensor_parallel_size, self.cfg.context_parallel_size
)
)
return callbacks return callbacks

View File

@@ -39,6 +39,7 @@ from axolotl.utils.collators import (
MambaDataCollator, MambaDataCollator,
V2BatchSamplerDataCollatorForSeq2Seq, V2BatchSamplerDataCollatorForSeq2Seq,
) )
from axolotl.utils.callbacks.tokens_per_second import TokensPerSecondCallback
from axolotl.utils.collators.mm_chat import MultiModalChatDataCollator from axolotl.utils.collators.mm_chat import MultiModalChatDataCollator
from axolotl.utils.import_helper import get_cls_from_module_str from axolotl.utils.import_helper import get_cls_from_module_str
from axolotl.utils.logging import get_logger from axolotl.utils.logging import get_logger
@@ -71,6 +72,12 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
if self.cfg.qat: if self.cfg.qat:
callbacks.append(QATCallback(self.cfg.qat)) callbacks.append(QATCallback(self.cfg.qat))
if self.cfg.include_tkps:
callbacks.append(
TokensPerSecondCallback(
self.cfg.tensor_parallel_size, self.cfg.context_parallel_size
)
)
return callbacks return callbacks
def get_post_trainer_create_callbacks(self, trainer): def get_post_trainer_create_callbacks(self, trainer):

View File

@@ -342,10 +342,10 @@ class AxolotlTrainer(
inputs_key = "labels" if "labels" in inputs else "input_ids" inputs_key = "labels" if "labels" in inputs else "input_ids"
if hasattr(self.state, "num_tokens"): if hasattr(self.state, "num_tokens"):
self.state.num_tokens = ( self.state.num_tokens = (
self.state.num_tokens + (inputs[inputs_key] != -100).sum() self.state.num_tokens + (inputs[inputs_key] != -100).sum().cpu()
) )
else: else:
self.state.num_tokens = (inputs[inputs_key] != -100).sum() self.state.num_tokens = (inputs[inputs_key] != -100).sum().cpu()
if self.args.orpo_alpha: if self.args.orpo_alpha:
return self.orpo_compute_loss( return self.orpo_compute_loss(

View File

@@ -43,11 +43,12 @@ class TokensPerSecondCallback(TrainerCallback):
control: TrainerControl, control: TrainerControl,
**kwargs, **kwargs,
): # pylint: disable=unused-argument ): # pylint: disable=unused-argument
step_time = time.perf_counter() - self.start_time if hasattr(state, "num_tokens"):
num_tokens_per_device = state.num_tokens.clone() step_time = time.perf_counter() - self.start_time
# non data parallel groups have duplicated tokens, so we avoid double-counting num_tokens_per_device = state.num_tokens.clone()
num_tokens_per_device = num_tokens_per_device / self.non_data_parallel_size # non data parallel groups have duplicated tokens, so we avoid double-counting
state.last_tokens_per_second = num_tokens_per_device / step_time num_tokens_per_device = num_tokens_per_device / self.non_data_parallel_size
state.last_tokens_per_second = num_tokens_per_device / step_time
def on_log( def on_log(
self, self,
@@ -58,5 +59,6 @@ class TokensPerSecondCallback(TrainerCallback):
**kwargs, **kwargs,
): # pylint: disable=unused-argument ): # pylint: disable=unused-argument
# after logging, clear the running metrics # after logging, clear the running metrics
state.last_tokens_per_second.zero_() if hasattr(state, "last_tokens_per_second"):
state.num_tokens = 0 state.last_tokens_per_second.zero_()
state.num_tokens = torch.zeros(1)

View File

@@ -855,9 +855,9 @@ class AxolotlInputConfig(
}, },
) )
include_tkps: bool | None = Field( include_tkps: bool | None = Field(
default=None, default=True,
json_schema_extra={ json_schema_extra={
"description": "bool of whether to report tokens per second during training by measuring throughput of non-padding tokens." "description": "bool of whether to report tokens per second per-gpu during training by measuring throughput of non-padding tokens."
}, },
) )
neftune_noise_alpha: float | None = Field( neftune_noise_alpha: float | None = Field(