fix: forward num_items_in_batch to the model when it accepts loss kwargs
This commit is contained in:
@@ -40,7 +40,7 @@ class DistillAttentionXentMSETrainer(AxolotlTrainer):
|
|||||||
model: nn.Module,
|
model: nn.Module,
|
||||||
inputs: dict[str, Tensor],
|
inputs: dict[str, Tensor],
|
||||||
return_outputs=False,
|
return_outputs=False,
|
||||||
num_items_in_batch=None, # pylint: disable=unused-argument
|
num_items_in_batch=None,
|
||||||
) -> tuple[Tensor, dict]:
|
) -> tuple[Tensor, dict]:
|
||||||
"""
|
"""
|
||||||
Attention distillation ("attention transfer")
|
Attention distillation ("attention transfer")
|
||||||
@@ -55,6 +55,13 @@ class DistillAttentionXentMSETrainer(AxolotlTrainer):
|
|||||||
# Filter out labels
|
# Filter out labels
|
||||||
inputs = {k: v.to(device) for k, v in data.items() if k != "labels"}
|
inputs = {k: v.to(device) for k, v in data.items() if k != "labels"}
|
||||||
|
|
||||||
|
# set num_items_in_batch
|
||||||
|
if self.model_accepts_loss_kwargs:
|
||||||
|
loss_kwargs = {}
|
||||||
|
if num_items_in_batch is not None:
|
||||||
|
loss_kwargs["num_items_in_batch"] = num_items_in_batch
|
||||||
|
inputs = {**inputs, **loss_kwargs}
|
||||||
|
|
||||||
# Forward pass
|
# Forward pass
|
||||||
outputs = model(**inputs, output_attentions=True, use_cache=False)
|
outputs = model(**inputs, output_attentions=True, use_cache=False)
|
||||||
outputs = outputs.get("attentions")
|
outputs = outputs.get("attentions")
|
||||||
|
|||||||
Reference in New Issue
Block a user