diff --git a/scripts/benchmarks/deepseek_v3_moe.py b/scripts/benchmarks/deepseek_v3_moe.py index 91c777048..7385ca1e9 100644 --- a/scripts/benchmarks/deepseek_v3_moe.py +++ b/scripts/benchmarks/deepseek_v3_moe.py @@ -185,7 +185,12 @@ def main() -> None: # pragma: no cover - CLI entrypoint max_tokens = int(tokens_per_expert.max().item()) if args.uniform_routing: - weights = torch.full_like(topk_idx, 1.0 / args.top_k) + weights = torch.full( + topk_idx.shape, + 1.0 / args.top_k, + device=device, + dtype=torch.float32, + ) def _uniform_gate(self, hidden_states): batch_tokens = hidden_states.shape[0]