feat: cheap ppl metric (#3317)
* Import math and compute perplexity from loss values * lint * coderabbit changes * lint * fix: add rounding to ppl --------- Co-authored-by: NanoCode012 <nano@axolotl.ai>
This commit is contained in:
@@ -2,6 +2,7 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import math
|
||||||
import os
|
import os
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from functools import partial, wraps
|
from functools import partial, wraps
|
||||||
@@ -615,6 +616,17 @@ class AxolotlTrainer(
|
|||||||
)
|
)
|
||||||
logs[key] = round(fn(values).item(), 4)
|
logs[key] = round(fn(values).item(), 4)
|
||||||
|
|
||||||
|
if "loss" in logs:
|
||||||
|
try:
|
||||||
|
logs["ppl"] = round(math.exp(logs["loss"]), 4)
|
||||||
|
except OverflowError:
|
||||||
|
logs["ppl"] = float("inf")
|
||||||
|
if "eval_loss" in logs:
|
||||||
|
try:
|
||||||
|
logs["eval_ppl"] = round(math.exp(logs["eval_loss"]), 4)
|
||||||
|
except OverflowError:
|
||||||
|
logs["eval_ppl"] = float("inf")
|
||||||
|
|
||||||
if is_main_process():
|
if is_main_process():
|
||||||
# Add memory usage
|
# Add memory usage
|
||||||
try:
|
try:
|
||||||
|
|||||||
Reference in New Issue
Block a user