Fix
This commit is contained in:
@@ -193,8 +193,9 @@ def main() -> None: # pragma: no cover - CLI entrypoint
|
|||||||
)
|
)
|
||||||
|
|
||||||
def _uniform_gate(self, hidden_states):
|
def _uniform_gate(self, hidden_states):
|
||||||
batch_tokens = hidden_states.shape[0]
|
flat = hidden_states.view(-1, hidden_states.shape[-1])
|
||||||
return topk_idx[:batch_tokens], weights[:batch_tokens]
|
token_count = flat.shape[0]
|
||||||
|
return topk_idx[:token_count], weights[:token_count]
|
||||||
|
|
||||||
patched_module.gate.forward = _uniform_gate.__get__(
|
patched_module.gate.forward = _uniform_gate.__get__(
|
||||||
patched_module.gate, patched_module.gate.__class__
|
patched_module.gate, patched_module.gate.__class__
|
||||||
|
|||||||
Reference in New Issue
Block a user