Fix
This commit is contained in:
@@ -193,8 +193,9 @@ def main() -> None: # pragma: no cover - CLI entrypoint
|
||||
)
|
||||
|
||||
def _uniform_gate(self, hidden_states):
|
||||
batch_tokens = hidden_states.shape[0]
|
||||
return topk_idx[:batch_tokens], weights[:batch_tokens]
|
||||
flat = hidden_states.view(-1, hidden_states.shape[-1])
|
||||
token_count = flat.shape[0]
|
||||
return topk_idx[:token_count], weights[:token_count]
|
||||
|
||||
patched_module.gate.forward = _uniform_gate.__get__(
|
||||
patched_module.gate, patched_module.gate.__class__
|
||||
|
||||
Reference in New Issue
Block a user