From faaff6c7929f948ec4f6ea9dd9816b7430f03a2a Mon Sep 17 00:00:00 2001
From: VED <146507396+ved1beta@users.noreply.github.com>
Date: Mon, 22 Dec 2025 19:24:43 +0530
Subject: [PATCH] allow users to set ndigits for rounding of metrics when logging (#3325)

* METRIC_PRECISION-> 8

* use ndigits and move env getter to top of log function

---------

Co-authored-by: Ved
Co-authored-by: Wing Lian
---
 src/axolotl/core/trainers/base.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/src/axolotl/core/trainers/base.py b/src/axolotl/core/trainers/base.py
index 8adafd42d..aae3d28fb 100644
--- a/src/axolotl/core/trainers/base.py
+++ b/src/axolotl/core/trainers/base.py
@@ -604,6 +604,7 @@ class AxolotlTrainer(
         """
         # logs either has 'loss' or 'eval_loss'
         train_eval = "train" if "loss" in logs else "eval"
+        metric_ndigits = int(os.getenv("AXOLOTL_METRIC_NDIGITS", "5"))
 
         for key, metric_data in self._stored_metrics[train_eval].items():
             values = torch.tensor(metric_data["values"])  # type: ignore[arg-type]
@@ -614,16 +615,16 @@
                 raise NotImplementedError(
                     "Metric reduction must be one of [mean, min, max, sum]"
                 )
-            logs[key] = round(fn(values).item(), 4)
+            logs[key] = round(fn(values).item(), metric_ndigits)
 
         if "loss" in logs:
             try:
-                logs["ppl"] = round(math.exp(logs["loss"]), 4)
+                logs["ppl"] = round(math.exp(logs["loss"]), metric_ndigits)
             except OverflowError:
                 logs["ppl"] = float("inf")
 
         if "eval_loss" in logs:
             try:
-                logs["eval_ppl"] = round(math.exp(logs["eval_loss"]), 4)
+                logs["eval_ppl"] = round(math.exp(logs["eval_loss"]), metric_ndigits)
             except OverflowError:
                 logs["eval_ppl"] = float("inf")