make sure to multiply against the correct loss

This commit is contained in:
Wing Lian
2024-12-19 01:42:57 -05:00
parent ae545e0165
commit 00ce77e7ef

View File

@@ -121,6 +121,6 @@ class AxolotlKDTrainer(AxolotlTrainer):
]
if self.args.average_tokens_across_devices and self.model_accepts_loss_kwargs:
loss_kd *= self.accelerator.num_processes
loss *= self.accelerator.num_processes
return (loss, outputs) if return_outputs else loss