diff --git a/src/axolotl/core/trainers/base.py b/src/axolotl/core/trainers/base.py index f4414d649..8adafd42d 100644 --- a/src/axolotl/core/trainers/base.py +++ b/src/axolotl/core/trainers/base.py @@ -2,6 +2,7 @@ from __future__ import annotations +import math import os from collections import defaultdict from functools import partial, wraps @@ -615,6 +616,17 @@ class AxolotlTrainer( ) logs[key] = round(fn(values).item(), 4) + if "loss" in logs: + try: + logs["ppl"] = round(math.exp(logs["loss"]), 4) + except OverflowError: + logs["ppl"] = float("inf") + if "eval_loss" in logs: + try: + logs["eval_ppl"] = round(math.exp(logs["eval_loss"]), 4) + except OverflowError: + logs["eval_ppl"] = float("inf") + if is_main_process(): # Add memory usage try: