diff --git a/src/axolotl/integrations/lolcats/trainer/distill_attention_xent_mse.py b/src/axolotl/integrations/lolcats/trainer/distill_attention_xent_mse.py index 453e141d8..62a95ed7b 100644 --- a/src/axolotl/integrations/lolcats/trainer/distill_attention_xent_mse.py +++ b/src/axolotl/integrations/lolcats/trainer/distill_attention_xent_mse.py @@ -33,6 +33,8 @@ class DistillAttentionXentMSETrainer(AxolotlTrainer): self.xent_factor = xent_factor # self.compute_loss_backprop = False # Whether we backprop in self.compute_loss # NOTE: this config seems unnecessary + self.model_accepts_loss_kwargs = False # added to combat explosive loss + def compute_loss( self, model: nn.Module,