fix: set model_accepts_loss_kwargs=False

This commit is contained in:
NanoCode012
2025-02-04 02:01:05 +07:00
parent 433cf4a8c7
commit fb88269dcb

View File

@@ -33,6 +33,8 @@ class DistillAttentionXentMSETrainer(AxolotlTrainer):
self.xent_factor = xent_factor
# self.compute_loss_backprop = False # Whether we backprop in self.compute_loss # NOTE: this config seems unnecessary
self.model_accepts_loss_kwargs = False # added to combat explosive loss
def compute_loss(
self,
model: nn.Module,