From 1f5c0d3613b13b1d480fac1f9e29e1c091278d7f Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Fri, 23 May 2025 11:50:37 -0400
Subject: [PATCH] fix graph break for compile

---
 src/axolotl/monkeypatch/utils.py | 25 +++++++++++++++++--------
 1 file changed, 17 insertions(+), 8 deletions(-)

diff --git a/src/axolotl/monkeypatch/utils.py b/src/axolotl/monkeypatch/utils.py
index 4c6a4de11..9b3b97ab6 100644
--- a/src/axolotl/monkeypatch/utils.py
+++ b/src/axolotl/monkeypatch/utils.py
@@ -16,15 +16,24 @@ from transformers.utils import is_torch_bf16_gpu_available
 
 @torch.jit.script
 def get_max_seqlen_in_batch(attention_mask: torch.Tensor) -> torch.Tensor:
-    max_num = int(torch.max(attention_mask).item())
-    batch_size, _ = attention_mask.shape
-    counts = torch.zeros((batch_size, max_num), dtype=torch.int32)
-    for i in range(1, max_num + 1):
-        mask = attention_mask == i
-        counts[:, i - 1] = torch.sum(mask, dim=-1).to(dtype=torch.int32)
+    # Keep max_num as a tensor instead of extracting it to a Python int
+    max_num = torch.max(attention_mask)
+
+    # Create a range tensor for comparison
+    range_tensor = torch.arange(
+        1, max_num + 1, device=attention_mask.device, dtype=attention_mask.dtype
+    )
+
+    # Vectorized approach - compare attention_mask with each value in the range
+    mask = attention_mask.unsqueeze(-1) == range_tensor.unsqueeze(0).unsqueeze(0)
+
+    # Sum along the sequence dimension to get counts
+    counts = mask.sum(dim=1).to(dtype=torch.int32)
+
+    # Flatten and keep only the non-zero values
     result = counts.flatten()
-    nonzero_indices = torch.nonzero(result).squeeze(-1)
-    return result[nonzero_indices]
+    nonzero_mask = result != 0
+    return result[nonzero_mask]
 
 
 @torch.jit.script
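
Why the graph break: `int(torch.max(attention_mask).item())` materializes the mask's max as a Python int, and under `torch.compile` the `.item()` call forces a host sync that splits the captured graph. The patched version stays in tensor ops end to end. Below is a minimal sketch, not part of the patch, that checks the two implementations agree; the function names are placeholders (in the repo both variants live under the single name `get_max_seqlen_in_batch`), and the sample `attention_mask` values are made up, assuming the multipack convention where ids 1..N label packed samples and 0 is padding.

```python
import torch


def seqlens_loop(attention_mask: torch.Tensor) -> torch.Tensor:
    # Pre-patch approach: .item() pulls max_num to the host (graph break).
    max_num = int(torch.max(attention_mask).item())
    batch_size, _ = attention_mask.shape
    counts = torch.zeros((batch_size, max_num), dtype=torch.int32)
    for i in range(1, max_num + 1):
        counts[:, i - 1] = torch.sum(attention_mask == i, dim=-1).to(dtype=torch.int32)
    result = counts.flatten()
    return result[torch.nonzero(result).squeeze(-1)]


def seqlens_vectorized(attention_mask: torch.Tensor) -> torch.Tensor:
    # Patched approach: one broadcast comparison, no host round-trip.
    max_num = torch.max(attention_mask)
    range_tensor = torch.arange(
        1, max_num + 1, device=attention_mask.device, dtype=attention_mask.dtype
    )
    # (batch, seq, 1) == (1, 1, max_num) -> (batch, seq, max_num)
    mask = attention_mask.unsqueeze(-1) == range_tensor.unsqueeze(0).unsqueeze(0)
    counts = mask.sum(dim=1).to(dtype=torch.int32)
    result = counts.flatten()
    return result[result != 0]


# Hypothetical packed batch: row 1 holds samples of length 3 and 2,
# row 2 holds samples of length 2, 4, and 1.
attention_mask = torch.tensor(
    [
        [1, 1, 1, 2, 2, 0, 0, 0],
        [1, 1, 2, 2, 2, 2, 3, 0],
    ]
)

assert torch.equal(seqlens_loop(attention_mask), seqlens_vectorized(attention_mask))
print(seqlens_vectorized(attention_mask))  # tensor([3, 2, 2, 4, 1], dtype=torch.int32)
```

The boolean-mask indexing (`result[result != 0]`) replaces `torch.nonzero(...).squeeze(-1)` for the same reason as the `.item()` change: it keeps the filtering expressible as a single data-dependent op rather than an index extraction, which compiles more cleanly.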