diff --git a/src/axolotl/core/builders/base.py b/src/axolotl/core/builders/base.py
index 44699e6ac..bee291fa2 100644
--- a/src/axolotl/core/builders/base.py
+++ b/src/axolotl/core/builders/base.py
@@ -24,9 +24,7 @@ from pathlib import Path
 from typing import Any
 
 import torch
-from transformers import (
-    TrainerCallback,
-)
+from transformers import TrainerCallback
 from transformers.trainer_pt_utils import AcceleratorConfig
 
 from axolotl.integrations.base import PluginManager
@@ -38,6 +36,7 @@ from axolotl.utils.callbacks import (
     SaveModelOnFirstStepCallback,
 )
 from axolotl.utils.callbacks.profiler import PytorchProfilerCallback
+from axolotl.utils.callbacks.tokens_per_second import TokensPerSecondCallback
 from axolotl.utils.distributed import build_parallelism_config
 from axolotl.utils.schemas.enums import CustomSupportedOptimizers
 
@@ -146,6 +145,12 @@ class TrainerBuilderBase(abc.ABC):
                     profiler_steps_start=self.cfg.profiler_steps_start,
                 )
             )
+        if self.cfg.include_tkps:
+            callbacks.append(
+                TokensPerSecondCallback(
+                    self.cfg.tensor_parallel_size, self.cfg.context_parallel_size
+                )
+            )
 
         return callbacks
 
@@ -512,6 +517,7 @@ class TrainerBuilderBase(abc.ABC):
                 self.cfg.eval_batch_size
             )
 
+        training_args_kwargs["include_tkps"] = bool(self.cfg.include_tkps)
         training_args_kwargs["max_steps"] = self.cfg.max_steps or total_num_steps or -1
         training_args_kwargs["num_train_epochs"] = self.cfg.num_epochs
 
