fix
@@ -91,22 +91,35 @@ def moe_ffn_forward_grouped(
     global LAST_ERROR
     LAST_ERROR = None
     bsz, seqlen, hdim = hidden_states.shape
-    compute_dtype = gate_linear.weight.dtype
-    if hidden_states.dtype != compute_dtype:
-        hidden_states = hidden_states.to(dtype=compute_dtype)
+    routing_dtype = gate_linear.weight.dtype
+    use_mixed_router = (
+        hidden_states.device.type == "cuda" and routing_dtype == torch.float32
+    )
+
+    if use_mixed_router:
+        x_router = hidden_states.to(dtype=routing_dtype)
+        router_logits = gate_linear(x_router)
+    else:
+        if hidden_states.dtype != routing_dtype:
+            hidden_states = hidden_states.to(dtype=routing_dtype)
+        x = hidden_states.view(-1, hdim)
+        router_logits = gate_linear(x)
+
+    if router_logits.dtype != routing_dtype:
+        router_logits = router_logits.to(dtype=routing_dtype)
     x = hidden_states.view(-1, hdim)
-    router_logits = gate_linear(x)
 
     # top-k routing executed in torch to avoid extra dependencies
     routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
     topk_weight, topk_idx = torch.topk(routing_weights, top_k, dim=-1, sorted=False)
-    topk_weight = (topk_weight / topk_weight.sum(dim=-1, keepdim=True)).to(x.dtype)
+    topk_weight = topk_weight / topk_weight.sum(dim=-1, keepdim=True)
 
     flat_idx = topk_idx.view(-1)
     x_rep = x.repeat_interleave(top_k, dim=0)
 
     E = _num_experts(experts_module)
-    dev, dt = x.device, x.dtype
+    dev = hidden_states.device
+    dt: torch.dtype = hidden_states.dtype
     first = experts_module[0]
 
     is_mixtral = _is_mixtral_layout(first)
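For readers skimming the hunk above, a minimal standalone sketch of the top-k routing math it converges on: logits in the router dtype, softmax in float32, top-k selection, then renormalization without casting back to the activation dtype. The expert count, tensor shapes, and the bare nn.Linear router below are illustrative assumptions, not code from this repository.

import torch
import torch.nn.functional as F

# Hypothetical sizes: 8 experts, hdim = 16, top_k = 2 (not taken from the commit).
gate_linear = torch.nn.Linear(16, 8, bias=False).float()    # router weights kept in fp32
hidden_states = torch.randn(2, 4, 16, dtype=torch.bfloat16) # [bsz, seqlen, hdim]
top_k = 2

x = hidden_states.view(-1, 16)
router_logits = gate_linear(x.to(gate_linear.weight.dtype))          # logits computed in fp32
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
topk_weight, topk_idx = torch.topk(routing_weights, top_k, dim=-1, sorted=False)
topk_weight = topk_weight / topk_weight.sum(dim=-1, keepdim=True)    # renormalize; stays fp32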
@@ -134,7 +147,7 @@ def moe_ffn_forward_grouped(
         LAST_ERROR = "unsupported expert layout"
         return None, None
 
-    def _resolve_expert(idx: int):
+    def _resolve_expert(idx: int) -> torch.nn.Module:
         expert = experts_module[idx]
         if nested_attr is None:
             return expert
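The return-type annotation added above only documents what _resolve_expert already yields. As a rough sketch of its behavior (the getattr unwrapping step and the exact nested_attr semantics are assumptions, since the rest of the helper sits outside this hunk):

import torch

def _resolve_expert_sketch(experts_module, idx: int, nested_attr) -> torch.nn.Module:
    # Hypothetical stand-in for the helper above: return the expert module directly,
    # or unwrap a named sub-module (e.g. expert.mlp) when nested_attr is set.
    expert = experts_module[idx]
    if nested_attr is None:
        return expert
    return getattr(expert, nested_attr)  # assumed unwrapping; not shown in the diff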
@@ -214,6 +227,18 @@ def moe_ffn_forward_grouped(
     W13 = experts_module._stacked_w13
     W2 = experts_module._stacked_w2
 
+    dt = W13.dtype
+    if router_logits.dtype != dt:
+        router_logits = router_logits.to(dtype=dt)
+    if x.dtype != dt:
+        x = x.to(dtype=dt)
+    flat_idx = topk_idx.view(-1)
+    if topk_weight.dtype != dt:
+        topk_weight = topk_weight.to(dtype=dt)
+    x_rep = x.repeat_interleave(top_k, dim=0)
+    if x_rep.dtype != dt:
+        x_rep = x_rep.to(dtype=dt)
+
     As: List[torch.Tensor] = []
     Bs: List[torch.Tensor] = []
     expert_slices: List[Tuple[int, torch.Tensor]] = []
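One way to see why the added casts matter: eager-mode matmuls in PyTorch refuse mixed input dtypes, so the routed token buffer has to agree with the stacked expert weights before any grouped GEMM. A small illustration with made-up shapes (the stacked-weight layout below is an assumption):

import torch

W13 = torch.randn(4, 16, 64, dtype=torch.bfloat16)  # assumed layout: [experts, hdim, 2*inter]
x_rep = torch.randn(10, 16, dtype=torch.float32)     # routed tokens, still fp32

# torch.matmul(x_rep, W13[0]) would raise a dtype-mismatch RuntimeError;
# casting the activations to the weight dtype first (as the hunk does) avoids it.
y = torch.matmul(x_rep.to(W13.dtype), W13[0])        # -> [10, 64] in bfloat16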
@@ -272,7 +297,7 @@ def moe_ffn_forward_grouped(
 
     As2: List[torch.Tensor] = []
     Bs2: List[torch.Tensor] = []
-    y_buf = torch.empty_like(x_rep)
+    y_buf = torch.empty_like(x_rep, dtype=dt)
     for (i, _sel), Yi in zip(expert_slices, Y_list, strict=False):
         I2 = Yi.shape[-1] // 2
         Yi_hidden = F.silu(Yi[:, :I2]) * Yi[:, I2:]
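For reference, the I2 split in the loop above is the fused gate/up (SwiGLU) arrangement: the first half of each expert output acts as the gate, the second half as the up projection. A tiny self-contained example with made-up sizes:

import torch
import torch.nn.functional as F

Yi = torch.randn(5, 2 * 32)                   # fused W13 output per expert: [gate | up]
I2 = Yi.shape[-1] // 2
Yi_hidden = F.silu(Yi[:, :I2]) * Yi[:, I2:]   # SwiGLU activation -> shape [5, 32]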