lint and additional train metric checks for kd
This commit is contained in:
@@ -102,7 +102,11 @@ def is_hopper():
|
||||
|
||||
|
||||
def check_tensorboard(
|
||||
temp_run_dir: str, tag: str, lt_val: float, assertion_err: str
|
||||
temp_run_dir: str,
|
||||
tag: str,
|
||||
comparison_val: float,
|
||||
assertion_err: str,
|
||||
lt: bool = True,
|
||||
) -> None:
|
||||
"""
|
||||
helper function to parse and check tensorboard logs
|
||||
@@ -112,10 +116,20 @@ def check_tensorboard(
|
||||
reader = SummaryReader(event_file)
|
||||
df = reader.scalars # pylint: disable=invalid-name
|
||||
df = df[(df.tag == tag)] # pylint: disable=invalid-name
|
||||
if "%s" in assertion_err:
|
||||
assert df.value.values[-1] < lt_val, assertion_err % df.value.values[-1]
|
||||
if lt:
|
||||
if "%s" in assertion_err:
|
||||
assert df.value.values[-1] < comparison_val, (
|
||||
assertion_err % df.value.values[-1]
|
||||
)
|
||||
else:
|
||||
assert df.value.values[-1] < comparison_val, assertion_err
|
||||
else:
|
||||
assert df.value.values[-1] < lt_val, assertion_err
|
||||
if "%s" in assertion_err:
|
||||
assert df.value.values[-1] > comparison_val, (
|
||||
assertion_err % df.value.values[-1]
|
||||
)
|
||||
else:
|
||||
assert df.value.values[-1] > comparison_val, assertion_err
|
||||
|
||||
|
||||
def check_model_output_exists(temp_dir: str, cfg: DictDefault) -> None:
|
||||
|
||||
Reference in New Issue
Block a user