diff --git a/src/axolotl/monkeypatch/llama_patch_multipack.py b/src/axolotl/monkeypatch/llama_patch_multipack.py
index e4729a45f..b286175bf 100644
--- a/src/axolotl/monkeypatch/llama_patch_multipack.py
+++ b/src/axolotl/monkeypatch/llama_patch_multipack.py
@@ -29,11 +29,11 @@ def hijack_llama_prepare_4d_mask():
 
     @staticmethod
     def llama_patched_prepare_4d_causal_attention_mask_with_cache_position(
-        attention_mask: Optional[torch.Tensor], *args
+        attention_mask: Optional[torch.Tensor], **kwargs
    ):
         dtype = torch.bfloat16 if is_torch_bf16_gpu_available() else torch.float32
         return LlamaModel._prepare_4d_causal_attention_mask_with_cache_position(
-            mask_2d_to_4d(attention_mask, dtype=dtype), *args
+            mask_2d_to_4d(attention_mask, dtype=dtype), **kwargs
         )
 
     LlamaModel._prepare_4d_causal_attention_mask_with_cache_position = (  # pylint: disable=protected-access
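
The change switches the wrapper from forwarding positional `*args` to forwarding keyword `**kwargs` into the original static method. A minimal sketch of why that matters, assuming the upstream caller passes its extra arguments by keyword (the helper names below are hypothetical stand-ins, not axolotl's or transformers' real API):

```python
def _prepare(mask, sequence_length=None, target_length=None):
    """Stand-in for the original static method being wrapped."""
    return (mask, sequence_length, target_length)


def patched_positional(mask, *args):
    # Keyword arguments from the caller are rejected before *args ever sees them.
    return _prepare(mask, *args)


def patched_keyword(mask, **kwargs):
    # Keyword arguments are collected and forwarded unchanged.
    return _prepare(mask, **kwargs)


print(patched_keyword("mask", sequence_length=8, target_length=16))
# -> ('mask', 8, 16)

try:
    patched_positional("mask", sequence_length=8)
except TypeError as err:
    print(err)  # unexpected keyword argument 'sequence_length'
```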