Tokens per second logging [skip-e2e] (#3072)
This commit is contained in:
@@ -24,9 +24,7 @@ from pathlib import Path
|
|||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from transformers import (
|
from transformers import TrainerCallback
|
||||||
TrainerCallback,
|
|
||||||
)
|
|
||||||
from transformers.trainer_pt_utils import AcceleratorConfig
|
from transformers.trainer_pt_utils import AcceleratorConfig
|
||||||
|
|
||||||
from axolotl.integrations.base import PluginManager
|
from axolotl.integrations.base import PluginManager
|
||||||
@@ -38,6 +36,7 @@ from axolotl.utils.callbacks import (
|
|||||||
SaveModelOnFirstStepCallback,
|
SaveModelOnFirstStepCallback,
|
||||||
)
|
)
|
||||||
from axolotl.utils.callbacks.profiler import PytorchProfilerCallback
|
from axolotl.utils.callbacks.profiler import PytorchProfilerCallback
|
||||||
|
from axolotl.utils.callbacks.tokens_per_second import TokensPerSecondCallback
|
||||||
from axolotl.utils.distributed import build_parallelism_config
|
from axolotl.utils.distributed import build_parallelism_config
|
||||||
from axolotl.utils.schemas.enums import CustomSupportedOptimizers
|
from axolotl.utils.schemas.enums import CustomSupportedOptimizers
|
||||||
|
|
||||||
@@ -146,6 +145,12 @@ class TrainerBuilderBase(abc.ABC):
|
|||||||
profiler_steps_start=self.cfg.profiler_steps_start,
|
profiler_steps_start=self.cfg.profiler_steps_start,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
if self.cfg.include_tkps:
|
||||||
|
callbacks.append(
|
||||||
|
TokensPerSecondCallback(
|
||||||
|
self.cfg.tensor_parallel_size, self.cfg.context_parallel_size
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
return callbacks
|
return callbacks
|
||||||
|
|
||||||
@@ -512,6 +517,7 @@ class TrainerBuilderBase(abc.ABC):
|
|||||||
self.cfg.eval_batch_size
|
self.cfg.eval_batch_size
|
||||||
)
|
)
|
||||||
|
|
||||||
|
training_args_kwargs["include_tkps"] = self.cfg.include_tkps
|
||||||
training_args_kwargs["max_steps"] = self.cfg.max_steps or total_num_steps or -1
|
training_args_kwargs["max_steps"] = self.cfg.max_steps or total_num_steps or -1
|
||||||
training_args_kwargs["num_train_epochs"] = self.cfg.num_epochs
|
training_args_kwargs["num_train_epochs"] = self.cfg.num_epochs
|
||||||
|
|
||||||
|
|||||||
@@ -88,7 +88,6 @@ class AxolotlTrainer(
|
|||||||
self._signature_columns = None # workaround for pylint
|
self._signature_columns = None # workaround for pylint
|
||||||
|
|
||||||
super().__init__(*_args, **kwargs)
|
super().__init__(*_args, **kwargs)
|
||||||
|
|
||||||
self.train_data_collator = self.data_collator
|
self.train_data_collator = self.data_collator
|
||||||
self._stored_metrics = defaultdict(lambda: defaultdict(list))
|
self._stored_metrics = defaultdict(lambda: defaultdict(list))
|
||||||
if self.args.orpo_alpha:
|
if self.args.orpo_alpha:
|
||||||
@@ -337,6 +336,17 @@ class AxolotlTrainer(
|
|||||||
# outputs = model(**inputs)
|
# outputs = model(**inputs)
|
||||||
# loss = trainer_weighted_loss(outputs, labels, shift_labels=True)
|
# loss = trainer_weighted_loss(outputs, labels, shift_labels=True)
|
||||||
# return (loss, outputs) if return_outputs else loss
|
# return (loss, outputs) if return_outputs else loss
|
||||||
|
|
||||||
|
# track number of tokens for tokens per second calculation
|
||||||
|
if self.args.include_tkps:
|
||||||
|
inputs_key = "labels" if "labels" in inputs else "input_ids"
|
||||||
|
if hasattr(self.state, "num_tokens"):
|
||||||
|
self.state.num_tokens = (
|
||||||
|
self.state.num_tokens + (inputs[inputs_key] != -100).sum()
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.state.num_tokens = (inputs[inputs_key] != -100).sum()
|
||||||
|
|
||||||
if self.args.orpo_alpha:
|
if self.args.orpo_alpha:
|
||||||
return self.orpo_compute_loss(
|
return self.orpo_compute_loss(
|
||||||
model,
|
model,
|
||||||
@@ -536,9 +546,6 @@ class AxolotlTrainer(
|
|||||||
|
|
||||||
super().create_accelerator_and_postprocess()
|
super().create_accelerator_and_postprocess()
|
||||||
|
|
||||||
# now we need to put parallelism_config back on the PartialState since we rely on that info in other places
|
|
||||||
# PartialState().parallelism_config = self.accelerator.state.parallelism_config
|
|
||||||
|
|
||||||
if self.is_fsdp_enabled:
|
if self.is_fsdp_enabled:
|
||||||
if (
|
if (
|
||||||
"limit_all_gathers" in self.args.fsdp_config
|
"limit_all_gathers" in self.args.fsdp_config
|
||||||
@@ -586,12 +593,19 @@ class AxolotlTrainer(
|
|||||||
# Add memory usage
|
# Add memory usage
|
||||||
try:
|
try:
|
||||||
active, allocated, reserved = get_gpu_memory_usage()
|
active, allocated, reserved = get_gpu_memory_usage()
|
||||||
logs["memory/max_mem_active(gib)"] = round(active, 2)
|
logs["memory/max_active (GiB)"] = round(active, 2)
|
||||||
logs["memory/max_mem_allocated(gib)"] = round(allocated, 2)
|
logs["memory/max_allocated (GiB)"] = round(allocated, 2)
|
||||||
logs["memory/device_mem_reserved(gib)"] = round(reserved, 2)
|
logs["memory/device_reserved (GiB)"] = round(reserved, 2)
|
||||||
except (ValueError, TypeError, FileNotFoundError):
|
except (ValueError, TypeError, FileNotFoundError):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
if self.args.include_tkps and train_eval == "train":
|
||||||
|
# each rank will log its own tokens per second
|
||||||
|
# for logging_steps > 1 we obtain a moving average of this metric
|
||||||
|
logs["tokens_per_second_per_gpu"] = round(
|
||||||
|
self.state.last_tokens_per_second.item() / self.args.logging_steps, 2
|
||||||
|
)
|
||||||
|
|
||||||
del self._stored_metrics[train_eval]
|
del self._stored_metrics[train_eval]
|
||||||
|
|
||||||
return super().log(logs, start_time)
|
return super().log(logs, start_time)
|
||||||
|
|||||||
@@ -49,6 +49,12 @@ class AxolotlTrainingMixins:
|
|||||||
default=False,
|
default=False,
|
||||||
metadata={"help": "Use real batches for efficient training."},
|
metadata={"help": "Use real batches for efficient training."},
|
||||||
)
|
)
|
||||||
|
include_tkps: bool = field(
|
||||||
|
default=True,
|
||||||
|
metadata={
|
||||||
|
"help": "Whether to include tokens per second in the training metrics."
|
||||||
|
},
|
||||||
|
)
|
||||||
eval_sample_packing: Optional[bool] = field(
|
eval_sample_packing: Optional[bool] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={"help": "Use sample packing for efficient evals."},
|
metadata={"help": "Use sample packing for efficient evals."},
|
||||||
|
|||||||
@@ -60,13 +60,14 @@ def gpu_memory_usage_all(device=0):
|
|||||||
active = torch.cuda.memory_stats().get("active_bytes.all.peak", 0) / 1024.0**3
|
active = torch.cuda.memory_stats().get("active_bytes.all.peak", 0) / 1024.0**3
|
||||||
allocated = torch.cuda.max_memory_allocated(device) / 1024.0**3
|
allocated = torch.cuda.max_memory_allocated(device) / 1024.0**3
|
||||||
reserved = torch.cuda.max_memory_reserved(device) / 1024.0**3
|
reserved = torch.cuda.max_memory_reserved(device) / 1024.0**3
|
||||||
|
torch.cuda.reset_peak_memory_stats(device)
|
||||||
return active, allocated, reserved
|
return active, allocated, reserved
|
||||||
|
|
||||||
|
|
||||||
def mps_memory_usage_all():
|
def mps_memory_usage_all():
|
||||||
usage = torch.mps.current_allocated_memory() / 1024.0**3
|
active = torch.mps.current_allocated_memory() / 1024.0**3
|
||||||
reserved = torch.mps.driver_allocated_memory() / 1024.0**3
|
allocated = torch.mps.driver_allocated_memory() / 1024.0**3
|
||||||
return usage, reserved - usage, 0
|
return active, allocated, 0
|
||||||
|
|
||||||
|
|
||||||
def npu_memory_usage_all(device=0):
|
def npu_memory_usage_all(device=0):
|
||||||
|
|||||||
62
src/axolotl/utils/callbacks/tokens_per_second.py
Normal file
62
src/axolotl/utils/callbacks/tokens_per_second.py
Normal file
@@ -0,0 +1,62 @@
|
|||||||
|
"""A callback for calculating tokens per second during training."""
|
||||||
|
|
||||||
|
import time
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from transformers import (
|
||||||
|
TrainerCallback,
|
||||||
|
TrainerControl,
|
||||||
|
TrainerState,
|
||||||
|
TrainingArguments,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TokensPerSecondCallback(TrainerCallback):
    """
    A callback to measure and log tokens per second during training.

    The callback times each training step with ``time.perf_counter`` and
    divides the token count found on ``state.num_tokens`` by that duration,
    storing the result on ``state.last_tokens_per_second``. The token count is
    not produced here; it is expected to be accumulated on the trainer state
    by the trainer (presumably in its loss computation -- verify against the
    trainer in use).

    Args:
        tensor_parallel_size: tensor parallel degree, or ``None`` if unused.
        context_parallel_size: context parallel degree, or ``None`` if unused.
    """

    def __init__(self, tensor_parallel_size, context_parallel_size):
        super().__init__()
        self.step_time = 0.0
        self.start_time = 0.0
        # Product of the non-data-parallel degrees; token counts are divided
        # by this factor in on_step_end.
        self.non_data_parallel_size = 1
        for parallel_size in (tensor_parallel_size, context_parallel_size):
            if parallel_size is not None:
                self.non_data_parallel_size *= parallel_size

    def on_step_begin(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs,
    ):  # pylint: disable=unused-argument
        """Start the step timer and reset the last throughput reading."""
        self.start_time = time.perf_counter()
        state.last_tokens_per_second = torch.zeros(1)

    def on_step_end(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs,
    ):  # pylint: disable=unused-argument
        """Compute tokens per second for the step that just finished."""
        step_time = time.perf_counter() - self.start_time
        # keep the most recent measurement available on the instance
        self.step_time = step_time

        num_tokens_per_device = getattr(state, "num_tokens", None)
        if not isinstance(num_tokens_per_device, torch.Tensor):
            # `num_tokens` is the plain int 0 right after `on_log` reset it, and
            # is missing entirely when the trainer never tracked tokens; treat
            # both as "no tokens seen" instead of failing on `.clone()`.
            num_tokens_per_device = torch.as_tensor(
                num_tokens_per_device or 0, dtype=torch.float32
            )

        # non data parallel groups have duplicated tokens, so we avoid double-counting
        # (the division is out-of-place, so `state.num_tokens` is not modified)
        num_tokens_per_device = num_tokens_per_device / self.non_data_parallel_size
        state.last_tokens_per_second = num_tokens_per_device / step_time

    def on_log(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        logs=None,
        **kwargs,
    ):  # pylint: disable=unused-argument
        """Clear the running metrics once they have been logged."""
        last_tokens_per_second = getattr(state, "last_tokens_per_second", None)
        if isinstance(last_tokens_per_second, torch.Tensor):
            last_tokens_per_second.zero_()
        else:
            # A log event can fire before the first `on_step_begin` (for
            # example an evaluation that runs before training starts), in
            # which case the attribute does not exist yet.
            state.last_tokens_per_second = torch.zeros(1)
        state.num_tokens = 0
|
||||||
@@ -830,10 +830,15 @@ class AxolotlInputConfig(
|
|||||||
include_tokens_per_second: bool | None = Field(
|
include_tokens_per_second: bool | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
json_schema_extra={
|
json_schema_extra={
|
||||||
"description": "bool of whether to include tokens trainer per second in the training metrics. This iterates over the entire dataset once, so it takes some time."
|
"description": "bool of whether to report tokens per second at the end of training. This is not supported with pre-training datasets."
|
||||||
|
},
|
||||||
|
)
|
||||||
|
include_tkps: bool | None = Field(
|
||||||
|
default=None,
|
||||||
|
json_schema_extra={
|
||||||
|
"description": "bool of whether to report tokens per second during training by measuring throughput of non-padding tokens."
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
neftune_noise_alpha: float | None = Field(
|
neftune_noise_alpha: float | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
json_schema_extra={
|
json_schema_extra={
|
||||||
|
|||||||
Reference in New Issue
Block a user