From 8b3eec7f6e1536424057e452db29c03dbca459a4 Mon Sep 17 00:00:00 2001 From: Sunny Liu Date: Wed, 22 Jan 2025 21:29:52 -0500 Subject: [PATCH] mask expansion --- src/axolotl/monkeypatch/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/axolotl/monkeypatch/utils.py b/src/axolotl/monkeypatch/utils.py index 541ffbf89..210eace3b 100644 --- a/src/axolotl/monkeypatch/utils.py +++ b/src/axolotl/monkeypatch/utils.py @@ -225,7 +225,7 @@ def mask_2d_to_4d( when they attend to each other within that sequence. This expansion transforms the mask to lower triangular form to prevent future peeking. """ - bsz, src_len = mask.size() + bsz, src_len = int(mask.size()[0]), int(mask.size()[1]) tgt_len = tgt_len if tgt_len is not None else src_len # mask = mask.unsqueeze(1).unsqueeze(2)