From a98526ef7843a3e8aa006f260e6b4fb8912b5f1a Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Thu, 13 Feb 2025 17:39:19 -0500 Subject: [PATCH] add support for include_tokens_per_second in training args (#2269) * add support for include_tokens_per_second in training args * Update docs/config.qmd Co-authored-by: NanoCode012 * Update src/axolotl/core/trainer_builder.py Co-authored-by: NanoCode012 --------- Co-authored-by: NanoCode012 --- docs/config.qmd | 3 +++ src/axolotl/core/trainer_builder.py | 6 ++++++ src/axolotl/utils/config/models/input/v0_4_1/__init__.py | 1 + 3 files changed, 10 insertions(+) diff --git a/docs/config.qmd b/docs/config.qmd index 5221cbe7d..a7a150862 100644 --- a/docs/config.qmd +++ b/docs/config.qmd @@ -390,6 +390,9 @@ save_total_limit: # Checkpoints saved at a time # e.g., when 1 epoch is 1000 steps => `num_epochs: 2` and `max_steps: 100` will train for 100 steps max_steps: +# bool of whether to include tokens trainer per second in the training metrics. This iterates over the entire dataset once, so it takes some time. +include_tokens_per_second: + eval_table_size: # Approximate number of predictions sent to wandb depending on batch size. Enabled above 0. Default is 0 eval_max_new_tokens: # Total number of tokens generated for predictions sent to wandb. Default is 128 eval_causal_lm_metrics: # HF evaluate metrics used during evaluation. Default is ["sacrebleu", "comet", "ter", "chrf", "perplexity"] diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index ae757cf43..12346b8a2 100755 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -330,6 +330,12 @@ class HFCausalTrainerBuilder(TrainerBuilderBase): ) training_arguments_kwargs = {} + + if self.cfg.include_tokens_per_second is not None: + training_arguments_kwargs[ + "include_tokens_per_second" + ] = self.cfg.include_tokens_per_second + if self.cfg.bf16 == "full": training_arguments_kwargs["bf16_full_eval"] = True else: diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py index 1f6fdc612..1e7a6aa8b 100644 --- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py +++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py @@ -844,6 +844,7 @@ class AxolotlInputConfig( save_only_model: Optional[bool] = False use_tensorboard: Optional[bool] = None profiler_steps: Optional[int] = None + include_tokens_per_second: Optional[bool] = None neftune_noise_alpha: Optional[float] = None