diff --git a/src/axolotl/kernels/moe/torch_grouped.py b/src/axolotl/kernels/moe/torch_grouped.py
index fbb6f9b4d..99d861551 100644
--- a/src/axolotl/kernels/moe/torch_grouped.py
+++ b/src/axolotl/kernels/moe/torch_grouped.py
@@ -26,54 +26,30 @@ def available() -> bool:
     return False
 
 
+def _iter_expert_impls(experts_module) -> List[torch.nn.Module]:
+    impls: List[torch.nn.Module] = []
+    for exp in experts_module:
+        impls.append(getattr(exp, "mlp", getattr(exp, "ffn", exp)))
+    return impls
+
+
 def _stack_weights(
     experts_module,
     names: Tuple[str, ...],
     *,
-    key: str,
     dtype: torch.dtype,
     device: torch.device,
 ) -> torch.Tensor:
-    attr = f"_ax_grouped_{key}"
-    cached = getattr(experts_module, attr, None)
-    if cached is not None and cached.dtype == dtype and cached.device == device:
-        return cached
-
     tensors: List[torch.Tensor] = []
-    for exp in experts_module:
-        mod = getattr(exp, "mlp", getattr(exp, "ffn", exp))
+    for mod in _iter_expert_impls(experts_module):
         parts = [getattr(mod, name).weight.t() for name in names]
         tensors.append(parts[0] if len(parts) == 1 else torch.cat(parts, dim=-1))
