Make sure to multiply the correct loss by the number of processes
This commit is contained in:
@@ -121,6 +121,6 @@ class AxolotlKDTrainer(AxolotlTrainer):
|
||||
]
|
||||
|
||||
if self.args.average_tokens_across_devices and self.model_accepts_loss_kwargs:
|
||||
loss_kd *= self.accelerator.num_processes
|
||||
loss *= self.accelerator.num_processes
|
||||
|
||||
return (loss, outputs) if return_outputs else loss
|
||||
|
||||
Reference in New Issue
Block a user