From 1fb8d863969faee486acac3539fb5b628444cfb2 Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Tue, 4 Feb 2025 19:32:20 +0700 Subject: [PATCH] fix: handle num_items_in_batch --- .../lolcats/trainer/distill_attention_xent_mse.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/src/axolotl/integrations/lolcats/trainer/distill_attention_xent_mse.py b/src/axolotl/integrations/lolcats/trainer/distill_attention_xent_mse.py index 62a95ed7b..053b29a6d 100644 --- a/src/axolotl/integrations/lolcats/trainer/distill_attention_xent_mse.py +++ b/src/axolotl/integrations/lolcats/trainer/distill_attention_xent_mse.py @@ -40,7 +40,7 @@ class DistillAttentionXentMSETrainer(AxolotlTrainer): model: nn.Module, inputs: dict[str, Tensor], return_outputs=False, - num_items_in_batch=None, # pylint: disable=unused-argument + num_items_in_batch=None, ) -> tuple[Tensor, dict]: """ Attention distillation ("attention transfer") @@ -55,6 +55,13 @@ class DistillAttentionXentMSETrainer(AxolotlTrainer): # Filter out labels inputs = {k: v.to(device) for k, v in data.items() if k != "labels"} + # set num_items_in_batch + if self.model_accepts_loss_kwargs: + loss_kwargs = {} + if num_items_in_batch is not None: + loss_kwargs["num_items_in_batch"] = num_items_in_batch + inputs = {**inputs, **loss_kwargs} + # Forward pass outputs = model(**inputs, output_attentions=True, use_cache=False) outputs = outputs.get("attentions")