-    stacked = (
+    return (
        torch.stack(tensors, dim=0)
         .to(device=device, dtype=dtype, non_blocking=True)
         .contiguous()
     )
-    setattr(experts_module, attr, stacked)
-    return stacked
-
-
-def _call_grouped_mm(
-    As: List[torch.Tensor], Bs: List[torch.Tensor], dtype: torch.dtype
-) -> List[torch.Tensor]:
-    if not As:
-        return []
-    if dtype not in (torch.bfloat16, torch.float16):
-        raise RuntimeError(f"unsupported dtype {dtype} for grouped_mm")
-
-    As2 = [a.to(dtype).contiguous().view(a.shape[0], a.shape[1]) for a in As]
-    Bs2 = [b.to(dtype).contiguous().view(b.shape[0], b.shape[1]) for b in Bs]
-    device = As2[0].device
-    lengths = torch.tensor([a.shape[0] for a in As2], device=device, dtype=torch.int32)
-    offsets = torch.cumsum(lengths, dim=0).to(torch.int32)
-    Y_cat = torch._grouped_mm(torch.cat(As2, dim=0), torch.stack(Bs2, dim=0), offsets)
-    outs: List[torch.Tensor] = []
-    start = 0
-    for size in lengths.tolist():
-        outs.append(Y_cat[start : start + size])
-        start += size
-    return outs
 
 
 def moe_ffn_forward_grouped(
@@ -99,30 +75,30 @@ def moe_ffn_forward_grouped(
         )
         return None, None
 
-    sample_mod = getattr(
-        experts_module[0], "mlp", getattr(experts_module[0], "ffn", experts_module[0])
-    )
+    for suffix in ("w13", "w2"):
+        attr = f"_ax_grouped_{suffix}"
+        if hasattr(experts_module, attr):
+            delattr(experts_module, attr)
+
+    expert_impls = _iter_expert_impls(experts_module)
+    sample_mod = expert_impls[0]
     if (
         hasattr(sample_mod, "w1")
         and hasattr(sample_mod, "w3")
         and hasattr(sample_mod, "w2")
     ):
         w13 = _stack_weights(
-            experts_module, ("w1", "w3"), key="w13", dtype=expert_dtype, device=device
-        )
-        w2 = _stack_weights(
-            experts_module, ("w2",), key="w2", dtype=expert_dtype, device=device
+            experts_module, ("w1", "w3"), dtype=expert_dtype, device=device
         )
+        w2 = _stack_weights(experts_module, ("w2",), dtype=expert_dtype, device=device)
     else:
         if hasattr(sample_mod, "gate_up_proj"):
             names13: Tuple[str, ...] = ("gate_up_proj",)
         else:
             names13 = ("up_proj", "gate_proj")
-        w13 = _stack_weights(
-            experts_module, names13, key="w13", dtype=expert_dtype, device=device
-        )
+        w13 = _stack_weights(experts_module, names13, dtype=expert_dtype, device=device)
         w2 = _stack_weights(
-            experts_module, ("down_proj",), key="w2", dtype=expert_dtype, device=device
+            experts_module, ("down_proj",), dtype=expert_dtype, device=device
         )
 
     x_flat = hidden_states.view(tokens, hdim).to(expert_dtype)
@@ -133,41 +109,45 @@ def moe_ffn_forward_grouped(
         topk_weight = topk_weight / topk_weight.sum(dim=-1, keepdim=True)
 
     flat_idx = topk_idx.view(-1)
-    x_rep = x_flat.repeat_interleave(top_k, dim=0)
+    num_experts = len(expert_impls)
+    if flat_idx.numel() == 0:
+        zero = torch.zeros_like(x_flat)
+        return zero.view(bsz, seqlen, hdim), router_logits
 
-    as_list: List[torch.Tensor] = []
-    bs_list: List[torch.Tensor] = []
-    slices: List[Tuple[int, torch.Tensor]] = []
-    for i, _ in enumerate(experts_module):
-        sel = flat_idx == i
-        if sel.any():
-            as_list.append(x_rep[sel])
-            bs_list.append(w13[i])
-            slices.append((i, sel))
+    assignments = torch.bincount(flat_idx, minlength=num_experts)
+    if assignments.sum() == 0:
+        zero = torch.zeros_like(x_flat)
+        return zero.view(bsz, seqlen, hdim), router_logits
 
-    if not as_list:
-        return torch.zeros_like(x_flat).view(bsz, seqlen, hdim), router_logits
+    perm = torch.argsort(flat_idx, stable=True)
+    token_indices_sorted = perm // top_k
+    scores_sorted = topk_weight.view(-1)[perm]
 
-    up_out = _call_grouped_mm(as_list, bs_list, expert_dtype)
+    gather_index = token_indices_sorted.unsqueeze(-1).expand(-1, hdim)
+    routed_input = torch.gather(x_flat, 0, gather_index).contiguous()
 
-    down_inputs: List[torch.Tensor] = []
-    down_weights: List[torch.Tensor] = []
-    buf = torch.empty_like(x_rep)
-    for (i, _sel), Yi in zip(slices, up_out, strict=False):
-        mid = Yi.shape[-1] // 2
-        hidden = F.silu(Yi[:, :mid]) * Yi[:, mid:]
-        down_inputs.append(hidden)
-        down_weights.append(w2[i])
+    offsets = torch.cumsum(assignments.to(device=device, dtype=torch.int32), dim=0)
+    if offsets[-1].item() == 0:
+        zero = torch.zeros_like(x_flat)
+        return zero.view(bsz, seqlen, hdim), router_logits
 
-    down_out = _call_grouped_mm(down_inputs, down_weights, expert_dtype)
+    mid = w13.shape[-1] // 2
+    w_gate = w13[..., :mid]
+    w_up = w13[..., mid:]
 
-    for (_i, sel), tensor in zip(slices, down_out, strict=False):
-        buf[sel] = tensor
+    w_gate_t = w_gate.transpose(-2, -1).contiguous()
+    w_up_t = w_up.transpose(-2, -1).contiguous()
+    w2_t = w2.transpose(-2, -1).contiguous()
 
-    combined = (
-        (buf.view(tokens, top_k, -1) * topk_weight.to(expert_dtype).unsqueeze(-1))
-        .sum(dim=1)
-        .to(torch.bfloat16)
-    )
+    routed_in = routed_input.to(expert_dtype)
+    gate_out = torch._grouped_mm(routed_in, w_gate_t, offs=offsets)
+    up_out = torch._grouped_mm(routed_in, w_up_t, offs=offsets)
+    activated = F.silu(gate_out) * up_out
+    down_out = torch._grouped_mm(activated, w2_t, offs=offsets)
+    weights_fp32 = scores_sorted.unsqueeze(-1).to(torch.float32)
+    weighted = (down_out.to(torch.float32) * weights_fp32).to(expert_dtype)
+
+    combined = torch.zeros_like(x_flat)
+    combined.scatter_add_(0, gather_index, weighted)
 
     return combined.view(bsz, seqlen, hdim), router_logits