diff --git a/src/axolotl/monkeypatch/utils.py b/src/axolotl/monkeypatch/utils.py index 1bb8f3d34..3b35f29dd 100644 --- a/src/axolotl/monkeypatch/utils.py +++ b/src/axolotl/monkeypatch/utils.py @@ -137,7 +137,7 @@ def get_packed_mask_from_pos_ids(position_ids): results.append(doc_mask) - return results + return torch.stack(results) def get_seqlens_from_pos_ids(position_ids):