compute loss only if training and update token metric naming (#3293) [skip ci]
* compute loss only if training * save total_tokens for checkpiont * check if string * refactor total_tokens/ num_tokens * refactor 2 * rplc trainable_step/trian_per_sec_per_gpu * lint + log trainable/tokens * consolidate it in the callback. * test for total_tokes aftr remuse * check if tokenstate exist after ckpt --------- Co-authored-by: Ved <ved.work2024@gmail.com>
This commit is contained in:
@@ -72,7 +72,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
if self.cfg.include_tkps:
|
||||
callbacks.append(
|
||||
TokensPerSecondCallback(
|
||||
self.cfg.tensor_parallel_size, self.cfg.context_parallel_size
|
||||
self.cfg.tensor_parallel_size,
|
||||
self.cfg.context_parallel_size,
|
||||
resume_from_checkpoint=self.cfg.resume_from_checkpoint,
|
||||
)
|
||||
)
|
||||
return callbacks
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
from collections import defaultdict
|
||||
@@ -50,6 +51,8 @@ from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
TOKENS_STATE_FILE = "tokens_state."
|
||||
|
||||
REDUCTION_FNS = {
|
||||
"mean": torch.mean,
|
||||
"min": torch.min,
|
||||
@@ -349,24 +352,34 @@ class AxolotlTrainer(
|
||||
# return (loss, outputs) if return_outputs else loss
|
||||
|
||||
# track number of tokens for tokens per second calculation
|
||||
if self.args.include_tkps:
|
||||
if self.args.include_tkps and model.training:
|
||||
inputs_key = "labels" if "labels" in inputs else "input_ids"
|
||||
num_tokens = (inputs[inputs_key] != -100).sum()
|
||||
trainable_tokens = (inputs[inputs_key] != -100).sum()
|
||||
total_tokens = inputs[inputs_key].numel()
|
||||
|
||||
if is_distributed():
|
||||
torch.distributed.all_reduce(
|
||||
num_tokens, op=torch.distributed.ReduceOp.SUM
|
||||
trainable_tokens, op=torch.distributed.ReduceOp.SUM
|
||||
)
|
||||
if hasattr(self.state, "num_tokens"):
|
||||
self.state.num_tokens = (
|
||||
self.state.num_tokens + (inputs[inputs_key] != -100).sum().cpu()
|
||||
torch.distributed.all_reduce(
|
||||
total_tokens, op=torch.distributed.ReduceOp.SUM
|
||||
)
|
||||
else:
|
||||
self.state.num_tokens = (inputs[inputs_key] != -100).sum().cpu()
|
||||
|
||||
if hasattr(self.state, "total_tokens"):
|
||||
self.state.total_tokens += num_tokens
|
||||
else:
|
||||
self.state.total_tokens = num_tokens
|
||||
if not hasattr(self.state, "tokens"):
|
||||
self.state.tokens = {
|
||||
"trainable": torch.zeros(1),
|
||||
"total": torch.zeros(1),
|
||||
}
|
||||
|
||||
# trainable tokens for throughput and total token slots for summaries
|
||||
self.state.tokens["trainable"] = (
|
||||
self.state.tokens["trainable"] + trainable_tokens.detach().cpu()
|
||||
)
|
||||
self.state.tokens["total"] = (
|
||||
self.state.tokens["total"] + torch.as_tensor(total_tokens).cpu()
|
||||
)
|
||||
# Store per-step trainable tokens for throughput calculation
|
||||
self.state.tokens["trainable_tokens"] = trainable_tokens.detach().cpu()
|
||||
|
||||
if self.args.orpo_alpha:
|
||||
return self.orpo_compute_loss(
|
||||
@@ -638,10 +651,14 @@ class AxolotlTrainer(
|
||||
except (ValueError, TypeError, FileNotFoundError):
|
||||
pass
|
||||
|
||||
if self.args.include_tkps and train_eval == "train":
|
||||
if (
|
||||
self.args.include_tkps
|
||||
and train_eval == "train"
|
||||
and hasattr(self.state, "tokens")
|
||||
):
|
||||
# each rank will log its own tokens per second
|
||||
# for logging_steps > 1 we obtain a moving average of this metric
|
||||
logs["tokens_per_second_per_gpu"] = round(
|
||||
logs["tokens/train_per_sec_per_gpu"] = round(
|
||||
self.state.last_tokens_per_second.item() / self.args.logging_steps, 2
|
||||
)
|
||||
if (
|
||||
@@ -683,6 +700,19 @@ class AxolotlTrainer(
|
||||
run_dir = self._get_output_dir(trial=trial)
|
||||
output_dir = os.path.join(run_dir, checkpoint_folder)
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
# Save total_tokens state if tracking is enabled
|
||||
if self.args.include_tkps and hasattr(self.state, "tokens"):
|
||||
tokens_state = {
|
||||
"total": int(torch.as_tensor(self.state.tokens.get("total", 0)).item()),
|
||||
"trainable": int(
|
||||
torch.as_tensor(self.state.tokens.get("trainable", 0)).item()
|
||||
),
|
||||
}
|
||||
tokens_state_path = os.path.join(output_dir, TOKENS_STATE_FILE)
|
||||
with open(tokens_state_path, "w", encoding="utf-8") as f:
|
||||
json.dump(tokens_state, f)
|
||||
|
||||
return super()._save_checkpoint(model, trial, **kwargs)
|
||||
|
||||
# TODO(wing): remove once https://github.com/huggingface/transformers/pull/39866/files is merged
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
"""A callback for calculating tokens per second during training."""
|
||||
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
|
||||
import torch
|
||||
@@ -10,22 +12,52 @@ from transformers import (
|
||||
TrainingArguments,
|
||||
)
|
||||
|
||||
from axolotl.utils.logging import get_logger
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
TOKENS_STATE_FILE = "tokens_state.json"
|
||||
|
||||
|
||||
class TokensPerSecondCallback(TrainerCallback):
|
||||
"""
|
||||
A callback to measure and log tokens per second during training.
|
||||
Also handles saving/restoring total_tokens state across checkpoint resumes.
|
||||
"""
|
||||
|
||||
def __init__(self, tensor_parallel_size, context_parallel_size):
|
||||
def __init__(
|
||||
self, tensor_parallel_size, context_parallel_size, resume_from_checkpoint=None
|
||||
):
|
||||
super().__init__()
|
||||
self.step_time = 0.0
|
||||
self.start_time = 0.0
|
||||
self.non_data_parallel_size = 1
|
||||
self.resume_from_checkpoint = resume_from_checkpoint
|
||||
if tensor_parallel_size is not None:
|
||||
self.non_data_parallel_size *= tensor_parallel_size
|
||||
if context_parallel_size is not None:
|
||||
self.non_data_parallel_size *= context_parallel_size
|
||||
|
||||
def on_train_begin(
|
||||
self,
|
||||
args: TrainingArguments,
|
||||
state: TrainerState,
|
||||
control: TrainerControl,
|
||||
**kwargs,
|
||||
): # pylint: disable=unused-argument
|
||||
"""Restore total_tokens state when resuming from checkpoint."""
|
||||
if not isinstance(self.resume_from_checkpoint, str):
|
||||
return
|
||||
tokens_state_path = os.path.join(self.resume_from_checkpoint, TOKENS_STATE_FILE)
|
||||
if os.path.isfile(tokens_state_path):
|
||||
with open(tokens_state_path, "r", encoding="utf-8") as f:
|
||||
tokens_state = json.load(f)
|
||||
state.tokens = {
|
||||
"total": torch.tensor(tokens_state.get("total", 0)),
|
||||
"trainable": torch.tensor(tokens_state.get("trainable", 0)),
|
||||
}
|
||||
LOG.info(f"Restored total_tokens: {state.tokens['total']}")
|
||||
|
||||
def on_step_begin(
|
||||
self,
|
||||
args: TrainingArguments,
|
||||
@@ -33,6 +65,8 @@ class TokensPerSecondCallback(TrainerCallback):
|
||||
control: TrainerControl,
|
||||
**kwargs,
|
||||
): # pylint: disable=unused-argument
|
||||
if not hasattr(state, "tokens"):
|
||||
state.tokens = {"trainable": torch.zeros(1), "total": torch.zeros(1)}
|
||||
self.start_time = time.perf_counter()
|
||||
state.last_tokens_per_second = torch.zeros(1)
|
||||
|
||||
@@ -43,9 +77,10 @@ class TokensPerSecondCallback(TrainerCallback):
|
||||
control: TrainerControl,
|
||||
**kwargs,
|
||||
): # pylint: disable=unused-argument
|
||||
if hasattr(state, "num_tokens"):
|
||||
tokens = getattr(state, "tokens", None)
|
||||
if tokens and "trainable_tokens" in tokens:
|
||||
step_time = time.perf_counter() - self.start_time
|
||||
num_tokens_per_device = state.num_tokens.clone()
|
||||
num_tokens_per_device = tokens["trainable_tokens"].clone()
|
||||
# non data parallel groups have duplicated tokens, so we avoid double-counting
|
||||
num_tokens_per_device = num_tokens_per_device / self.non_data_parallel_size
|
||||
state.last_tokens_per_second = num_tokens_per_device / step_time
|
||||
@@ -60,5 +95,15 @@ class TokensPerSecondCallback(TrainerCallback):
|
||||
): # pylint: disable=unused-argument
|
||||
# after logging, clear the running metrics
|
||||
if hasattr(state, "last_tokens_per_second"):
|
||||
logs["tokens/train_per_sec_per_gpu"] = state.last_tokens_per_second.item()
|
||||
state.last_tokens_per_second.zero_()
|
||||
state.num_tokens = torch.zeros(1)
|
||||
tokens = getattr(state, "tokens", None)
|
||||
# Clear per-step tokens after logging
|
||||
if tokens and "trainable_tokens" in tokens:
|
||||
tokens["trainable_tokens"] = torch.zeros_like(tokens["trainable_tokens"])
|
||||
|
||||
if tokens and "total" in tokens:
|
||||
logs["tokens/total"] = tokens["total"].item()
|
||||
|
||||
if tokens and "trainable" in tokens:
|
||||
logs["tokens/trainable"] = tokens["trainable"].item()
|
||||
|
||||
Reference in New Issue
Block a user