dtype fix

Dan Saunders
2025-09-22 16:07:45 -04:00
parent e5d2aebe16
commit 92f2f6e73c


@@ -185,7 +185,12 @@ def main() -> None: # pragma: no cover - CLI entrypoint
     max_tokens = int(tokens_per_expert.max().item())
     if args.uniform_routing:
-        weights = torch.full_like(topk_idx, 1.0 / args.top_k)
+        weights = torch.full(
+            topk_idx.shape,
+            1.0 / args.top_k,
+            device=device,
+            dtype=torch.float32,
+        )
         def _uniform_gate(self, hidden_states):
             batch_tokens = hidden_states.shape[0]
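
The commit message does not spell out the motivation, but the likely issue is that `torch.full_like(topk_idx, ...)` inherits the dtype of `topk_idx`; since `topk_idx` holds integer expert indices, the fractional routing weight `1.0 / args.top_k` cannot be stored in it (it gets truncated or triggers a dtype error, depending on the PyTorch version). The replacement allocates the weights explicitly as `float32` and keeps `device=device` so they still land on the same device as the routing tensors. A minimal standalone sketch of the before/after behaviour, with a hypothetical `top_k` and made-up index values:

```python
import torch

# Hypothetical reproduction: names mirror the diff (top_k, device, topk_idx),
# but the values here are invented for illustration.
top_k = 2
device = torch.device("cpu")
topk_idx = torch.tensor([[0, 3], [1, 2]], device=device)  # int64 expert indices

# Old code: torch.full_like inherits topk_idx's integer dtype, so the
# fractional weight 1.0 / top_k cannot be represented there.
# weights = torch.full_like(topk_idx, 1.0 / top_k)

# New code: allocate the weights explicitly as float32 on the same device.
weights = torch.full(
    topk_idx.shape,
    1.0 / top_k,
    device=device,
    dtype=torch.float32,
)
print(weights.dtype, weights)  # torch.float32, tensor([[0.5, 0.5], [0.5, 0.5]])
```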