From d024048d74c8faa8a64548dd7938531e93815c31 Mon Sep 17 00:00:00 2001 From: Dan Saunders Date: Wed, 17 Sep 2025 14:50:49 -0400 Subject: [PATCH] logs + fix --- src/axolotl/kernels/moe/torch_grouped.py | 27 +++++++++++++++--------- 1 file changed, 17 insertions(+), 10 deletions(-) diff --git a/src/axolotl/kernels/moe/torch_grouped.py b/src/axolotl/kernels/moe/torch_grouped.py index f44684abd..8f1b149fe 100644 --- a/src/axolotl/kernels/moe/torch_grouped.py +++ b/src/axolotl/kernels/moe/torch_grouped.py @@ -249,9 +249,9 @@ def moe_ffn_forward_grouped( for i in range(E): sel = flat_idx == i if sel.any(): - Xi = x_rep[sel] + Xi = x_rep[sel].contiguous() As.append(Xi) - Bs.append(W13[i]) + Bs.append(W13[i].contiguous()) expert_slices.append((i, sel)) if not As: @@ -263,13 +263,20 @@ def moe_ffn_forward_grouped( b_tensors: List[torch.Tensor], target_dtype: torch.dtype, ) -> Optional[List[torch.Tensor]]: - if target_dtype != dt: - a_tensors = [t.to(target_dtype) for t in a_tensors] - b_tensors = [t.to(target_dtype) for t in b_tensors] - outputs = _call_grouped_mm(a_tensors, b_tensors) - if outputs is not None and target_dtype != dt: - outputs = [t.to(dt) for t in outputs] - return outputs + global LAST_ERROR + try: + if target_dtype != dt: + a_tensors = [t.to(target_dtype).contiguous() for t in a_tensors] + b_tensors = [t.to(target_dtype).contiguous() for t in b_tensors] + outputs = _call_grouped_mm(a_tensors, b_tensors) + if outputs is not None and target_dtype != dt: + outputs = [t.to(dt).contiguous() for t in outputs] + return outputs + except RuntimeError as err: + LAST_ERROR = f"grouped_mm cast failure: {err}" # type: ignore[assignment] + if torch.cuda.is_available(): # pragma: no cover - defensive + torch.cuda.synchronize() + return None def _try_grouped_mm( a_tensors: List[torch.Tensor], b_tensors: List[torch.Tensor] @@ -305,7 +312,7 @@ def moe_ffn_forward_grouped( I2 = Yi.shape[-1] // 2 Yi_hidden = F.silu(Yi[:, :I2]) * Yi[:, I2:] As2.append(Yi_hidden) - Bs2.append(W2[i]) + Bs2.append(W2[i].contiguous()) Y2_list, _cast_used_down = _try_grouped_mm(As2, Bs2) if Y2_list is None: