diff --git a/src/axolotl/core/trainers/kd.py b/src/axolotl/core/trainers/kd.py index 6047c72ff..e8adfab41 100644 --- a/src/axolotl/core/trainers/kd.py +++ b/src/axolotl/core/trainers/kd.py @@ -182,7 +182,8 @@ class AxolotlKDTrainer(AxolotlTrainer): ) if self.args.kd_ce_alpha > 0: - loss = self.args.kd_ce_alpha * outputs["loss"] + loss_kd + kd_alpha = self.args.kd_alpha + loss = self.args.kd_ce_alpha * outputs["loss"] + kd_alpha * loss_kd else: loss = loss_kd # Save past state if it exists