fix graph break for compile
This commit is contained in:
@@ -16,15 +16,24 @@ from transformers.utils import is_torch_bf16_gpu_available
|
||||
|
||||
@torch.jit.script
|
||||
def get_max_seqlen_in_batch(attention_mask: torch.Tensor) -> torch.Tensor:
|
||||
max_num = int(torch.max(attention_mask).item())
|
||||
batch_size, _ = attention_mask.shape
|
||||
counts = torch.zeros((batch_size, max_num), dtype=torch.int32)
|
||||
for i in range(1, max_num + 1):
|
||||
mask = attention_mask == i
|
||||
counts[:, i - 1] = torch.sum(mask, dim=-1).to(dtype=torch.int32)
|
||||
# Keep max_num as a tensor instead of extracting to Python int
|
||||
max_num = torch.max(attention_mask)
|
||||
|
||||
# Create a range tensor for comparison
|
||||
range_tensor = torch.arange(
|
||||
1, max_num + 1, device=attention_mask.device, dtype=attention_mask.dtype
|
||||
)
|
||||
|
||||
# Vectorized approach - compare attention_mask with each value in range
|
||||
mask = attention_mask.unsqueeze(-1) == range_tensor.unsqueeze(0).unsqueeze(0)
|
||||
|
||||
# Sum along sequence dimension to get counts
|
||||
counts = mask.sum(dim=1).to(dtype=torch.int32)
|
||||
|
||||
# Flatten and filter non-zero values
|
||||
result = counts.flatten()
|
||||
nonzero_indices = torch.nonzero(result).squeeze(-1)
|
||||
return result[nonzero_indices]
|
||||
nonzero_mask = result != 0
|
||||
return result[nonzero_mask]
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
|
||||
Reference in New Issue
Block a user