dtype fix
This commit is contained in:
@@ -93,18 +93,16 @@ def moe_ffn_forward_grouped(
|
|||||||
device = hidden_states.device
|
device = hidden_states.device
|
||||||
|
|
||||||
routing_dtype = gate_linear.weight.dtype
|
routing_dtype = gate_linear.weight.dtype
|
||||||
expert_dtype = hidden_states.dtype
|
|
||||||
x_flat = hidden_states.view(tokens, hdim)
|
|
||||||
router_logits = gate_linear(x_flat.to(routing_dtype))
|
|
||||||
|
|
||||||
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
|
sample_mod = getattr(
|
||||||
topk_weight, topk_idx = torch.topk(routing_weights, top_k, dim=-1, sorted=False)
|
|
||||||
topk_weight = topk_weight / topk_weight.sum(dim=-1, keepdim=True)
|
|
||||||
|
|
||||||
sample = getattr(
|
|
||||||
experts_module[0], "mlp", getattr(experts_module[0], "ffn", experts_module[0])
|
experts_module[0], "mlp", getattr(experts_module[0], "ffn", experts_module[0])
|
||||||
)
|
)
|
||||||
if hasattr(sample, "w1") and hasattr(sample, "w3") and hasattr(sample, "w2"):
|
if (
|
||||||
|
hasattr(sample_mod, "w1")
|
||||||
|
and hasattr(sample_mod, "w3")
|
||||||
|
and hasattr(sample_mod, "w2")
|
||||||
|
):
|
||||||
|
expert_dtype = sample_mod.w1.weight.dtype
|
||||||
w13 = _stack_weights(
|
w13 = _stack_weights(
|
||||||
experts_module, ("w1", "w3"), key="w13", dtype=expert_dtype, device=device
|
experts_module, ("w1", "w3"), key="w13", dtype=expert_dtype, device=device
|
||||||
)
|
)
|
||||||
@@ -112,11 +110,12 @@ def moe_ffn_forward_grouped(
|
|||||||
experts_module, ("w2",), key="w2", dtype=expert_dtype, device=device
|
experts_module, ("w2",), key="w2", dtype=expert_dtype, device=device
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
names13 = (
|
if hasattr(sample_mod, "gate_up_proj"):
|
||||||
("gate_up_proj",)
|
expert_dtype = sample_mod.gate_up_proj.weight.dtype
|
||||||
if hasattr(sample, "gate_up_proj")
|
names13: Tuple[str, ...] = ("gate_up_proj",)
|
||||||
else ("up_proj", "gate_proj")
|
else:
|
||||||
)
|
expert_dtype = sample_mod.up_proj.weight.dtype
|
||||||
|
names13 = ("up_proj", "gate_proj")
|
||||||
w13 = _stack_weights(
|
w13 = _stack_weights(
|
||||||
experts_module, names13, key="w13", dtype=expert_dtype, device=device
|
experts_module, names13, key="w13", dtype=expert_dtype, device=device
|
||||||
)
|
)
|
||||||
@@ -124,8 +123,15 @@ def moe_ffn_forward_grouped(
|
|||||||
experts_module, ("down_proj",), key="w2", dtype=expert_dtype, device=device
|
experts_module, ("down_proj",), key="w2", dtype=expert_dtype, device=device
|
||||||
)
|
)
|
||||||
|
|
||||||
|
x_flat = hidden_states.view(tokens, hdim).to(expert_dtype)
|
||||||
|
router_logits = gate_linear(x_flat.to(routing_dtype))
|
||||||
|
|
||||||
|
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
|
||||||
|
topk_weight, topk_idx = torch.topk(routing_weights, top_k, dim=-1, sorted=False)
|
||||||
|
topk_weight = topk_weight / topk_weight.sum(dim=-1, keepdim=True)
|
||||||
|
|
||||||
flat_idx = topk_idx.view(-1)
|
flat_idx = topk_idx.view(-1)
|
||||||
x_rep = x_flat.to(expert_dtype).repeat_interleave(top_k, dim=0)
|
x_rep = x_flat.repeat_interleave(top_k, dim=0)
|
||||||
|
|
||||||
as_list: List[torch.Tensor] = []
|
as_list: List[torch.Tensor] = []
|
||||||
bs_list: List[torch.Tensor] = []
|
bs_list: List[torch.Tensor] = []
|
||||||
|
|||||||
Reference in New Issue
Block a user