From b88128d067bc2dbfb2bd91d244792d488d5129fb Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Tue, 24 Dec 2024 19:54:32 -0500 Subject: [PATCH] use kd_alpha in the correct loss method --- src/axolotl/core/trainers/kd.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/axolotl/core/trainers/kd.py b/src/axolotl/core/trainers/kd.py index 6047c72ff..e8adfab41 100644 --- a/src/axolotl/core/trainers/kd.py +++ b/src/axolotl/core/trainers/kd.py @@ -182,7 +182,8 @@ class AxolotlKDTrainer(AxolotlTrainer): ) if self.args.kd_ce_alpha > 0: - loss = self.args.kd_ce_alpha * outputs["loss"] + loss_kd + kd_alpha = self.args.kd_alpha + loss = self.args.kd_ce_alpha * outputs["loss"] + kd_alpha * loss_kd else: loss = loss_kd # Save past state if it exists