fix: cast position_ids to int64 for qwen35 patch (#3468) [skip ci]

* fix: cast position_ids to int64 for qwen35 patch

* fix: use view instead of reshape so that non-contiguous input raises an explicit error

* chore: lint
This commit is contained in:
NanoCode012
2026-03-06 23:44:00 +07:00
committed by GitHub
parent fc2d63ee5f
commit 0a23ae08f7

View File

@@ -35,9 +35,9 @@ def get_cu_seqlens(position_ids):
if position_ids.ndim == 3:
position_ids = position_ids[0]
tensor_kwargs = {"dtype": torch.long, "device": position_ids.device}
position_ids = position_ids.reshape(-1)
indices_q = (position_ids == 0).nonzero().reshape(-1)
tensor_kwargs = {"dtype": torch.int32, "device": position_ids.device}
position_ids = position_ids.view(-1)
indices_q = (position_ids == 0).nonzero().view(-1)
return torch.cat(
(
indices_q.to(**tensor_kwargs),