fix?
This commit is contained in:
@@ -57,8 +57,8 @@ def _stack_weights(
|
|||||||
def _call_grouped_mm(
|
def _call_grouped_mm(
|
||||||
As: List[torch.Tensor], Bs: List[torch.Tensor], dtype: torch.dtype
|
As: List[torch.Tensor], Bs: List[torch.Tensor], dtype: torch.dtype
|
||||||
) -> Optional[List[torch.Tensor]]:
|
) -> Optional[List[torch.Tensor]]:
|
||||||
if not As:
|
if not As or dtype not in (torch.bfloat16, torch.float16):
|
||||||
return []
|
return [] if not As else None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
As2 = [a.to(dtype).contiguous().view(a.shape[0], a.shape[1]) for a in As]
|
As2 = [a.to(dtype).contiguous().view(a.shape[0], a.shape[1]) for a in As]
|
||||||
@@ -94,6 +94,14 @@ def moe_ffn_forward_grouped(
|
|||||||
|
|
||||||
routing_dtype = gate_linear.weight.dtype
|
routing_dtype = gate_linear.weight.dtype
|
||||||
|
|
||||||
|
expert_dtype = hidden_states.dtype
|
||||||
|
if expert_dtype not in (torch.bfloat16, torch.float16):
|
||||||
|
_LOGGER.debug(
|
||||||
|
"torch_grouped: unsupported expert dtype %s; falling back to naive",
|
||||||
|
expert_dtype,
|
||||||
|
)
|
||||||
|
return None, None
|
||||||
|
|
||||||
sample_mod = getattr(
|
sample_mod = getattr(
|
||||||
experts_module[0], "mlp", getattr(experts_module[0], "ffn", experts_module[0])
|
experts_module[0], "mlp", getattr(experts_module[0], "ffn", experts_module[0])
|
||||||
)
|
)
|
||||||
@@ -102,7 +110,6 @@ def moe_ffn_forward_grouped(
|
|||||||
and hasattr(sample_mod, "w3")
|
and hasattr(sample_mod, "w3")
|
||||||
and hasattr(sample_mod, "w2")
|
and hasattr(sample_mod, "w2")
|
||||||
):
|
):
|
||||||
expert_dtype = sample_mod.w1.weight.dtype
|
|
||||||
w13 = _stack_weights(
|
w13 = _stack_weights(
|
||||||
experts_module, ("w1", "w3"), key="w13", dtype=expert_dtype, device=device
|
experts_module, ("w1", "w3"), key="w13", dtype=expert_dtype, device=device
|
||||||
)
|
)
|
||||||
@@ -111,10 +118,8 @@ def moe_ffn_forward_grouped(
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
if hasattr(sample_mod, "gate_up_proj"):
|
if hasattr(sample_mod, "gate_up_proj"):
|
||||||
expert_dtype = sample_mod.gate_up_proj.weight.dtype
|
|
||||||
names13: Tuple[str, ...] = ("gate_up_proj",)
|
names13: Tuple[str, ...] = ("gate_up_proj",)
|
||||||
else:
|
else:
|
||||||
expert_dtype = sample_mod.up_proj.weight.dtype
|
|
||||||
names13 = ("up_proj", "gate_proj")
|
names13 = ("up_proj", "gate_proj")
|
||||||
w13 = _stack_weights(
|
w13 = _stack_weights(
|
||||||
experts_module, names13, key="w13", dtype=expert_dtype, device=device
|
experts_module, names13, key="w13", dtype=expert_dtype, device=device
|
||||||
|
|||||||
Reference in New Issue
Block a user