From bfc848f81d6daacd332363717959105c2f33a476 Mon Sep 17 00:00:00 2001 From: Dan Saunders Date: Fri, 19 Sep 2025 02:12:57 +0000 Subject: [PATCH] bits and pieces --- scripts/bench_moe.py | 14 ++++++++++++++ src/axolotl/kernels/moe/torch_grouped.py | 3 +-- 2 files changed, 15 insertions(+), 2 deletions(-) diff --git a/scripts/bench_moe.py b/scripts/bench_moe.py index 49f54a23b..76c3929c2 100644 --- a/scripts/bench_moe.py +++ b/scripts/bench_moe.py @@ -189,6 +189,20 @@ def main(): else: print("torch_grouped\tN/A (unavailable)") + with torch.profiler.profile( + activities=[torch.profiler.ProfilerActivity.CUDA], record_shapes=True + ) as prof: + forward_naive(x, gate, experts, args.top_k) + print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=20)) + + with torch.profiler.profile( + activities=[torch.profiler.ProfilerActivity.CUDA], + record_shapes=True, + with_stack=False, + ) as prof: + forward_tg(x, gate, experts, args.top_k) + print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=20)) + if __name__ == "__main__": main() diff --git a/src/axolotl/kernels/moe/torch_grouped.py b/src/axolotl/kernels/moe/torch_grouped.py index 4f161b571..71e5d5626 100644 --- a/src/axolotl/kernels/moe/torch_grouped.py +++ b/src/axolotl/kernels/moe/torch_grouped.py @@ -300,11 +300,10 @@ def moe_ffn_forward_grouped( routed_input = torch.gather(x_flat, 0, gather_index) counts_i32 = assignments.to(device=device, dtype=torch.int32) - offsets = torch.cumsum(counts_i32, dim=0) + offsets = torch.cumsum(counts_i32, dim=0).to(dtype=torch.int32) if offsets[-1].item() == 0: zero = torch.zeros_like(x_flat) return zero.view(bsz, seqlen, hdim), router_logits - mm_dtype = torch.bfloat16 if expert_dtype == torch.bfloat16 else expert_dtype routed_in = routed_input.to(mm_dtype) w_gate_t = w_gate.transpose(-2, -1).to(mm_dtype)