From 98dc945838056d42553497fd5a645f0490240712 Mon Sep 17 00:00:00 2001
From: Dan Saunders
Date: Wed, 17 Sep 2025 14:42:53 -0400
Subject: [PATCH] fix

---
 src/axolotl/kernels/moe/torch_grouped.py | 56 +++++++++++++----------
 1 file changed, 31 insertions(+), 25 deletions(-)

diff --git a/src/axolotl/kernels/moe/torch_grouped.py b/src/axolotl/kernels/moe/torch_grouped.py
index 18b8dcc06..f44684abd 100644
--- a/src/axolotl/kernels/moe/torch_grouped.py
+++ b/src/axolotl/kernels/moe/torch_grouped.py
@@ -258,28 +258,38 @@ def moe_ffn_forward_grouped(
         out = torch.zeros_like(x)
         return out.view(bsz, seqlen, hdim), router_logits
 
-    comp_dtype = dt
-    if dt == torch.bfloat16:
-        comp_dtype = torch.float16
-        if not getattr(experts_module, "_ax_grouped_logged_cast", False):
-            _LOGGER.info("torch_grouped: casting grouped_mm operands to float16")
-            experts_module._ax_grouped_logged_cast = True
+    def _run_grouped_mm(
+        a_tensors: List[torch.Tensor],
+        b_tensors: List[torch.Tensor],
+        target_dtype: torch.dtype,
+    ) -> Optional[List[torch.Tensor]]:
+        if target_dtype != dt:
+            a_tensors = [t.to(target_dtype) for t in a_tensors]
+            b_tensors = [t.to(target_dtype) for t in b_tensors]
+        outputs = _call_grouped_mm(a_tensors, b_tensors)
+        if outputs is not None and target_dtype != dt:
+            outputs = [t.to(dt) for t in outputs]
+        return outputs
 
-    def _maybe_cast(
-        tensors: List[torch.Tensor], *, to_dtype: torch.dtype
-    ) -> List[torch.Tensor]:
-        if to_dtype == dt:
-            return tensors
-        return [t.to(to_dtype) for t in tensors]
+    def _try_grouped_mm(
+        a_tensors: List[torch.Tensor], b_tensors: List[torch.Tensor]
+    ) -> Tuple[Optional[List[torch.Tensor]], bool]:
+        global LAST_ERROR
+        result = _run_grouped_mm(a_tensors, b_tensors, target_dtype=dt)
+        cast_used_local = False
+        if result is None and dt == torch.bfloat16:
+            result = _run_grouped_mm(a_tensors, b_tensors, target_dtype=torch.float16)
+            if result is not None:
+                cast_used_local = True
+                LAST_ERROR = None
+                if not getattr(experts_module, "_ax_grouped_logged_cast", False):
+                    _LOGGER.info(
+                        "torch_grouped: grouped_mm casting bfloat16 operands to float16"
+                    )
+                    experts_module._ax_grouped_logged_cast = True
+        return result, cast_used_local
 
-    def _restore_dtype(tensors: List[torch.Tensor]) -> List[torch.Tensor]:
-        if comp_dtype == dt:
-            return tensors
-        return [t.to(dt) for t in tensors]
-
-    As_mm = _maybe_cast(As, to_dtype=comp_dtype)
-    Bs_mm = _maybe_cast(Bs, to_dtype=comp_dtype)
-    Y_list = _call_grouped_mm(As_mm, Bs_mm)
+    Y_list, _cast_used_up = _try_grouped_mm(As, Bs)
     if Y_list is None:
         if not getattr(experts_module, "_ax_grouped_logged_fail", False):
             _LOGGER.warning(
@@ -287,7 +297,6 @@ def moe_ffn_forward_grouped(
             )
             experts_module._ax_grouped_logged_fail = True
         return None, None
-    Y_list = _restore_dtype(Y_list)
 
     As2: List[torch.Tensor] = []
     Bs2: List[torch.Tensor] = []
@@ -298,9 +307,7 @@ def moe_ffn_forward_grouped(
         As2.append(Yi_hidden)
         Bs2.append(W2[i])
 
-    As2_mm = _maybe_cast(As2, to_dtype=comp_dtype)
-    Bs2_mm = _maybe_cast(Bs2, to_dtype=comp_dtype)
-    Y2_list = _call_grouped_mm(As2_mm, Bs2_mm)
+    Y2_list, _cast_used_down = _try_grouped_mm(As2, Bs2)
     if Y2_list is None:
         if not getattr(experts_module, "_ax_grouped_logged_fail", False):
             _LOGGER.warning(
@@ -308,7 +315,6 @@ def moe_ffn_forward_grouped(
             )
             experts_module._ax_grouped_logged_fail = True
         return None, None
-    Y2_list = _restore_dtype(Y2_list)
 
     for (_i, sel), Out_i in zip(expert_slices, Y2_list, strict=False):
         y_buf[sel] = Out_i
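
Reviewer note: the patch replaces the old eager cast (bfloat16 operands were always
downcast to float16 before the grouped matmul) with a lazy fallback: attempt the
grouped matmul in the native dtype first, and only retry in float16 if that attempt
fails. The standalone sketch below illustrates this control flow in isolation; it
assumes the backend signals an unsupported dtype by raising or returning None, and
the names run_grouped_mm and try_grouped_mm_with_fallback are illustrative, not this
module's actual API.

from typing import List, Optional, Tuple

import torch


def run_grouped_mm(
    a_tensors: List[torch.Tensor],
    b_tensors: List[torch.Tensor],
    target_dtype: torch.dtype,
    orig_dtype: torch.dtype,
) -> Optional[List[torch.Tensor]]:
    # Cast operands only when the compute dtype differs from the model dtype.
    if target_dtype != orig_dtype:
        a_tensors = [t.to(target_dtype) for t in a_tensors]
        b_tensors = [t.to(target_dtype) for t in b_tensors]
    try:
        # Stand-in for the backend grouped matmul; assume an unsupported
        # dtype surfaces as a RuntimeError.
        outputs = [a @ b for a, b in zip(a_tensors, b_tensors)]
    except RuntimeError:
        return None
    # Cast results back so callers always see the model dtype.
    if target_dtype != orig_dtype:
        outputs = [t.to(orig_dtype) for t in outputs]
    return outputs


def try_grouped_mm_with_fallback(
    a_tensors: List[torch.Tensor],
    b_tensors: List[torch.Tensor],
    dtype: torch.dtype,
) -> Tuple[Optional[List[torch.Tensor]], bool]:
    # First attempt in the native dtype.
    result = run_grouped_mm(a_tensors, b_tensors, dtype, dtype)
    if result is not None or dtype != torch.bfloat16:
        return result, False
    # Retry once in float16, mirroring the patch's bfloat16 fallback.
    result = run_grouped_mm(a_tensors, b_tensors, torch.float16, dtype)
    return result, result is not None


# Example: two "experts", bf16 activations against bf16 weights.
xs = [torch.randn(4, 8, dtype=torch.bfloat16) for _ in range(2)]
ws = [torch.randn(8, 16, dtype=torch.bfloat16) for _ in range(2)]
ys, used_fp16_fallback = try_grouped_mm_with_fallback(xs, ws, torch.bfloat16)

The upside of the lazy variant is that backends which do support bfloat16 grouped
matmuls keep full bf16 precision and skip two casts per call; the float16 path only
runs, and is only logged, when the first attempt actually fails.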