Simplify conversion + more debug

This commit is contained in:
Casper Hansen
2024-03-17 20:21:46 +00:00
parent d43a79b7bf
commit 04168801c9
2 changed files with 27 additions and 21 deletions

View File

@@ -1,14 +1,26 @@
import torch
from tqdm import tqdm
from axolotl.monkeypatch.moe.moe import SparseMoeBlock from axolotl.monkeypatch.moe.moe import SparseMoeBlock
from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
def compute_memory_used_pct(device):
memory_used = torch.cuda.max_memory_allocated(device) / (1024**3)
memory_pct = (
memory_used
/ (torch.cuda.get_device_properties(device).total_memory / (1024**3))
* 100
)
return memory_pct
model_path = "mistralai/Mixtral-8x7B-Instruct-v0.1" model_path = "mistralai/Mixtral-8x7B-Instruct-v0.1"
# Load model # Load model
model = AutoModelForCausalLM.from_pretrained(model_path, device_map="auto") model = AutoModelForCausalLM.from_pretrained(model_path, device_map="auto")
modules = {k:v for k,v in model.named_modules() if isinstance(v, MixtralSparseMoeBlock)}
for name, module in model.named_modules(): with tqdm(modules.items(), desc="scatter moe") as pbar:
if isinstance(module, MixtralSparseMoeBlock): for name, module in pbar:
smoe = SparseMoeBlock( smoe = SparseMoeBlock(
experts=module.experts, experts=module.experts,
gate=module.gate, gate=module.gate,
@@ -18,6 +30,9 @@ for name, module in model.named_modules():
top_k=module.top_k, top_k=module.top_k,
) )
setattr(model, name, smoe) setattr(model, name, smoe)
for device_index in range(torch.cuda.device_count()):
device_memory_pct = compute_memory_used_pct(device_index)
print(device_index, device_memory_pct)
tokenizer = AutoTokenizer.from_pretrained(model_path) tokenizer = AutoTokenizer.from_pretrained(model_path)
streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True) streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

View File

@@ -4,6 +4,7 @@ https://github.com/shawntan/scattermoe
https://arxiv.org/abs/2403.08245 https://arxiv.org/abs/2403.08245
""" """
import gc
import torch import torch
from torch import nn from torch import nn
@@ -26,34 +27,24 @@ class FusedExperts(nn.Module):
MLP of type Gated-Linear Unit, typically with a SiLU activation function. MLP of type Gated-Linear Unit, typically with a SiLU activation function.
""" """
super(FusedExperts, self).__init__() super(FusedExperts, self).__init__()
expert_device = experts[0].w1.weight.device
output_expert_device = experts[0].w2.weight.device
device = experts[0].w1.weight.device
self.num_experts = num_experts self.num_experts = num_experts
self.hidden_dim = hidden_dim self.hidden_dim = hidden_dim
self.ffn_dim = ffn_dim self.ffn_dim = ffn_dim
self.experts = ParallelExperts(num_experts, hidden_dim, 2 * ffn_dim, expert_device) self.experts = ParallelExperts(num_experts, hidden_dim, 2 * ffn_dim, device=device)
self.output_experts = ParallelExperts(num_experts, ffn_dim, hidden_dim, output_expert_device) self.output_experts = ParallelExperts(num_experts, ffn_dim, hidden_dim, device=device)
self.top_k = min(top_k, self.num_experts) self.top_k = min(top_k, self.num_experts)
self.activation = activation self.activation = activation
# parallelize all w1 and w3 computation by concat + stack
with torch.no_grad(): with torch.no_grad():
torch.stack( for i in range(len(experts)):
[ self.experts.weight.data[i] = torch.cat([experts[i].w1.weight, experts[i].w3.weight], dim=0)
torch.cat([experts[i].w1.weight, experts[i].w3.weight], dim=0) self.output_experts.weight.data[i] = experts[i].w2.weight
for i in range(len(experts))
],
dim=0,
out=self.experts.weight.data,
)
# parallelize all w2 computation by stack del experts[i].w1, experts[i].w2, experts[i].w3
torch.stack( gc.collect()
[expert.w2.weight for expert in experts], torch.cuda.empty_cache()
dim=0,
out=self.output_experts.weight.data,
)
def forward( def forward(
self, x: torch.Tensor, routing_weights: torch.Tensor, selected_experts: torch.Tensor self, x: torch.Tensor, routing_weights: torch.Tensor, selected_experts: torch.Tensor