fix: set model_accepts_loss_kwargs=False

This commit is contained in:
NanoCode012
2025-02-04 02:01:05 +07:00
parent 433cf4a8c7
commit fb88269dcb

View File

@@ -33,6 +33,8 @@ class DistillAttentionXentMSETrainer(AxolotlTrainer):
self.xent_factor = xent_factor
# self.compute_loss_backprop = False # Whether we backprop in self.compute_loss # NOTE: this config seems unnecessary
self.model_accepts_loss_kwargs = False # set to False to stop the Trainer from forwarding loss kwargs to the model, which was causing explosive loss values
def compute_loss(
self,
model: nn.Module,