lint and additional train metric checks for kd

This commit is contained in:
Wing Lian
2025-02-26 03:19:42 -05:00
parent afbb44f08b
commit 23f029a89c
5 changed files with 35 additions and 4 deletions

View File

@@ -102,7 +102,11 @@ def is_hopper():
def check_tensorboard(
temp_run_dir: str, tag: str, lt_val: float, assertion_err: str
temp_run_dir: str,
tag: str,
comparison_val: float,
assertion_err: str,
lt: bool = True,
) -> None:
"""
helper function to parse and check tensorboard logs
@@ -112,10 +116,20 @@ def check_tensorboard(
reader = SummaryReader(event_file)
df = reader.scalars # pylint: disable=invalid-name
df = df[(df.tag == tag)] # pylint: disable=invalid-name
if "%s" in assertion_err:
assert df.value.values[-1] < lt_val, assertion_err % df.value.values[-1]
if lt:
if "%s" in assertion_err:
assert df.value.values[-1] < comparison_val, (
assertion_err % df.value.values[-1]
)
else:
assert df.value.values[-1] < comparison_val, assertion_err
else:
assert df.value.values[-1] < lt_val, assertion_err
if "%s" in assertion_err:
assert df.value.values[-1] > comparison_val, (
assertion_err % df.value.values[-1]
)
else:
assert df.value.values[-1] > comparison_val, assertion_err
def check_model_output_exists(temp_dir: str, cfg: DictDefault) -> None: