compute loss only if training and update token metric naming (#3293) [skip ci]
* compute loss only if training * save total_tokens for checkpiont * check if string * refactor total_tokens/ num_tokens * refactor 2 * rplc trainable_step/trian_per_sec_per_gpu * lint + log trainable/tokens * consolidate it in the callback. * test for total_tokes aftr remuse * check if tokenstate exist after ckpt --------- Co-authored-by: Ved <ved.work2024@gmail.com>
This commit is contained in:
@@ -72,7 +72,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
if self.cfg.include_tkps:
|
if self.cfg.include_tkps:
|
||||||
callbacks.append(
|
callbacks.append(
|
||||||
TokensPerSecondCallback(
|
TokensPerSecondCallback(
|
||||||
self.cfg.tensor_parallel_size, self.cfg.context_parallel_size
|
self.cfg.tensor_parallel_size,
|
||||||
|
self.cfg.context_parallel_size,
|
||||||
|
resume_from_checkpoint=self.cfg.resume_from_checkpoint,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
return callbacks
|
return callbacks
|
||||||
|
|||||||
@@ -2,6 +2,7 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
@@ -50,6 +51,8 @@ from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
|
|||||||
|
|
||||||
LOG = get_logger(__name__)
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
|
TOKENS_STATE_FILE = "tokens_state."
|
||||||
|
|
||||||
REDUCTION_FNS = {
|
REDUCTION_FNS = {
|
||||||
"mean": torch.mean,
|
"mean": torch.mean,
|
||||||
"min": torch.min,
|
"min": torch.min,
|
||||||
@@ -349,24 +352,34 @@ class AxolotlTrainer(
|
|||||||
# return (loss, outputs) if return_outputs else loss
|
# return (loss, outputs) if return_outputs else loss
|
||||||
|
|
||||||
# track number of tokens for tokens per second calculation
|
# track number of tokens for tokens per second calculation
|
||||||
if self.args.include_tkps:
|
if self.args.include_tkps and model.training:
|
||||||
inputs_key = "labels" if "labels" in inputs else "input_ids"
|
inputs_key = "labels" if "labels" in inputs else "input_ids"
|
||||||
num_tokens = (inputs[inputs_key] != -100).sum()
|
trainable_tokens = (inputs[inputs_key] != -100).sum()
|
||||||
|
total_tokens = inputs[inputs_key].numel()
|
||||||
|
|
||||||
if is_distributed():
|
if is_distributed():
|
||||||
torch.distributed.all_reduce(
|
torch.distributed.all_reduce(
|
||||||
num_tokens, op=torch.distributed.ReduceOp.SUM
|
trainable_tokens, op=torch.distributed.ReduceOp.SUM
|
||||||
)
|
)
|
||||||
if hasattr(self.state, "num_tokens"):
|
torch.distributed.all_reduce(
|
||||||
self.state.num_tokens = (
|
total_tokens, op=torch.distributed.ReduceOp.SUM
|
||||||
self.state.num_tokens + (inputs[inputs_key] != -100).sum().cpu()
|
|
||||||
)
|
)
|
||||||
else:
|
|
||||||
self.state.num_tokens = (inputs[inputs_key] != -100).sum().cpu()
|
|
||||||
|
|
||||||
if hasattr(self.state, "total_tokens"):
|
if not hasattr(self.state, "tokens"):
|
||||||
self.state.total_tokens += num_tokens
|
self.state.tokens = {
|
||||||
else:
|
"trainable": torch.zeros(1),
|
||||||
self.state.total_tokens = num_tokens
|
"total": torch.zeros(1),
|
||||||
|
}
|
||||||
|
|
||||||
|
# trainable tokens for throughput and total token slots for summaries
|
||||||
|
self.state.tokens["trainable"] = (
|
||||||
|
self.state.tokens["trainable"] + trainable_tokens.detach().cpu()
|
||||||
|
)
|
||||||
|
self.state.tokens["total"] = (
|
||||||
|
self.state.tokens["total"] + torch.as_tensor(total_tokens).cpu()
|
||||||
|
)
|
||||||
|
# Store per-step trainable tokens for throughput calculation
|
||||||
|
self.state.tokens["trainable_tokens"] = trainable_tokens.detach().cpu()
|
||||||
|
|
||||||
if self.args.orpo_alpha:
|
if self.args.orpo_alpha:
|
||||||
return self.orpo_compute_loss(
|
return self.orpo_compute_loss(
|
||||||
@@ -638,10 +651,14 @@ class AxolotlTrainer(
|
|||||||
except (ValueError, TypeError, FileNotFoundError):
|
except (ValueError, TypeError, FileNotFoundError):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
if self.args.include_tkps and train_eval == "train":
|
if (
|
||||||
|
self.args.include_tkps
|
||||||
|
and train_eval == "train"
|
||||||
|
and hasattr(self.state, "tokens")
|
||||||
|
):
|
||||||
# each rank will log its own tokens per second
|
# each rank will log its own tokens per second
|
||||||
# for logging_steps > 1 we obtain a moving average of this metric
|
# for logging_steps > 1 we obtain a moving average of this metric
|
||||||
logs["tokens_per_second_per_gpu"] = round(
|
logs["tokens/train_per_sec_per_gpu"] = round(
|
||||||
self.state.last_tokens_per_second.item() / self.args.logging_steps, 2
|
self.state.last_tokens_per_second.item() / self.args.logging_steps, 2
|
||||||
)
|
)
|
||||||
if (
|
if (
|
||||||
@@ -683,6 +700,19 @@ class AxolotlTrainer(
|
|||||||
run_dir = self._get_output_dir(trial=trial)
|
run_dir = self._get_output_dir(trial=trial)
|
||||||
output_dir = os.path.join(run_dir, checkpoint_folder)
|
output_dir = os.path.join(run_dir, checkpoint_folder)
|
||||||
os.makedirs(output_dir, exist_ok=True)
|
os.makedirs(output_dir, exist_ok=True)
|
||||||
|
|
||||||
|
# Save total_tokens state if tracking is enabled
|
||||||
|
if self.args.include_tkps and hasattr(self.state, "tokens"):
|
||||||
|
tokens_state = {
|
||||||
|
"total": int(torch.as_tensor(self.state.tokens.get("total", 0)).item()),
|
||||||
|
"trainable": int(
|
||||||
|
torch.as_tensor(self.state.tokens.get("trainable", 0)).item()
|
||||||
|
),
|
||||||
|
}
|
||||||
|
tokens_state_path = os.path.join(output_dir, TOKENS_STATE_FILE)
|
||||||
|
with open(tokens_state_path, "w", encoding="utf-8") as f:
|
||||||
|
json.dump(tokens_state, f)
|
||||||
|
|
||||||
return super()._save_checkpoint(model, trial, **kwargs)
|
return super()._save_checkpoint(model, trial, **kwargs)
|
||||||
|
|
||||||
# TODO(wing): remove once https://github.com/huggingface/transformers/pull/39866/files is merged
|
# TODO(wing): remove once https://github.com/huggingface/transformers/pull/39866/files is merged
|
||||||
|
|||||||
@@ -1,5 +1,7 @@
|
|||||||
"""A callback for calculating tokens per second during training."""
|
"""A callback for calculating tokens per second during training."""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import os
|
||||||
import time
|
import time
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -10,22 +12,52 @@ from transformers import (
|
|||||||
TrainingArguments,
|
TrainingArguments,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
from axolotl.utils.logging import get_logger
|
||||||
|
|
||||||
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
|
TOKENS_STATE_FILE = "tokens_state.json"
|
||||||
|
|
||||||
|
|
||||||
class TokensPerSecondCallback(TrainerCallback):
|
class TokensPerSecondCallback(TrainerCallback):
|
||||||
"""
|
"""
|
||||||
A callback to measure and log tokens per second during training.
|
A callback to measure and log tokens per second during training.
|
||||||
|
Also handles saving/restoring total_tokens state across checkpoint resumes.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, tensor_parallel_size, context_parallel_size):
|
def __init__(
|
||||||
|
self, tensor_parallel_size, context_parallel_size, resume_from_checkpoint=None
|
||||||
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.step_time = 0.0
|
self.step_time = 0.0
|
||||||
self.start_time = 0.0
|
self.start_time = 0.0
|
||||||
self.non_data_parallel_size = 1
|
self.non_data_parallel_size = 1
|
||||||
|
self.resume_from_checkpoint = resume_from_checkpoint
|
||||||
if tensor_parallel_size is not None:
|
if tensor_parallel_size is not None:
|
||||||
self.non_data_parallel_size *= tensor_parallel_size
|
self.non_data_parallel_size *= tensor_parallel_size
|
||||||
if context_parallel_size is not None:
|
if context_parallel_size is not None:
|
||||||
self.non_data_parallel_size *= context_parallel_size
|
self.non_data_parallel_size *= context_parallel_size
|
||||||
|
|
||||||
|
def on_train_begin(
|
||||||
|
self,
|
||||||
|
args: TrainingArguments,
|
||||||
|
state: TrainerState,
|
||||||
|
control: TrainerControl,
|
||||||
|
**kwargs,
|
||||||
|
): # pylint: disable=unused-argument
|
||||||
|
"""Restore total_tokens state when resuming from checkpoint."""
|
||||||
|
if not isinstance(self.resume_from_checkpoint, str):
|
||||||
|
return
|
||||||
|
tokens_state_path = os.path.join(self.resume_from_checkpoint, TOKENS_STATE_FILE)
|
||||||
|
if os.path.isfile(tokens_state_path):
|
||||||
|
with open(tokens_state_path, "r", encoding="utf-8") as f:
|
||||||
|
tokens_state = json.load(f)
|
||||||
|
state.tokens = {
|
||||||
|
"total": torch.tensor(tokens_state.get("total", 0)),
|
||||||
|
"trainable": torch.tensor(tokens_state.get("trainable", 0)),
|
||||||
|
}
|
||||||
|
LOG.info(f"Restored total_tokens: {state.tokens['total']}")
|
||||||
|
|
||||||
def on_step_begin(
|
def on_step_begin(
|
||||||
self,
|
self,
|
||||||
args: TrainingArguments,
|
args: TrainingArguments,
|
||||||
@@ -33,6 +65,8 @@ class TokensPerSecondCallback(TrainerCallback):
|
|||||||
control: TrainerControl,
|
control: TrainerControl,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
): # pylint: disable=unused-argument
|
): # pylint: disable=unused-argument
|
||||||
|
if not hasattr(state, "tokens"):
|
||||||
|
state.tokens = {"trainable": torch.zeros(1), "total": torch.zeros(1)}
|
||||||
self.start_time = time.perf_counter()
|
self.start_time = time.perf_counter()
|
||||||
state.last_tokens_per_second = torch.zeros(1)
|
state.last_tokens_per_second = torch.zeros(1)
|
||||||
|
|
||||||
@@ -43,9 +77,10 @@ class TokensPerSecondCallback(TrainerCallback):
|
|||||||
control: TrainerControl,
|
control: TrainerControl,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
): # pylint: disable=unused-argument
|
): # pylint: disable=unused-argument
|
||||||
if hasattr(state, "num_tokens"):
|
tokens = getattr(state, "tokens", None)
|
||||||
|
if tokens and "trainable_tokens" in tokens:
|
||||||
step_time = time.perf_counter() - self.start_time
|
step_time = time.perf_counter() - self.start_time
|
||||||
num_tokens_per_device = state.num_tokens.clone()
|
num_tokens_per_device = tokens["trainable_tokens"].clone()
|
||||||
# non data parallel groups have duplicated tokens, so we avoid double-counting
|
# non data parallel groups have duplicated tokens, so we avoid double-counting
|
||||||
num_tokens_per_device = num_tokens_per_device / self.non_data_parallel_size
|
num_tokens_per_device = num_tokens_per_device / self.non_data_parallel_size
|
||||||
state.last_tokens_per_second = num_tokens_per_device / step_time
|
state.last_tokens_per_second = num_tokens_per_device / step_time
|
||||||
@@ -60,5 +95,15 @@ class TokensPerSecondCallback(TrainerCallback):
|
|||||||
): # pylint: disable=unused-argument
|
): # pylint: disable=unused-argument
|
||||||
# after logging, clear the running metrics
|
# after logging, clear the running metrics
|
||||||
if hasattr(state, "last_tokens_per_second"):
|
if hasattr(state, "last_tokens_per_second"):
|
||||||
|
logs["tokens/train_per_sec_per_gpu"] = state.last_tokens_per_second.item()
|
||||||
state.last_tokens_per_second.zero_()
|
state.last_tokens_per_second.zero_()
|
||||||
state.num_tokens = torch.zeros(1)
|
tokens = getattr(state, "tokens", None)
|
||||||
|
# Clear per-step tokens after logging
|
||||||
|
if tokens and "trainable_tokens" in tokens:
|
||||||
|
tokens["trainable_tokens"] = torch.zeros_like(tokens["trainable_tokens"])
|
||||||
|
|
||||||
|
if tokens and "total" in tokens:
|
||||||
|
logs["tokens/total"] = tokens["total"].item()
|
||||||
|
|
||||||
|
if tokens and "trainable" in tokens:
|
||||||
|
logs["tokens/trainable"] = tokens["trainable"].item()
|
||||||
|
|||||||
@@ -2,6 +2,7 @@
|
|||||||
E2E tests for resuming training
|
E2E tests for resuming training
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
import re
|
import re
|
||||||
import subprocess
|
import subprocess
|
||||||
|
|
||||||
@@ -9,6 +10,7 @@ from transformers.utils import is_torch_bf16_gpu_available
|
|||||||
|
|
||||||
from axolotl.common.datasets import load_datasets
|
from axolotl.common.datasets import load_datasets
|
||||||
from axolotl.train import train
|
from axolotl.train import train
|
||||||
|
from axolotl.utils.callbacks.tokens_per_second import TOKENS_STATE_FILE
|
||||||
from axolotl.utils.config import normalize_config, validate_config
|
from axolotl.utils.config import normalize_config, validate_config
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
@@ -58,6 +60,7 @@ class TestResumeLlama:
|
|||||||
"use_tensorboard": True,
|
"use_tensorboard": True,
|
||||||
"save_safetensors": True,
|
"save_safetensors": True,
|
||||||
"save_first_step": False,
|
"save_first_step": False,
|
||||||
|
"include_tkps": True,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
if is_torch_bf16_gpu_available():
|
if is_torch_bf16_gpu_available():
|
||||||
@@ -68,8 +71,19 @@ class TestResumeLlama:
|
|||||||
normalize_config(cfg)
|
normalize_config(cfg)
|
||||||
dataset_meta = load_datasets(cfg=cfg)
|
dataset_meta = load_datasets(cfg=cfg)
|
||||||
|
|
||||||
|
initial_total_num_tokens = cfg.total_num_tokens
|
||||||
|
assert initial_total_num_tokens is not None, (
|
||||||
|
"total_num_tokens should be calculated during load_datasets"
|
||||||
|
)
|
||||||
|
|
||||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||||
|
|
||||||
|
checkpoint_path = f"{temp_dir}/checkpoint-9"
|
||||||
|
tokens_state_path = os.path.join(checkpoint_path, TOKENS_STATE_FILE)
|
||||||
|
assert os.path.isfile(tokens_state_path), (
|
||||||
|
f"{TOKENS_STATE_FILE} should exist in checkpoint at {tokens_state_path}"
|
||||||
|
)
|
||||||
|
|
||||||
resume_cfg = cfg | DictDefault(
|
resume_cfg = cfg | DictDefault(
|
||||||
{
|
{
|
||||||
"resume_from_checkpoint": f"{temp_dir}/checkpoint-9/",
|
"resume_from_checkpoint": f"{temp_dir}/checkpoint-9/",
|
||||||
@@ -77,7 +91,24 @@ class TestResumeLlama:
|
|||||||
)
|
)
|
||||||
normalize_config(resume_cfg)
|
normalize_config(resume_cfg)
|
||||||
|
|
||||||
train(cfg=resume_cfg, dataset_meta=dataset_meta)
|
assert resume_cfg.total_num_tokens == initial_total_num_tokens, (
|
||||||
|
f"total_num_tokens should be preserved on resume. "
|
||||||
|
f"Expected {initial_total_num_tokens}, got {resume_cfg.total_num_tokens}"
|
||||||
|
)
|
||||||
|
|
||||||
|
resume_dataset_meta = load_datasets(cfg=resume_cfg)
|
||||||
|
|
||||||
|
assert resume_cfg.total_num_tokens == initial_total_num_tokens, (
|
||||||
|
f"total_num_tokens should not be recalculated when resuming. "
|
||||||
|
f"Expected {initial_total_num_tokens}, got {resume_cfg.total_num_tokens}"
|
||||||
|
)
|
||||||
|
|
||||||
|
train(cfg=resume_cfg, dataset_meta=resume_dataset_meta)
|
||||||
|
|
||||||
|
assert resume_cfg.total_num_tokens == initial_total_num_tokens, (
|
||||||
|
f"total_num_tokens should remain unchanged after resume training. "
|
||||||
|
f"Expected {initial_total_num_tokens}, got {resume_cfg.total_num_tokens}"
|
||||||
|
)
|
||||||
check_model_output_exists(temp_dir, cfg)
|
check_model_output_exists(temp_dir, cfg)
|
||||||
|
|
||||||
tb_log_path_1 = most_recent_subdir(temp_dir + "/runs")
|
tb_log_path_1 = most_recent_subdir(temp_dir + "/runs")
|
||||||
|
|||||||
Reference in New Issue
Block a user