diff --git a/src/axolotl/monkeypatch/models/qwen3_5/modeling.py b/src/axolotl/monkeypatch/models/qwen3_5/modeling.py index f88f60555..0b3302d82 100644 --- a/src/axolotl/monkeypatch/models/qwen3_5/modeling.py +++ b/src/axolotl/monkeypatch/models/qwen3_5/modeling.py @@ -35,9 +35,9 @@ def get_cu_seqlens(position_ids): if position_ids.ndim == 3: position_ids = position_ids[0] - tensor_kwargs = {"dtype": torch.long, "device": position_ids.device} - position_ids = position_ids.reshape(-1) - indices_q = (position_ids == 0).nonzero().reshape(-1) + tensor_kwargs = {"dtype": torch.int32, "device": position_ids.device} + position_ids = position_ids.view(-1) + indices_q = (position_ids == 0).nonzero().view(-1) return torch.cat( ( indices_q.to(**tensor_kwargs),