diff --git a/src/axolotl/kernels/moe/torch_grouped.py b/src/axolotl/kernels/moe/torch_grouped.py index 858342685..72efde43c 100644 --- a/src/axolotl/kernels/moe/torch_grouped.py +++ b/src/axolotl/kernels/moe/torch_grouped.py @@ -91,22 +91,35 @@ def moe_ffn_forward_grouped( global LAST_ERROR LAST_ERROR = None bsz, seqlen, hdim = hidden_states.shape - compute_dtype = gate_linear.weight.dtype - if hidden_states.dtype != compute_dtype: - hidden_states = hidden_states.to(dtype=compute_dtype) + routing_dtype = gate_linear.weight.dtype + use_mixed_router = ( + hidden_states.device.type == "cuda" and routing_dtype == torch.float32 + ) + + if use_mixed_router: + x_router = hidden_states.to(dtype=routing_dtype) + router_logits = gate_linear(x_router) + else: + if hidden_states.dtype != routing_dtype: + hidden_states = hidden_states.to(dtype=routing_dtype) + x = hidden_states.view(-1, hdim) + router_logits = gate_linear(x) + + if router_logits.dtype != routing_dtype: + router_logits = router_logits.to(dtype=routing_dtype) + x = hidden_states.view(-1, hdim) - router_logits = gate_linear(x) # top-k routing executed in torch to avoid extra dependencies routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) topk_weight, topk_idx = torch.topk(routing_weights, top_k, dim=-1, sorted=False) - topk_weight = (topk_weight / topk_weight.sum(dim=-1, keepdim=True)).to(x.dtype) + topk_weight = topk_weight / topk_weight.sum(dim=-1, keepdim=True) flat_idx = topk_idx.view(-1) - x_rep = x.repeat_interleave(top_k, dim=0) E = _num_experts(experts_module) - dev, dt = x.device, x.dtype + dev = hidden_states.device + dt: torch.dtype = hidden_states.dtype first = experts_module[0] is_mixtral = _is_mixtral_layout(first) @@ -134,7 +147,7 @@ def moe_ffn_forward_grouped( LAST_ERROR = "unsupported expert layout" return None, None - def _resolve_expert(idx: int): + def _resolve_expert(idx: int) -> torch.nn.Module: expert = experts_module[idx] if nested_attr is None: return expert @@ -214,6 +227,18 @@ def moe_ffn_forward_grouped( W13 = experts_module._stacked_w13 W2 = experts_module._stacked_w2 + dt = W13.dtype + if router_logits.dtype != dt: + router_logits = router_logits.to(dtype=dt) + if x.dtype != dt: + x = x.to(dtype=dt) + flat_idx = topk_idx.view(-1) + if topk_weight.dtype != dt: + topk_weight = topk_weight.to(dtype=dt) + x_rep = x.repeat_interleave(top_k, dim=0) + if x_rep.dtype != dt: + x_rep = x_rep.to(dtype=dt) + As: List[torch.Tensor] = [] Bs: List[torch.Tensor] = [] expert_slices: List[Tuple[int, torch.Tensor]] = [] @@ -272,7 +297,7 @@ def moe_ffn_forward_grouped( As2: List[torch.Tensor] = [] Bs2: List[torch.Tensor] = [] - y_buf = torch.empty_like(x_rep) + y_buf = torch.empty_like(x_rep, dtype=dt) for (i, _sel), Yi in zip(expert_slices, Y_list, strict=False): I2 = Yi.shape[-1] // 2 Yi_hidden = F.silu(Yi[:, :I2]) * Yi[:, I2:]