make sure to multiply against the correct loss

This commit is contained in:
Wing Lian
2024-12-19 01:42:57 -05:00
parent 1107f1f603
commit 222dc27410

View File

@@ -121,6 +121,6 @@ class AxolotlKDTrainer(AxolotlTrainer):
]
if self.args.average_tokens_across_devices and self.model_accepts_loss_kwargs:
loss_kd *= self.accelerator.num_processes
loss *= self.accelerator.num_processes
return (loss, outputs) if return_outputs else loss