make sure to multiply against the correct loss
This commit is contained in:
@@ -121,6 +121,6 @@ class AxolotlKDTrainer(AxolotlTrainer):
|
|||||||
]
|
]
|
||||||
|
|
||||||
if self.args.average_tokens_across_devices and self.model_accepts_loss_kwargs:
|
if self.args.average_tokens_across_devices and self.model_accepts_loss_kwargs:
|
||||||
loss_kd *= self.accelerator.num_processes
|
loss *= self.accelerator.num_processes
|
||||||
|
|
||||||
return (loss, outputs) if return_outputs else loss
|
return (loss, outputs) if return_outputs else loss
|
||||||
|
|||||||
Reference in New Issue
Block a user