use kd_alpha in the correct loss method

This commit is contained in:
Wing Lian
2024-12-24 19:54:32 -05:00
parent 3416302b0d
commit ca5e397fc5

View File

@@ -182,7 +182,8 @@ class AxolotlKDTrainer(AxolotlTrainer):
)
if self.args.kd_ce_alpha > 0:
-loss = self.args.kd_ce_alpha * outputs["loss"] + loss_kd
+kd_alpha = self.args.kd_alpha
+loss = self.args.kd_ce_alpha * outputs["loss"] + kd_alpha * loss_kd
else:
loss = loss_kd
# Save past state if it exists