fix: compute_loss return signature

This commit is contained in:
NanoCode012
2025-02-04 01:53:18 +07:00
parent 0b7b58c8be
commit 433cf4a8c7

View File

@@ -22,13 +22,11 @@ class DistillAttentionXentMSETrainer(AxolotlTrainer):
def __init__(
self,
model: nn.Module,
metric_for_best_model: str = "distill/eval/loss",
mse_factor: float = 1e3,
xent_factor: float = 0,
**kwargs: Any,
):
super().__init__(model=model, **kwargs)
self.metric_for_best_model = metric_for_best_model
self.criterion_xent = nn.CrossEntropyLoss(reduction="mean")
self.criterion_mse = nn.MSELoss(reduction="mean")
self.mse_factor = mse_factor
@@ -40,7 +38,7 @@ class DistillAttentionXentMSETrainer(AxolotlTrainer):
model: nn.Module,
inputs: dict[str, Tensor],
return_outputs=False,
num_items_in_batch=None,
num_items_in_batch=None, # pylint: disable=unused-argument
) -> tuple[Tensor, dict]:
"""
Attention distillation ("attention transfer")
@@ -50,8 +48,10 @@ class DistillAttentionXentMSETrainer(AxolotlTrainer):
# alias inputs to data
data = inputs
device = model.device
# Filter out labels
inputs = {k: v.to(model.device) for k, v in data.items() if k != "labels"}
inputs = {k: v.to(device) for k, v in data.items() if k != "labels"}
# Forward pass
outputs = model(**inputs, output_attentions=True, use_cache=False)
@@ -60,8 +60,8 @@ class DistillAttentionXentMSETrainer(AxolotlTrainer):
# Attentions are tuple[tuple[torch.Tensor, torch.Tensor]]
# n_layers x (predicted_attns, true_attns)
# predicted_attns and true_attns are shape (batch, n_heads, q_len, k_len)
loss_mse = tensor(0.0)
loss_xent = tensor(0.0)
loss_mse = tensor(0.0, device=device)
loss_xent = tensor(0.0, device=device)
n_layers = 0 # Number of layers to distill
softmax_layers = []
for layer_idx, attns in enumerate(outputs):
@@ -106,4 +106,4 @@ class DistillAttentionXentMSETrainer(AxolotlTrainer):
"mse_factor": self.mse_factor,
"xent_factor": self.xent_factor,
}
return loss, outputs
return (loss, outputs) if return_outputs else loss