diff --git a/src/axolotl/integrations/kd/callbacks.py b/src/axolotl/integrations/kd/callbacks.py index 521833477..911c3d517 100644 --- a/src/axolotl/integrations/kd/callbacks.py +++ b/src/axolotl/integrations/kd/callbacks.py @@ -29,7 +29,7 @@ class KDTemperatureSchedulerCallback(TrainerCallback): # This factor goes from 1 (at progress=0) to 0 (at progress=1) decay_factor = 0.5 * (1.0 + math.cos(math.pi * progress)) self.temperature = self.temperature_start - ( - (self.temperature_start - self.temperature_min) * decay_factor + (self.temperature_start - self.temperature_min) * (1.0 - decay_factor) ) if hasattr(self.trainer.data_collator, "kd_temperature"):