bits and pieces

This commit is contained in:
Dan Saunders
2025-09-19 02:12:57 +00:00
parent abe1cad6bc
commit bfc848f81d
2 changed files with 15 additions and 2 deletions

View File

@@ -189,6 +189,20 @@ def main():
else:
print("torch_grouped\tN/A (unavailable)")
with torch.profiler.profile(
activities=[torch.profiler.ProfilerActivity.CUDA], record_shapes=True
) as prof:
forward_naive(x, gate, experts, args.top_k)
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=20))
with torch.profiler.profile(
activities=[torch.profiler.ProfilerActivity.CUDA],
record_shapes=True,
with_stack=False,
) as prof:
forward_tg(x, gate, experts, args.top_k)
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=20))
if __name__ == "__main__":
main()

View File

@@ -300,11 +300,10 @@ def moe_ffn_forward_grouped(
routed_input = torch.gather(x_flat, 0, gather_index)
counts_i32 = assignments.to(device=device, dtype=torch.int32)
offsets = torch.cumsum(counts_i32, dim=0)
offsets = torch.cumsum(counts_i32, dim=0).to(dtype=torch.int32)
if offsets[-1].item() == 0:
zero = torch.zeros_like(x_flat)
return zero.view(bsz, seqlen, hdim), router_logits
mm_dtype = torch.bfloat16 if expert_dtype == torch.bfloat16 else expert_dtype
routed_in = routed_input.to(mm_dtype)
w_gate_t = w_gate.transpose(-2, -1).to(mm_dtype)