allow users to set ndigits for rounding of metrics when logging (#3325)
* METRIC_PRECISION-> 8 * use ndigits and move env getter to top of log function --------- Co-authored-by: Ved <ved.work2024@gmail.com> Co-authored-by: Wing Lian <wing@axolotl.ai>
This commit is contained in:
@@ -604,6 +604,7 @@ class AxolotlTrainer(
|
|||||||
"""
|
"""
|
||||||
# logs either has 'loss' or 'eval_loss'
|
# logs either has 'loss' or 'eval_loss'
|
||||||
train_eval = "train" if "loss" in logs else "eval"
|
train_eval = "train" if "loss" in logs else "eval"
|
||||||
|
metric_ndigits = int(os.getenv("AXOLOTL_METRIC_NDIGITS", "5"))
|
||||||
|
|
||||||
for key, metric_data in self._stored_metrics[train_eval].items():
|
for key, metric_data in self._stored_metrics[train_eval].items():
|
||||||
values = torch.tensor(metric_data["values"]) # type: ignore[arg-type]
|
values = torch.tensor(metric_data["values"]) # type: ignore[arg-type]
|
||||||
@@ -614,16 +615,16 @@ class AxolotlTrainer(
|
|||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
"Metric reduction must be one of [mean, min, max, sum]"
|
"Metric reduction must be one of [mean, min, max, sum]"
|
||||||
)
|
)
|
||||||
logs[key] = round(fn(values).item(), 4)
|
logs[key] = round(fn(values).item(), metric_ndigits)
|
||||||
|
|
||||||
if "loss" in logs:
|
if "loss" in logs:
|
||||||
try:
|
try:
|
||||||
logs["ppl"] = round(math.exp(logs["loss"]), 4)
|
logs["ppl"] = round(math.exp(logs["loss"]), metric_ndigits)
|
||||||
except OverflowError:
|
except OverflowError:
|
||||||
logs["ppl"] = float("inf")
|
logs["ppl"] = float("inf")
|
||||||
if "eval_loss" in logs:
|
if "eval_loss" in logs:
|
||||||
try:
|
try:
|
||||||
logs["eval_ppl"] = round(math.exp(logs["eval_loss"]), 4)
|
logs["eval_ppl"] = round(math.exp(logs["eval_loss"]), metric_ndigits)
|
||||||
except OverflowError:
|
except OverflowError:
|
||||||
logs["eval_ppl"] = float("inf")
|
logs["eval_ppl"] = float("inf")
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user