Fix decay
This commit is contained in:
@@ -29,7 +29,7 @@ class KDTemperatureSchedulerCallback(TrainerCallback):
|
|||||||
# This factor goes from 1 (at progress=0) to 0 (at progress=1)
|
# This factor goes from 1 (at progress=0) to 0 (at progress=1)
|
||||||
decay_factor = 0.5 * (1.0 + math.cos(math.pi * progress))
|
decay_factor = 0.5 * (1.0 + math.cos(math.pi * progress))
|
||||||
self.temperature = self.temperature_start - (
|
self.temperature = self.temperature_start - (
|
||||||
(self.temperature_start - self.temperature_min) * decay_factor
|
(self.temperature_start - self.temperature_min) * (1.0 - decay_factor)
|
||||||
)
|
)
|
||||||
|
|
||||||
if hasattr(self.trainer.data_collator, "kd_temperature"):
|
if hasattr(self.trainer.data_collator, "kd_temperature"):
|
||||||
|
|||||||
Reference in New Issue
Block a user