fix: compute_loss return sig
This commit is contained in:
@@ -22,13 +22,11 @@ class DistillAttentionXentMSETrainer(AxolotlTrainer):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
model: nn.Module,
|
model: nn.Module,
|
||||||
metric_for_best_model: str = "distill/eval/loss",
|
|
||||||
mse_factor: float = 1e3,
|
mse_factor: float = 1e3,
|
||||||
xent_factor: float = 0,
|
xent_factor: float = 0,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
):
|
):
|
||||||
super().__init__(model=model, **kwargs)
|
super().__init__(model=model, **kwargs)
|
||||||
self.metric_for_best_model = metric_for_best_model
|
|
||||||
self.criterion_xent = nn.CrossEntropyLoss(reduction="mean")
|
self.criterion_xent = nn.CrossEntropyLoss(reduction="mean")
|
||||||
self.criterion_mse = nn.MSELoss(reduction="mean")
|
self.criterion_mse = nn.MSELoss(reduction="mean")
|
||||||
self.mse_factor = mse_factor
|
self.mse_factor = mse_factor
|
||||||
@@ -40,7 +38,7 @@ class DistillAttentionXentMSETrainer(AxolotlTrainer):
|
|||||||
model: nn.Module,
|
model: nn.Module,
|
||||||
inputs: dict[str, Tensor],
|
inputs: dict[str, Tensor],
|
||||||
return_outputs=False,
|
return_outputs=False,
|
||||||
num_items_in_batch=None,
|
num_items_in_batch=None, # pylint: disable=unused-argument
|
||||||
) -> tuple[Tensor, dict]:
|
) -> tuple[Tensor, dict]:
|
||||||
"""
|
"""
|
||||||
Attention distillation ("attention transfer")
|
Attention distillation ("attention transfer")
|
||||||
@@ -50,8 +48,10 @@ class DistillAttentionXentMSETrainer(AxolotlTrainer):
|
|||||||
# alias inputs to data
|
# alias inputs to data
|
||||||
data = inputs
|
data = inputs
|
||||||
|
|
||||||
|
device = model.device
|
||||||
|
|
||||||
# Filter out labels
|
# Filter out labels
|
||||||
inputs = {k: v.to(model.device) for k, v in data.items() if k != "labels"}
|
inputs = {k: v.to(device) for k, v in data.items() if k != "labels"}
|
||||||
|
|
||||||
# Forward pass
|
# Forward pass
|
||||||
outputs = model(**inputs, output_attentions=True, use_cache=False)
|
outputs = model(**inputs, output_attentions=True, use_cache=False)
|
||||||
@@ -60,8 +60,8 @@ class DistillAttentionXentMSETrainer(AxolotlTrainer):
|
|||||||
# Attentions are tuple[tuple[torch.Tensor, torch.Tensor]]
|
# Attentions are tuple[tuple[torch.Tensor, torch.Tensor]]
|
||||||
# n_layers x (predicted_attns, true_attns)
|
# n_layers x (predicted_attns, true_attns)
|
||||||
# predicted_attns and true_attns are shape (batch, n_heads, q_len, k_len)
|
# predicted_attns and true_attns are shape (batch, n_heads, q_len, k_len)
|
||||||
loss_mse = tensor(0.0)
|
loss_mse = tensor(0.0, device=device)
|
||||||
loss_xent = tensor(0.0)
|
loss_xent = tensor(0.0, device=device)
|
||||||
n_layers = 0 # Number of layers to distill
|
n_layers = 0 # Number of layers to distill
|
||||||
softmax_layers = []
|
softmax_layers = []
|
||||||
for layer_idx, attns in enumerate(outputs):
|
for layer_idx, attns in enumerate(outputs):
|
||||||
@@ -106,4 +106,4 @@ class DistillAttentionXentMSETrainer(AxolotlTrainer):
|
|||||||
"mse_factor": self.mse_factor,
|
"mse_factor": self.mse_factor,
|
||||||
"xent_factor": self.xent_factor,
|
"xent_factor": self.xent_factor,
|
||||||
}
|
}
|
||||||
return loss, outputs
|
return (loss, outputs) if return_outputs else loss
|
||||||
|
|||||||
Reference in New Issue
Block a user