This commit is contained in:
Dan Saunders
2025-09-22 16:10:41 -04:00
parent 92f2f6e73c
commit 9d69c6fb3e

View File

@@ -193,8 +193,9 @@ def main() -> None: # pragma: no cover - CLI entrypoint
)
def _uniform_gate(self, hidden_states):
    """Return the precomputed routing entries for every token in the input.

    Replacement for the gate's ``forward``; it is bound onto the gate
    module, so ``self`` is that module (unused here).

    ``topk_idx`` and ``weights`` are not defined in this function; they are
    read from the enclosing scope.
    NOTE(review): presumably each holds one row per token and has at least
    as many rows as the largest flattened input -- confirm at the call site.

    Args:
        hidden_states: Array-like whose last dimension is the feature
            dimension; all leading dimensions are treated as token
            dimensions.

    Returns:
        A ``(topk_idx, weights)`` pair, each sliced to the number of tokens
        in ``hidden_states``.
    """
    # Count tokens over *all* leading dimensions, not just the first one:
    # slicing by shape[0] alone undercounts when the input is not already
    # flattened to (tokens, features).
    # reshape (rather than view) also accepts non-contiguous inputs; only the
    # resulting row count is used.
    flat = hidden_states.reshape(-1, hidden_states.shape[-1])
    token_count = flat.shape[0]
    return topk_idx[:token_count], weights[:token_count]
patched_module.gate.forward = _uniform_gate.__get__(
patched_module.gate, patched_module.gate.__class__