use kd_alpha in the correct loss method

This commit is contained in:
Wing Lian
2024-12-24 19:54:32 -05:00
parent 2e6422a711
commit b88128d067

View File

@@ -182,7 +182,8 @@ class AxolotlKDTrainer(AxolotlTrainer):
)
if self.args.kd_ce_alpha > 0:
loss = self.args.kd_ce_alpha * outputs["loss"] + loss_kd
kd_alpha = self.args.kd_alpha
loss = self.args.kd_ce_alpha * outputs["loss"] + kd_alpha * loss_kd
else:
loss = loss_kd
# Save past state if it exists