Fix decay

This commit is contained in:
Wing Lian
2025-05-28 08:19:52 -04:00
parent 3a0faa97ca
commit e77d62933d

View File

@@ -29,7 +29,7 @@ class KDTemperatureSchedulerCallback(TrainerCallback):
# This factor goes from 1 (at progress=0) to 0 (at progress=1)
decay_factor = 0.5 * (1.0 + math.cos(math.pi * progress))
self.temperature = self.temperature_start - (
(self.temperature_start - self.temperature_min) * decay_factor
(self.temperature_start - self.temperature_min) * (1.0 - decay_factor)
)
if hasattr(self.trainer.data_collator, "kd_temperature"):