fix: set model_accepts_loss_kwargs=False
This commit is contained in:
@@ -33,6 +33,8 @@ class DistillAttentionXentMSETrainer(AxolotlTrainer):
|
||||
self.xent_factor = xent_factor
|
||||
# self.compute_loss_backprop = False # Whether we backprop in self.compute_loss # NOTE: this config seems unnecessary
|
||||
|
||||
self.model_accepts_loss_kwargs = False # added to combat explosive loss
|
||||
|
||||
def compute_loss(
|
||||
self,
|
||||
model: nn.Module,
|
||||
|
||||
Reference in New Issue
Block a user