diff --git a/src/axolotl/integrations/lolcats/trainer/distill_attention_xent_mse.py b/src/axolotl/integrations/lolcats/trainer/distill_attention_xent_mse.py
index 582b21057..453e141d8 100644
--- a/src/axolotl/integrations/lolcats/trainer/distill_attention_xent_mse.py
+++ b/src/axolotl/integrations/lolcats/trainer/distill_attention_xent_mse.py
@@ -22,13 +22,11 @@ class DistillAttentionXentMSETrainer(AxolotlTrainer):
     def __init__(
         self,
         model: nn.Module,
-        metric_for_best_model: str = "distill/eval/loss",
         mse_factor: float = 1e3,
         xent_factor: float = 0,
         **kwargs: Any,
     ):
         super().__init__(model=model, **kwargs)
-        self.metric_for_best_model = metric_for_best_model
         self.criterion_xent = nn.CrossEntropyLoss(reduction="mean")
         self.criterion_mse = nn.MSELoss(reduction="mean")
         self.mse_factor = mse_factor
@@ -40,7 +38,7 @@ class DistillAttentionXentMSETrainer(AxolotlTrainer):
         model: nn.Module,
         inputs: dict[str, Tensor],
         return_outputs=False,
-        num_items_in_batch=None,
+        num_items_in_batch=None,  # pylint: disable=unused-argument
     ) -> tuple[Tensor, dict]:
         """
         Attention distillation ("attention transfer")
@@ -50,8 +48,10 @@ class DistillAttentionXentMSETrainer(AxolotlTrainer):
         # alias inputs to data
         data = inputs

+        device = model.device
+
         # Filter out labels
-        inputs = {k: v.to(model.device) for k, v in data.items() if k != "labels"}
+        inputs = {k: v.to(device) for k, v in data.items() if k != "labels"}

         # Forward pass
         outputs = model(**inputs, output_attentions=True, use_cache=False)
@@ -60,8 +60,8 @@ class DistillAttentionXentMSETrainer(AxolotlTrainer):
         # Attentions are tuple[tuple[torch.Tensor, torch.Tensor]]
         # n_layers x (predicted_attns, true_attns)
         # predicted_attns and true_attns are shape (batch, n_heads, q_len, k_len)
-        loss_mse = tensor(0.0)
-        loss_xent = tensor(0.0)
+        loss_mse = tensor(0.0, device=device)
+        loss_xent = tensor(0.0, device=device)
         n_layers = 0  # Number of layers to distill
         softmax_layers = []
         for layer_idx, attns in enumerate(outputs):
@@ -106,4 +106,4 @@ class DistillAttentionXentMSETrainer(AxolotlTrainer):
             "mse_factor": self.mse_factor,
             "xent_factor": self.xent_factor,
         }
-        return loss, outputs
+        return (loss, outputs) if return_outputs else loss
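
For context, the diff elides the per-layer distillation loop, so the sketch below only illustrates the loss-weighting pattern implied by `mse_factor`/`xent_factor` and the `(predicted_attns, true_attns)` shape comments. The function name `attention_transfer_loss` and the soft cross-entropy form are illustrative assumptions, not the repository's actual implementation.

```python
import torch
import torch.nn as nn


def attention_transfer_loss(
    pred_attns: torch.Tensor,  # (batch, n_heads, q_len, k_len), rows sum to 1
    true_attns: torch.Tensor,  # same shape, teacher attention weights
    mse_factor: float = 1e3,
    xent_factor: float = 0.0,
) -> torch.Tensor:
    """Sketch of a single-layer attention-transfer loss (assumed form).

    The real compute_loss accumulates such terms over every distilled layer;
    only the mse_factor/xent_factor weighting is mirrored here.
    """
    # MSE between predicted and teacher attention maps.
    mse = nn.MSELoss(reduction="mean")(pred_attns, true_attns)
    # Soft cross-entropy against the teacher attention distribution
    # (assumption: the trainer's criterion_xent may be applied differently).
    xent = -(true_attns * torch.log(pred_attns.clamp_min(1e-12))).sum(-1).mean()
    return mse_factor * mse + xent_factor * xent


# Example usage with random attention-shaped tensors:
pred = torch.softmax(torch.randn(2, 8, 16, 16), dim=-1)
true = torch.softmax(torch.randn(2, 8, 16, 16), dim=-1)
loss = attention_transfer_loss(pred, true)
```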