From f3b953e22224565252315330fdb4c1c79dad7bcc Mon Sep 17 00:00:00 2001
From: Dan Saunders
Date: Wed, 17 Sep 2025 18:42:10 -0400
Subject: [PATCH] fix?

---
 src/axolotl/kernels/moe/torch_grouped.py | 15 ++++++++++-----
 1 file changed, 10 insertions(+), 5 deletions(-)

diff --git a/src/axolotl/kernels/moe/torch_grouped.py b/src/axolotl/kernels/moe/torch_grouped.py
index 467dfeb89..c442a0e9e 100644
--- a/src/axolotl/kernels/moe/torch_grouped.py
+++ b/src/axolotl/kernels/moe/torch_grouped.py
@@ -57,8 +57,8 @@ def _stack_weights(
 def _call_grouped_mm(
     As: List[torch.Tensor], Bs: List[torch.Tensor], dtype: torch.dtype
 ) -> Optional[List[torch.Tensor]]:
-    if not As:
-        return []
+    if not As or dtype not in (torch.bfloat16, torch.float16):
+        return [] if not As else None
 
     try:
         As2 = [a.to(dtype).contiguous().view(a.shape[0], a.shape[1]) for a in As]
@@ -94,6 +94,14 @@ def moe_ffn_forward_grouped(
 
     routing_dtype = gate_linear.weight.dtype
 
+    expert_dtype = hidden_states.dtype
+    if expert_dtype not in (torch.bfloat16, torch.float16):
+        _LOGGER.debug(
+            "torch_grouped: unsupported expert dtype %s; falling back to naive",
+            expert_dtype,
+        )
+        return None, None
+
     sample_mod = getattr(
         experts_module[0], "mlp", getattr(experts_module[0], "ffn", experts_module[0])
     )
@@ -102,7 +110,6 @@ def moe_ffn_forward_grouped(
         and hasattr(sample_mod, "w3")
         and hasattr(sample_mod, "w2")
     ):
-        expert_dtype = sample_mod.w1.weight.dtype
         w13 = _stack_weights(
             experts_module, ("w1", "w3"), key="w13", dtype=expert_dtype, device=device
         )
@@ -111,10 +118,8 @@ def moe_ffn_forward_grouped(
         )
     else:
         if hasattr(sample_mod, "gate_up_proj"):
-            expert_dtype = sample_mod.gate_up_proj.weight.dtype
             names13: Tuple[str, ...] = ("gate_up_proj",)
         else:
-            expert_dtype = sample_mod.up_proj.weight.dtype
             names13 = ("up_proj", "gate_proj")
         w13 = _stack_weights(
             experts_module, names13, key="w13", dtype=expert_dtype, device=device
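
Note: the patch replaces the per-branch weight-dtype lookups with a single
gate on the activation dtype, so the grouped-GEMM fast path only runs for
bf16/fp16 inputs and everything else signals fallback by returning
(None, None). A minimal sketch of that contract follows; the helper name and
tensors are illustrative assumptions, not part of the patch — only the
(torch.bfloat16, torch.float16) allowlist and the (None, None) fallback
signal mirror the diff above.

    import torch

    SUPPORTED_DTYPES = (torch.bfloat16, torch.float16)

    def grouped_path_or_fallback(hidden_states: torch.Tensor):
        # Mirrors moe_ffn_forward_grouped's early exit: inputs outside the
        # half-precision allowlist return (None, None) so the caller runs
        # the naive per-expert loop instead.
        if hidden_states.dtype not in SUPPORTED_DTYPES:
            return None, None
        # Placeholder for the real grouped-GEMM path (hypothetical).
        return hidden_states, hidden_states.dtype

    out, _ = grouped_path_or_fallback(torch.randn(2, 4, dtype=torch.float32))
    assert out is None  # fp32 input requests the naive fallback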