yet another refactor: fuse gate/up weight allocation into a shared helper and simplify grouped MoE routing
@@ -45,6 +45,29 @@ class _GroupedWeightStorage:
     device: torch.device


+def _allocate_fused_gate_up(
+    num_experts: int,
+    gate_shape: torch.Size,
+    up_shape: torch.Size,
+    *,
+    device: torch.device,
+    dtype: torch.dtype,
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    if gate_shape[1] != up_shape[1]:
+        raise RuntimeError(
+            "torch_grouped: gate and up projections must share the hidden dimension"
+        )
+
+    fused = torch.empty(
+        (num_experts, gate_shape[0] + up_shape[0], gate_shape[1]),
+        device=device,
+        dtype=dtype,
+    )
+    gate_view = fused[:, : gate_shape[0]]
+    up_view = fused[:, gate_shape[0] : gate_shape[0] + up_shape[0]]
+    return fused, gate_view, up_view
+
+
 def _ensure_grouped_weights(
     experts_module, expert_impls: List[torch.nn.Module], sample_mod: torch.nn.Module
 ) -> _GroupedWeightStorage:
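A minimal sketch (not part of this commit) of what the new helper guarantees: the gate and up views alias the fused buffer, so filling them fills `fused` in place. It assumes `_allocate_fused_gate_up` from the hunk above is in scope; the expert count and projection shapes are hypothetical.

import torch

fused, gate_view, up_view = _allocate_fused_gate_up(
    4,                        # num_experts (hypothetical)
    torch.Size([128, 64]),    # gate projection shape: (intermediate, hidden)
    torch.Size([128, 64]),    # up projection shape: must share hidden dim 64
    device=torch.device("cpu"),
    dtype=torch.float32,
)

# Both views share storage with `fused`, so copying per-expert weights into
# gate_view / up_view fills the fused buffer directly.
assert fused.shape == (4, 256, 64)
assert gate_view.data_ptr() == fused.data_ptr()
assert up_view.untyped_storage().data_ptr() == fused.untyped_storage().data_ptr()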
@@ -63,28 +86,26 @@ def _ensure_grouped_weights(
         and hasattr(sample_mod, "w2")
     ):
         pattern = "swi_glu"
+        num_experts = len(expert_impls)
+        w1_shape = sample_mod.w1.weight.shape
+        w3_shape = sample_mod.w3.weight.shape
+        w2_shape = sample_mod.w2.weight.shape
         if (
             storage is not None
             and storage.pattern == pattern
             and storage.dtype == sample_mod.w1.weight.dtype
             and storage.device == sample_mod.w1.weight.device
+            and storage.gate.shape[1:] == w1_shape
         ):
             return storage

-        num_experts = len(expert_impls)
-        w1_shape = sample_mod.w1.weight.shape
-        w3_shape = sample_mod.w3.weight.shape
-        w2_shape = sample_mod.w2.weight.shape
-        gate = torch.empty(
-            (num_experts, *w1_shape),
+        fused, gate, up = _allocate_fused_gate_up(
+            num_experts,
+            w1_shape,
+            w3_shape,
             device=sample_mod.w1.weight.device,
             dtype=sample_mod.w1.weight.dtype,
         )
-        up = torch.empty(
-            (num_experts, *w3_shape),
-            device=sample_mod.w3.weight.device,
-            dtype=sample_mod.w3.weight.dtype,
-        )
         down = torch.empty(
             (num_experts, *w2_shape),
             device=sample_mod.w2.weight.device,
@@ -102,7 +123,6 @@ def _ensure_grouped_weights(
                 mod.w2.weight.detach_()
                 mod.w2.weight.set_(down[idx])

-        fused = torch.cat((gate, up), dim=1).contiguous()
         return _store(
             _GroupedWeightStorage(
                 pattern=pattern,
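Because the views returned by `_allocate_fused_gate_up` already occupy a contiguous (gate, up) layout, the explicit concatenation removed in this hunk is redundant. A self-contained sketch of the equivalence, with hypothetical sizes:

import torch

E, G, U, H = 2, 8, 8, 16                  # hypothetical expert / projection sizes
gate_w = torch.randn(E, G, H)
up_w = torch.randn(E, U, H)

# Old path: separate buffers, then an extra copy via torch.cat.
fused_old = torch.cat((gate_w, up_w), dim=1).contiguous()

# New path: copy straight into views of a preallocated fused buffer.
fused_new = torch.empty(E, G + U, H)
fused_new[:, :G].copy_(gate_w)
fused_new[:, G:].copy_(up_w)

assert torch.equal(fused_old, fused_new)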
@@ -124,6 +144,8 @@ def _ensure_grouped_weights(
             and storage.pattern == pattern
             and storage.dtype == gate_weight.dtype
             and storage.device == gate_weight.device
+            and storage.gate.shape[1:]
+            == (gate_weight.shape[0] // 2, gate_weight.shape[1])
         ):
             return storage

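The added condition reuses the cached storage only when the cached gate slice matches half of the first dimension of `gate_weight`, which suggests this branch handles a weight holding the gate and up projections fused along dim 0. A tiny worked shape example with a hypothetical fused weight:

import torch

gate_up_weight = torch.empty(512, 64)     # hypothetical fused (2 * intermediate, hidden)
expected = (gate_up_weight.shape[0] // 2, gate_up_weight.shape[1])
assert expected == (256, 64)              # what storage.gate.shape[1:] must equal for reuse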
@@ -168,25 +190,23 @@ def _ensure_grouped_weights(
         and hasattr(sample_mod, "down_proj")
     ):
         pattern = "dual_proj"
+        up_weight = sample_mod.up_proj.weight
+        gate_weight = sample_mod.gate_proj.weight
+        down_weight = sample_mod.down_proj.weight
         if (
             storage is not None
             and storage.pattern == pattern
             and storage.dtype == sample_mod.up_proj.weight.dtype
             and storage.device == sample_mod.up_proj.weight.device
+            and storage.gate.shape[1:] == gate_weight.shape
         ):
             return storage

         num_experts = len(expert_impls)
-        up_weight = sample_mod.up_proj.weight
-        gate_weight = sample_mod.gate_proj.weight
-        down_weight = sample_mod.down_proj.weight
-        up = torch.empty(
-            (num_experts, *up_weight.shape),
-            device=up_weight.device,
-            dtype=up_weight.dtype,
-        )
-        gate = torch.empty(
-            (num_experts, *gate_weight.shape),
+        fused, gate, up = _allocate_fused_gate_up(
+            num_experts,
+            gate_weight.shape,
+            up_weight.shape,
             device=gate_weight.device,
             dtype=gate_weight.dtype,
         )
@@ -197,8 +217,8 @@ def _ensure_grouped_weights(
         )
         with torch.no_grad():
             for idx, mod in enumerate(expert_impls):
-                up[idx].copy_(mod.up_proj.weight.detach())
                 gate[idx].copy_(mod.gate_proj.weight.detach())
+                up[idx].copy_(mod.up_proj.weight.detach())
                 down[idx].copy_(mod.down_proj.weight.detach())
                 mod.up_proj.weight.detach_()
                 mod.up_proj.weight.set_(up[idx])
@@ -207,7 +227,6 @@ def _ensure_grouped_weights(
                 mod.down_proj.weight.detach_()
                 mod.down_proj.weight.set_(down[idx])

-        fused = torch.cat((gate, up), dim=1).contiguous()
         return _store(
             _GroupedWeightStorage(
                 pattern=pattern,
@@ -268,33 +287,31 @@ def moe_ffn_forward_grouped(
         zero = torch.zeros_like(x_flat)
         return zero.view(bsz, seqlen, hdim), router_logits

-    perm = torch.argsort(flat_idx, stable=True)
-    sorted_experts = flat_idx[perm]
+    sorted_experts, perm = torch.sort(flat_idx)
     assignments = torch.bincount(sorted_experts, minlength=num_experts)
     if assignments.sum() == 0:
         zero = torch.zeros_like(x_flat)
         return zero.view(bsz, seqlen, hdim), router_logits

-    token_indices_sorted = perm // top_k
-    scores_sorted = topk_weight.view(-1)[perm]
+    token_indices_sorted = torch.div(perm, top_k, rounding_mode="floor").contiguous()
+    scores_sorted = topk_weight.reshape(-1).index_select(0, perm)

-    gather_index = token_indices_sorted.unsqueeze(-1).expand(-1, hdim)
-    routed_input = torch.gather(x_flat, 0, gather_index).contiguous()
+    routed_input = x_flat.index_select(0, token_indices_sorted).contiguous()

-    active_idx = torch.nonzero(assignments, as_tuple=False).squeeze(-1)
+    active_idx = torch.nonzero(assignments, as_tuple=False).squeeze(-1).contiguous()
     if active_idx.numel() == 0:
         zero = torch.zeros_like(x_flat)
         return zero.view(bsz, seqlen, hdim), router_logits

     counts_active = assignments[active_idx]
-    offsets = torch.cumsum(counts_active.to(device=device, dtype=torch.int32), dim=0)
-    offsets = offsets.to(torch.int32)
+    counts_active_i32 = counts_active.to(device=device, dtype=torch.int32)
+    offsets = torch.cumsum(counts_active_i32, dim=0)
     if offsets[-1].item() == 0:
         zero = torch.zeros_like(x_flat)
         return zero.view(bsz, seqlen, hdim), router_logits

-    w_gate_up_t = w_gate_up[active_idx].transpose(-2, -1)
-    w2_t = w2[active_idx].transpose(-2, -1)
+    w_gate_up_t = w_gate_up.index_select(0, active_idx).transpose(-2, -1)
+    w2_t = w2.index_select(0, active_idx).transpose(-2, -1)

     routed_in = routed_input.to(expert_dtype)
     gate_up_out = torch._grouped_mm(routed_in, w_gate_up_t, offs=offsets)
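A self-contained sketch (not part of the commit) of the routing bookkeeping this hunk rewrites: sort the flat expert assignments, recover each slot's source token with a floor division by top_k, and turn per-expert counts into the row offsets consumed by the grouped matmul. The router output below is hypothetical.

import torch

top_k, num_experts = 2, 3
# Hypothetical router output: 3 tokens, each assigned to top_k experts.
flat_idx = torch.tensor([1, 0, 2, 2, 0, 1])                # (tokens * top_k,)

sorted_experts, perm = torch.sort(flat_idx)                # group slots by expert
token_indices_sorted = torch.div(perm, top_k, rounding_mode="floor")
assignments = torch.bincount(sorted_experts, minlength=num_experts)
offsets = torch.cumsum(assignments.to(torch.int32), dim=0)

print(sorted_experts)        # tensor([0, 0, 1, 1, 2, 2])
print(token_indices_sorted)  # original token index feeding each sorted slot
print(offsets)               # cumulative row counts: one boundary per expert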
@@ -307,5 +324,5 @@ def moe_ffn_forward_grouped(
     down_out.mul_(weights)

     combined = torch.zeros_like(x_flat)
-    combined.scatter_add_(0, gather_index, down_out)
+    combined.index_add_(0, token_indices_sorted, down_out)
     return combined.view(bsz, seqlen, hdim), router_logits
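A minimal sketch (not part of the commit) showing that the new `index_add_` accumulates the same result as the old `scatter_add_` over a broadcast index, without materialising the (rows, hdim) index tensor; the sizes are hypothetical.

import torch

hdim, num_tokens = 4, 3
token_indices_sorted = torch.tensor([0, 2, 0, 1])          # token each routed row belongs to
down_out = torch.randn(4, hdim)

# Old formulation: scatter_add_ over a broadcast index matrix.
combined_old = torch.zeros(num_tokens, hdim)
gather_index = token_indices_sorted.unsqueeze(-1).expand(-1, hdim)
combined_old.scatter_add_(0, gather_index, down_out)

# New formulation: index_add_ accumulates whole rows by token index.
combined_new = torch.zeros(num_tokens, hdim)
combined_new.index_add_(0, token_indices_sorted, down_out)

assert torch.allclose(combined_old, combined_new)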