fix: handle num_items_in_batch

This commit is contained in:
NanoCode012
2025-02-04 19:32:20 +07:00
parent adeefc1991
commit 1fb8d86396

View File

@@ -40,7 +40,7 @@ class DistillAttentionXentMSETrainer(AxolotlTrainer):
model: nn.Module,
inputs: dict[str, Tensor],
return_outputs=False,
num_items_in_batch=None, # pylint: disable=unused-argument
num_items_in_batch=None,
) -> tuple[Tensor, dict]:
"""
Attention distillation ("attention transfer")
@@ -55,6 +55,13 @@ class DistillAttentionXentMSETrainer(AxolotlTrainer):
# Filter out labels
inputs = {k: v.to(device) for k, v in data.items() if k != "labels"}
# set num_items_in_batch
if self.model_accepts_loss_kwargs:
loss_kwargs = {}
if num_items_in_batch is not None:
loss_kwargs["num_items_in_batch"] = num_items_in_batch
inputs = {**inputs, **loss_kwargs}
# Forward pass
outputs = model(**inputs, output_attentions=True, use_cache=False)
outputs = outputs.get("attentions")