diff --git a/src/axolotl/monkeypatch/moe/__init__.py b/src/axolotl/monkeypatch/moe/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/src/axolotl/monkeypatch/moe/linear.py b/src/axolotl/monkeypatch/moe/linear.py
new file mode 100644
index 000000000..1d2c3e101
--- /dev/null
+++ b/src/axolotl/monkeypatch/moe/linear.py
@@ -0,0 +1,147 @@
+"""
+Adapted from:
+https://github.com/shawntan/scattermoe
+https://arxiv.org/abs/2403.08245
+"""
+
+import torch
+import torch.nn as nn
+
+from axolotl.monkeypatch.moe import ops
+
+
+class ParallelLinear(torch.autograd.Function):
+    @staticmethod
+    def forward(
+        ctx, x, expert_weights, k,
+        sorted_expert_idxs, sorted_scattered_idxs,
+        padded_block_idxs, expert_offsets,
+        gates=None, grouped_in=False, grouped_out=False,
+    ):
+        output = ops.scatter2scatter(
+            X=x, W=expert_weights,
+            sorted_expert_idxs=sorted_expert_idxs,
+            sorted_scattered_idxs=sorted_scattered_idxs,
+            padded_block_idxs=padded_block_idxs,
+            k=k, x_grouped=grouped_in, y_grouped=grouped_out,
+        )
+        if gates is not None:
+            output_expanded = output.view(gates.size(0), gates.size(1), output.size(-1))
+            output = torch.bmm(gates[:, None, :], output_expanded).squeeze(1)
+        else:
+            output_expanded = None
+
+        ctx.save_for_backward(
+            x, expert_weights,
+            sorted_expert_idxs, sorted_scattered_idxs,
+            padded_block_idxs, expert_offsets,
+            gates, output_expanded,
+        )
+        ctx.grouped_in = grouped_in
+        ctx.grouped_out = grouped_out
+        ctx.k = k
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_out):
+        (x, expert_weights,
+         sorted_expert_idxs, sorted_scattered_idxs,
+         padded_block_idxs, expert_offsets,
+         gates, output_expanded) = ctx.saved_tensors
+        k = ctx.k
+        grouped_in = ctx.grouped_in
+        grouped_out = ctx.grouped_out
+
+        if gates is not None:
+            # calculate gates gradient
+            d_gates = torch.bmm(output_expanded, grad_out[:, :, None]).squeeze(-1)
+            gates_flat = gates.flatten()
+            gate_fan = gates.size(1)
+            grouped_grad_out = output_expanded.flatten(0, 1)  # reuse expanded buffer later
+        else:
+            d_gates = None
+            gates_flat = None
+            gate_fan = 1
+            grouped_grad_out = None
+
+        if grouped_out:
+            grouped_grad_out = grad_out
+        else:
+            grouped_grad_out = ops.group(grad_out, sorted_scattered_idxs,
+                                         fan_out=gate_fan, coeff=gates_flat,
+                                         out=grouped_grad_out)
+        if grouped_in:
+            grouped_x = x
+            d_expanded_input = None
+        else:
+            grouped_x = ops.group(x, sorted_scattered_idxs, fan_out=k)
+            d_expanded_input = grouped_x
+        d_weights = ops.group_bwd_W(
+            DY=grouped_grad_out, X=grouped_x,
+            expert_offsets=expert_offsets,
+            E=expert_weights.size(0),
+        )
+        d_expanded_input = ops.scatter2scatter(
+            X=grouped_grad_out, x_grouped=True,
+            W=expert_weights.permute(0, 2, 1),
+            padded_block_idxs=padded_block_idxs,
+            sorted_expert_idxs=sorted_expert_idxs,
+            sorted_scattered_idxs=sorted_scattered_idxs,
+            k=1,
+            y_grouped=grouped_in,
+            out=d_expanded_input,  # reuse grouped_x buffer
+        )
+
+        if k == 1:
+            d_input = d_expanded_input
+        else:
+            d_input = d_expanded_input.view(x.size(0), k, d_expanded_input.size(-1)).sum(-2)
+        return (
+            # x, expert_weights, k,
+            d_input, d_weights, None,
+            # sorted_expert_idxs, sorted_scattered_idxs,
+            None, None,
+            # padded_block_idxs, expert_offsets,
+            None, None,
+            # gates, grouped_in, grouped_out
+            d_gates, None, None,
+        )
+
+
+def parallel_linear(inputs, expert_weights, k,
+                    sorted_expert_idxs, sorted_scattered_idxs,
+                    padded_block_idxs, expert_offsets,
+                    gates=None):
+    results = ParallelLinear.apply(inputs, expert_weights, k,
+                                   sorted_expert_idxs, sorted_scattered_idxs,
+                                   padded_block_idxs, expert_offsets, gates)
+    return results
+
+
+class ParallelExperts(nn.Module):
+    def __init__(self, num_experts, input_size, output_size) -> None:
+        super().__init__()
+        self.weight = nn.Parameter(torch.empty(num_experts, output_size, input_size))
+        self.num_experts = num_experts
+        self.input_size = input_size
+        self.output_size = output_size
+
+    def extra_repr(self):
+        return 'num_experts={}, input_size={}, output_size={}'.format(
+            self.num_experts, self.input_size, self.output_size)
+
+    def forward(self, inputs, k, sorted_expert_idxs, sorted_scattered_idxs,
+                padded_block_idxs, expert_offsets,
+                gates=None, grouped_in=False, grouped_out=False):
+        results = ParallelLinear.apply(
+            inputs, self.weight.permute(0, 2, 1), k,
+            sorted_expert_idxs, sorted_scattered_idxs,
+            padded_block_idxs, expert_offsets,
+            gates, grouped_in, grouped_out
+        )
+        return results
\ No newline at end of file
diff --git a/src/axolotl/monkeypatch/moe/mlp.py b/src/axolotl/monkeypatch/moe/mlp.py
new file mode 100644
index 000000000..9092d3a86
--- /dev/null
+++ b/src/axolotl/monkeypatch/moe/mlp.py
@@ -0,0 +1,89 @@
+"""
+Adapted from:
+https://github.com/shawntan/scattermoe
+https://arxiv.org/abs/2403.08245
+"""
+
+import torch
+from torch import nn
+
+from axolotl.monkeypatch.moe import ops
+from axolotl.monkeypatch.moe.linear import ParallelExperts
+
+
+class FusedExperts(nn.Module):
+    def __init__(
+        self,
+        experts,
+        input_size,
+        hidden_size,
+        num_experts,
+        top_k,
+        activation=nn.SiLU(),
+    ):
+        """
+        Fused experts compatible with Mixtral: a gated-linear-unit (SwiGLU-style)
+        MLP, typically with a SiLU activation function. `input_size` is the model
+        hidden size; `hidden_size` is the expert intermediate (FFN) size.
+        """
+        super().__init__()
+
+        self.num_experts = num_experts
+        self.input_size = input_size
+        self.hidden_size = hidden_size
+        self.experts = ParallelExperts(num_experts, input_size, 2 * hidden_size)
+        self.output_experts = ParallelExperts(num_experts, hidden_size, input_size)
+        self.top_k = min(top_k, self.num_experts)
+        self.activation = activation
+
+        # parallelize all w1 and w3 computation by concat + stack; each fused
+        # weight is (2 * hidden_size, input_size): w1 rows first, then w3
+        self.experts.weight = nn.Parameter(
+            torch.stack(
+                [
+                    torch.cat([experts[i].w1.weight, experts[i].w3.weight], dim=0)
+                    for i in range(len(experts))
+                ],
+                dim=0,
+            ).detach()
+        )
+
+        # parallelize all w2 computation by stack
+        self.output_experts.weight = nn.Parameter(
+            torch.stack([expert.w2.weight for expert in experts], dim=0).detach()
+        )
+
+    def forward(
+        self, x: torch.Tensor, expert_p: torch.Tensor, expert_idxs: torch.Tensor
+    ):
+        x_shape = x.size()
+        x = x.view(-1, x_shape[-1])
+        with torch.no_grad():
+            sorted_expert_idxs, sorted_scattered_idxs = ops.flatten_and_sort(
+                expert_idxs
+            )
+            padded_block_idxs, expert_offsets = ops.padded_block_indices(
+                sorted_expert_idxs, self.num_experts
+            )
+
+        h, gates = self.experts(
+            x,
+            self.top_k,
+            sorted_expert_idxs,
+            sorted_scattered_idxs,
+            padded_block_idxs,
+            expert_offsets,
+            grouped_out=True,
+        ).chunk(2, dim=-1)
+        # SwiGLU as in Mixtral: act_fn(w1(x)) * w3(x); `h` is the w1 half of the
+        # fused projection, `gates` the w3 half
+        h = self.activation(h) * gates
+        y = self.output_experts(
+            h,
+            1,
+            sorted_expert_idxs,
+            sorted_scattered_idxs,
+            padded_block_idxs,
+            expert_offsets,
+            grouped_in=True,
+            gates=expert_p,
+        )
+        y = y.view(*x_shape[:-1], y.size(-1))
+        return y
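Editor's note: the cat + stack above pins down the layout contract that FusedExperts.forward relies on. Row-concatenating the (out_features, in_features) weights means the first chunk of the fused output is the w1 projection and the second is w3. A minimal sketch of that contract in plain PyTorch (toy sizes, CPU only, no Triton; all names are local to the example):

    import torch
    import torch.nn as nn

    hidden_dim, ffn_dim = 16, 32
    w1 = nn.Linear(hidden_dim, ffn_dim, bias=False)
    w3 = nn.Linear(hidden_dim, ffn_dim, bias=False)

    fused = torch.cat([w1.weight, w3.weight], dim=0)  # (2 * ffn_dim, hidden_dim)
    x = torch.randn(4, hidden_dim)
    h, gates = (x @ fused.t()).chunk(2, dim=-1)       # first chunk: w1(x), second: w3(x)

    assert torch.allclose(h, w1(x), atol=1e-6)
    assert torch.allclose(gates, w3(x), atol=1e-6)
    out = nn.SiLU()(h) * gates                        # SwiGLU, as in FusedExperts.forward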
diff --git a/src/axolotl/monkeypatch/moe/moe.py b/src/axolotl/monkeypatch/moe/moe.py
new file mode 100644
index 000000000..0f68f0c43
--- /dev/null
+++ b/src/axolotl/monkeypatch/moe/moe.py
@@ -0,0 +1,43 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from axolotl.monkeypatch.moe.mlp import FusedExperts
+
+
+class SparseMoeBlock(nn.Module):
+    def __init__(self, experts, hidden_dim, ffn_dim, num_experts, top_k):
+        super().__init__()
+        self.hidden_dim = hidden_dim
+        self.ffn_dim = ffn_dim
+        self.num_experts = num_experts
+        self.top_k = top_k
+        # NOTE: this is a freshly initialized router; when patching a pretrained
+        # model, the original gate must be copied in (see load_model below)
+        self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False)
+        self.experts: FusedExperts = experts
+
+    def _post_training(self, model, name):
+        # get original weights back: reverse the concat + stack in the fused
+        # experts, splitting each fused weight back into its w1 and w3 halves
+        w1s, w3s = zip(
+            *(w.chunk(2, dim=0) for w in self.experts.experts.weight.unbind(dim=0))
+        )
+        w2s = self.experts.output_experts.weight.unbind(dim=0)
+
+        # TODO: recreate MoE class with original weights
+        experts = []
+        for i in range(self.num_experts):
+            pass
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        batch_size, sequence_length, hidden_dim = hidden_states.shape
+        hidden_states = hidden_states.view(-1, hidden_dim)
+
+        # router_logits: (batch * sequence_length, n_experts)
+        router_logits = self.gate(hidden_states)
+        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
+        routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
+        routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
+
+        # we cast back to the input dtype
+        routing_weights = routing_weights.to(hidden_states.dtype)
+
+        # Fused expert forward
+        final_hidden_states = self.experts(hidden_states, routing_weights, selected_experts)
+
+        final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
+        return final_hidden_states, router_logits
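The kernels below consume three pieces of routing metadata: the slot order grouped by expert, the permutation back to the flattened (token, k) layout, and per-expert offsets. A toy illustration of what they contain (plain PyTorch, runnable on CPU; mirrors flatten_and_sort and the bincount/cumsum inside padded_block_indices):

    import torch

    expert_idxs = torch.tensor([[0, 2], [1, 0], [2, 2]])  # top-2 choices for 3 tokens
    sorted_expert_idxs, sorted_scattered_idxs = torch.sort(expert_idxs.flatten())
    # sorted_expert_idxs    -> tensor([0, 0, 1, 2, 2, 2]): slots grouped by expert
    # sorted_scattered_idxs -> tensor([0, 3, 2, 1, 4, 5]): position of each slot
    #                          in the flattened (token, k) layout
    expert_counts = torch.bincount(sorted_expert_idxs, minlength=3)
    expert_offsets = expert_counts.cumsum(-1)             # tensor([2, 3, 6])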
diff --git a/src/axolotl/monkeypatch/moe/ops.py b/src/axolotl/monkeypatch/moe/ops.py
new file mode 100644
index 000000000..8b674299c
--- /dev/null
+++ b/src/axolotl/monkeypatch/moe/ops.py
@@ -0,0 +1,353 @@
+"""
+Adapted from:
+https://github.com/shawntan/scattermoe
+https://arxiv.org/abs/2403.08245
+"""
+
+import torch
+import triton
+import triton.language as tl
+
+BLOCK_M = 128
+
+
+@torch.jit.script
+def flatten_and_sort(expert_idxs: torch.Tensor):
+    flattened_expert_idxs = expert_idxs.flatten()
+    sorted_expert_idxs, sorted_scattered_idxs = torch.sort(flattened_expert_idxs)
+    return sorted_expert_idxs, sorted_scattered_idxs
+
+
+@torch.jit.script
+def padded_block_indices(sorted_experts_idxs: torch.Tensor, num_experts: int, N_BLOCK_SIZE: int = BLOCK_M):
+    expert_counts = torch.bincount(sorted_experts_idxs, minlength=num_experts)
+    padded_block_counts = ((expert_counts - 1) // N_BLOCK_SIZE) + 1
+    padded_expert_block_end = padded_block_counts.cumsum(-1)
+    expert_boundaries_end = expert_counts.cumsum(-1)
+    expert_boundaries_start = expert_boundaries_end - expert_counts
+    padded_expert_block_start = padded_expert_block_end - padded_block_counts
+    block_idxs = torch.arange(padded_expert_block_end[-1],
+                              dtype=sorted_experts_idxs.dtype,
+                              device=sorted_experts_idxs.device)
+    block_mask = (
+        (block_idxs[:, None] < padded_expert_block_start) |
+        (block_idxs[:, None] >= padded_expert_block_end)
+    )
+    expanded_block_idxs = (
+        N_BLOCK_SIZE * (block_idxs[:, None] - padded_expert_block_start) +
+        expert_boundaries_start
+    )
+    expanded_block_idxs = expanded_block_idxs.masked_fill(block_mask, 0).sum(-1)
+    return expanded_block_idxs, expert_boundaries_end
+
+
+def _scatter2scatter_configs():
+    return [
+        triton.Config({'BLOCK_N': 128, 'BLOCK_K': 32}, num_stages=4, num_warps=4),
+    ]
+
+
+@triton.autotune(configs=_scatter2scatter_configs(), key=['M', 'N', 'K'])
+@triton.heuristics({
+    "NO_K_MASK": lambda args: (args['K'] % args['BLOCK_K']) == 0,
+    "NO_N_MASK": lambda args: (args['N'] % args['BLOCK_N']) == 0,
+})
+@triton.jit
+def _scatter2scatter(
+    X_ptr, stride_xm, stride_xk,
+    W_ptr, stride_we, stride_wk, stride_wn,
+    Y_ptr, stride_ym, stride_yn,
+    grouped_idx_ptr, expert_idxs_ptr, block_start_idx_ptr,
+    FAN_OUT: tl.constexpr,
+    M: tl.constexpr, K: tl.constexpr, N: tl.constexpr, E: tl.constexpr,
+    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
+    ACC_TYPE: tl.constexpr,
+    OUT_M: tl.constexpr,
+    allow_tf32: tl.constexpr,
+    x_grouped: tl.constexpr, y_grouped: tl.constexpr,
+    NO_K_MASK: tl.constexpr, NO_N_MASK: tl.constexpr
+):
+    pid = tl.program_id(axis=0)
+
+    N_BLOCK_COUNT = tl.cdiv(N, BLOCK_N)
+    M_block_id = pid // N_BLOCK_COUNT
+    N_block_id = pid % N_BLOCK_COUNT
+    M_range = tl.arange(0, BLOCK_M)
+    block_start_idx = tl.load(block_start_idx_ptr + M_block_id)
+    M_block = tl.max_contiguous(block_start_idx + M_range, BLOCK_M)
+    E_idxs = tl.load(expert_idxs_ptr + M_block, mask=M_block < (FAN_OUT * M), other=E)
+    E_idx = tl.min(E_idxs)
+    E_mask = E_idxs == E_idx
+    M_idx = tl.load(grouped_idx_ptr + M_block, mask=E_mask, other=0)
+    if x_grouped:
+        M_in_idx = M_block
+    else:
+        M_in_idx = M_idx // FAN_OUT
+
+    if y_grouped:
+        M_out_idx = M_block
+    else:
+        M_out_idx = M_idx
+
+    K_block = tl.arange(0, BLOCK_K)
+
+    N_block = N_block_id * BLOCK_N + tl.arange(0, BLOCK_N)
+    N_mask = N_block < N
+
+    X_blk_ptrs = X_ptr + M_in_idx[:, None] * stride_xm + K_block[None, :] * stride_xk
+    W_blk_ptrs = W_ptr + K_block[:, None] * stride_wk + N_block[None, :] * stride_wn + E_idx * stride_we
+
+    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
+    iters = tl.cdiv(K, BLOCK_K)
+    for K_block_id in range(0, iters):
+        if NO_K_MASK:
+            x = tl.load(X_blk_ptrs, mask=E_mask[:, None])
+            if NO_N_MASK:
+                w = tl.load(W_blk_ptrs)
+            else:
+                w = tl.load(W_blk_ptrs, mask=N_mask[None, :])
+        else:
+            K_mask = (K_block_id * BLOCK_K + K_block) < K
+            x = tl.load(X_blk_ptrs, mask=E_mask[:, None] & K_mask[None, :])
+            w = tl.load(W_blk_ptrs, mask=K_mask[:, None] & N_mask[None, :])
+        X_blk_ptrs += BLOCK_K * stride_xk
+        W_blk_ptrs += BLOCK_K * stride_wk
+        acc += tl.dot(x, w, allow_tf32=allow_tf32, out_dtype=ACC_TYPE)
+
+    Y_blk_ptrs = Y_ptr + (M_out_idx[:, None] * stride_ym + N_block[None, :] * stride_yn)
+    tl.store(Y_blk_ptrs, acc, mask=E_mask[:, None] & N_mask[None, :])
+
+
+def scatter2scatter(X, W, sorted_expert_idxs, sorted_scattered_idxs, k,
+                    padded_block_idxs, x_grouped=False, y_grouped=False,
+                    out=None):
+    assert sorted_scattered_idxs.size(0) == sorted_expert_idxs.size(0)
+    assert sorted_scattered_idxs.size(0) == X.size(0) * k
+    # Pre-kernel setup
+    x_dim = X.size(-1)
+    y_dim = W.size(-1)
+    L_scattered = sorted_expert_idxs.size(0)
+    if out is None:
+        O = torch.empty((L_scattered, y_dim), device=X.device, dtype=X.dtype)
+    else:
+        assert out.size(0) == L_scattered and out.size(1) == y_dim
+        O = out
+
+    def grid(META):
+        grid_num = (
+            padded_block_idxs.size(0) *
+            triton.cdiv(META['N'], META['BLOCK_N']),
+        )
+        return grid_num
+
+    _scatter2scatter[grid](
+        # X_ptr, stride_xm, stride_xk,
+        X, X.stride(0), X.stride(1),
+        # W_ptr, stride_we, stride_wk, stride_wn,
+        W, W.stride(0), W.stride(1), W.stride(2),
+        # Y_ptr, stride_ym, stride_yn,
+        O, O.stride(0), O.stride(1),
+        grouped_idx_ptr=sorted_scattered_idxs,
+        expert_idxs_ptr=sorted_expert_idxs,
+        block_start_idx_ptr=padded_block_idxs,
+        FAN_OUT=k,
+        M=X.size(0),
+        K=X.size(1),
+        N=O.size(1), E=W.size(0),
+        BLOCK_M=BLOCK_M,
+        ACC_TYPE=tl.float32,
+        OUT_M=O.size(0),
+        allow_tf32=True,
+        x_grouped=x_grouped, y_grouped=y_grouped,
+    )
+    return O
+
+
+def _config_XtY():
+    return [
+        triton.Config({'BLOCK_N': 128, 'BLOCK_K': 128, 'BLOCK_M': 32}, num_stages=4, num_warps=4),
+    ]
+
+
+def group_bwd_W(DY, X, expert_offsets, E):
+    DWt = torch.zeros((E, DY.size(-1), X.size(-1)), device=DY.device, dtype=DY.dtype)
+    DW = DWt.permute(0, 2, 1)
+
+    def grid(META):
+        grid = (
+            E * triton.cdiv(META['K'], META['BLOCK_K']),
+            triton.cdiv(META['N'], META['BLOCK_N']),
+        )
+        return grid
+
+    _groupXtY[grid](
+        # DY_ptr, stride_dym, stride_dyk,
+        DY, DY.stride(0), DY.stride(1),
+        # X_ptr, stride_xm, stride_xn,
+        X, X.stride(0), X.stride(1),
+        # DW_ptr, stride_dwe, stride_dwk, stride_dwn,
+        DW, DW.stride(0), DW.stride(1), DW.stride(2),
+        # expert_offsets_ptr,
+        expert_offsets,
+        # M, N, K,
+        M=DY.size(0), N=DY.size(-1), K=X.size(-1),
+        # ACC_TYPE,
+        ACC_TYPE=tl.float32,
+        allow_tf32=True
+    )
+    return DW
+
+
+@triton.autotune(configs=_config_XtY(), key=['M', 'N', 'K'])
+@triton.heuristics({
+    "NO_K_MASK": lambda args: (args['K'] % args['BLOCK_K']) == 0,
+    "NO_N_MASK": lambda args: (args['N'] % args['BLOCK_N']) == 0,
+})
+@triton.jit
+def _groupXtY(
+    DY_ptr, stride_dym, stride_dyk,
+    X_ptr, stride_xm, stride_xn,
+    DW_ptr, stride_dwe, stride_dwk, stride_dwn,
+    expert_offsets_ptr,
+    M: tl.constexpr, K: tl.constexpr, N: tl.constexpr,
+    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
+    ACC_TYPE: tl.constexpr,
+    allow_tf32: tl.constexpr,
+    NO_K_MASK: tl.constexpr, NO_N_MASK: tl.constexpr
+):
+    pid0 = tl.program_id(axis=0)
+    pid1 = tl.program_id(axis=1)
+    num0 = tl.num_programs(0)
+    num1 = tl.num_programs(1)
+    pid1, pid0 = tl.swizzle2d(pid1, pid0, num1, num0, 128)
+
+    K_BLOCK_COUNT = tl.cdiv(K, BLOCK_K)
+    E_idx = pid0 // K_BLOCK_COUNT
+    K_block_id = pid0 % K_BLOCK_COUNT
+    N_block_id = pid1
+
+    if E_idx == 0:
+        start_idx = 0
+    else:
+        start_idx = tl.load(expert_offsets_ptr + E_idx - 1).to(tl.int32)
+    end_idx = tl.load(expert_offsets_ptr + E_idx).to(tl.int32)
+
+    if end_idx > start_idx:
+        M_block = tl.max_contiguous(start_idx + tl.arange(0, BLOCK_M), BLOCK_M)
+
+        K_block = K_block_id * BLOCK_K + tl.arange(0, BLOCK_K)
+        K_mask = K_block < K
+        K_block = tl.max_contiguous(tl.multiple_of(K_block % K, BLOCK_K), BLOCK_K)
+
+        N_block = N_block_id * BLOCK_N + tl.arange(0, BLOCK_N)
+        N_mask = N_block < N
+        N_block = tl.max_contiguous(tl.multiple_of(N_block % N, BLOCK_N), BLOCK_N)
+
+        M_idxs = M_block
+        xt_blk_ptrs = X_ptr + K_block[:, None] * stride_xn + M_idxs[None, :] * stride_xm
+        dy_blk_ptrs = DY_ptr + M_idxs[:, None] * stride_dym + N_block[None, :] * stride_dyk
+
+        acc = tl.zeros((BLOCK_K, BLOCK_N), dtype=ACC_TYPE)
+        iters = tl.cdiv(end_idx - start_idx, BLOCK_M)
+        for i in range(0, iters):
+            M_mask = (i * BLOCK_M + M_block) < end_idx
+            if NO_K_MASK:
+                xt = tl.load(xt_blk_ptrs, mask=M_mask[None, :])
+            else:
+                xt = tl.load(xt_blk_ptrs, mask=K_mask[:, None] & M_mask[None, :])
+            if NO_N_MASK:
+                dy = tl.load(dy_blk_ptrs, mask=M_mask[:, None])
+            else:
+                dy = tl.load(dy_blk_ptrs, mask=M_mask[:, None] & N_mask[None, :])
+            acc += tl.dot(xt, dy, out_dtype=ACC_TYPE, allow_tf32=allow_tf32)
+            xt_blk_ptrs += BLOCK_M * stride_xm
+            dy_blk_ptrs += BLOCK_M * stride_dym
+
+        DW_blk_ptrs = DW_ptr + E_idx * stride_dwe + K_block[:, None] * stride_dwk + N_block[None, :] * stride_dwn
+        acc = acc.to(DW_blk_ptrs.dtype.element_ty)
+        tl.store(DW_blk_ptrs, acc, mask=K_mask[:, None] & N_mask[None, :])
+
+
+def _config_grouping():
+    return [
+        triton.Config({'BLOCK_N': 256, 'BLOCK_K': 128}, num_stages=4, num_warps=4),
+        triton.Config({'BLOCK_N': 128, 'BLOCK_K': 64}, num_stages=4, num_warps=4),
+        triton.Config({'BLOCK_N': 64, 'BLOCK_K': 32}, num_stages=4, num_warps=4),
+    ]
+
+
+def group(A, sorted_scattered_idxs, coeff=None, fan_out=1, out=None):
+    N = sorted_scattered_idxs.size(0)
+    K = A.size(1)
+    assert A.size(0) * fan_out == N
+    if out is not None:
+        Y = out
+    else:
+        Y = torch.empty((N, K), dtype=A.dtype, device=A.device)
+
+    def grid(META):
+        grid_num = (triton.cdiv(META['N'], META['BLOCK_N']),)
+        return grid_num
+
+    _group[grid](
+        # src_ptr, stride_sn, stride_sk, has_coeff, coeff_ptr, FAN_OUT,
+        A, A.stride(0), A.stride(1), coeff is not None, coeff, fan_out,
+        # tgt_ptr, stride_tn, stride_ti,
+        Y, Y.stride(0), Y.stride(1),
+        # grouped_idx_ptr,
+        sorted_scattered_idxs,
+        # N, K,
+        N, K
+    )
+    return Y
+
+
+@triton.autotune(configs=_config_grouping(), key=['K'])
+@triton.heuristics({
+    "NO_K_MASK": lambda args: (args['K'] % args['BLOCK_K']) == 0
+})
+@triton.jit
+def _group(
+    src_ptr, stride_sn, stride_sk, has_coeff: tl.constexpr, coeff_ptr, FAN_OUT: tl.constexpr,
+    tgt_ptr, stride_tn, stride_ti,
+    grouped_idx_ptr,
+    N: tl.constexpr, K: tl.constexpr,
+    BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
+    NO_K_MASK: tl.constexpr
+):
+    pid = tl.program_id(axis=0)
+
+    N_block_id = pid
+    N_blk = N_block_id * BLOCK_N + tl.arange(0, BLOCK_N)
+    N_mask = N_blk < N
+    N_blk = tl.max_contiguous(tl.multiple_of(N_blk % N, BLOCK_N), BLOCK_N)
+    N_idx = tl.load(grouped_idx_ptr + N_blk, mask=N_mask, other=0)
+
+    K_blk = tl.arange(0, BLOCK_K)
+    src_blk_ptrs = src_ptr + (N_idx // FAN_OUT)[:, None] * stride_sn + K_blk[None, :] * stride_sk
+    tgt_blk_ptrs = tgt_ptr + N_blk[:, None] * stride_tn + K_blk[None, :] * stride_ti
+
+    if has_coeff:
+        c = tl.load(coeff_ptr + N_idx, mask=N_mask)[:, None]
+
+    iters = tl.cdiv(K, BLOCK_K)
+    for i in range(0, iters):
+        if NO_K_MASK:
+            block = tl.load(src_blk_ptrs)
+            if has_coeff:
+                block *= c
+            tl.store(tgt_blk_ptrs, block, mask=N_mask[:, None])
+        else:
+            K_mask = (i * BLOCK_K + K_blk) < K
+            mask = N_mask[:, None] & K_mask[None, :]
+            block = tl.load(src_blk_ptrs, mask=mask)
+            if has_coeff:
+                block *= c
+            tl.store(tgt_blk_ptrs, block, mask=mask)
+
+        src_blk_ptrs += BLOCK_K * stride_sk
+        tgt_blk_ptrs += BLOCK_K * stride_ti
\ No newline at end of file
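A rough way to sanity-check the forward kernel against an eager reference (a sketch, not part of the patch: it assumes a CUDA device, since these are Triton kernels, and a loose tolerance because the launch above enables TF32; all sizes and names are illustrative):

    import torch
    from axolotl.monkeypatch.moe import ops

    E, K, N, M, k = 4, 64, 32, 8, 1
    X = torch.randn(M, K, device="cuda")
    W = torch.randn(E, K, N, device="cuda")
    idxs = torch.randint(0, E, (M, k), device="cuda")

    sorted_e, sorted_s = ops.flatten_and_sort(idxs)
    padded, offsets = ops.padded_block_indices(sorted_e, E)
    Y = ops.scatter2scatter(X=X, W=W, sorted_expert_idxs=sorted_e,
                            sorted_scattered_idxs=sorted_s,
                            padded_block_idxs=padded, k=k)
    # with y_grouped=False and k=1, row i of Y should be X[i] @ W[idxs[i, 0]]
    ref = torch.stack([X[i] @ W[idxs[i, 0]] for i in range(M)])
    assert torch.allclose(Y, ref, atol=1e-2, rtol=1e-2)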
diff --git a/src/axolotl/monkeypatch/moe/single.py b/src/axolotl/monkeypatch/moe/single.py
new file mode 100644
index 000000000..e8bc3ca78
--- /dev/null
+++ b/src/axolotl/monkeypatch/moe/single.py
@@ -0,0 +1,66 @@
+"""
+Adapted from:
+https://github.com/shawntan/scattermoe
+https://arxiv.org/abs/2403.08245
+"""
+
+import torch
+import triton
+import triton.language as tl
+
+
+@triton.jit
+def _single2scatter(
+    X_ptr, stride_xm, stride_xk,
+    W_ptr, stride_we, stride_wk, stride_wn,
+    Y_ptr, stride_ym, stride_yn,
+    expert_idxs_ptr,
+    FAN_OUT: tl.constexpr,
+    K: tl.constexpr, N: tl.constexpr, E: tl.constexpr,
+    BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
+    ACC_TYPE: tl.constexpr,
+):
+    pid0 = tl.program_id(axis=0)
+    pid1 = tl.program_id(axis=1)
+
+    N_block_id = pid0
+    if FAN_OUT == 1:
+        in_idx = pid1
+    else:
+        in_idx = 0
+    out_idx = pid1
+
+    K_block = tl.arange(0, BLOCK_K)
+    N_block = tl.max_contiguous(tl.multiple_of((N_block_id * BLOCK_N + tl.arange(0, BLOCK_N)) % N, BLOCK_N), BLOCK_N)
+    E_idx = tl.load(expert_idxs_ptr + pid1)
+    X_blk_ptrs = X_ptr + in_idx * stride_xm + K_block[:, None] * stride_xk
+    W_blk_ptrs = W_ptr + E_idx * stride_we + K_block[:, None] * stride_wk + N_block[None, :] * stride_wn
+    acc = tl.zeros((1, BLOCK_N), dtype=ACC_TYPE)
+    for K_block_id in range(0, tl.cdiv(K, BLOCK_K)):
+        x = tl.load(X_blk_ptrs)
+        w = tl.load(W_blk_ptrs)
+        acc += tl.sum(x * w, axis=0)[None, :]
+        X_blk_ptrs += BLOCK_K * stride_xk
+        W_blk_ptrs += BLOCK_K * stride_wk
+    Y_blk_ptrs = Y_ptr + out_idx * stride_ym + N_block[None, :] * stride_yn
+    tl.store(Y_blk_ptrs, acc)
+
+
+def single2scatter(X, W, expert_idxs):
+    E, xdim, ydim = W.size()
+    k = expert_idxs.size(1)
+    assert X.size(0) == k or X.size(0) == 1
+    Y = torch.empty((k, ydim), device=X.device, dtype=X.dtype)
+    BLOCK_N = 128
+    BLOCK_K = 128
+    # NOTE: the kernel loads and stores without K/N masks, so xdim and ydim
+    # are assumed to be multiples of BLOCK_K and BLOCK_N respectively
+    grid = ydim // BLOCK_N, k
+    _single2scatter[grid](
+        X, X.stride(0), X.stride(1),
+        W, W.stride(0), W.stride(1), W.stride(2),
+        Y, Y.stride(0), Y.stride(1),
+        expert_idxs,
+        FAN_OUT=Y.size(0) // X.size(0),
+        K=xdim, N=ydim, E=E,
+        BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K,
+        ACC_TYPE=tl.float32
+    )
+    return Y
\ No newline at end of file
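single2scatter is not wired into the fused block above; it appears intended for single-token (decode-time) dispatch, producing one output row per selected expert. A usage sketch under the divisibility assumption noted in the NOTE above (hypothetical sizes, CUDA required):

    import torch
    from axolotl.monkeypatch.moe.single import single2scatter

    X = torch.randn(1, 128, device="cuda")                    # one token's hidden state
    W = torch.randn(8, 128, 256, device="cuda")               # 8 experts, 128 -> 256
    expert_idxs = torch.randint(0, 8, (1, 2), device="cuda")  # its top-2 experts
    Y = single2scatter(X, W, expert_idxs)                     # (2, 256): one row per expert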
diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index fce7b20a7..6128269b2 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -715,32 +715,38 @@ def load_model(
         if cfg.flash_attn_fuse_qkv:
             LOG.info("patching with fused QKV")
             replace_llama_qkv_with_fused(model)
-        # elif model_type == "GPTNeoXForCausalLM" and cfg.flash_attention:
-        # This is a WIP, still an issue with the backward pass
-        # RuntimeError: grad can be implicitly created only for scalar outputs
-        # TODO: try config.sequence_parallel = False
-        # # https://github.com/HazyResearch/flash-attention/blob/40a25c8ee7465cf547b929cfa2937034e37bfce9/tests/models/test_gpt_neox.py#L12
-        # # https://github.com/HazyResearch/flash-attention/tree/main/training#model-components
-        # # add `**kwargs` to https://github.com/HazyResearch/flash-attention/blob/40a25c8ee7465cf547b929cfa2937034e37bfce9/flash_attn/models/gpt.py#L442
-        # from flash_attn.utils.pretrained import state_dict_from_pretrained
-        # from flash_attn.models.gpt import GPTLMHeadModel
-        # from flash_attn.models.gpt_neox import remap_state_dict_hf_gpt_neox, gpt_neox_config_to_gpt2_config
-        # from transformers import GPTNeoXConfig
-        # config = gpt_neox_config_to_gpt2_config(GPTNeoXConfig.from_pretrained(base_model))
-        # config.use_flash_attn = True
-        # config.fused_bias_fc = True
-        # config.fused_mlp = True  # GPT-NeoX-20B uses "gelu_fast"
-        # config.activation_function = "gelu_fast"
-        # config.fused_dropout_add_ln = True
-        # # config.residual_in_fp32 = True
-        #
-        # model: GPTLMHeadModel = GPTLMHeadModel.from_pretrained(
-        #     base_model,
-        #     config,
-        #     dtype=torch_dtype,
-        #     device=cfg.device,
-        # )
-        # model.train()  # sets to train instead of eval mode
+        elif (
+            model_config.model_type == "mixtral"
+            and not cfg.adapter
+            and cfg.fuse_moe
+        ):
+            from axolotl.monkeypatch.moe.mlp import FusedExperts
+            from axolotl.monkeypatch.utils import set_module_name
+            from axolotl.monkeypatch.moe.moe import SparseMoeBlock
+            from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
+
+            for name, module in model.named_modules():
+                if isinstance(module, MixtralSparseMoeBlock):
+                    experts = FusedExperts(
+                        experts=module.experts,
+                        input_size=module.hidden_dim,
+                        hidden_size=module.ffn_dim,
+                        num_experts=module.num_experts,
+                        top_k=module.top_k,
+                        activation=module.experts[0].act_fn,
+                    )
+                    smoe = SparseMoeBlock(
+                        experts=experts,
+                        hidden_dim=module.hidden_dim,
+                        ffn_dim=module.ffn_dim,
+                        num_experts=module.num_experts,
+                        top_k=module.top_k,
+                    )
+                    # keep the pretrained router weights instead of the fresh
+                    # gate created in SparseMoeBlock.__init__
+                    smoe.gate = module.gate
+                    set_module_name(model, name, smoe)
+
         elif model_type == "MambaLMHeadModel":
             # FIXME this is janky at best and hacked together to make it work
             MambaLMHeadModel = fix_mamba_attn_for_loss()  # pylint: disable=invalid-name
diff --git a/tests/monkeypatch/test_moe.py b/tests/monkeypatch/test_moe.py
new file mode 100644
index 000000000..1a4ead522
--- /dev/null
+++ b/tests/monkeypatch/test_moe.py
@@ -0,0 +1,59 @@
+from copy import deepcopy
+
+import torch
+from transformers import MixtralConfig
+from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
+
+from axolotl.monkeypatch.moe.mlp import FusedExperts
+from axolotl.monkeypatch.moe.moe import SparseMoeBlock
+
+
+def test_fused_mixtral_moe():
+    # Set random seeds for reproducibility
+    torch.manual_seed(0)
+    torch.cuda.manual_seed(0)
+    torch.cuda.manual_seed_all(0)
+
+    # MixtralSparseMoeBlock reads attributes off a config object, so build a
+    # real MixtralConfig rather than a plain dict
+    config = MixtralConfig(
+        hidden_size=128,
+        intermediate_size=512,
+        num_local_experts=8,
+        num_experts_per_tok=2,
+    )
+
+    # the fused path runs Triton kernels, which require a CUDA device
+    device = torch.device("cuda")
+
+    # Initialize the MixtralSparseMoeBlock and the fused SparseMoeBlock with
+    # the same configuration and weights
+    mixtral_moe = MixtralSparseMoeBlock(config).to(device)
+    mixtral_moe_copy = deepcopy(mixtral_moe)
+
+    experts = FusedExperts(
+        experts=mixtral_moe_copy.experts,
+        input_size=mixtral_moe_copy.hidden_dim,
+        hidden_size=mixtral_moe_copy.ffn_dim,
+        num_experts=mixtral_moe_copy.num_experts,
+        top_k=mixtral_moe_copy.top_k,
+        activation=mixtral_moe_copy.experts[0].act_fn,
+    )
+    sparse_moe = SparseMoeBlock(
+        experts,
+        hidden_dim=config.hidden_size,
+        ffn_dim=config.intermediate_size,
+        num_experts=config.num_local_experts,
+        top_k=config.num_experts_per_tok,
+    ).to(device)
+    # both blocks must route with the same gate for the outputs to be comparable
+    sparse_moe.gate = mixtral_moe_copy.gate
+
+    # Generate random input data
+    batch_size = 16
+    sequence_length = 32
+    input_data = torch.randn(batch_size, sequence_length, config.hidden_size, device=device)
+
+    # Run the forward pass for both models
+    mixtral_output, mixtral_router_logits = mixtral_moe(input_data)
+    sparse_output, sparse_router_logits = sparse_moe(input_data)
+
+    # Compute the difference between the outputs and router logits
+    output_diff = torch.abs(mixtral_output - sparse_output).mean().item()
+    router_logits_diff = torch.abs(mixtral_router_logits - sparse_router_logits).mean().item()
+
+    # Tolerance: loose enough to absorb TF32 matmuls inside the Triton kernels
+    tolerance = 0.001
+
+    # Check if the difference is within the tolerance
+    assert output_diff < tolerance, f"Output difference is {output_diff}, which is greater than the tolerance of {tolerance}"
+    assert router_logits_diff < tolerance, f"Router logits difference is {router_logits_diff}, which is greater than the tolerance of {tolerance}"
\ No newline at end of file
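One gap worth flagging: the test only exercises the forward pass, while the point of the patch is training, so ParallelLinear.backward goes unexercised. A possible follow-up sketch, reusing `sparse_moe` and `input_data` from the test above (assumes the same CUDA setup; not part of the patch):

    out, router_logits = sparse_moe(input_data)
    loss = out.float().pow(2).mean()
    loss.backward()
    assert sparse_moe.experts.experts.weight.grad is not None
    assert sparse_moe.experts.output_experts.weight.grad is not None
    assert sparse_moe.gate.weight.grad is not None

As wired in load_model, the fused path is enabled by setting `fuse_moe` in the axolotl config for a mixtral model, and is skipped whenever an adapter is configured.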