From f2f23c80411c655a45c33d937b2f970a32617b6a Mon Sep 17 00:00:00 2001 From: Sunny Liu Date: Wed, 22 Jan 2025 21:31:42 -0500 Subject: [PATCH] mask expansion --- src/axolotl/monkeypatch/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/axolotl/monkeypatch/utils.py b/src/axolotl/monkeypatch/utils.py index 210eace3b..da85aa9cc 100644 --- a/src/axolotl/monkeypatch/utils.py +++ b/src/axolotl/monkeypatch/utils.py @@ -229,7 +229,7 @@ def mask_2d_to_4d( tgt_len = tgt_len if tgt_len is not None else src_len # mask = mask.unsqueeze(1).unsqueeze(2) - mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len) + mask = mask[:, None, :].expand(bsz, 1, tgt_len, src_len) # Create a binary mask from the original mask where zeros remain zeros and all other values are set to one binary_mask = torch.where(