diff --git a/scripts/benchmarks/deepseek_v3_moe.py b/scripts/benchmarks/deepseek_v3_moe.py index 7385ca1e9..4d65b4e20 100644 --- a/scripts/benchmarks/deepseek_v3_moe.py +++ b/scripts/benchmarks/deepseek_v3_moe.py @@ -193,8 +193,9 @@ def main() -> None: # pragma: no cover - CLI entrypoint ) def _uniform_gate(self, hidden_states): - batch_tokens = hidden_states.shape[0] - return topk_idx[:batch_tokens], weights[:batch_tokens] + flat = hidden_states.view(-1, hidden_states.shape[-1]) + token_count = flat.shape[0] + return topk_idx[:token_count], weights[:token_count] patched_module.gate.forward = _uniform_gate.__get__( patched_module.gate, patched_module.gate.__class__