This commit is contained in:
Dan Saunders
2025-09-22 16:10:41 -04:00
parent 92f2f6e73c
commit 9d69c6fb3e

View File

@@ -193,8 +193,9 @@ def main() -> None: # pragma: no cover - CLI entrypoint
) )
def _uniform_gate(self, hidden_states): def _uniform_gate(self, hidden_states):
batch_tokens = hidden_states.shape[0] flat = hidden_states.view(-1, hidden_states.shape[-1])
return topk_idx[:batch_tokens], weights[:batch_tokens] token_count = flat.shape[0]
return topk_idx[:token_count], weights[:token_count]
patched_module.gate.forward = _uniform_gate.__get__( patched_module.gate.forward = _uniform_gate.__get__(
patched_module.gate, patched_module.gate.__class__ patched_module.gate, patched_module.gate.__class__