fix: cast position_ids to int64 for qwen35 patch (#3468) [skip ci]
* fix: cast position_ids to int64 for qwen35 patch * fix: use view instead of reshape so that a non-contiguous tensor raises an error explicitly * chore: lint
This commit is contained in:
@@ -35,9 +35,9 @@ def get_cu_seqlens(position_ids):
|
||||
if position_ids.ndim == 3:
|
||||
position_ids = position_ids[0]
|
||||
|
||||
tensor_kwargs = {"dtype": torch.long, "device": position_ids.device}
|
||||
position_ids = position_ids.reshape(-1)
|
||||
indices_q = (position_ids == 0).nonzero().reshape(-1)
|
||||
tensor_kwargs = {"dtype": torch.int32, "device": position_ids.device}
|
||||
position_ids = position_ids.view(-1)
|
||||
indices_q = (position_ids == 0).nonzero().view(-1)
|
||||
return torch.cat(
|
||||
(
|
||||
indices_q.to(**tensor_kwargs),
|
||||
|
||||
Reference in New Issue
Block a user