diff --git a/src/axolotl/core/trainers/base.py b/src/axolotl/core/trainers/base.py index 8adafd42d..aae3d28fb 100644 --- a/src/axolotl/core/trainers/base.py +++ b/src/axolotl/core/trainers/base.py @@ -604,6 +604,7 @@ class AxolotlTrainer( """ # logs either has 'loss' or 'eval_loss' train_eval = "train" if "loss" in logs else "eval" + metric_ndigits = int(os.getenv("AXOLOTL_METRIC_NDIGITS", "5")) for key, metric_data in self._stored_metrics[train_eval].items(): values = torch.tensor(metric_data["values"]) # type: ignore[arg-type] @@ -614,16 +615,16 @@ class AxolotlTrainer( raise NotImplementedError( "Metric reduction must be one of [mean, min, max, sum]" ) - logs[key] = round(fn(values).item(), 4) + logs[key] = round(fn(values).item(), metric_ndigits) if "loss" in logs: try: - logs["ppl"] = round(math.exp(logs["loss"]), 4) + logs["ppl"] = round(math.exp(logs["loss"]), metric_ndigits) except OverflowError: logs["ppl"] = float("inf") if "eval_loss" in logs: try: - logs["eval_ppl"] = round(math.exp(logs["eval_loss"]), 4) + logs["eval_ppl"] = round(math.exp(logs["eval_loss"]), metric_ndigits) except OverflowError: logs["eval_ppl"] = float("inf")