From 0e9387c395a71117ce10d7a5c2c80caae52f85e8 Mon Sep 17 00:00:00 2001 From: Dan Saunders Date: Wed, 17 Sep 2025 14:35:36 -0400 Subject: [PATCH] fix --- src/axolotl/kernels/moe/torch_grouped.py | 29 ++++++++++++++++++++++-- 1 file changed, 27 insertions(+), 2 deletions(-) diff --git a/src/axolotl/kernels/moe/torch_grouped.py b/src/axolotl/kernels/moe/torch_grouped.py index 9453ad437..18b8dcc06 100644 --- a/src/axolotl/kernels/moe/torch_grouped.py +++ b/src/axolotl/kernels/moe/torch_grouped.py @@ -258,7 +258,28 @@ def moe_ffn_forward_grouped( out = torch.zeros_like(x) return out.view(bsz, seqlen, hdim), router_logits - Y_list = _call_grouped_mm(As, Bs) + comp_dtype = dt + if dt == torch.bfloat16: + comp_dtype = torch.float16 + if not getattr(experts_module, "_ax_grouped_logged_cast", False): + _LOGGER.info("torch_grouped: casting grouped_mm operands to float16") + experts_module._ax_grouped_logged_cast = True + + def _maybe_cast( + tensors: List[torch.Tensor], *, to_dtype: torch.dtype + ) -> List[torch.Tensor]: + if to_dtype == dt: + return tensors + return [t.to(to_dtype) for t in tensors] + + def _restore_dtype(tensors: List[torch.Tensor]) -> List[torch.Tensor]: + if comp_dtype == dt: + return tensors + return [t.to(dt) for t in tensors] + + As_mm = _maybe_cast(As, to_dtype=comp_dtype) + Bs_mm = _maybe_cast(Bs, to_dtype=comp_dtype) + Y_list = _call_grouped_mm(As_mm, Bs_mm) if Y_list is None: if not getattr(experts_module, "_ax_grouped_logged_fail", False): _LOGGER.warning( @@ -266,6 +287,7 @@ def moe_ffn_forward_grouped( ) experts_module._ax_grouped_logged_fail = True return None, None + Y_list = _restore_dtype(Y_list) As2: List[torch.Tensor] = [] Bs2: List[torch.Tensor] = [] @@ -276,7 +298,9 @@ def moe_ffn_forward_grouped( As2.append(Yi_hidden) Bs2.append(W2[i]) - Y2_list = _call_grouped_mm(As2, Bs2) + As2_mm = _maybe_cast(As2, to_dtype=comp_dtype) + Bs2_mm = _maybe_cast(Bs2, to_dtype=comp_dtype) + Y2_list = _call_grouped_mm(As2_mm, Bs2_mm) if Y2_list is None: if not getattr(experts_module, "_ax_grouped_logged_fail", False): _LOGGER.warning( @@ -284,6 +308,7 @@ def moe_ffn_forward_grouped( ) experts_module._ax_grouped_logged_fail = True return None, None + Y2_list = _restore_dtype(Y2_list) for (_i, sel), Out_i in zip(expert_slices, Y2_list, strict=False): y_buf[sel] = Out_i