fix
This commit is contained in:
@@ -91,22 +91,35 @@ def moe_ffn_forward_grouped(
|
|||||||
global LAST_ERROR
|
global LAST_ERROR
|
||||||
LAST_ERROR = None
|
LAST_ERROR = None
|
||||||
bsz, seqlen, hdim = hidden_states.shape
|
bsz, seqlen, hdim = hidden_states.shape
|
||||||
compute_dtype = gate_linear.weight.dtype
|
routing_dtype = gate_linear.weight.dtype
|
||||||
if hidden_states.dtype != compute_dtype:
|
use_mixed_router = (
|
||||||
hidden_states = hidden_states.to(dtype=compute_dtype)
|
hidden_states.device.type == "cuda" and routing_dtype == torch.float32
|
||||||
|
)
|
||||||
|
|
||||||
|
if use_mixed_router:
|
||||||
|
x_router = hidden_states.to(dtype=routing_dtype)
|
||||||
|
router_logits = gate_linear(x_router)
|
||||||
|
else:
|
||||||
|
if hidden_states.dtype != routing_dtype:
|
||||||
|
hidden_states = hidden_states.to(dtype=routing_dtype)
|
||||||
|
x = hidden_states.view(-1, hdim)
|
||||||
|
router_logits = gate_linear(x)
|
||||||
|
|
||||||
|
if router_logits.dtype != routing_dtype:
|
||||||
|
router_logits = router_logits.to(dtype=routing_dtype)
|
||||||
|
|
||||||
x = hidden_states.view(-1, hdim)
|
x = hidden_states.view(-1, hdim)
|
||||||
router_logits = gate_linear(x)
|
|
||||||
|
|
||||||
# top-k routing executed in torch to avoid extra dependencies
|
# top-k routing executed in torch to avoid extra dependencies
|
||||||
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
|
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
|
||||||
topk_weight, topk_idx = torch.topk(routing_weights, top_k, dim=-1, sorted=False)
|
topk_weight, topk_idx = torch.topk(routing_weights, top_k, dim=-1, sorted=False)
|
||||||
topk_weight = (topk_weight / topk_weight.sum(dim=-1, keepdim=True)).to(x.dtype)
|
topk_weight = topk_weight / topk_weight.sum(dim=-1, keepdim=True)
|
||||||
|
|
||||||
flat_idx = topk_idx.view(-1)
|
flat_idx = topk_idx.view(-1)
|
||||||
x_rep = x.repeat_interleave(top_k, dim=0)
|
|
||||||
|
|
||||||
E = _num_experts(experts_module)
|
E = _num_experts(experts_module)
|
||||||
dev, dt = x.device, x.dtype
|
dev = hidden_states.device
|
||||||
|
dt: torch.dtype = hidden_states.dtype
|
||||||
first = experts_module[0]
|
first = experts_module[0]
|
||||||
|
|
||||||
is_mixtral = _is_mixtral_layout(first)
|
is_mixtral = _is_mixtral_layout(first)
|
||||||
@@ -134,7 +147,7 @@ def moe_ffn_forward_grouped(
|
|||||||
LAST_ERROR = "unsupported expert layout"
|
LAST_ERROR = "unsupported expert layout"
|
||||||
return None, None
|
return None, None
|
||||||
|
|
||||||
def _resolve_expert(idx: int):
|
def _resolve_expert(idx: int) -> torch.nn.Module:
|
||||||
expert = experts_module[idx]
|
expert = experts_module[idx]
|
||||||
if nested_attr is None:
|
if nested_attr is None:
|
||||||
return expert
|
return expert
|
||||||
@@ -214,6 +227,18 @@ def moe_ffn_forward_grouped(
|
|||||||
W13 = experts_module._stacked_w13
|
W13 = experts_module._stacked_w13
|
||||||
W2 = experts_module._stacked_w2
|
W2 = experts_module._stacked_w2
|
||||||
|
|
||||||
|
dt = W13.dtype
|
||||||
|
if router_logits.dtype != dt:
|
||||||
|
router_logits = router_logits.to(dtype=dt)
|
||||||
|
if x.dtype != dt:
|
||||||
|
x = x.to(dtype=dt)
|
||||||
|
flat_idx = topk_idx.view(-1)
|
||||||
|
if topk_weight.dtype != dt:
|
||||||
|
topk_weight = topk_weight.to(dtype=dt)
|
||||||
|
x_rep = x.repeat_interleave(top_k, dim=0)
|
||||||
|
if x_rep.dtype != dt:
|
||||||
|
x_rep = x_rep.to(dtype=dt)
|
||||||
|
|
||||||
As: List[torch.Tensor] = []
|
As: List[torch.Tensor] = []
|
||||||
Bs: List[torch.Tensor] = []
|
Bs: List[torch.Tensor] = []
|
||||||
expert_slices: List[Tuple[int, torch.Tensor]] = []
|
expert_slices: List[Tuple[int, torch.Tensor]] = []
|
||||||
@@ -272,7 +297,7 @@ def moe_ffn_forward_grouped(
|
|||||||
|
|
||||||
As2: List[torch.Tensor] = []
|
As2: List[torch.Tensor] = []
|
||||||
Bs2: List[torch.Tensor] = []
|
Bs2: List[torch.Tensor] = []
|
||||||
y_buf = torch.empty_like(x_rep)
|
y_buf = torch.empty_like(x_rep, dtype=dt)
|
||||||
for (i, _sel), Yi in zip(expert_slices, Y_list, strict=False):
|
for (i, _sel), Yi in zip(expert_slices, Y_list, strict=False):
|
||||||
I2 = Yi.shape[-1] // 2
|
I2 = Yi.shape[-1] // 2
|
||||||
Yi_hidden = F.silu(Yi[:, :I2]) * Yi[:, I2:]
|
Yi_hidden = F.silu(Yi[:, :I2]) * Yi[:, I2:]
|
||||||
|
|||||||
Reference in New Issue
Block a user