feat: cheap perplexity (ppl) metric (#3317)

* Import math and compute perplexity from loss values

* lint

* Apply CodeRabbit review changes

* lint

* fix: add rounding to ppl

---------

Co-authored-by: NanoCode012 <nano@axolotl.ai>
This commit is contained in:
xzuyn
2025-12-18 09:02:41 -05:00
committed by GitHub
parent 3e51a680c2
commit 2197b0bf89

View File

@@ -2,6 +2,7 @@
from __future__ import annotations
import math
import os
from collections import defaultdict
from functools import partial, wraps
@@ -615,6 +616,17 @@ class AxolotlTrainer(
)
logs[key] = round(fn(values).item(), 4)
if "loss" in logs:
try:
logs["ppl"] = round(math.exp(logs["loss"]), 4)
except OverflowError:
logs["ppl"] = float("inf")
if "eval_loss" in logs:
try:
logs["eval_ppl"] = round(math.exp(logs["eval_loss"]), 4)
except OverflowError:
logs["eval_ppl"] = float("inf")
if is_main_process():
# Add memory usage
try: