fix: compute_loss return signature
This commit is contained in:
@@ -22,13 +22,11 @@ class DistillAttentionXentMSETrainer(AxolotlTrainer):
|
||||
def __init__(
|
||||
self,
|
||||
model: nn.Module,
|
||||
metric_for_best_model: str = "distill/eval/loss",
|
||||
mse_factor: float = 1e3,
|
||||
xent_factor: float = 0,
|
||||
**kwargs: Any,
|
||||
):
|
||||
super().__init__(model=model, **kwargs)
|
||||
self.metric_for_best_model = metric_for_best_model
|
||||
self.criterion_xent = nn.CrossEntropyLoss(reduction="mean")
|
||||
self.criterion_mse = nn.MSELoss(reduction="mean")
|
||||
self.mse_factor = mse_factor
|
||||
@@ -40,7 +38,7 @@ class DistillAttentionXentMSETrainer(AxolotlTrainer):
|
||||
model: nn.Module,
|
||||
inputs: dict[str, Tensor],
|
||||
return_outputs=False,
|
||||
num_items_in_batch=None,
|
||||
num_items_in_batch=None, # pylint: disable=unused-argument
|
||||
) -> tuple[Tensor, dict]:
|
||||
"""
|
||||
Attention distillation ("attention transfer")
|
||||
@@ -50,8 +48,10 @@ class DistillAttentionXentMSETrainer(AxolotlTrainer):
|
||||
# alias inputs to data
|
||||
data = inputs
|
||||
|
||||
device = model.device
|
||||
|
||||
# Filter out labels
|
||||
inputs = {k: v.to(model.device) for k, v in data.items() if k != "labels"}
|
||||
inputs = {k: v.to(device) for k, v in data.items() if k != "labels"}
|
||||
|
||||
# Forward pass
|
||||
outputs = model(**inputs, output_attentions=True, use_cache=False)
|
||||
@@ -60,8 +60,8 @@ class DistillAttentionXentMSETrainer(AxolotlTrainer):
|
||||
# Attentions are tuple[tuple[torch.Tensor, torch.Tensor]]
|
||||
# n_layers x (predicted_attns, true_attns)
|
||||
# predicted_attns and true_attns are shape (batch, n_heads, q_len, k_len)
|
||||
loss_mse = tensor(0.0)
|
||||
loss_xent = tensor(0.0)
|
||||
loss_mse = tensor(0.0, device=device)
|
||||
loss_xent = tensor(0.0, device=device)
|
||||
n_layers = 0 # Number of layers to distill
|
||||
softmax_layers = []
|
||||
for layer_idx, attns in enumerate(outputs):
|
||||
@@ -106,4 +106,4 @@ class DistillAttentionXentMSETrainer(AxolotlTrainer):
|
||||
"mse_factor": self.mse_factor,
|
||||
"xent_factor": self.xent_factor,
|
||||
}
|
||||
return loss, outputs
|
||||
return (loss, outputs) if return_outputs else loss
|
||||
|
||||
Reference in New Issue
Block a user