Tokens per second logging [skip-e2e] (#3072)
This commit is contained in:
@@ -24,9 +24,7 @@ from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
from transformers import (
|
||||
TrainerCallback,
|
||||
)
|
||||
from transformers import TrainerCallback
|
||||
from transformers.trainer_pt_utils import AcceleratorConfig
|
||||
|
||||
from axolotl.integrations.base import PluginManager
|
||||
@@ -38,6 +36,7 @@ from axolotl.utils.callbacks import (
|
||||
SaveModelOnFirstStepCallback,
|
||||
)
|
||||
from axolotl.utils.callbacks.profiler import PytorchProfilerCallback
|
||||
from axolotl.utils.callbacks.tokens_per_second import TokensPerSecondCallback
|
||||
from axolotl.utils.distributed import build_parallelism_config
|
||||
from axolotl.utils.schemas.enums import CustomSupportedOptimizers
|
||||
|
||||
@@ -146,6 +145,12 @@ class TrainerBuilderBase(abc.ABC):
|
||||
profiler_steps_start=self.cfg.profiler_steps_start,
|
||||
)
|
||||
)
|
||||
if self.cfg.include_tkps:
|
||||
callbacks.append(
|
||||
TokensPerSecondCallback(
|
||||
self.cfg.tensor_parallel_size, self.cfg.context_parallel_size
|
||||
)
|
||||
)
|
||||
|
||||
return callbacks
|
||||
|
||||
@@ -512,6 +517,7 @@ class TrainerBuilderBase(abc.ABC):
|
||||
self.cfg.eval_batch_size
|
||||
)
|
||||
|
||||
training_args_kwargs["include_tkps"] = self.cfg.include_tkps
|
||||
training_args_kwargs["max_steps"] = self.cfg.max_steps or total_num_steps or -1
|
||||
training_args_kwargs["num_train_epochs"] = self.cfg.num_epochs
|
||||
|
||||
|
||||
@@ -88,7 +88,6 @@ class AxolotlTrainer(
|
||||
self._signature_columns = None # workaround for pylint
|
||||
|
||||
super().__init__(*_args, **kwargs)
|
||||
|
||||
self.train_data_collator = self.data_collator
|
||||
self._stored_metrics = defaultdict(lambda: defaultdict(list))
|
||||
if self.args.orpo_alpha:
|
||||
@@ -337,6 +336,17 @@ class AxolotlTrainer(
|
||||
# outputs = model(**inputs)
|
||||
# loss = trainer_weighted_loss(outputs, labels, shift_labels=True)
|
||||
# return (loss, outputs) if return_outputs else loss
|
||||
|
||||
# track number of tokens for tokens per second calculation
|
||||
if self.args.include_tkps:
|
||||
inputs_key = "labels" if "labels" in inputs else "input_ids"
|
||||
if hasattr(self.state, "num_tokens"):
|
||||
self.state.num_tokens = (
|
||||
self.state.num_tokens + (inputs[inputs_key] != -100).sum()
|
||||
)
|
||||
else:
|
||||
self.state.num_tokens = (inputs[inputs_key] != -100).sum()
|
||||
|
||||
if self.args.orpo_alpha:
|
||||
return self.orpo_compute_loss(
|
||||
model,
|
||||
@@ -536,9 +546,6 @@ class AxolotlTrainer(
|
||||
|
||||
super().create_accelerator_and_postprocess()
|
||||
|
||||
# now we need to put parallelism_config back on the PartialState since we rely on that info in other places
|
||||
# PartialState().parallelism_config = self.accelerator.state.parallelism_config
|
||||
|
||||
if self.is_fsdp_enabled:
|
||||
if (
|
||||
"limit_all_gathers" in self.args.fsdp_config
|
||||
@@ -586,12 +593,19 @@ class AxolotlTrainer(
|
||||
# Add memory usage
|
||||
try:
|
||||
active, allocated, reserved = get_gpu_memory_usage()
|
||||
logs["memory/max_mem_active(gib)"] = round(active, 2)
|
||||
logs["memory/max_mem_allocated(gib)"] = round(allocated, 2)
|
||||
logs["memory/device_mem_reserved(gib)"] = round(reserved, 2)
|
||||
logs["memory/max_active (GiB)"] = round(active, 2)
|
||||
logs["memory/max_allocated (GiB)"] = round(allocated, 2)
|
||||
logs["memory/device_reserved (GiB)"] = round(reserved, 2)
|
||||
except (ValueError, TypeError, FileNotFoundError):
|
||||
pass
|
||||
|
||||
if self.args.include_tkps and train_eval == "train":
|
||||
# each rank will log its own tokens per second
|
||||
# for logging_steps > 1 we obtain a moving average of this metric
|
||||
logs["tokens_per_second_per_gpu"] = round(
|
||||
self.state.last_tokens_per_second.item() / self.args.logging_steps, 2
|
||||
)
|
||||
|
||||
del self._stored_metrics[train_eval]
|
||||
|
||||
return super().log(logs, start_time)
|
||||
|
||||
@@ -49,6 +49,12 @@ class AxolotlTrainingMixins:
|
||||
default=False,
|
||||
metadata={"help": "Use real batches for efficient training."},
|
||||
)
|
||||
include_tkps: bool = field(
|
||||
default=True,
|
||||
metadata={
|
||||
"help": "Whether to include tokens per second in the training metrics."
|
||||
},
|
||||
)
|
||||
eval_sample_packing: Optional[bool] = field(
|
||||
default=None,
|
||||
metadata={"help": "Use sample packing for efficient evals."},
|
||||
|
||||
@@ -60,13 +60,14 @@ def gpu_memory_usage_all(device=0):
|
||||
active = torch.cuda.memory_stats().get("active_bytes.all.peak", 0) / 1024.0**3
|
||||
allocated = torch.cuda.max_memory_allocated(device) / 1024.0**3
|
||||
reserved = torch.cuda.max_memory_reserved(device) / 1024.0**3
|
||||
torch.cuda.reset_peak_memory_stats(device)
|
||||
return active, allocated, reserved
|
||||
|
||||
|
||||
def mps_memory_usage_all():
|
||||
usage = torch.mps.current_allocated_memory() / 1024.0**3
|
||||
reserved = torch.mps.driver_allocated_memory() / 1024.0**3
|
||||
return usage, reserved - usage, 0
|
||||
active = torch.mps.current_allocated_memory() / 1024.0**3
|
||||
allocated = torch.mps.driver_allocated_memory() / 1024.0**3
|
||||
return active, allocated, 0
|
||||
|
||||
|
||||
def npu_memory_usage_all(device=0):
|
||||
|
||||
62
src/axolotl/utils/callbacks/tokens_per_second.py
Normal file
62
src/axolotl/utils/callbacks/tokens_per_second.py
Normal file
@@ -0,0 +1,62 @@
|
||||
"""A callback for calculating tokens per second during training."""
|
||||
|
||||
import time
|
||||
|
||||
import torch
|
||||
from transformers import (
|
||||
TrainerCallback,
|
||||
TrainerControl,
|
||||
TrainerState,
|
||||
TrainingArguments,
|
||||
)
|
||||
|
||||
|
||||
class TokensPerSecondCallback(TrainerCallback):
    """
    Trainer callback that computes per-device token throughput for each
    optimizer step.

    The trainer is expected to accumulate the number of non-padding tokens
    it has processed into ``state.num_tokens``; this callback converts that
    count into ``state.last_tokens_per_second`` so the trainer can log it,
    and clears both counters after every log event.
    """

    def __init__(self, tensor_parallel_size, context_parallel_size):
        super().__init__()
        self.step_time = 0.0
        self.start_time = 0.0
        # Ranks inside a tensor- or context-parallel group all see the same
        # tokens, so the product of those group sizes is divided out later to
        # avoid double-counting.
        self.non_data_parallel_size = 1
        for group_size in (tensor_parallel_size, context_parallel_size):
            if group_size is not None:
                self.non_data_parallel_size *= group_size

    def on_step_begin(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs,
    ):  # pylint: disable=unused-argument
        # Stamp the step start and make sure the metric tensor exists so the
        # trainer can always read it when logging.
        self.start_time = time.perf_counter()
        state.last_tokens_per_second = torch.zeros(1)

    def on_step_end(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs,
    ):  # pylint: disable=unused-argument
        elapsed = time.perf_counter() - self.start_time
        # NOTE(review): assumes ``state.num_tokens`` is a tensor at this point
        # (written by the trainer during the step) — ``.clone()`` would fail on
        # the plain int that on_log resets it to; confirm the trainer always
        # updates it before step end.
        tokens_on_device = state.num_tokens.clone()
        # non data parallel groups have duplicated tokens, so we avoid double-counting
        tokens_on_device = tokens_on_device / self.non_data_parallel_size
        state.last_tokens_per_second = tokens_on_device / elapsed

    def on_log(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        logs=None,
        **kwargs,
    ):  # pylint: disable=unused-argument
        # after logging, clear the running metrics
        state.last_tokens_per_second.zero_()
        state.num_tokens = 0
|
||||
@@ -830,10 +830,15 @@ class AxolotlInputConfig(
|
||||
include_tokens_per_second: bool | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "bool of whether to include tokens trainer per second in the training metrics. This iterates over the entire dataset once, so it takes some time."
|
||||
"description": "bool of whether to report tokens per second at the end of training. This is not supported with pre-training datasets."
|
||||
},
|
||||
)
|
||||
include_tkps: bool | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "bool of whether to report tokens per second during training by measuring throughput of non-padding tokens."
|
||||
},
|
||||
)
|
||||
|
||||
neftune_noise_alpha: float | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
|
||||
Reference in New Issue
Block a user