refactor
This commit is contained in:
@@ -26,54 +26,30 @@ def available() -> bool:
|
||||
return False
|
||||
|
||||
|
||||
def _iter_expert_impls(experts_module) -> List[torch.nn.Module]:
|
||||
impls: List[torch.nn.Module] = []
|
||||
for exp in experts_module:
|
||||
impls.append(getattr(exp, "mlp", getattr(exp, "ffn", exp)))
|
||||
return impls
|
||||
|
||||
|
||||
def _stack_weights(
|
||||
experts_module,
|
||||
names: Tuple[str, ...],
|
||||
*,
|
||||
key: str,
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
) -> torch.Tensor:
|
||||
attr = f"_ax_grouped_{key}"
|
||||
cached = getattr(experts_module, attr, None)
|
||||
if cached is not None and cached.dtype == dtype and cached.device == device:
|
||||
return cached
|
||||
|
||||
tensors: List[torch.Tensor] = []
|
||||
for exp in experts_module:
|
||||
mod = getattr(exp, "mlp", getattr(exp, "ffn", exp))
|
||||
for mod in _iter_expert_impls(experts_module):
|
||||
parts = [getattr(mod, name).weight.t() for name in names]
|
||||
tensors.append(parts[0] if len(parts) == 1 else torch.cat(parts, dim=-1))
|
||||
|
||||
stacked = (
|
||||
return (
|
||||
torch.stack(tensors, dim=0)
|
||||
.to(device=device, dtype=dtype, non_blocking=True)
|
||||
.contiguous()
|
||||
)
|
||||
setattr(experts_module, attr, stacked)
|
||||
return stacked
|
||||
|
||||
|
||||
def _call_grouped_mm(
|
||||
As: List[torch.Tensor], Bs: List[torch.Tensor], dtype: torch.dtype
|
||||
) -> List[torch.Tensor]:
|
||||
if not As:
|
||||
return []
|
||||
if dtype not in (torch.bfloat16, torch.float16):
|
||||
raise RuntimeError(f"unsupported dtype {dtype} for grouped_mm")
|
||||
|
||||
As2 = [a.to(dtype).contiguous().view(a.shape[0], a.shape[1]) for a in As]
|
||||
Bs2 = [b.to(dtype).contiguous().view(b.shape[0], b.shape[1]) for b in Bs]
|
||||
device = As2[0].device
|
||||
lengths = torch.tensor([a.shape[0] for a in As2], device=device, dtype=torch.int32)
|
||||
offsets = torch.cumsum(lengths, dim=0).to(torch.int32)
|
||||
Y_cat = torch._grouped_mm(torch.cat(As2, dim=0), torch.stack(Bs2, dim=0), offsets)
|
||||
outs: List[torch.Tensor] = []
|
||||
start = 0
|
||||
for size in lengths.tolist():
|
||||
outs.append(Y_cat[start : start + size])
|
||||
start += size
|
||||
return outs
|
||||
|
||||
|
||||
def moe_ffn_forward_grouped(
|
||||
@@ -99,30 +75,30 @@ def moe_ffn_forward_grouped(
|
||||
)
|
||||
return None, None
|
||||
|
||||
sample_mod = getattr(
|
||||
experts_module[0], "mlp", getattr(experts_module[0], "ffn", experts_module[0])
|
||||
)
|
||||
for suffix in ("w13", "w2"):
|
||||
attr = f"_ax_grouped_{suffix}"
|
||||
if hasattr(experts_module, attr):
|
||||
delattr(experts_module, attr)
|
||||
|
||||
expert_impls = _iter_expert_impls(experts_module)
|
||||
sample_mod = expert_impls[0]
|
||||
if (
|
||||
hasattr(sample_mod, "w1")
|
||||
and hasattr(sample_mod, "w3")
|
||||
and hasattr(sample_mod, "w2")
|
||||
):
|
||||
w13 = _stack_weights(
|
||||
experts_module, ("w1", "w3"), key="w13", dtype=expert_dtype, device=device
|
||||
)
|
||||
w2 = _stack_weights(
|
||||
experts_module, ("w2",), key="w2", dtype=expert_dtype, device=device
|
||||
experts_module, ("w1", "w3"), dtype=expert_dtype, device=device
|
||||
)
|
||||
w2 = _stack_weights(experts_module, ("w2",), dtype=expert_dtype, device=device)
|
||||
else:
|
||||
if hasattr(sample_mod, "gate_up_proj"):
|
||||
names13: Tuple[str, ...] = ("gate_up_proj",)
|
||||
else:
|
||||
names13 = ("up_proj", "gate_proj")
|
||||
w13 = _stack_weights(
|
||||
experts_module, names13, key="w13", dtype=expert_dtype, device=device
|
||||
)
|
||||
w13 = _stack_weights(experts_module, names13, dtype=expert_dtype, device=device)
|
||||
w2 = _stack_weights(
|
||||
experts_module, ("down_proj",), key="w2", dtype=expert_dtype, device=device
|
||||
experts_module, ("down_proj",), dtype=expert_dtype, device=device
|
||||
)
|
||||
|
||||
x_flat = hidden_states.view(tokens, hdim).to(expert_dtype)
|
||||
@@ -133,41 +109,45 @@ def moe_ffn_forward_grouped(
|
||||
topk_weight = topk_weight / topk_weight.sum(dim=-1, keepdim=True)
|
||||
|
||||
flat_idx = topk_idx.view(-1)
|
||||
x_rep = x_flat.repeat_interleave(top_k, dim=0)
|
||||
num_experts = len(expert_impls)
|
||||
if flat_idx.numel() == 0:
|
||||
zero = torch.zeros_like(x_flat)
|
||||
return zero.view(bsz, seqlen, hdim), router_logits
|
||||
|
||||
as_list: List[torch.Tensor] = []
|
||||
bs_list: List[torch.Tensor] = []
|
||||
slices: List[Tuple[int, torch.Tensor]] = []
|
||||
for i, _ in enumerate(experts_module):
|
||||
sel = flat_idx == i
|
||||
if sel.any():
|
||||
as_list.append(x_rep[sel])
|
||||
bs_list.append(w13[i])
|
||||
slices.append((i, sel))
|
||||
assignments = torch.bincount(flat_idx, minlength=num_experts)
|
||||
if assignments.sum() == 0:
|
||||
zero = torch.zeros_like(x_flat)
|
||||
return zero.view(bsz, seqlen, hdim), router_logits
|
||||
|
||||
if not as_list:
|
||||
return torch.zeros_like(x_flat).view(bsz, seqlen, hdim), router_logits
|
||||
perm = torch.argsort(flat_idx, stable=True)
|
||||
token_indices_sorted = perm // top_k
|
||||
scores_sorted = topk_weight.view(-1)[perm]
|
||||
|
||||
up_out = _call_grouped_mm(as_list, bs_list, expert_dtype)
|
||||
gather_index = token_indices_sorted.unsqueeze(-1).expand(-1, hdim)
|
||||
routed_input = torch.gather(x_flat, 0, gather_index).contiguous()
|
||||
|
||||
down_inputs: List[torch.Tensor] = []
|
||||
down_weights: List[torch.Tensor] = []
|
||||
buf = torch.empty_like(x_rep)
|
||||
for (i, _sel), Yi in zip(slices, up_out, strict=False):
|
||||
mid = Yi.shape[-1] // 2
|
||||
hidden = F.silu(Yi[:, :mid]) * Yi[:, mid:]
|
||||
down_inputs.append(hidden)
|
||||
down_weights.append(w2[i])
|
||||
offsets = torch.cumsum(assignments.to(device=device, dtype=torch.int32), dim=0)
|
||||
if offsets[-1].item() == 0:
|
||||
zero = torch.zeros_like(x_flat)
|
||||
return zero.view(bsz, seqlen, hdim), router_logits
|
||||
|
||||
down_out = _call_grouped_mm(down_inputs, down_weights, expert_dtype)
|
||||
mid = w13.shape[-1] // 2
|
||||
w_gate = w13[..., :mid]
|
||||
w_up = w13[..., mid:]
|
||||
|
||||
for (_i, sel), tensor in zip(slices, down_out, strict=False):
|
||||
buf[sel] = tensor
|
||||
w_gate_t = w_gate.transpose(-2, -1).contiguous()
|
||||
w_up_t = w_up.transpose(-2, -1).contiguous()
|
||||
w2_t = w2.transpose(-2, -1).contiguous()
|
||||
|
||||
combined = (
|
||||
(buf.view(tokens, top_k, -1) * topk_weight.to(expert_dtype).unsqueeze(-1))
|
||||
.sum(dim=1)
|
||||
.to(torch.bfloat16)
|
||||
)
|
||||
routed_in = routed_input.to(expert_dtype)
|
||||
gate_out = torch._grouped_mm(routed_in, w_gate_t, offs=offsets)
|
||||
up_out = torch._grouped_mm(routed_in, w_up_t, offs=offsets)
|
||||
activated = F.silu(gate_out) * up_out
|
||||
down_out = torch._grouped_mm(activated, w2_t, offs=offsets)
|
||||
|
||||
weights_fp32 = scores_sorted.unsqueeze(-1).to(torch.float32)
|
||||
weighted = (down_out.to(torch.float32) * weights_fp32).to(expert_dtype)
|
||||
|
||||
combined = torch.zeros_like(x_flat)
|
||||
combined.scatter_add_(0, gather_index, weighted)
|
||||
return combined.view(bsz, seqlen, hdim), router_logits
|
||||
|
||||
Reference in New Issue
Block a user