bits and pieces
This commit is contained in:
@@ -189,6 +189,20 @@ def main():
|
||||
else:
|
||||
print("torch_grouped\tN/A (unavailable)")
|
||||
|
||||
with torch.profiler.profile(
|
||||
activities=[torch.profiler.ProfilerActivity.CUDA], record_shapes=True
|
||||
) as prof:
|
||||
forward_naive(x, gate, experts, args.top_k)
|
||||
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=20))
|
||||
|
||||
with torch.profiler.profile(
|
||||
activities=[torch.profiler.ProfilerActivity.CUDA],
|
||||
record_shapes=True,
|
||||
with_stack=False,
|
||||
) as prof:
|
||||
forward_tg(x, gate, experts, args.top_k)
|
||||
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=20))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
@@ -300,11 +300,10 @@ def moe_ffn_forward_grouped(
|
||||
routed_input = torch.gather(x_flat, 0, gather_index)
|
||||
|
||||
counts_i32 = assignments.to(device=device, dtype=torch.int32)
|
||||
offsets = torch.cumsum(counts_i32, dim=0)
|
||||
offsets = torch.cumsum(counts_i32, dim=0).to(dtype=torch.int32)
|
||||
if offsets[-1].item() == 0:
|
||||
zero = torch.zeros_like(x_flat)
|
||||
return zero.view(bsz, seqlen, hdim), router_logits
|
||||
|
||||
mm_dtype = torch.bfloat16 if expert_dtype == torch.bfloat16 else expert_dtype
|
||||
routed_in = routed_input.to(mm_dtype)
|
||||
w_gate_t = w_gate.transpose(-2, -1).to(mm_dtype)
|
||||
|
||||
Reference in New Issue
Block a user