From 04168801c9cf75dcd500f5a3028bee6703ff40f2 Mon Sep 17 00:00:00 2001
From: Casper Hansen
Date: Sun, 17 Mar 2024 20:21:46 +0000
Subject: [PATCH] Simplify conversion + more debug

---
 examples/mistral/mixtral_fused.py  | 19 +++++++++++++++++--
 src/axolotl/monkeypatch/moe/mlp.py | 29 ++++++++++-------------------
 2 files changed, 27 insertions(+), 21 deletions(-)

diff --git a/examples/mistral/mixtral_fused.py b/examples/mistral/mixtral_fused.py
index 5e72e2266..d954d5995 100644
--- a/examples/mistral/mixtral_fused.py
+++ b/examples/mistral/mixtral_fused.py
@@ -1,14 +1,26 @@
+import torch
+from tqdm import tqdm
 from axolotl.monkeypatch.moe.moe import SparseMoeBlock
 from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer
 from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
 
+def compute_memory_used_pct(device):
+    memory_used = torch.cuda.max_memory_allocated(device) / (1024**3)
+    memory_pct = (
+        memory_used
+        / (torch.cuda.get_device_properties(device).total_memory / (1024**3))
+        * 100
+    )
+    return memory_pct
+
 model_path = "mistralai/Mixtral-8x7B-Instruct-v0.1"
 
 # Load model
 model = AutoModelForCausalLM.from_pretrained(model_path, device_map="auto")
+modules = {k:v for k,v in model.named_modules() if isinstance(v, MixtralSparseMoeBlock)}
 
-for name, module in model.named_modules():
-    if isinstance(module, MixtralSparseMoeBlock):
+with tqdm(modules.items(), desc="scatter moe") as pbar:
+    for name, module in pbar:
         smoe = SparseMoeBlock(
             experts=module.experts,
             gate=module.gate,
@@ -18,6 +30,9 @@ for name, module in model.named_modules():
             top_k=module.top_k,
         )
         setattr(model, name, smoe)
+        for device_index in range(torch.cuda.device_count()):
+            device_memory_pct = compute_memory_used_pct(device_index)
+            print(device_index, device_memory_pct)
 
 tokenizer = AutoTokenizer.from_pretrained(model_path)
 streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
diff --git a/src/axolotl/monkeypatch/moe/mlp.py b/src/axolotl/monkeypatch/moe/mlp.py
index f1a1328a7..5dbce5d61 100644
--- a/src/axolotl/monkeypatch/moe/mlp.py
+++ b/src/axolotl/monkeypatch/moe/mlp.py
@@ -4,6 +4,7 @@ https://github.com/shawntan/scattermoe
 https://arxiv.org/abs/2403.08245
 """
 
+import gc
 import torch
 from torch import nn
 
@@ -26,34 +27,24 @@ class FusedExperts(nn.Module):
         MLP of type Gated-Linear Unit, typically with a SiLU activation function.
         """
         super(FusedExperts, self).__init__()
-        expert_device = experts[0].w1.weight.device
-        output_expert_device = experts[0].w2.weight.device
+        device = experts[0].w1.weight.device
         self.num_experts = num_experts
         self.hidden_dim = hidden_dim
         self.ffn_dim = ffn_dim
-        self.experts = ParallelExperts(num_experts, hidden_dim, 2 * ffn_dim, expert_device)
-        self.output_experts = ParallelExperts(num_experts, ffn_dim, hidden_dim, output_expert_device)
+        self.experts = ParallelExperts(num_experts, hidden_dim, 2 * ffn_dim, device=device)
+        self.output_experts = ParallelExperts(num_experts, ffn_dim, hidden_dim, device=device)
         self.top_k = min(top_k, self.num_experts)
         self.activation = activation
 
-        # parallelize all w1 and w3 computation by concat + stack
         with torch.no_grad():
-            torch.stack(
-                [
-                    torch.cat([experts[i].w1.weight, experts[i].w3.weight], dim=0)
-                    for i in range(len(experts))
-                ],
-                dim=0,
-                out=self.experts.weight.data,
-            )
+            for i in range(len(experts)):
+                self.experts.weight.data[i] = torch.cat([experts[i].w1.weight, experts[i].w3.weight], dim=0)
+                self.output_experts.weight.data[i] = experts[i].w2.weight
 
-            # parallelize all w2 computation by stack
-            torch.stack(
-                [expert.w2.weight for expert in experts],
-                dim=0,
-                out=self.output_experts.weight.data,
-            )
+                del experts[i].w1, experts[i].w2, experts[i].w3
+        gc.collect()
+        torch.cuda.empty_cache()
 
     def forward(
         self, x: torch.Tensor, routing_weights: torch.Tensor, selected_experts: torch.Tensor