diff --git a/scripts/bench_torchtitan_moe_sweep.py b/scripts/bench_torchtitan_moe_sweep.py index 61c5c77aa..b1afdd653 100644 --- a/scripts/bench_torchtitan_moe_sweep.py +++ b/scripts/bench_torchtitan_moe_sweep.py @@ -5,7 +5,6 @@ from __future__ import annotations import argparse import csv -import math import sys import time from dataclasses import dataclass @@ -28,9 +27,15 @@ def _parse_int_list(value: str) -> List[int]: def _parse_args() -> argparse.Namespace: p = argparse.ArgumentParser(description="Torchtitan MoE grouped vs naive sweep") - p.add_argument("--batch-sizes", default="4,8,16", help="Comma separated batch sizes") - p.add_argument("--seq-lens", default="1024,2048", help="Comma separated sequence lengths") - p.add_argument("--experts", default="8,16,32,64", help="Comma separated expert counts") + p.add_argument( + "--batch-sizes", default="4,8,16", help="Comma separated batch sizes" + ) + p.add_argument( + "--seq-lens", default="1024,2048", help="Comma separated sequence lengths" + ) + p.add_argument( + "--experts", default="8,16,32,64", help="Comma separated expert counts" + ) p.add_argument("--top-ks", default="1,2,4", help="Comma separated top_k choices") p.add_argument("--hidden", type=int, default=4096) p.add_argument("--inter", type=int, default=14336) @@ -188,9 +193,7 @@ def _run_case( diff = (y_naive.float() - y_grouped.float()).abs() max_abs = diff.max().item() mean_abs = diff.mean().item() - rel_l2 = ( - (diff.pow(2).sum() / (y_naive.float().pow(2).sum() + 1e-12)).sqrt().item() - ) + rel_l2 = (diff.pow(2).sum() / (y_naive.float().pow(2).sum() + 1e-12)).sqrt().item() tokens = bsz * seq flops = _estimate_flops(tokens, hidden, inter, top_k) @@ -215,10 +218,10 @@ def _run_case( ) -def _print_header(hidden: int, inter: int, dtype: torch.dtype, device: torch.device) -> None: - print( - f"Device={device} dtype={dtype} hidden={hidden} inter={inter}" - ) +def _print_header( + hidden: int, inter: int, dtype: torch.dtype, device: torch.device +) -> None: + print(f"Device={device} dtype={dtype} hidden={hidden} inter={inter}") print( "bsz\tseq\texperts\ttop_k\tnaive(ms)\tgrouped(ms)\tspeedup\t" "naive TF/s\tgrouped TF/s\tmax_abs\tmean_abs\trel_l2" diff --git a/src/axolotl/kernels/moe/torch_grouped.py b/src/axolotl/kernels/moe/torch_grouped.py index 71e5d5626..b5ea0c532 100644 --- a/src/axolotl/kernels/moe/torch_grouped.py +++ b/src/axolotl/kernels/moe/torch_grouped.py @@ -277,6 +277,17 @@ def moe_ffn_forward_grouped( x_flat = hidden_states.view(tokens, hdim).to(expert_dtype) router_logits = gate_linear(x_flat.to(routing_dtype)) + shared_out_flat: Optional[torch.Tensor] = None + if hasattr(experts_module, "shared_expert"): + shared_expert = experts_module.shared_expert + shared_out_flat = shared_expert(x_flat) + shared_out_flat = shared_out_flat.to(expert_dtype) + shared_gate = getattr(experts_module, "shared_expert_gate", None) + if shared_gate is not None: + gate_input = shared_gate(x_flat.to(shared_gate.weight.dtype)) + gate_vals = torch.sigmoid(gate_input) + shared_out_flat.mul_(gate_vals.to(expert_dtype)) + routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) topk_weight, topk_idx = torch.topk(routing_weights, top_k, dim=-1, sorted=False) topk_weight = topk_weight / topk_weight.sum(dim=-1, keepdim=True) @@ -321,4 +332,8 @@ def moe_ffn_forward_grouped( combined = torch.zeros_like(x_flat) combined.scatter_add_(0, gather_index, down_out) - return combined.view(bsz, seqlen, hdim), router_logits + + output = combined.view(bsz, seqlen, hdim) + if shared_out_flat is not None: + output = output + shared_out_flat.view(bsz, seqlen, hdim) + return output, router_logits