This commit is contained in:
Dan Saunders
2025-09-17 18:42:10 -04:00
parent 7935dc0911
commit f3b953e222

View File

@@ -57,8 +57,8 @@ def _stack_weights(
def _call_grouped_mm(
As: List[torch.Tensor], Bs: List[torch.Tensor], dtype: torch.dtype
) -> Optional[List[torch.Tensor]]:
if not As:
return []
if not As or dtype not in (torch.bfloat16, torch.float16):
return [] if not As else None
try:
As2 = [a.to(dtype).contiguous().view(a.shape[0], a.shape[1]) for a in As]
@@ -94,6 +94,14 @@ def moe_ffn_forward_grouped(
routing_dtype = gate_linear.weight.dtype
expert_dtype = hidden_states.dtype
if expert_dtype not in (torch.bfloat16, torch.float16):
_LOGGER.debug(
"torch_grouped: unsupported expert dtype %s; falling back to naive",
expert_dtype,
)
return None, None
sample_mod = getattr(
experts_module[0], "mlp", getattr(experts_module[0], "ffn", experts_module[0])
)
@@ -102,7 +110,6 @@ def moe_ffn_forward_grouped(
and hasattr(sample_mod, "w3")
and hasattr(sample_mod, "w2")
):
expert_dtype = sample_mod.w1.weight.dtype
w13 = _stack_weights(
experts_module, ("w1", "w3"), key="w13", dtype=expert_dtype, device=device
)
@@ -111,10 +118,8 @@ def moe_ffn_forward_grouped(
)
else:
if hasattr(sample_mod, "gate_up_proj"):
expert_dtype = sample_mod.gate_up_proj.weight.dtype
names13: Tuple[str, ...] = ("gate_up_proj",)
else:
expert_dtype = sample_mod.up_proj.weight.dtype
names13 = ("up_proj", "gate_proj")
w13 = _stack_weights(
experts_module, names13, key="w13", dtype=expert_dtype, device=device