diff --git a/src/axolotl/core/trainers/base.py b/src/axolotl/core/trainers/base.py
index f707d4b5a..06eef445b 100644
--- a/src/axolotl/core/trainers/base.py
+++ b/src/axolotl/core/trainers/base.py
@@ -88,7 +88,6 @@ class AxolotlTrainer(
         self._signature_columns = None  # workaround for pylint
 
         super().__init__(*_args, **kwargs)
-
         self.train_data_collator = self.data_collator
         self._stored_metrics = defaultdict(lambda: defaultdict(list))
         if self.args.orpo_alpha:
@@ -337,6 +336,18 @@ class AxolotlTrainer(
         #     outputs = model(**inputs)
         #     loss = trainer_weighted_loss(outputs, labels, shift_labels=True)
         #     return (loss, outputs) if return_outputs else loss
+
+        # track number of tokens for tokens per second calculation; evaluation
+        # also runs through compute_loss, so only count while the model trains
+        if self.args.include_tkps and model.training:
+            inputs_key = "labels" if "labels" in inputs else "input_ids"
+            if hasattr(self.state, "num_tokens"):
+                self.state.num_tokens = (
+                    self.state.num_tokens + (inputs[inputs_key] != -100).sum()
+                )
+            else:
+                self.state.num_tokens = (inputs[inputs_key] != -100).sum()
+
         if self.args.orpo_alpha:
             return self.orpo_compute_loss(
                 model,
@@ -536,9 +547,6 @@ class AxolotlTrainer(
 
         super().create_accelerator_and_postprocess()
 
-        # now we need to put parallelism_config back on the PartialState since we rely on that info in other places
-        # PartialState().parallelism_config = self.accelerator.state.parallelism_config
-
         if self.is_fsdp_enabled:
             if (
                 "limit_all_gathers" in self.args.fsdp_config
@@ -586,12 +594,25 @@ class AxolotlTrainer(
         # Add memory usage
         try:
             active, allocated, reserved = get_gpu_memory_usage()
-            logs["memory/max_mem_active(gib)"] = round(active, 2)
-            logs["memory/max_mem_allocated(gib)"] = round(allocated, 2)
-            logs["memory/device_mem_reserved(gib)"] = round(reserved, 2)
+            logs["memory/max_active (GiB)"] = round(active, 2)
+            logs["memory/max_allocated (GiB)"] = round(allocated, 2)
+            logs["memory/device_reserved (GiB)"] = round(reserved, 2)
         except (ValueError, TypeError, FileNotFoundError):
             pass
 
+        # last_tokens_per_second only exists once TokensPerSecondCallback has run
+        tokens_per_second = getattr(self.state, "last_tokens_per_second", None)
+        if (
+            self.args.include_tkps
+            and train_eval == "train"
+            and tokens_per_second is not None
+        ):
+            # each rank will log its own tokens per second
+            # for logging_steps > 1 we obtain a moving average of this metric
+            logs["tokens_per_second_per_gpu"] = round(
+                tokens_per_second.item() / self.args.logging_steps, 2
+            )
+
         del self._stored_metrics[train_eval]
 
         return super().log(logs, start_time)
diff --git a/src/axolotl/core/training_args_base.py b/src/axolotl/core/training_args_base.py
index a9cc7d224..41ee8e91e 100644
--- a/src/axolotl/core/training_args_base.py
+++ b/src/axolotl/core/training_args_base.py
@@ -49,6 +49,12 @@ class AxolotlTrainingMixins:
         default=False,
         metadata={"help": "Use real batches for efficient training."},
     )
+    include_tkps: bool = field(
+        default=True,
+        metadata={
+            "help": "Whether to include tokens per second in the training metrics."
+        },
+    )
     eval_sample_packing: Optional[bool] = field(
         default=None,
         metadata={"help": "Use sample packing for efficient evals."},
diff --git a/src/axolotl/utils/bench.py b/src/axolotl/utils/bench.py
index dd3a85b8c..0a4594991 100644
--- a/src/axolotl/utils/bench.py
+++ b/src/axolotl/utils/bench.py
@@ -60,13 +60,14 @@ def gpu_memory_usage_all(device=0):
     active = torch.cuda.memory_stats().get("active_bytes.all.peak", 0) / 1024.0**3
     allocated = torch.cuda.max_memory_allocated(device) / 1024.0**3
     reserved = torch.cuda.max_memory_reserved(device) / 1024.0**3
+    torch.cuda.reset_peak_memory_stats(device)
     return active, allocated, reserved
 
 
 def mps_memory_usage_all():
-    usage = torch.mps.current_allocated_memory() / 1024.0**3
-    reserved = torch.mps.driver_allocated_memory() / 1024.0**3
-    return usage, reserved - usage, 0
+    active = torch.mps.current_allocated_memory() / 1024.0**3
+    allocated = torch.mps.driver_allocated_memory() / 1024.0**3
+    return active, allocated, 0
 
 
 def npu_memory_usage_all(device=0):
diff --git a/src/axolotl/utils/callbacks/tokens_per_second.py b/src/axolotl/utils/callbacks/tokens_per_second.py
new file mode 100644
index 000000000..85bcd5041
--- /dev/null
+++ b/src/axolotl/utils/callbacks/tokens_per_second.py
@@ -0,0 +1,72 @@
+"""A callback for calculating tokens per second during training."""
+
+import time
+
+import torch
+from transformers import (
+    TrainerCallback,
+    TrainerControl,
+    TrainerState,
+    TrainingArguments,
+)
+
+
+class TokensPerSecondCallback(TrainerCallback):
+    """
+    A callback to measure and log tokens per second during training.
+
+    The trainer accumulates the token count on `state.num_tokens`; this callback
+    times every optimizer step and publishes `state.last_tokens_per_second`.
+    """
+
+    def __init__(self, tensor_parallel_size, context_parallel_size):
+        super().__init__()
+        self.step_time = 0.0
+        self.start_time = 0.0
+        self.non_data_parallel_size = 1
+        if tensor_parallel_size is not None:
+            self.non_data_parallel_size *= tensor_parallel_size
+        if context_parallel_size is not None:
+            self.non_data_parallel_size *= context_parallel_size
+
+    def on_step_begin(
+        self,
+        args: TrainingArguments,
+        state: TrainerState,
+        control: TrainerControl,
+        **kwargs,
+    ):  # pylint: disable=unused-argument
+        self.start_time = time.perf_counter()
+        state.last_tokens_per_second = torch.zeros(1)
+
+    def on_step_end(
+        self,
+        args: TrainingArguments,
+        state: TrainerState,
+        control: TrainerControl,
+        **kwargs,
+    ):  # pylint: disable=unused-argument
+        self.step_time = time.perf_counter() - self.start_time
+        num_tokens = getattr(state, "num_tokens", None)
+        if not torch.is_tensor(num_tokens) or self.step_time <= 0:
+            # no tokens were counted for this step (e.g. the trainer does not track
+            # them), so keep the zeroed metric instead of failing on a missing value
+            return
+        # non data parallel groups have duplicated tokens, so we avoid double-counting
+        num_tokens_per_device = num_tokens / self.non_data_parallel_size
+        state.last_tokens_per_second = num_tokens_per_device / self.step_time
+
+    def on_log(
+        self,
+        args: TrainingArguments,
+        state: TrainerState,
+        control: TrainerControl,
+        logs=None,
+        **kwargs,
+    ):  # pylint: disable=unused-argument
+        # after logging, clear the running metrics; on_log also fires for logs
+        # emitted before the first training step, when nothing is set yet
+        last_tokens_per_second = getattr(state, "last_tokens_per_second", None)
+        if torch.is_tensor(last_tokens_per_second):
+            last_tokens_per_second.zero_()
+        state.num_tokens = 0
diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py
index 4d660d4b7..4b5f571dc 100644
--- a/src/axolotl/utils/schemas/config.py
+++ b/src/axolotl/utils/schemas/config.py
@@ -830,10 +830,15 @@ class AxolotlInputConfig(
     include_tokens_per_second: bool | None = Field(
         default=None,
         json_schema_extra={
-            "description": "bool of whether to include tokens trainer per second in the training metrics. This iterates over the entire dataset once, so it takes some time."
+            "description": "bool of whether to report tokens per second at the end of training. This is not supported with pre-training datasets."
+        },
+    )
+    include_tkps: bool | None = Field(
+        default=None,
+        json_schema_extra={
+            "description": "bool of whether to report tokens per second during training by measuring throughput of non-padding tokens."
         },
     )
-
     neftune_noise_alpha: float | None = Field(
         default=None,
         json_schema_extra={