fix: handle num_items_in_batch
This commit is contained in:
@@ -40,7 +40,7 @@ class DistillAttentionXentMSETrainer(AxolotlTrainer):
|
||||
         model: nn.Module,
         inputs: dict[str, Tensor],
         return_outputs=False,
-        num_items_in_batch=None, # pylint: disable=unused-argument
+        num_items_in_batch=None,
     ) -> tuple[Tensor, dict]:
         """
         Attention distillation ("attention transfer")
|
||||
@@ -55,6 +55,13 @@ class DistillAttentionXentMSETrainer(AxolotlTrainer):
|
||||
         # Filter out labels
         inputs = {k: v.to(device) for k, v in data.items() if k != "labels"}
 
+        # set num_items_in_batch
+        if self.model_accepts_loss_kwargs:
+            loss_kwargs = {}
+            if num_items_in_batch is not None:
+                loss_kwargs["num_items_in_batch"] = num_items_in_batch
+            inputs = {**inputs, **loss_kwargs}
+
         # Forward pass
         outputs = model(**inputs, output_attentions=True, use_cache=False)
         outputs = outputs.get("attentions")
|
||||
|
||||
Reference in New Issue
Block a user