replace tensorboard checks with helper function (#2120) [skip ci]
* replace tensorboard checks with helper function * move helper function * use relative
This commit is contained in:
@@ -7,15 +7,13 @@ import os
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
from tbparse import SummaryReader
|
||||
|
||||
from axolotl.cli import load_datasets
|
||||
from axolotl.common.cli import TrainerCliArgs
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from .utils import most_recent_subdir, with_temp_dir
|
||||
from .utils import check_tensorboard, with_temp_dir
|
||||
|
||||
LOG = logging.getLogger("axolotl.tests.e2e")
|
||||
os.environ["WANDB_DISABLED"] = "true"
|
||||
@@ -66,12 +64,9 @@ class TestEmbeddingsLrScale(unittest.TestCase):
|
||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||
assert (Path(temp_dir) / "model.safetensors").exists()
|
||||
|
||||
tb_log_path = most_recent_subdir(temp_dir + "/runs")
|
||||
event_file = os.path.join(tb_log_path, sorted(os.listdir(tb_log_path))[0])
|
||||
reader = SummaryReader(event_file)
|
||||
df = reader.scalars # pylint: disable=invalid-name
|
||||
df = df[(df.tag == "train/train_loss")] # pylint: disable=invalid-name
|
||||
assert df.value.values[-1] < 2.0, "Loss is too high"
|
||||
check_tensorboard(
|
||||
temp_dir + "/runs", "train/train_loss", 2.0, "Loss is too high"
|
||||
)
|
||||
|
||||
@with_temp_dir
|
||||
def test_train_w_embedding_lr(self, temp_dir):
|
||||
@@ -113,9 +108,6 @@ class TestEmbeddingsLrScale(unittest.TestCase):
|
||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||
assert (Path(temp_dir) / "model.safetensors").exists()
|
||||
|
||||
tb_log_path = most_recent_subdir(temp_dir + "/runs")
|
||||
event_file = os.path.join(tb_log_path, sorted(os.listdir(tb_log_path))[0])
|
||||
reader = SummaryReader(event_file)
|
||||
df = reader.scalars # pylint: disable=invalid-name
|
||||
df = df[(df.tag == "train/train_loss")] # pylint: disable=invalid-name
|
||||
assert df.value.values[-1] < 2.0, "Loss is too high"
|
||||
check_tensorboard(
|
||||
temp_dir + "/runs", "train/train_loss", 2.0, "Loss is too high"
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user