Simplify conversion + more debug
This commit is contained in:
@@ -1,14 +1,26 @@
|
|||||||
|
import torch
|
||||||
|
from tqdm import tqdm
|
||||||
from axolotl.monkeypatch.moe.moe import SparseMoeBlock
|
from axolotl.monkeypatch.moe.moe import SparseMoeBlock
|
||||||
from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer
|
from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer
|
||||||
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
|
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
|
||||||
|
|
||||||
|
def compute_memory_used_pct(device):
|
||||||
|
memory_used = torch.cuda.max_memory_allocated(device) / (1024**3)
|
||||||
|
memory_pct = (
|
||||||
|
memory_used
|
||||||
|
/ (torch.cuda.get_device_properties(device).total_memory / (1024**3))
|
||||||
|
* 100
|
||||||
|
)
|
||||||
|
return memory_pct
|
||||||
|
|
||||||
model_path = "mistralai/Mixtral-8x7B-Instruct-v0.1"
|
model_path = "mistralai/Mixtral-8x7B-Instruct-v0.1"
|
||||||
|
|
||||||
# Load model
|
# Load model
|
||||||
model = AutoModelForCausalLM.from_pretrained(model_path, device_map="auto")
|
model = AutoModelForCausalLM.from_pretrained(model_path, device_map="auto")
|
||||||
|
modules = {k:v for k,v in model.named_modules() if isinstance(v, MixtralSparseMoeBlock)}
|
||||||
|
|
||||||
for name, module in model.named_modules():
|
with tqdm(modules.items(), desc="scatter moe") as pbar:
|
||||||
if isinstance(module, MixtralSparseMoeBlock):
|
for name, module in pbar:
|
||||||
smoe = SparseMoeBlock(
|
smoe = SparseMoeBlock(
|
||||||
experts=module.experts,
|
experts=module.experts,
|
||||||
gate=module.gate,
|
gate=module.gate,
|
||||||
@@ -18,6 +30,9 @@ for name, module in model.named_modules():
|
|||||||
top_k=module.top_k,
|
top_k=module.top_k,
|
||||||
)
|
)
|
||||||
setattr(model, name, smoe)
|
setattr(model, name, smoe)
|
||||||
|
for device_index in range(torch.cuda.device_count()):
|
||||||
|
device_memory_pct = compute_memory_used_pct(device_index)
|
||||||
|
print(device_index, device_memory_pct)
|
||||||
|
|
||||||
tokenizer = AutoTokenizer.from_pretrained(model_path)
|
tokenizer = AutoTokenizer.from_pretrained(model_path)
|
||||||
streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
|
streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ https://github.com/shawntan/scattermoe
|
|||||||
https://arxiv.org/abs/2403.08245
|
https://arxiv.org/abs/2403.08245
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import gc
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
@@ -26,34 +27,24 @@ class FusedExperts(nn.Module):
|
|||||||
MLP of type Gated-Linear Unit, typically with a SiLU activation function.
|
MLP of type Gated-Linear Unit, typically with a SiLU activation function.
|
||||||
"""
|
"""
|
||||||
super(FusedExperts, self).__init__()
|
super(FusedExperts, self).__init__()
|
||||||
expert_device = experts[0].w1.weight.device
|
|
||||||
output_expert_device = experts[0].w2.weight.device
|
|
||||||
|
|
||||||
|
device = experts[0].w1.weight.device
|
||||||
self.num_experts = num_experts
|
self.num_experts = num_experts
|
||||||
self.hidden_dim = hidden_dim
|
self.hidden_dim = hidden_dim
|
||||||
self.ffn_dim = ffn_dim
|
self.ffn_dim = ffn_dim
|
||||||
self.experts = ParallelExperts(num_experts, hidden_dim, 2 * ffn_dim, expert_device)
|
self.experts = ParallelExperts(num_experts, hidden_dim, 2 * ffn_dim, device=device)
|
||||||
self.output_experts = ParallelExperts(num_experts, ffn_dim, hidden_dim, output_expert_device)
|
self.output_experts = ParallelExperts(num_experts, ffn_dim, hidden_dim, device=device)
|
||||||
self.top_k = min(top_k, self.num_experts)
|
self.top_k = min(top_k, self.num_experts)
|
||||||
self.activation = activation
|
self.activation = activation
|
||||||
|
|
||||||
# parallelize all w1 and w3 computation by concat + stack
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
torch.stack(
|
for i in range(len(experts)):
|
||||||
[
|
self.experts.weight.data[i] = torch.cat([experts[i].w1.weight, experts[i].w3.weight], dim=0)
|
||||||
torch.cat([experts[i].w1.weight, experts[i].w3.weight], dim=0)
|
self.output_experts.weight.data[i] = experts[i].w2.weight
|
||||||
for i in range(len(experts))
|
|
||||||
],
|
|
||||||
dim=0,
|
|
||||||
out=self.experts.weight.data,
|
|
||||||
)
|
|
||||||
|
|
||||||
# parallelize all w2 computation by stack
|
del experts[i].w1, experts[i].w2, experts[i].w3
|
||||||
torch.stack(
|
gc.collect()
|
||||||
[expert.w2.weight for expert in experts],
|
torch.cuda.empty_cache()
|
||||||
dim=0,
|
|
||||||
out=self.output_experts.weight.data,
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self, x: torch.Tensor, routing_weights: torch.Tensor, selected_experts: torch.Tensor
|
self, x: torch.Tensor, routing_weights: torch.Tensor, selected_experts: torch.Tensor
|
||||||
|
|||||||
Reference in New Issue
Block a user