diff --git a/src/axolotl/kernels/moe/torch_grouped.py b/src/axolotl/kernels/moe/torch_grouped.py index 44447d83e..b0f308497 100644 --- a/src/axolotl/kernels/moe/torch_grouped.py +++ b/src/axolotl/kernels/moe/torch_grouped.py @@ -56,9 +56,13 @@ def _stack_weights( def _call_grouped_mm( As: List[torch.Tensor], Bs: List[torch.Tensor], dtype: torch.dtype -) -> Optional[List[torch.Tensor]]: - if not As or dtype not in (torch.bfloat16, torch.float16): - return [] if not As else None +) -> Tuple[Optional[List[torch.Tensor]], Optional[str]]: + if not As: + return [], None + if dtype not in (torch.bfloat16, torch.float16): + msg = f"unsupported dtype {dtype}" + _LOGGER.debug("torch_grouped: %s", msg) + return None, msg try: As2 = [a.to(dtype).contiguous().view(a.shape[0], a.shape[1]) for a in As] @@ -68,7 +72,7 @@ def _call_grouped_mm( [a.shape[0] for a in As2], device=device, dtype=torch.int32 ) offsets = torch.cumsum(lengths, dim=0) - Y_cat = torch.ops.aten._grouped_mm( + Y_cat = torch._grouped_mm( torch.cat(As2, dim=0), torch.stack(Bs2, dim=0), offsets ) outs: List[torch.Tensor] = [] @@ -76,10 +80,11 @@ def _call_grouped_mm( for size in lengths.tolist(): outs.append(Y_cat[start : start + size]) start += size - return outs + return outs, None except RuntimeError as err: - _LOGGER.warning("torch_grouped: _grouped_mm failed (%s)", err) - return None + message = f"_grouped_mm failed ({err})" + _LOGGER.warning("torch_grouped: %s", message) + return None, message def moe_ffn_forward_grouped( @@ -154,7 +159,7 @@ def moe_ffn_forward_grouped( if not as_list: return torch.zeros_like(x_flat).view(bsz, seqlen, hdim), router_logits - up_out = _call_grouped_mm(as_list, bs_list, expert_dtype) + up_out, reason = _call_grouped_mm(as_list, bs_list, expert_dtype) if up_out is None: return None, None @@ -167,7 +172,7 @@ def moe_ffn_forward_grouped( down_inputs.append(hidden) down_weights.append(w2[i]) - down_out = _call_grouped_mm(down_inputs, down_weights, expert_dtype) + down_out, reason = _call_grouped_mm(down_inputs, down_weights, expert_dtype) if down_out is None: return None, None