compute loss only if training and update token metric naming (#3293) [skip ci]
* compute loss only if training * save total_tokens for checkpiont * check if string * refactor total_tokens/ num_tokens * refactor 2 * rplc trainable_step/trian_per_sec_per_gpu * lint + log trainable/tokens * consolidate it in the callback. * test for total_tokes aftr remuse * check if tokenstate exist after ckpt --------- Co-authored-by: Ved <ved.work2024@gmail.com>
This commit is contained in:
@@ -2,6 +2,7 @@
|
||||
E2E tests for resuming training
|
||||
"""
|
||||
|
||||
import os
|
||||
import re
|
||||
import subprocess
|
||||
|
||||
@@ -9,6 +10,7 @@ from transformers.utils import is_torch_bf16_gpu_available
|
||||
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.callbacks.tokens_per_second import TOKENS_STATE_FILE
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
@@ -58,6 +60,7 @@ class TestResumeLlama:
|
||||
"use_tensorboard": True,
|
||||
"save_safetensors": True,
|
||||
"save_first_step": False,
|
||||
"include_tkps": True,
|
||||
}
|
||||
)
|
||||
if is_torch_bf16_gpu_available():
|
||||
@@ -68,8 +71,19 @@ class TestResumeLlama:
|
||||
normalize_config(cfg)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
initial_total_num_tokens = cfg.total_num_tokens
|
||||
assert initial_total_num_tokens is not None, (
|
||||
"total_num_tokens should be calculated during load_datasets"
|
||||
)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
|
||||
checkpoint_path = f"{temp_dir}/checkpoint-9"
|
||||
tokens_state_path = os.path.join(checkpoint_path, TOKENS_STATE_FILE)
|
||||
assert os.path.isfile(tokens_state_path), (
|
||||
f"{TOKENS_STATE_FILE} should exist in checkpoint at {tokens_state_path}"
|
||||
)
|
||||
|
||||
resume_cfg = cfg | DictDefault(
|
||||
{
|
||||
"resume_from_checkpoint": f"{temp_dir}/checkpoint-9/",
|
||||
@@ -77,7 +91,24 @@ class TestResumeLlama:
|
||||
)
|
||||
normalize_config(resume_cfg)
|
||||
|
||||
train(cfg=resume_cfg, dataset_meta=dataset_meta)
|
||||
assert resume_cfg.total_num_tokens == initial_total_num_tokens, (
|
||||
f"total_num_tokens should be preserved on resume. "
|
||||
f"Expected {initial_total_num_tokens}, got {resume_cfg.total_num_tokens}"
|
||||
)
|
||||
|
||||
resume_dataset_meta = load_datasets(cfg=resume_cfg)
|
||||
|
||||
assert resume_cfg.total_num_tokens == initial_total_num_tokens, (
|
||||
f"total_num_tokens should not be recalculated when resuming. "
|
||||
f"Expected {initial_total_num_tokens}, got {resume_cfg.total_num_tokens}"
|
||||
)
|
||||
|
||||
train(cfg=resume_cfg, dataset_meta=resume_dataset_meta)
|
||||
|
||||
assert resume_cfg.total_num_tokens == initial_total_num_tokens, (
|
||||
f"total_num_tokens should remain unchanged after resume training. "
|
||||
f"Expected {initial_total_num_tokens}, got {resume_cfg.total_num_tokens}"
|
||||
)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
tb_log_path_1 = most_recent_subdir(temp_dir + "/runs")
|
||||
|
||||
Reference in New Issue
Block a user