Compare commits
6 Commits
scatter_mo ... scatter_mo
| Author | SHA1 | Date |
|---|---|---|
| | 10328b3429 | |
| | 5bfc470d57 | |
| | 04168801c9 | |
| | d43a79b7bf | |
| | 884d81331e | |
| | 2ea75b4160 | |
examples/mistral/mixtral_fused.py (new file, 75 lines)

```python
import gc
import torch
from tqdm import tqdm
from axolotl.monkeypatch.moe.moe import SparseMoeBlock
from transformers import AutoTokenizer, TextStreamer
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock, MixtralForCausalLM, MixtralConfig


def compute_memory_used_pct(device):
    memory_used = torch.cuda.max_memory_allocated(device) / (1024**3)
    memory_pct = (
        memory_used
        / (torch.cuda.get_device_properties(device).total_memory / (1024**3))
        * 100
    )
    return memory_pct


model_path = "mistralai/Mixtral-8x7B-Instruct-v0.1"

# Load model
config = MixtralConfig.from_pretrained(model_path, max_position_embeddings=2048, use_cache=False)
model = MixtralForCausalLM.from_pretrained(
    model_path,
    config=config,
    device_map="auto",
    low_cpu_mem_usage=True,
    torch_dtype=torch.float16,
)
modules = {k:v for k,v in model.named_modules() if isinstance(v, MixtralSparseMoeBlock)}

for device_index in range(torch.cuda.device_count()):
    device_memory_pct = compute_memory_used_pct(device_index)
    print(device_index, device_memory_pct)

with tqdm(modules.items(), desc="scatter moe") as pbar:
    for i, (name, module) in enumerate(pbar):
        smoe = SparseMoeBlock(
            experts=module.experts,
            gate=module.gate,
            hidden_dim=module.hidden_dim,
            ffn_dim=module.ffn_dim,
            num_experts=module.num_experts,
            top_k=module.top_k,
        )
        old_module = model.model.layers[i].block_sparse_moe
        setattr(model.model.layers[i], "block_sparse_moe", smoe)
        del old_module
        torch.cuda.empty_cache()
        gc.collect()
        torch.cuda.empty_cache()

for device_index in range(torch.cuda.device_count()):
    device_memory_pct = compute_memory_used_pct(device_index)
    print(device_index, device_memory_pct)

tokenizer = AutoTokenizer.from_pretrained(model_path)
streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

# Convert prompt to tokens
prompt_template = "[INST] {prompt} [/INST]"

prompt = "You're standing on the surface of the Earth. "\
    "You walk one mile south, one mile west and one mile north. "\
    "You end up exactly where you started. Where are you?"

tokens = tokenizer(
    prompt_template.format(prompt=prompt),
    return_tensors='pt'
).input_ids.cuda()

# Generate output
generation_output = model.generate(
    tokens,
    streamer=streamer,
    max_new_tokens=512
)
```
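The script above swaps every `MixtralSparseMoeBlock` for the fused `SparseMoeBlock` in place. A minimal sanity check, not part of this diff, could run right after the conversion loop; it assumes the `model` and `MixtralSparseMoeBlock` names from the script above:

```python
# Hypothetical check (not in this commit): after the swap, no
# MixtralSparseMoeBlock should remain anywhere in the module tree.
remaining = [
    module_name
    for module_name, mod in model.named_modules()
    if isinstance(mod, MixtralSparseMoeBlock)
]
assert not remaining, f"unconverted MoE blocks: {remaining}"
print(f"converted {len(model.model.layers)} layers to SparseMoeBlock")
```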
axolotl/monkeypatch/moe/linear.py

```diff
@@ -123,9 +123,11 @@ def parallel_linear(inputs, expert_weights, k,
     return results

 class ParallelExperts(nn.Module):
-    def __init__(self, num_experts, input_size, output_size) -> None:
+    def __init__(self, num_experts, input_size, output_size, device) -> None:
         super().__init__()
-        self.weight = nn.Parameter(torch.empty(num_experts, output_size, input_size))
+        self.weight = nn.Parameter(
+            torch.empty(num_experts, output_size, input_size, device=device)
+        )
         self.num_experts = num_experts
         self.input_size = input_size
         self.output_size = output_size
```
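The new `device` argument lets `ParallelExperts` allocate its fused weight directly on the device that already holds the source expert weights, rather than materializing it on CPU and moving it later. A minimal sketch of the difference, with assumed toy shapes:

```python
import torch
from torch import nn

num_experts, output_size, input_size = 8, 1024, 512  # toy shapes for illustration
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Old behaviour: the parameter starts on CPU and must be moved to the GPU
# afterwards, briefly holding a second full copy in host memory.
w_cpu = nn.Parameter(torch.empty(num_experts, output_size, input_size))

# New behaviour: the parameter is created directly on the target device.
w_dev = nn.Parameter(torch.empty(num_experts, output_size, input_size, device=device))

print(w_cpu.device, w_dev.device)  # cpu, cuda:0 (when a GPU is available)
```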
```diff
@@ -4,6 +4,7 @@ https://github.com/shawntan/scattermoe
 https://arxiv.org/abs/2403.08245
 """

+import gc
 import torch
 from torch import nn

@@ -14,7 +15,7 @@ from axolotl.monkeypatch.moe.linear import ParallelExperts
 class FusedExperts(nn.Module):
     def __init__(
         self,
-        experts=None,
+        experts: nn.ModuleList =None,
         hidden_dim=128,
         ffn_dim=512,
         num_experts=8,
@@ -27,31 +28,26 @@ class FusedExperts(nn.Module):
         """
         super(FusedExperts, self).__init__()

+        device = experts[0].w1.weight.device
         self.num_experts = num_experts
         self.hidden_dim = hidden_dim
         self.ffn_dim = ffn_dim
-        self.experts = ParallelExperts(num_experts, hidden_dim, 2 * ffn_dim)
-        self.output_experts = ParallelExperts(num_experts, ffn_dim, hidden_dim)
+        self.experts = ParallelExperts(num_experts, hidden_dim, 2 * ffn_dim, device=device)
+        self.output_experts = ParallelExperts(num_experts, ffn_dim, hidden_dim, device=device)
         self.top_k = min(top_k, self.num_experts)
         self.activation = activation

-        # parallelize all w1 and w3 computation by concat + stack
         with torch.no_grad():
-            torch.stack(
-                [
-                    torch.cat([experts[i].w1.weight, experts[i].w3.weight], dim=0)
-                    for i in range(len(experts))
-                ],
-                dim=0,
-                out=self.experts.weight.data,
-            )
-
-            # parallelize all w2 computation by stack
-            torch.stack(
-                [expert.w2.weight for expert in experts],
-                dim=0,
-                out=self.output_experts.weight.data,
-            )
+            for i in range(len(experts)):
+                self.experts.weight.data[i].copy_(
+                    torch.cat(
+                        [experts[i].w1.weight.detach(), experts[i].w3.weight.detach()],
+                        dim=0
+                    )
+                )
+                self.output_experts.weight.data[i].copy_(
+                    experts[i].w2.weight.detach()
+                )

     def forward(
         self, x: torch.Tensor, routing_weights: torch.Tensor, selected_experts: torch.Tensor
```
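The rewritten `__init__` replaces one large `torch.stack(..., out=...)` call, which materializes every concatenated expert tensor at once, with a per-expert `copy_` loop that writes one slab at a time and so keeps peak memory lower while fusing a large model. Both strategies produce the same fused weight; a sketch with toy shapes (the shapes are assumptions for illustration):

```python
import torch

num_experts, ffn_dim, hidden_dim = 2, 6, 4  # toy shapes for illustration
w1 = [torch.randn(ffn_dim, hidden_dim) for _ in range(num_experts)]
w3 = [torch.randn(ffn_dim, hidden_dim) for _ in range(num_experts)]

# Per-expert copy_, as in the new code: one concatenated slab at a time.
fused = torch.empty(num_experts, 2 * ffn_dim, hidden_dim)
for i in range(num_experts):
    fused[i].copy_(torch.cat([w1[i], w3[i]], dim=0))

# Single stacked write, as in the old code: every torch.cat result exists at once.
stacked = torch.stack(
    [torch.cat([w1[i], w3[i]], dim=0) for i in range(num_experts)],
    dim=0,
)

assert torch.equal(fused, stacked)  # identical result, different peak memory
```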
axolotl/monkeypatch/moe/moe.py

```diff
@@ -21,37 +21,14 @@ class SparseMoeBlock(nn.Module):
         )

     def _post_training(self, model, name):
-        # Get original weights back: reverse the concat + stack in the fused experts
+        # get original weights back: reverse the concat + stack in the fused experts
         w1s, w3s = torch.split(torch.unbind(self.experts.experts.weight, dim=0), 2, dim=1)
         w2s = torch.unbind(self.experts.output_experts.weight, dim=0)

-        # Recreate the structure of the original MixtralSparseMoeBlock
-        original_moe = nn.Module()
-        original_moe.hidden_dim = self.hidden_dim
-        original_moe.ffn_dim = self.ffn_dim
-        original_moe.num_experts = self.num_experts
-        original_moe.top_k = self.top_k
-
-        # Recreate the gating module
-        original_moe.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False)
-        original_moe.gate.weight.data = self.gate.weight.data
-
-        # Recreate the experts as a ModuleList
-        original_moe.experts = nn.ModuleList()
-        for expert_idx in range(self.num_experts):
-            expert = nn.Module()
-            expert.w1 = nn.Linear(self.hidden_dim, 2 * self.ffn_dim, bias=False)
-            expert.w2 = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False)
-            expert.w3 = nn.Linear(self.hidden_dim, 2 * self.ffn_dim, bias=False)
-            expert.act_fn = self.experts.activation
-
-            expert.w1.weight.data = torch.cat([w1s[expert_idx], w3s[expert_idx]], dim=0)
-            expert.w2.weight.data = w2s[expert_idx]
-
-            original_moe.experts.append(expert)
-
-        # Replace the SparseMoeBlock with the recreated MixtralSparseMoeBlock structure
-        setattr(model, name, original_moe)
+        # TODO: recreate MoE class with original weights
+        experts = []
+        for i in range(self.num_experts):
+            pass

     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         batch_size, sequence_length, hidden_dim = hidden_states.shape
```
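`_post_training` now leaves the un-fusing as a TODO. Given the fused layout built in `FusedExperts.__init__` (each expert's w1 stacked on top of its w3), recovering per-expert weights could look like the following sketch; the shapes are inferred from this diff, and rebuilding the original Mixtral expert modules is deliberately left out:

```python
import torch

num_experts, hidden_dim, ffn_dim = 2, 4, 6  # toy shapes for illustration
# experts.weight: (num_experts, 2 * ffn_dim, hidden_dim); w1 on top, w3 below.
fused_w13 = torch.randn(num_experts, 2 * ffn_dim, hidden_dim)
# output_experts.weight: (num_experts, hidden_dim, ffn_dim), one w2 per expert.
fused_w2 = torch.randn(num_experts, hidden_dim, ffn_dim)

for i in range(num_experts):
    # Undo the concat along dim 0: first ffn_dim rows are w1, the rest are w3.
    w1_weight, w3_weight = torch.split(fused_w13[i], ffn_dim, dim=0)
    w2_weight = fused_w2[i]
    # ...copy these into a freshly built per-expert MLP to complete the un-fusing...
```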