diff --git a/src/axolotl/kernels/moe/torch_grouped.py b/src/axolotl/kernels/moe/torch_grouped.py index 8f1b149fe..858342685 100644 --- a/src/axolotl/kernels/moe/torch_grouped.py +++ b/src/axolotl/kernels/moe/torch_grouped.py @@ -23,7 +23,7 @@ def available() -> bool: major, minor = torch.cuda.get_device_capability() if major < 9: return False - return hasattr(torch.ops, "aten") and hasattr(torch.ops.aten, "_grouped_mm") + return hasattr(torch.ops, "_grouped_mm") except Exception: return False @@ -57,7 +57,7 @@ def _call_grouped_mm( As: List[torch.Tensor], Bs: List[torch.Tensor] ) -> Optional[List[torch.Tensor]]: """ - Call grouped mm using aten._grouped_mm with packed representation. + Call grouped mm with packed representation - A_cat: concat As along rows -> [sum_i Mi, K] - B_stk: stack Bs per group -> [G, K, N] @@ -65,35 +65,23 @@ def _call_grouped_mm( Returns list of per-group outputs split from concatenated result. """ global LAST_ERROR - try: - # Ensure 2D contiguous inputs - As2 = [a.contiguous().view(a.shape[0], a.shape[1]) for a in As] - Bs2 = [b.contiguous().view(b.shape[0], b.shape[1]) for b in Bs] + # Ensure 2D contiguous inputs + As2 = [a.contiguous().view(a.shape[0], a.shape[1]) for a in As] + Bs2 = [b.contiguous().view(b.shape[0], b.shape[1]) for b in Bs] - if not As2: - return [] - device = As2[0].device - A_cat = torch.cat(As2, dim=0) - B_stk = torch.stack(Bs2, dim=0) - offs = torch.tensor([a.shape[0] for a in As2], device=device, dtype=torch.int32) - - if hasattr(torch.ops.aten, "_grouped_mm"): - try: - Y_cat = torch.ops.aten._grouped_mm(A_cat, B_stk, offs) # type: ignore[attr-defined] - outs: List[torch.Tensor] = [] - start = 0 - for m in offs.tolist(): - outs.append(Y_cat[start : start + m, :]) - start += m - return outs - except Exception as e: - LAST_ERROR = f"_grouped_mm failed: {e}" - return None - LAST_ERROR = "aten._grouped_mm not present" - return None - except Exception as e: - LAST_ERROR = f"call error: {e}" - return None + if not As2: + return [] + device = As2[0].device + A_cat = torch.cat(As2, dim=0).to(torch.bfloat16) + B_stk = torch.stack(Bs2, dim=0).to(torch.bfloat16) + offs = torch.tensor([a.shape[0] for a in As2], device=device, dtype=torch.int32) + Y_cat = torch._grouped_mm(A_cat, B_stk, offs) # type: ignore[attr-defined] + outs: List[torch.Tensor] = [] + start = 0 + for m in offs.tolist(): + outs.append(Y_cat[start : start + m, :]) + start += m + return outs def moe_ffn_forward_grouped( @@ -103,6 +91,9 @@ def moe_ffn_forward_grouped( global LAST_ERROR LAST_ERROR = None bsz, seqlen, hdim = hidden_states.shape + compute_dtype = gate_linear.weight.dtype + if hidden_states.dtype != compute_dtype: + hidden_states = hidden_states.to(dtype=compute_dtype) x = hidden_states.view(-1, hdim) router_logits = gate_linear(x) @@ -114,16 +105,7 @@ def moe_ffn_forward_grouped( flat_idx = topk_idx.view(-1) x_rep = x.repeat_interleave(top_k, dim=0) - try: - E = _num_experts(experts_module) - except AttributeError as err: - LAST_ERROR = str(err) - if not getattr(experts_module, "_ax_grouped_logged_fail", False): - _LOGGER.warning( - "torch_grouped: could not determine expert count; falling back to naive" - ) - experts_module._ax_grouped_logged_fail = True - return None, None + E = _num_experts(experts_module) dev, dt = x.device, x.dtype first = experts_module[0] @@ -161,87 +143,76 @@ def moe_ffn_forward_grouped( raise AttributeError(f"expert {idx} missing nested module '{nested_attr}'") return nested_mod - try: - if is_mixtral: - if ( - not hasattr(experts_module, "_stacked_w1") - or experts_module._stacked_w1.device != dev - or experts_module._stacked_w1.dtype != dt - ): - mods = [_resolve_expert(i) for i in range(E)] - w1 = [mod.w1.weight.t() for mod in mods] - w3 = [mod.w3.weight.t() for mod in mods] - w2 = [mod.w2.weight.t() for mod in mods] - experts_module._stacked_w1 = ( - torch.stack(w1, dim=0) - .to(device=dev, dtype=dt, non_blocking=True) - .contiguous() - ) - experts_module._stacked_w3 = ( - torch.stack(w3, dim=0) - .to(device=dev, dtype=dt, non_blocking=True) - .contiguous() - ) - experts_module._stacked_w2 = ( - torch.stack(w2, dim=0) - .to(device=dev, dtype=dt, non_blocking=True) - .contiguous() - ) - experts_module._stacked_w13 = torch.cat( - [experts_module._stacked_w1, experts_module._stacked_w3], dim=-1 - ).contiguous() - W13 = experts_module._stacked_w13 - W2 = experts_module._stacked_w2 - else: - if ( - not hasattr(experts_module, "_stacked_w13") - or experts_module._stacked_w13.device != dev - or experts_module._stacked_w13.dtype != dt - ): - w13 = [] - w2 = [] - for i in range(E): - mod = _resolve_expert(i) - if hasattr(mod, "gate_up_proj"): - w13.append(mod.gate_up_proj.weight.t()) - elif hasattr(mod, "up_proj") and hasattr(mod, "gate_proj"): - w13.append( - torch.cat( - [mod.up_proj.weight.t(), mod.gate_proj.weight.t()], - dim=-1, - ) - ) - else: - LAST_ERROR = "unrecognized Qwen MoE expert weight layout" - if not getattr( - experts_module, "_ax_grouped_logged_fail", False - ): - _LOGGER.warning( - "torch_grouped: could not resolve Qwen MoE expert weights; fallback to naive" - ) - experts_module._ax_grouped_logged_fail = True - return None, None - w2.append(mod.down_proj.weight.t()) - experts_module._stacked_w13 = ( - torch.stack(w13, dim=0) - .to(device=dev, dtype=dt, non_blocking=True) - .contiguous() - ) - experts_module._stacked_w2 = ( - torch.stack(w2, dim=0) - .to(device=dev, dtype=dt, non_blocking=True) - .contiguous() - ) - W13 = experts_module._stacked_w13 - W2 = experts_module._stacked_w2 - except AttributeError as err: - LAST_ERROR = str(err) - if not getattr(experts_module, "_ax_grouped_logged_fail", False): - _LOGGER.warning( - "torch_grouped: expert weights missing expected attributes; falling back to naive" + if is_mixtral: + if ( + not hasattr(experts_module, "_stacked_w1") + or experts_module._stacked_w1.device != dev + or experts_module._stacked_w1.dtype != dt + ): + mods = [_resolve_expert(i) for i in range(E)] + w1 = [mod.w1.weight.t() for mod in mods] + w3 = [mod.w3.weight.t() for mod in mods] + w2 = [mod.w2.weight.t() for mod in mods] + experts_module._stacked_w1 = ( + torch.stack(w1, dim=0) + .to(device=dev, dtype=dt, non_blocking=True) + .contiguous() ) - experts_module._ax_grouped_logged_fail = True - return None, None + experts_module._stacked_w3 = ( + torch.stack(w3, dim=0) + .to(device=dev, dtype=dt, non_blocking=True) + .contiguous() + ) + experts_module._stacked_w2 = ( + torch.stack(w2, dim=0) + .to(device=dev, dtype=dt, non_blocking=True) + .contiguous() + ) + experts_module._stacked_w13 = torch.cat( + [experts_module._stacked_w1, experts_module._stacked_w3], dim=-1 + ).contiguous() + W13 = experts_module._stacked_w13 + W2 = experts_module._stacked_w2 + else: + if ( + not hasattr(experts_module, "_stacked_w13") + or experts_module._stacked_w13.device != dev + or experts_module._stacked_w13.dtype != dt + ): + w13 = [] + w2 = [] + for i in range(E): + mod = _resolve_expert(i) + if hasattr(mod, "gate_up_proj"): + w13.append(mod.gate_up_proj.weight.t()) + elif hasattr(mod, "up_proj") and hasattr(mod, "gate_proj"): + w13.append( + torch.cat( + [mod.up_proj.weight.t(), mod.gate_proj.weight.t()], + dim=-1, + ) + ) + else: + LAST_ERROR = "unrecognized Qwen MoE expert weight layout" + if not getattr(experts_module, "_ax_grouped_logged_fail", False): + _LOGGER.warning( + "torch_grouped: could not resolve Qwen MoE expert weights; fallback to naive" + ) + experts_module._ax_grouped_logged_fail = True + return None, None + w2.append(mod.down_proj.weight.t()) + experts_module._stacked_w13 = ( + torch.stack(w13, dim=0) + .to(device=dev, dtype=dt, non_blocking=True) + .contiguous() + ) + experts_module._stacked_w2 = ( + torch.stack(w2, dim=0) + .to(device=dev, dtype=dt, non_blocking=True) + .contiguous() + ) + W13 = experts_module._stacked_w13 + W2 = experts_module._stacked_w2 As: List[torch.Tensor] = [] Bs: List[torch.Tensor] = [] @@ -251,7 +222,7 @@ def moe_ffn_forward_grouped( if sel.any(): Xi = x_rep[sel].contiguous() As.append(Xi) - Bs.append(W13[i].contiguous()) + Bs.append(W13[i].reshape(hdim, -1).contiguous()) expert_slices.append((i, sel)) if not As: @@ -264,19 +235,13 @@ def moe_ffn_forward_grouped( target_dtype: torch.dtype, ) -> Optional[List[torch.Tensor]]: global LAST_ERROR - try: - if target_dtype != dt: - a_tensors = [t.to(target_dtype).contiguous() for t in a_tensors] - b_tensors = [t.to(target_dtype).contiguous() for t in b_tensors] - outputs = _call_grouped_mm(a_tensors, b_tensors) - if outputs is not None and target_dtype != dt: - outputs = [t.to(dt).contiguous() for t in outputs] - return outputs - except RuntimeError as err: - LAST_ERROR = f"grouped_mm cast failure: {err}" # type: ignore[assignment] - if torch.cuda.is_available(): # pragma: no cover - defensive - torch.cuda.synchronize() - return None + if target_dtype != dt: + a_tensors = [t.to(target_dtype).contiguous() for t in a_tensors] + b_tensors = [t.to(target_dtype).contiguous() for t in b_tensors] + outputs = _call_grouped_mm(a_tensors, b_tensors) + if outputs is not None and target_dtype != dt: + outputs = [t.to(dt).contiguous() for t in outputs] + return outputs def _try_grouped_mm( a_tensors: List[torch.Tensor], b_tensors: List[torch.Tensor] @@ -312,7 +277,7 @@ def moe_ffn_forward_grouped( I2 = Yi.shape[-1] // 2 Yi_hidden = F.silu(Yi[:, :I2]) * Yi[:, I2:] As2.append(Yi_hidden) - Bs2.append(W2[i].contiguous()) + Bs2.append(W2[i].reshape(I2, hdim).contiguous()) Y2_list, _cast_used_down = _try_grouped_mm(As2, Bs2) if Y2_list is None: