From e8b2789086673b1af234a7e7f94e84526fb43211 Mon Sep 17 00:00:00 2001 From: Sunny Liu Date: Thu, 23 Jan 2025 11:20:38 -0500 Subject: [PATCH] revert mask expand --- src/axolotl/monkeypatch/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/axolotl/monkeypatch/utils.py b/src/axolotl/monkeypatch/utils.py index c360cf94e..be834e98d 100644 --- a/src/axolotl/monkeypatch/utils.py +++ b/src/axolotl/monkeypatch/utils.py @@ -228,8 +228,8 @@ def mask_2d_to_4d( bsz, src_len = int(mask.size()[0]), int(mask.size()[1]) tgt_len = tgt_len if tgt_len is not None else src_len - # mask = mask.unsqueeze(1).unsqueeze(2) - mask = mask[None, None, :, :].expand(bsz, 1, tgt_len, src_len) + mask = mask.unsqueeze(1).unsqueeze(2) + mask = mask.expand(bsz, 1, tgt_len, src_len) # Create a binary mask from the original mask where zeros remain zeros and all other values are set to one binary_mask = torch.where(