Compare commits
11 Commits
train-refa ... scatter_mo

| Author | SHA1 | Date |
|---|---|---|
| | 10328b3429 | |
| | 5bfc470d57 | |
| | 04168801c9 | |
| | d43a79b7bf | |
| | 884d81331e | |
| | 2ea75b4160 | |
| | 035e680631 | |
| | 26fc10df01 | |
| | 1bc008e901 | |
| | 3f7ed6a784 | |
| | feea977923 | |
examples/mistral/mixtral_fused.py (new file, 75 lines)
@@ -0,0 +1,75 @@
import gc
import torch
from tqdm import tqdm
from axolotl.monkeypatch.moe.moe import SparseMoeBlock
from transformers import AutoTokenizer, TextStreamer
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock, MixtralForCausalLM, MixtralConfig


def compute_memory_used_pct(device):
    memory_used = torch.cuda.max_memory_allocated(device) / (1024**3)
    memory_pct = (
        memory_used
        / (torch.cuda.get_device_properties(device).total_memory / (1024**3))
        * 100
    )
    return memory_pct


model_path = "mistralai/Mixtral-8x7B-Instruct-v0.1"

# Load model
config = MixtralConfig.from_pretrained(model_path, max_position_embeddings=2048, use_cache=False)
model = MixtralForCausalLM.from_pretrained(
    model_path,
    config=config,
    device_map="auto",
    low_cpu_mem_usage=True,
    torch_dtype=torch.float16,
)
modules = {k: v for k, v in model.named_modules() if isinstance(v, MixtralSparseMoeBlock)}

for device_index in range(torch.cuda.device_count()):
    device_memory_pct = compute_memory_used_pct(device_index)
    print(device_index, device_memory_pct)

# Swap each MixtralSparseMoeBlock for the fused SparseMoeBlock, layer by layer,
# freeing the original block as we go
with tqdm(modules.items(), desc="scatter moe") as pbar:
    for i, (name, module) in enumerate(pbar):
        smoe = SparseMoeBlock(
            experts=module.experts,
            gate=module.gate,
            hidden_dim=module.hidden_dim,
            ffn_dim=module.ffn_dim,
            num_experts=module.num_experts,
            top_k=module.top_k,
        )
        old_module = model.model.layers[i].block_sparse_moe
        setattr(model.model.layers[i], "block_sparse_moe", smoe)
        del old_module
        torch.cuda.empty_cache()
        gc.collect()
        torch.cuda.empty_cache()

for device_index in range(torch.cuda.device_count()):
    device_memory_pct = compute_memory_used_pct(device_index)
    print(device_index, device_memory_pct)

tokenizer = AutoTokenizer.from_pretrained(model_path)
streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

# Convert prompt to tokens
prompt_template = "[INST] {prompt} [/INST]"

prompt = "You're standing on the surface of the Earth. "\
    "You walk one mile south, one mile west and one mile north. "\
    "You end up exactly where you started. Where are you?"

tokens = tokenizer(
    prompt_template.format(prompt=prompt),
    return_tensors='pt'
).input_ids.cuda()

# Generate output
generation_output = model.generate(
    tokens,
    streamer=streamer,
    max_new_tokens=512
)
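The example above swaps blocks by layer index. As an orientation-only sketch (not part of the diff), the same replacement can be done purely by module traversal, using only the imports from the script; `fuse_moe_blocks` is a hypothetical helper name:

def fuse_moe_blocks(model):
    # Hypothetical helper: replace every MixtralSparseMoeBlock child in place,
    # without assuming the model.model.layers[i].block_sparse_moe layout.
    for parent in model.modules():
        for child_name, child in parent.named_children():
            if isinstance(child, MixtralSparseMoeBlock):
                setattr(parent, child_name, SparseMoeBlock(
                    experts=child.experts,
                    gate=child.gate,
                    hidden_dim=child.hidden_dim,
                    ffn_dim=child.ffn_dim,
                    num_experts=child.num_experts,
                    top_k=child.top_k,
                ))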
src/axolotl/monkeypatch/moe/__init__.py (new file, 0 lines)
src/axolotl/monkeypatch/moe/linear.py (new file, 149 lines)
@@ -0,0 +1,149 @@
"""
Adapted from:
https://github.com/shawntan/scattermoe
https://arxiv.org/abs/2403.08245
"""

import torch
import torch.nn as nn
from axolotl.monkeypatch.moe import ops


class ParallelLinear(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx, x, expert_weights, k,
        sorted_expert_idxs, sorted_scattered_idxs,
        padded_block_idxs, expert_offsets,
        gates=None, grouped_in=False, grouped_out=False,
    ):
        output = ops.scatter2scatter(
            X=x, W=expert_weights,
            sorted_expert_idxs=sorted_expert_idxs,
            sorted_scattered_idxs=sorted_scattered_idxs,
            padded_block_idxs=padded_block_idxs,
            k=k, x_grouped=grouped_in, y_grouped=grouped_out
        )
        if gates is not None:
            output_expanded = output.view(gates.size(0), gates.size(1), output.size(-1))
            output = torch.bmm(
                gates[:, None, :],
                output_expanded
            ).squeeze(1)
        else:
            output_expanded = None

        ctx.save_for_backward(
            x, expert_weights,
            sorted_expert_idxs,
            sorted_scattered_idxs,
            padded_block_idxs, expert_offsets,
            gates,
            output_expanded
        )
        ctx.grouped_in = grouped_in
        ctx.grouped_out = grouped_out
        ctx.k = k
        return output

    @staticmethod
    def backward(ctx, grad_out):
        (x, expert_weights,
         sorted_expert_idxs,
         sorted_scattered_idxs,
         padded_block_idxs, expert_offsets,
         gates, output_expanded) = ctx.saved_tensors
        k = ctx.k
        grouped_in = ctx.grouped_in
        grouped_out = ctx.grouped_out
        if gates is not None:
            # calculate gates gradient
            d_gates = torch.bmm(output_expanded, grad_out[:, :, None]).squeeze(-1)
            gates_flat = gates.flatten()
            gate_fan = gates.size(1)
            grouped_grad_out = output_expanded.flatten(0, 1)  # reuse expanded buffer later
        else:
            d_gates = None
            gates_flat = None
            gate_fan = 1
            grouped_grad_out = None

        if grouped_out:
            grouped_grad_out = grad_out
        else:
            grouped_grad_out = ops.group(grad_out, sorted_scattered_idxs,
                                         fan_out=gate_fan, coeff=gates_flat,
                                         out=grouped_grad_out)
        if grouped_in:
            grouped_x = x
            d_expanded_input = None
        else:
            grouped_x = ops.group(x, sorted_scattered_idxs, fan_out=k)
            d_expanded_input = grouped_x
        d_weights = ops.group_bwd_W(
            DY=grouped_grad_out, X=grouped_x,
            expert_offsets=expert_offsets,
            E=expert_weights.size(0)
        )
        d_expanded_input = ops.scatter2scatter(
            X=grouped_grad_out, x_grouped=True,
            W=expert_weights.permute(0, 2, 1),
            padded_block_idxs=padded_block_idxs,
            sorted_expert_idxs=sorted_expert_idxs,
            sorted_scattered_idxs=sorted_scattered_idxs,
            k=1,
            y_grouped=grouped_in,
            out=d_expanded_input  # reuse grouped_x buffer
        )

        if k == 1:
            d_input = d_expanded_input
        else:
            d_input = d_expanded_input.view(x.size(0), k, d_expanded_input.size(-1)).sum(-2)
        # gradients are returned in the same order as forward's arguments:
        return (
            # x, expert_weights, k,
            d_input, d_weights, None,
            # sorted_expert_idxs, sorted_scattered_idxs,
            None, None,
            # padded_block_idxs, expert_offsets,
            None, None,
            # gates, grouped_in, grouped_out
            d_gates, None, None
        )


def parallel_linear(inputs, expert_weights, k,
                    sorted_expert_idxs, sorted_scattered_idxs,
                    padded_block_idxs, expert_offsets,
                    gates=None):
    results = ParallelLinear.apply(inputs, expert_weights, k,
                                   sorted_expert_idxs, sorted_scattered_idxs,
                                   padded_block_idxs, expert_offsets, gates)
    return results


class ParallelExperts(nn.Module):
    def __init__(self, num_experts, input_size, output_size, device) -> None:
        super().__init__()
        self.weight = nn.Parameter(
            torch.empty(num_experts, output_size, input_size, device=device)
        )
        self.num_experts = num_experts
        self.input_size = input_size
        self.output_size = output_size

    def extra_repr(self):
        return 'num_experts={}, input_size={}, output_size={}'.format(
            self.num_experts, self.input_size, self.output_size)

    def forward(self, inputs, k, sorted_expert_idxs, sorted_scattered_idxs,
                padded_block_idxs, expert_offsets,
                gates=None, grouped_in=False, grouped_out=False):
        results = ParallelLinear.apply(
            inputs, self.weight.permute(0, 2, 1), k,
            sorted_expert_idxs, sorted_scattered_idxs,
            padded_block_idxs, expert_offsets,
            gates, grouped_in, grouped_out
        )
        return results
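For orientation, a hypothetical end-to-end use of ParallelExperts together with the ops helpers defined later in this diff (not part of the patch; shapes and values are made up, and it requires CUDA plus triton):

import torch
from axolotl.monkeypatch.moe import ops
from axolotl.monkeypatch.moe.linear import ParallelExperts

experts = ParallelExperts(num_experts=4, input_size=64, output_size=128, device="cuda")
torch.nn.init.normal_(experts.weight, std=0.02)  # weight starts uninitialized

x = torch.randn(2, 64, device="cuda")                   # two tokens
selected = torch.tensor([[1], [3]], device="cuda")      # top-1 expert per token
sorted_e, sorted_s = ops.flatten_and_sort(selected)
padded, offsets = ops.padded_block_indices(sorted_e, 4)
y = experts(x, 1, sorted_e, sorted_s, padded, offsets)  # (2, 128): row i = expert_i(x[i])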
src/axolotl/monkeypatch/moe/mlp.py (new file, 86 lines)
@@ -0,0 +1,86 @@
"""
Adapted from:
https://github.com/shawntan/scattermoe
https://arxiv.org/abs/2403.08245
"""

import torch
from torch import nn

from axolotl.monkeypatch.moe import ops
from axolotl.monkeypatch.moe.linear import ParallelExperts


class FusedExperts(nn.Module):
    def __init__(
        self,
        experts: nn.ModuleList = None,
        hidden_dim=128,
        ffn_dim=512,
        num_experts=8,
        top_k=2,
        activation=nn.SiLU(),
    ):
        """
        Fused experts compatible with Mixtral: an MLP of gated-linear-unit
        type, typically with a SiLU activation function.
        """
        super().__init__()

        device = experts[0].w1.weight.device
        self.num_experts = num_experts
        self.hidden_dim = hidden_dim
        self.ffn_dim = ffn_dim
        # w1 and w3 are fused into one (2 * ffn_dim, hidden_dim) weight per expert
        self.experts = ParallelExperts(num_experts, hidden_dim, 2 * ffn_dim, device=device)
        self.output_experts = ParallelExperts(num_experts, ffn_dim, hidden_dim, device=device)
        self.top_k = min(top_k, self.num_experts)
        self.activation = activation

        with torch.no_grad():
            for i in range(len(experts)):
                self.experts.weight.data[i].copy_(
                    torch.cat(
                        [experts[i].w1.weight.detach(), experts[i].w3.weight.detach()],
                        dim=0
                    )
                )
                self.output_experts.weight.data[i].copy_(
                    experts[i].w2.weight.detach()
                )

    def forward(
        self, x: torch.Tensor, routing_weights: torch.Tensor, selected_experts: torch.Tensor
    ):
        x_shape = x.size()
        x = x.view(-1, x_shape[-1])
        with torch.no_grad():
            sorted_expert_idxs, sorted_scattered_idxs = ops.flatten_and_sort(
                selected_experts
            )
            padded_block_idxs, expert_offsets = ops.padded_block_indices(
                sorted_expert_idxs, self.num_experts
            )

        # The first chunk of the fused projection comes from w1 (the activated
        # gate projection) and the second from w3 (the up projection), so the
        # activation goes on the first chunk, matching Mixtral's
        # act_fn(w1(x)) * w3(x). (The original diff unpacked these the other
        # way around, which does not match the cat([w1, w3]) layout above.)
        gates, h = self.experts(
            x,
            self.top_k,
            sorted_expert_idxs,
            sorted_scattered_idxs,
            padded_block_idxs,
            expert_offsets,
            grouped_out=True,
        ).chunk(2, dim=-1)
        h = self.activation(gates) * h
        y = self.output_experts(
            h,
            1,
            sorted_expert_idxs,
            sorted_scattered_idxs,
            padded_block_idxs,
            expert_offsets,
            grouped_in=True,
            gates=routing_weights,
        )
        y = y.view(*x_shape[:-1], y.size(-1))
        return y
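To see why the cat-then-chunk layout round-trips, a small CPU-only sketch (not part of the diff; shapes made up):

import torch
import torch.nn.functional as F

hidden, ffn = 8, 16
w1 = torch.randn(ffn, hidden)
w3 = torch.randn(ffn, hidden)
x = torch.randn(3, hidden)

fused = torch.cat([w1, w3], dim=0)        # (2*ffn, hidden), as in FusedExperts
a, b = (x @ fused.t()).chunk(2, dim=-1)   # a == x @ w1.T, b == x @ w3.T
assert torch.allclose(a, x @ w1.t()) and torch.allclose(b, x @ w3.t())
glu = F.silu(a) * b                       # matches Mixtral: act_fn(w1(x)) * w3(x)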
src/axolotl/monkeypatch/moe/moe.py (new file, 50 lines)
@@ -0,0 +1,50 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from axolotl.monkeypatch.moe.mlp import FusedExperts


class SparseMoeBlock(nn.Module):
    def __init__(self, experts, gate, hidden_dim, ffn_dim, num_experts, top_k):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.ffn_dim = ffn_dim
        self.num_experts = num_experts
        self.top_k = top_k
        self.gate = gate
        self.experts = FusedExperts(
            experts=experts,
            hidden_dim=hidden_dim,
            ffn_dim=ffn_dim,
            num_experts=num_experts,
            top_k=top_k,
            activation=experts[0].act_fn
        )

    def _post_training(self, model, name):
        # Get the original weights back: reverse the concat + stack done in
        # FusedExperts by splitting each expert's fused (2 * ffn_dim, hidden_dim)
        # weight back into w1 and w3. (The original diff called torch.split on
        # the tuple returned by torch.unbind, which would raise a TypeError.)
        w1s, w3s = zip(*(
            torch.chunk(w, 2, dim=0)
            for w in torch.unbind(self.experts.experts.weight, dim=0)
        ))
        w2s = torch.unbind(self.experts.output_experts.weight, dim=0)

        # TODO: recreate the MoE class with the original weights
        experts = []
        for i in range(self.num_experts):
            pass

    def forward(self, hidden_states: torch.Tensor):
        batch_size, sequence_length, hidden_dim = hidden_states.shape
        hidden_states = hidden_states.view(-1, hidden_dim)

        # router_logits: (batch * sequence_length, n_experts)
        router_logits = self.gate(hidden_states)
        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
        routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
        routing_weights /= routing_weights.sum(dim=-1, keepdim=True)

        # cast back to the input dtype
        routing_weights = routing_weights.to(hidden_states.dtype)

        # fused expert forward
        final_hidden_states = self.experts(hidden_states, routing_weights, selected_experts)

        final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
        return final_hidden_states, router_logits
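The routing arithmetic above is the same as stock Mixtral: softmax over all experts, keep the top-k, then renormalize the kept weights to sum to 1. A tiny worked example (not part of the diff; values approximate):

import torch
import torch.nn.functional as F

logits = torch.tensor([[2.0, 0.0, 1.0, -1.0]])   # one token, 4 experts
weights = F.softmax(logits, dim=1)               # ~[0.64, 0.09, 0.24, 0.03]
top_w, top_idx = torch.topk(weights, 2, dim=-1)  # experts 0 and 2: ~[0.64, 0.24]
top_w = top_w / top_w.sum(dim=-1, keepdim=True)  # renormalized: ~[0.73, 0.27]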
src/axolotl/monkeypatch/moe/ops.py (new file, 353 lines)
@@ -0,0 +1,353 @@
"""
Adapted from:
https://github.com/shawntan/scattermoe
https://arxiv.org/abs/2403.08245
"""

import torch
import triton
import triton.language as tl

BLOCK_M = 128


@torch.jit.script
def flatten_and_sort(expert_idxs: torch.Tensor):
    flattened_expert_idxs = expert_idxs.flatten()
    sorted_expert_idxs, sorted_scattered_idxs = torch.sort(flattened_expert_idxs)
    return sorted_expert_idxs, sorted_scattered_idxs


@torch.jit.script
def padded_block_indices(sorted_experts_idxs: torch.Tensor, k: int, N_BLOCK_SIZE: int = BLOCK_M):
    expert_counts = torch.bincount(sorted_experts_idxs, minlength=k)
    padded_block_counts = ((expert_counts - 1) // N_BLOCK_SIZE) + 1
    padded_expert_block_end = padded_block_counts.cumsum(-1)
    expert_boundaries_end = expert_counts.cumsum(-1)
    expert_boundaries_start = expert_boundaries_end - expert_counts
    padded_expert_block_start = padded_expert_block_end - padded_block_counts
    block_idxs = torch.arange(padded_expert_block_end[-1],
                              dtype=sorted_experts_idxs.dtype,
                              device=sorted_experts_idxs.device)
    block_mask = (
        (block_idxs[:, None] < padded_expert_block_start) |
        (block_idxs[:, None] >= padded_expert_block_end)
    )
    expanded_block_idxs = (
        N_BLOCK_SIZE * (block_idxs[:, None] - padded_expert_block_start) +
        expert_boundaries_start
    )
    expanded_block_idxs = expanded_block_idxs.masked_fill(block_mask, 0).sum(-1)
    return expanded_block_idxs, expert_boundaries_end


def _scatter2scatter_configs():
    return [
        triton.Config({'BLOCK_N': 128, 'BLOCK_K': 32}, num_stages=4, num_warps=4),
    ]


@triton.autotune(configs=_scatter2scatter_configs(), key=['M', 'N', 'K'])
@triton.heuristics({
    "NO_K_MASK": lambda args: (args['K'] % args['BLOCK_K']) == 0,
    "NO_N_MASK": lambda args: (args['N'] % args['BLOCK_N']) == 0,
})
@triton.jit
def _scatter2scatter(
    X_ptr, stride_xm, stride_xk,
    W_ptr, stride_we, stride_wk, stride_wn,
    Y_ptr, stride_ym, stride_yn,
    grouped_idx_ptr, expert_idxs_ptr, block_start_idx_ptr,
    FAN_OUT: tl.constexpr,
    M: tl.constexpr, K: tl.constexpr, N: tl.constexpr, E: tl.constexpr,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
    ACC_TYPE: tl.constexpr,
    OUT_M: tl.constexpr,
    allow_tf32: tl.constexpr,
    x_grouped: tl.constexpr, y_grouped: tl.constexpr,
    NO_K_MASK: tl.constexpr, NO_N_MASK: tl.constexpr
):
    pid = tl.program_id(axis=0)

    N_BLOCK_COUNT = tl.cdiv(N, BLOCK_N)
    M_block_id = pid // N_BLOCK_COUNT
    N_block_id = pid % N_BLOCK_COUNT
    M_range = tl.arange(0, BLOCK_M)
    block_start_idx = tl.load(block_start_idx_ptr + M_block_id)
    M_block = tl.max_contiguous(block_start_idx + M_range, BLOCK_M)
    # out-of-range rows load expert id E (one past the last expert), so they
    # can never win the tl.min below and are masked out by E_mask
    E_idxs = tl.load(expert_idxs_ptr + M_block, mask=M_block < (FAN_OUT * M), other=E)
    E_idx = tl.min(E_idxs)
    E_mask = E_idxs == E_idx
    M_idx = tl.load(grouped_idx_ptr + M_block, mask=E_mask, other=0)
    if x_grouped:
        M_in_idx = M_block
    else:
        M_in_idx = M_idx // FAN_OUT

    if y_grouped:
        M_out_idx = M_block
    else:
        M_out_idx = M_idx

    K_block = tl.arange(0, BLOCK_K)

    N_block = N_block_id * BLOCK_N + tl.arange(0, BLOCK_N)
    N_mask = N_block < N

    X_blk_ptrs = X_ptr + M_in_idx[:, None] * stride_xm + K_block[None, :] * stride_xk
    W_blk_ptrs = W_ptr + K_block[:, None] * stride_wk + N_block[None, :] * stride_wn + E_idx * stride_we

    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
    iters = tl.cdiv(K, BLOCK_K)
    for K_block_id in range(0, iters):
        if NO_K_MASK:
            x = tl.load(X_blk_ptrs, mask=E_mask[:, None])
            if NO_N_MASK:
                w = tl.load(W_blk_ptrs)
            else:
                w = tl.load(W_blk_ptrs, mask=N_mask[None, :])
        else:
            K_mask = (K_block_id * BLOCK_K + K_block) < K
            x = tl.load(X_blk_ptrs, mask=E_mask[:, None] & K_mask[None, :])
            w = tl.load(W_blk_ptrs, mask=K_mask[:, None] & N_mask[None, :])
        X_blk_ptrs += BLOCK_K * stride_xk
        W_blk_ptrs += BLOCK_K * stride_wk
        acc += tl.dot(x, w, allow_tf32=allow_tf32, out_dtype=ACC_TYPE)

    Y_blk_ptrs = Y_ptr + (M_out_idx[:, None] * stride_ym + N_block[None, :] * stride_yn)
    tl.store(Y_blk_ptrs, acc, mask=E_mask[:, None] & N_mask[None, :])


def scatter2scatter(X, W, sorted_expert_idxs, sorted_scattered_idxs, k,
                    padded_block_idxs, x_grouped=False, y_grouped=False,
                    out=None):
    assert sorted_scattered_idxs.size(0) == sorted_expert_idxs.size(0)
    assert sorted_scattered_idxs.size(0) == X.size(0) * k
    # Pre-kernel setup
    x_dim = X.size(-1)
    y_dim = W.size(-1)
    L_scattered = sorted_expert_idxs.size(0)
    if out is None:
        O = torch.empty((L_scattered, y_dim), device=X.device, dtype=X.dtype)
    else:
        assert out.size(0) == L_scattered and out.size(1) == y_dim
        O = out

    def grid(META):
        grid_num = (
            padded_block_idxs.size(0) *
            triton.cdiv(META['N'], META['BLOCK_N']),
        )
        return grid_num

    _scatter2scatter[grid](
        # X_ptr, stride_xm, stride_xk,
        X, X.stride(0), X.stride(1),
        # W_ptr, stride_we, stride_wk, stride_wn,
        W, W.stride(0), W.stride(1), W.stride(2),
        # Y_ptr, stride_ym, stride_yn,
        O, O.stride(0), O.stride(1),
        grouped_idx_ptr=sorted_scattered_idxs,
        expert_idxs_ptr=sorted_expert_idxs,
        block_start_idx_ptr=padded_block_idxs,
        FAN_OUT=k,
        M=X.size(0),
        K=X.size(1),
        N=O.size(1), E=W.size(0),
        BLOCK_M=BLOCK_M,
        ACC_TYPE=tl.float32,
        OUT_M=O.size(0),
        allow_tf32=True,
        x_grouped=x_grouped, y_grouped=y_grouped,
    )
    return O
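# (editor's note: the sketch below is NOT part of ops.py; it is an eager-mode
# reading of what the _scatter2scatter kernel computes, for orientation only.
# scatter2scatter_reference is a hypothetical name.)
def scatter2scatter_reference(X, W, sorted_expert_idxs, sorted_scattered_idxs, k,
                              x_grouped=False, y_grouped=False):
    # For the i-th entry of the expert-sorted order: read one token row,
    # multiply by that entry's expert weight, and write the result to either
    # the grouped (sorted) position or the original scattered position.
    Y = X.new_empty(sorted_expert_idxs.size(0), W.size(-1))
    for pos in range(sorted_expert_idxs.size(0)):
        e = sorted_expert_idxs[pos].item()
        s = sorted_scattered_idxs[pos].item()
        src = pos if x_grouped else s // k
        dst = pos if y_grouped else s
        Y[dst] = X[src] @ W[e]
    return Y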


def _config_XtY():
    return [
        triton.Config({'BLOCK_N': 128, 'BLOCK_K': 128, 'BLOCK_M': 32}, num_stages=4, num_warps=4),
    ]


def group_bwd_W(DY, X, expert_offsets, E):
    DWt = torch.zeros((E, DY.size(-1), X.size(-1)), device=DY.device, dtype=DY.dtype)
    DW = DWt.permute(0, 2, 1)

    def grid(META):
        grid = (
            E * triton.cdiv(META['K'], META['BLOCK_K']),
            triton.cdiv(META['N'], META['BLOCK_N']),
        )
        return grid

    _groupXtY[grid](
        # DY_ptr, stride_dym, stride_dyk,
        DY, DY.stride(0), DY.stride(1),
        # X_ptr, stride_xm, stride_xn,
        X, X.stride(0), X.stride(1),
        # DW_ptr, stride_dwe, stride_dwk, stride_dwn,
        DW, DW.stride(0), DW.stride(1), DW.stride(2),
        # expert_offsets_ptr,
        expert_offsets,
        # K: tl.constexpr, N: tl.constexpr,
        M=DY.size(0), N=DY.size(-1), K=X.size(-1),
        # ACC_TYPE: tl.constexpr,
        ACC_TYPE=tl.float32,
        allow_tf32=True
    )
    return DW


@triton.autotune(configs=_config_XtY(), key=['M', 'N', 'K'])
@triton.heuristics({
    "NO_K_MASK": lambda args: (args['K'] % args['BLOCK_K']) == 0,
    "NO_N_MASK": lambda args: (args['N'] % args['BLOCK_N']) == 0,
})
@triton.jit
def _groupXtY(
    DY_ptr, stride_dym, stride_dyk,
    X_ptr, stride_xm, stride_xn,
    DW_ptr, stride_dwe, stride_dwk, stride_dwn,
    expert_offsets_ptr,
    M: tl.constexpr, K: tl.constexpr, N: tl.constexpr,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
    ACC_TYPE: tl.constexpr,
    allow_tf32: tl.constexpr,
    NO_K_MASK: tl.constexpr, NO_N_MASK: tl.constexpr
):
    pid0 = tl.program_id(axis=0)
    pid1 = tl.program_id(axis=1)
    num0 = tl.num_programs(0)
    num1 = tl.num_programs(1)
    pid1, pid0 = tl.swizzle2d(pid1, pid0, num1, num0, 128)

    K_BLOCK_COUNT = tl.cdiv(K, BLOCK_K)
    E_idx = pid0 // K_BLOCK_COUNT
    K_block_id = pid0 % K_BLOCK_COUNT
    N_block_id = pid1

    if E_idx == 0:
        start_idx = 0
    else:
        start_idx = tl.load(expert_offsets_ptr + E_idx - 1).to(tl.int32)
    end_idx = tl.load(expert_offsets_ptr + E_idx).to(tl.int32)

    if end_idx > start_idx:
        M_block = tl.max_contiguous(start_idx + tl.arange(0, BLOCK_M), BLOCK_M)

        K_block = K_block_id * BLOCK_K + tl.arange(0, BLOCK_K)
        K_mask = K_block < K
        K_block = tl.max_contiguous(tl.multiple_of(K_block % K, BLOCK_K), BLOCK_K)

        N_block = N_block_id * BLOCK_N + tl.arange(0, BLOCK_N)
        N_mask = N_block < N
        N_block = tl.max_contiguous(tl.multiple_of(N_block % N, BLOCK_N), BLOCK_N)

        M_idxs = M_block
        xt_blk_ptrs = X_ptr + K_block[:, None] * stride_xn + M_idxs[None, :] * stride_xm
        dy_blk_ptrs = DY_ptr + M_idxs[:, None] * stride_dym + N_block[None, :] * stride_dyk

        acc = tl.zeros((BLOCK_K, BLOCK_N), dtype=ACC_TYPE)
        iters = tl.cdiv(end_idx - start_idx, BLOCK_M)
        for i in range(0, iters):
            M_mask = (i * BLOCK_M + M_block) < end_idx
            if NO_K_MASK:
                xt = tl.load(xt_blk_ptrs, mask=M_mask[None, :])
            else:
                xt = tl.load(xt_blk_ptrs, mask=K_mask[:, None] & M_mask[None, :])
            if NO_N_MASK:
                dy = tl.load(dy_blk_ptrs, mask=M_mask[:, None])
            else:
                dy = tl.load(dy_blk_ptrs, mask=M_mask[:, None] & N_mask[None, :])
            acc += tl.dot(xt, dy, out_dtype=ACC_TYPE, allow_tf32=allow_tf32)
            xt_blk_ptrs += BLOCK_M * stride_xm
            dy_blk_ptrs += BLOCK_M * stride_dym

        DW_blk_ptrs = DW_ptr + E_idx * stride_dwe + K_block[:, None] * stride_dwk + N_block[None, :] * stride_dwn
        acc = acc.to(DW_blk_ptrs.dtype.element_ty)
        tl.store(DW_blk_ptrs, acc, mask=K_mask[:, None] & N_mask[None, :])


def _config_grouping():
    return [
        triton.Config({'BLOCK_N': 256, 'BLOCK_K': 128}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_N': 128, 'BLOCK_K': 64}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_N': 64, 'BLOCK_K': 32}, num_stages=4, num_warps=4),
    ]


def group(A, sorted_expert_idxs, coeff=None, fan_out=1, out=None):
    N = sorted_expert_idxs.size(0)
    K = A.size(1)
    assert A.size(0) * fan_out == N
    if out is not None:
        Y = out
    else:
        Y = torch.empty((N, K), dtype=A.dtype, device=A.device)

    def grid(META):
        grid_num = (triton.cdiv(META['N'], META['BLOCK_N']),)
        return grid_num

    _group[grid](
        # A_ptr, stride_an, stride_ai,
        A, A.stride(0), A.stride(1), coeff is not None, coeff, fan_out,
        # Y_ptr, stride_yn, stride_yk,
        Y, Y.stride(0), Y.stride(1),
        # grouped_idx_ptr,
        sorted_expert_idxs,
        # N: tl.constexpr, K: tl.constexpr,
        N, K
    )
    return Y


@triton.autotune(configs=_config_grouping(), key=['K'])
@triton.heuristics({
    "NO_K_MASK": lambda args: (args['K'] % args['BLOCK_K']) == 0
})
@triton.jit
def _group(
    src_ptr, stride_sn, stride_sk, has_coeff: tl.constexpr, coeff_ptr, FAN_OUT: tl.constexpr,
    tgt_ptr, stride_tn, stride_ti,
    grouped_idx_ptr,
    N: tl.constexpr, K: tl.constexpr,
    BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
    NO_K_MASK: tl.constexpr
):
    pid = tl.program_id(axis=0)

    N_block_id = pid
    N_blk = N_block_id * BLOCK_N + tl.arange(0, BLOCK_N)
    N_mask = N_blk < N
    N_blk = tl.max_contiguous(tl.multiple_of(N_blk % N, BLOCK_N), BLOCK_N)
    N_idx = tl.load(grouped_idx_ptr + N_blk, mask=N_mask, other=0)

    K_blk = tl.arange(0, BLOCK_K)
    src_blk_ptrs = src_ptr + (N_idx // FAN_OUT)[:, None] * stride_sn + K_blk[None, :] * stride_sk
    tgt_blk_ptrs = tgt_ptr + N_blk[:, None] * stride_tn + K_blk[None, :] * stride_ti

    if has_coeff:
        c = tl.load(coeff_ptr + N_idx, mask=N_mask)[:, None]

    iters = tl.cdiv(K, BLOCK_K)
    for i in range(0, iters):
        if NO_K_MASK:
            block = tl.load(src_blk_ptrs)
            if has_coeff:
                block *= c
            tl.store(tgt_blk_ptrs, block, mask=N_mask[:, None])
        else:
            K_mask = (i * BLOCK_K + K_blk) < K
            mask = N_mask[:, None] & K_mask[None, :]
            block = tl.load(src_blk_ptrs, mask=mask)
            if has_coeff:
                block *= c
            tl.store(tgt_blk_ptrs, block, mask=mask)

        src_blk_ptrs += BLOCK_K * stride_sk
        tgt_blk_ptrs += BLOCK_K * stride_ti
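An eager sketch of what `group` does (the gather used in the backward pass), under the same reading of the kernel; `group_reference` is a hypothetical name, not part of the diff:

def group_reference(A, sorted_scattered_idxs, coeff=None, fan_out=1):
    # Row i of the output is A[idx // fan_out] for the i-th sorted index,
    # optionally scaled by that index's gate coefficient.
    idx = sorted_scattered_idxs
    Y = A[idx // fan_out]
    if coeff is not None:
        Y = Y * coeff[idx][:, None]
    return Y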
src/axolotl/monkeypatch/moe/single.py (new file, 66 lines)
@@ -0,0 +1,66 @@
"""
Adapted from:
https://github.com/shawntan/scattermoe
https://arxiv.org/abs/2403.08245
"""

import torch
import triton
import triton.language as tl


@triton.jit
def _single2scatter(
    X_ptr, stride_xm, stride_xk,
    W_ptr, stride_we, stride_wk, stride_wn,
    Y_ptr, stride_ym, stride_yn,
    expert_idxs_ptr,
    FAN_OUT: tl.constexpr,
    K: tl.constexpr, N: tl.constexpr, E: tl.constexpr,
    BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
    ACC_TYPE: tl.constexpr,
):
    pid0 = tl.program_id(axis=0)
    pid1 = tl.program_id(axis=1)

    N_block_id = pid0
    if FAN_OUT == 1:
        in_idx = pid1
    else:
        in_idx = 0
    out_idx = pid1

    K_block = tl.arange(0, BLOCK_K)
    N_block = tl.max_contiguous(tl.multiple_of((N_block_id * BLOCK_N + tl.arange(0, BLOCK_N)) % N, BLOCK_N), BLOCK_N)
    E_idx = tl.load(expert_idxs_ptr + pid1)
    X_blk_ptrs = X_ptr + in_idx * stride_xm + K_block[:, None] * stride_xk
    W_blk_ptrs = W_ptr + E_idx * stride_we + K_block[:, None] * stride_wk + N_block[None, :] * stride_wn
    acc = tl.zeros((1, BLOCK_N), dtype=ACC_TYPE)
    for K_block_id in range(0, tl.cdiv(K, BLOCK_K)):
        x = tl.load(X_blk_ptrs)
        w = tl.load(W_blk_ptrs)
        acc += tl.sum(x * w, axis=0)[None, :]
        X_blk_ptrs += BLOCK_K * stride_xk
        W_blk_ptrs += BLOCK_K * stride_wk
    Y_blk_ptrs = Y_ptr + out_idx * stride_ym + N_block[None, :] * stride_yn
    tl.store(Y_blk_ptrs, acc)


def single2scatter(X, W, expert_idxs):
    E, xdim, ydim = W.size()
    k = expert_idxs.size(1)
    assert X.size(0) == k or X.size(0) == 1
    Y = torch.empty((k, ydim), device=X.device, dtype=X.dtype)
    BLOCK_N = 128
    BLOCK_K = 128
    grid = ydim // BLOCK_N, k
    _single2scatter[grid](
        X, X.stride(0), X.stride(1),
        W, W.stride(0), W.stride(1), W.stride(2),
        Y, Y.stride(0), Y.stride(1),
        expert_idxs,
        FAN_OUT=Y.size(0) // X.size(0),
        K=xdim, N=ydim, E=E,
        BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K,
        ACC_TYPE=tl.float32
    )
    return Y
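Reading the kernel, `single2scatter` covers the case of a single input row (or one row per selected expert) being sent to its experts; an eager sketch of that reading (hypothetical name, not part of the diff):

def single2scatter_reference(X, W, expert_idxs):
    # Output row j applies expert expert_idxs[..., j] either to row j of X,
    # or to X's single row when X holds just one token.
    k = expert_idxs.size(1)
    rows = X if X.size(0) == k else X.expand(k, -1)
    flat = expert_idxs.flatten()
    return torch.stack([rows[j] @ W[flat[j]] for j in range(k)])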
@@ -715,32 +715,27 @@ def load_model(
    if cfg.flash_attn_fuse_qkv:
        LOG.info("patching with fused QKV")
        replace_llama_qkv_with_fused(model)
    # elif model_type == "GPTNeoXForCausalLM" and cfg.flash_attention:
    # This is a WIP, still an issue with the backward pass
    # RuntimeError: grad can be implicitly created only for scalar outputs
    # TODO: try config.sequence_parallel = False
    # # https://github.com/HazyResearch/flash-attention/blob/40a25c8ee7465cf547b929cfa2937034e37bfce9/tests/models/test_gpt_neox.py#L12
    # # https://github.com/HazyResearch/flash-attention/tree/main/training#model-components
    # # add `**kwargs` to https://github.com/HazyResearch/flash-attention/blob/40a25c8ee7465cf547b929cfa2937034e37bfce9/flash_attn/models/gpt.py#L442
    # from flash_attn.utils.pretrained import state_dict_from_pretrained
    # from flash_attn.models.gpt import GPTLMHeadModel
    # from flash_attn.models.gpt_neox import remap_state_dict_hf_gpt_neox, gpt_neox_config_to_gpt2_config
    # from transformers import GPTNeoXConfig
    # config = gpt_neox_config_to_gpt2_config(GPTNeoXConfig.from_pretrained(base_model))
    # config.use_flash_attn = True
    # config.fused_bias_fc = True
    # config.fused_mlp = True  # GPT-NeoX-20B uses "gelu_fast"
    # config.activation_function = "gelu_fast"
    # config.fused_dropout_add_ln = True
    # # config.residual_in_fp32 = True
    #
    # model: GPTLMHeadModel = GPTLMHeadModel.from_pretrained(
    #     base_model,
    #     config,
    #     dtype=torch_dtype,
    #     device=cfg.device,
    # )
    # model.train()  # sets to train instead of eval mode
    elif (
        model_config.model_type == "mixtral"
        and not cfg.adapter
        and cfg.fuse_moe
    ):
        from axolotl.monkeypatch.utils import set_module_name
        from axolotl.monkeypatch.moe.moe import SparseMoeBlock
        from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock

        for name, module in model.named_modules():
            if isinstance(module, MixtralSparseMoeBlock):
                smoe = SparseMoeBlock(
                    experts=module.experts,
                    gate=module.gate,
                    hidden_dim=module.hidden_dim,
                    ffn_dim=module.ffn_dim,
                    num_experts=module.num_experts,
                    top_k=module.top_k,
                )
                set_module_name(model, name, smoe)

    elif model_type == "MambaLMHeadModel":
        # FIXME this is janky at best and hacked together to make it work
        MambaLMHeadModel = fix_mamba_attn_for_loss()  # pylint: disable=invalid-name
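Judging from the branch condition above, the fused path is opt-in: it appears to require a Mixtral model, no `adapter` configured, and a truthy `fuse_moe` flag in the axolotl config (presumably `fuse_moe: true` in the YAML; the flag's exact schema is not shown in this diff).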
tests/monkeypatch/test_moe.py (new file, 60 lines)
@@ -0,0 +1,60 @@
import torch

from axolotl.monkeypatch.moe.moe import SparseMoeBlock

from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock, MixtralConfig


def test_fused_mixtral_moe():
    # NOTE: requires torch 2.2.0
    # Set default dtype/device and a fixed seed for reproducibility
    torch.set_default_dtype(torch.float16)
    torch.set_default_device("cuda")
    torch.manual_seed(0)

    # Define the configuration for the MixtralSparseMoeBlock
    config = MixtralConfig(
        hidden_size=128,
        intermediate_size=512,
        num_local_experts=8,
        num_experts_per_tok=2,
    )

    # Initialize the MixtralSparseMoeBlock and SparseMoeBlock with the same configuration
    mixtral_moe = MixtralSparseMoeBlock(config)
    sparse_moe = SparseMoeBlock(
        experts=mixtral_moe.experts,
        gate=mixtral_moe.gate,
        hidden_dim=config.hidden_size,
        ffn_dim=config.intermediate_size,
        num_experts=config.num_local_experts,
        top_k=config.num_experts_per_tok
    )

    # The fused weight of expert 0 must equal the concatenation of its w1 and w3
    assert torch.cat([
        mixtral_moe.experts[0].w1.weight.data,
        mixtral_moe.experts[0].w3.weight.data], dim=0
    ).equal(sparse_moe.experts.experts.weight[0])

    # Generate random input data
    batch_size = 16
    sequence_length = 32
    input_data = torch.randn(batch_size, sequence_length, config.hidden_size)

    # Run the forward pass without gradients for both models
    with torch.no_grad():
        mixtral_output, mixtral_router_logits = mixtral_moe(input_data)
        sparse_output, sparse_router_logits = sparse_moe(input_data)

    # Compute the mean absolute difference between the outputs
    output_diff = torch.abs(mixtral_output - sparse_output).mean().item()
    router_diff = torch.abs(mixtral_router_logits - sparse_router_logits).mean().item()

    # Define the tolerance for the difference
    tolerance = 0.05

    # Check that the differences are within tolerance
    assert output_diff < tolerance, f"Output difference is {output_diff}, which is greater than the tolerance of {tolerance}"
    assert router_diff == 0, f"Router logits difference is {router_diff}; the router is unchanged and should match exactly"
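Given the test's own setup (default device "cuda", the torch 2.2.0 note, and the triton kernels it exercises), it presumably needs a CUDA machine with triton installed; something like `pytest tests/monkeypatch/test_moe.py` would run it.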