feat: cheap ppl metric (#3317)
* Import math and compute perplexity from loss values
* lint
* coderabbit changes
* lint
* fix: add rounding to ppl

---------

Co-authored-by: NanoCode012 <nano@axolotl.ai>
This commit is contained in:
@@ -2,6 +2,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
import os
|
||||
from collections import defaultdict
|
||||
from functools import partial, wraps
|
||||
@@ -615,6 +616,17 @@ class AxolotlTrainer(
|
||||
)
|
||||
logs[key] = round(fn(values).item(), 4)
|
||||
|
||||
if "loss" in logs:
|
||||
try:
|
||||
logs["ppl"] = round(math.exp(logs["loss"]), 4)
|
||||
except OverflowError:
|
||||
logs["ppl"] = float("inf")
|
||||
if "eval_loss" in logs:
|
||||
try:
|
||||
logs["eval_ppl"] = round(math.exp(logs["eval_loss"]), 4)
|
||||
except OverflowError:
|
||||
logs["eval_ppl"] = float("inf")
|
||||
|
||||
if is_main_process():
|
||||
# Add memory usage
|
||||
try:
|
||||
|
||||
Reference in New Issue
Block a user