fix: compute_loss return signature (return only the loss unless return_outputs is set)

This commit is contained in:
NanoCode012
2025-02-04 01:53:18 +07:00
parent 0b7b58c8be
commit 433cf4a8c7

View File

@@ -22,13 +22,11 @@ class DistillAttentionXentMSETrainer(AxolotlTrainer):
def __init__( def __init__(
self, self,
model: nn.Module, model: nn.Module,
metric_for_best_model: str = "distill/eval/loss",
mse_factor: float = 1e3, mse_factor: float = 1e3,
xent_factor: float = 0, xent_factor: float = 0,
**kwargs: Any, **kwargs: Any,
): ):
super().__init__(model=model, **kwargs) super().__init__(model=model, **kwargs)
self.metric_for_best_model = metric_for_best_model
self.criterion_xent = nn.CrossEntropyLoss(reduction="mean") self.criterion_xent = nn.CrossEntropyLoss(reduction="mean")
self.criterion_mse = nn.MSELoss(reduction="mean") self.criterion_mse = nn.MSELoss(reduction="mean")
self.mse_factor = mse_factor self.mse_factor = mse_factor
@@ -40,7 +38,7 @@ class DistillAttentionXentMSETrainer(AxolotlTrainer):
model: nn.Module, model: nn.Module,
inputs: dict[str, Tensor], inputs: dict[str, Tensor],
return_outputs=False, return_outputs=False,
num_items_in_batch=None, num_items_in_batch=None, # pylint: disable=unused-argument
) -> tuple[Tensor, dict]: ) -> tuple[Tensor, dict]:
""" """
Attention distillation ("attention transfer") Attention distillation ("attention transfer")
@@ -50,8 +48,10 @@ class DistillAttentionXentMSETrainer(AxolotlTrainer):
# alias inputs to data # alias inputs to data
data = inputs data = inputs
device = model.device
# Filter out labels # Filter out labels
inputs = {k: v.to(model.device) for k, v in data.items() if k != "labels"} inputs = {k: v.to(device) for k, v in data.items() if k != "labels"}
# Forward pass # Forward pass
outputs = model(**inputs, output_attentions=True, use_cache=False) outputs = model(**inputs, output_attentions=True, use_cache=False)
@@ -60,8 +60,8 @@ class DistillAttentionXentMSETrainer(AxolotlTrainer):
# Attentions are tuple[tuple[torch.Tensor, torch.Tensor]] # Attentions are tuple[tuple[torch.Tensor, torch.Tensor]]
# n_layers x (predicted_attns, true_attns) # n_layers x (predicted_attns, true_attns)
# predicted_attns and true_attns are shape (batch, n_heads, q_len, k_len) # predicted_attns and true_attns are shape (batch, n_heads, q_len, k_len)
loss_mse = tensor(0.0) loss_mse = tensor(0.0, device=device)
loss_xent = tensor(0.0) loss_xent = tensor(0.0, device=device)
n_layers = 0 # Number of layers to distill n_layers = 0 # Number of layers to distill
softmax_layers = [] softmax_layers = []
for layer_idx, attns in enumerate(outputs): for layer_idx, attns in enumerate(outputs):
@@ -106,4 +106,4 @@ class DistillAttentionXentMSETrainer(AxolotlTrainer):
"mse_factor": self.mse_factor, "mse_factor": self.mse_factor,
"xent_factor": self.xent_factor, "xent_factor": self.xent_factor,
} }
return loss, outputs return (loss, outputs) if return_outputs else loss