dtype fix
This commit is contained in:
@@ -185,7 +185,12 @@ def main() -> None: # pragma: no cover - CLI entrypoint
|
|||||||
max_tokens = int(tokens_per_expert.max().item())
|
max_tokens = int(tokens_per_expert.max().item())
|
||||||
|
|
||||||
if args.uniform_routing:
|
if args.uniform_routing:
|
||||||
weights = torch.full_like(topk_idx, 1.0 / args.top_k)
|
weights = torch.full(
|
||||||
|
topk_idx.shape,
|
||||||
|
1.0 / args.top_k,
|
||||||
|
device=device,
|
||||||
|
dtype=torch.float32,
|
||||||
|
)
|
||||||
|
|
||||||
def _uniform_gate(self, hidden_states):
|
def _uniform_gate(self, hidden_states):
|
||||||
batch_tokens = hidden_states.shape[0]
|
batch_tokens = hidden_states.shape[0]
|
||||||
|
|||||||
Reference in New Issue
Block a user