From 555aa5772abe0fd9695054b4c92b6e5f889ca747 Mon Sep 17 00:00:00 2001 From: Sunny Liu Date: Thu, 23 Jan 2025 14:01:53 -0500 Subject: [PATCH] skip mask conversion if already 4d --- src/axolotl/monkeypatch/utils.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/axolotl/monkeypatch/utils.py b/src/axolotl/monkeypatch/utils.py index be834e98d..5371b3ff3 100644 --- a/src/axolotl/monkeypatch/utils.py +++ b/src/axolotl/monkeypatch/utils.py @@ -225,6 +225,9 @@ def mask_2d_to_4d( when they attend to each other within that sequence. This expansion transforms the mask to lower triangular form to prevent future peeking. """ + + if len(mask.size()) == 4: + return mask bsz, src_len = int(mask.size()[0]), int(mask.size()[1]) tgt_len = tgt_len if tgt_len is not None else src_len