Compare commits


2 Commits

Author          SHA1         Message                    Date
Eric Hartford   9c221a6761   code review feedback       2024-03-15 14:10:22 -07:00
Eric Hartford   301cc4c006   implement post training    2024-03-15 13:16:06 -07:00
4 changed files with 49 additions and 99 deletions

View File

@@ -1,75 +0,0 @@
-import gc
-import torch
-from tqdm import tqdm
-from axolotl.monkeypatch.moe.moe import SparseMoeBlock
-from transformers import AutoTokenizer, TextStreamer
-from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock, MixtralForCausalLM, MixtralConfig
-def compute_memory_used_pct(device):
-    memory_used = torch.cuda.max_memory_allocated(device) / (1024**3)
-    memory_pct = (
-        memory_used
-        / (torch.cuda.get_device_properties(device).total_memory / (1024**3))
-        * 100
-    )
-    return memory_pct
-model_path = "mistralai/Mixtral-8x7B-Instruct-v0.1"
-# Load model
-config = MixtralConfig.from_pretrained(model_path, max_position_embeddings=2048, use_cache=False)
-model = MixtralForCausalLM.from_pretrained(
-    model_path,
-    config=config,
-    device_map="auto",
-    low_cpu_mem_usage=True,
-    torch_dtype=torch.float16,
-)
-modules = {k:v for k,v in model.named_modules() if isinstance(v, MixtralSparseMoeBlock)}
-for device_index in range(torch.cuda.device_count()):
-    device_memory_pct = compute_memory_used_pct(device_index)
-    print(device_index, device_memory_pct)
-with tqdm(modules.items(), desc="scatter moe") as pbar:
-    for i, (name, module) in enumerate(pbar):
-        smoe = SparseMoeBlock(
-            experts=module.experts,
-            gate=module.gate,
-            hidden_dim=module.hidden_dim,
-            ffn_dim=module.ffn_dim,
-            num_experts=module.num_experts,
-            top_k=module.top_k,
-        )
-        old_module = model.model.layers[i].block_sparse_moe
-        setattr(model.model.layers[i], "block_sparse_moe", smoe)
-        del old_module
-        torch.cuda.empty_cache()
-        gc.collect()
-torch.cuda.empty_cache()
-for device_index in range(torch.cuda.device_count()):
-    device_memory_pct = compute_memory_used_pct(device_index)
-    print(device_index, device_memory_pct)
-tokenizer = AutoTokenizer.from_pretrained(model_path)
-streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
-# Convert prompt to tokens
-prompt_template = "[INST] {prompt} [/INST]"
-prompt = "You're standing on the surface of the Earth. "\
-    "You walk one mile south, one mile west and one mile north. "\
-    "You end up exactly where you started. Where are you?"
-tokens = tokenizer(
-    prompt_template.format(prompt=prompt),
-    return_tensors='pt'
-).input_ids.cuda()
-# Generate output
-generation_output = model.generate(
-    tokens,
-    streamer=streamer,
-    max_new_tokens=512
-)
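
Note on the deleted script: it pairs the i-th MixtralSparseMoeBlock found by named_modules() with model.model.layers[i], relying on the blocks being discovered in layer order. A more general way to do the same in-place swap is to resolve each match's parent by name. The sketch below is self-contained and illustrative only; ToyMoE, FusedToyMoE and ToyLayer are hypothetical stand-ins, not axolotl or transformers classes.

from torch import nn


class ToyMoE(nn.Module):
    """Stand-in for a MixtralSparseMoeBlock-style MoE block."""

    def __init__(self, hidden_dim=8, num_experts=4):
        super().__init__()
        self.gate = nn.Linear(hidden_dim, num_experts, bias=False)


class FusedToyMoE(nn.Module):
    """Stand-in for the fused replacement block; reuses the original gate."""

    def __init__(self, gate):
        super().__init__()
        self.gate = gate


class ToyLayer(nn.Module):
    """Stand-in for a decoder layer that owns a block_sparse_moe submodule."""

    def __init__(self):
        super().__init__()
        self.block_sparse_moe = ToyMoE()


def swap_moe_blocks(model):
    """Replace every ToyMoE in `model` with a FusedToyMoE, in place."""
    swapped = 0
    for name, module in list(model.named_modules()):
        if isinstance(module, ToyMoE):
            parent_name, _, child_name = name.rpartition(".")
            parent = model.get_submodule(parent_name)  # "" resolves to the root module
            setattr(parent, child_name, FusedToyMoE(module.gate))
            swapped += 1
    return swapped


layers = nn.ModuleList(ToyLayer() for _ in range(2))
print(swap_moe_blocks(layers))                                             # 2
print(all(isinstance(layer.block_sparse_moe, FusedToyMoE) for layer in layers))  # True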

View File

@@ -123,11 +123,9 @@ def parallel_linear(inputs, expert_weights, k,
     return results
 class ParallelExperts(nn.Module):
-    def __init__(self, num_experts, input_size, output_size, device) -> None:
+    def __init__(self, num_experts, input_size, output_size) -> None:
         super().__init__()
-        self.weight = nn.Parameter(
-            torch.empty(num_experts, output_size, input_size, device=device)
-        )
+        self.weight = nn.Parameter(torch.empty(num_experts, output_size, input_size))
         self.num_experts = num_experts
         self.input_size = input_size
         self.output_size = output_size
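
For reference, ParallelExperts stores every expert's projection in a single (num_experts, output_size, input_size) parameter; after the change above the tensor is allocated on the default device and filled or moved by the caller instead of at construction time. The snippet below is a naive, self-contained illustration of what that layout computes for top-1 routing; the real class dispatches to scattermoe kernels rather than this loop.

import torch

num_experts, input_size, output_size = 4, 8, 16
weight = torch.randn(num_experts, output_size, input_size)  # same layout as ParallelExperts.weight

tokens = torch.randn(10, input_size)
expert_idx = torch.randint(0, num_experts, (10,))  # one expert chosen per token

# Gather each token's expert matrix and apply it: (10, out, in) x (10, in, 1) -> (10, out)
out = torch.bmm(weight[expert_idx], tokens.unsqueeze(-1)).squeeze(-1)

# Reference: project the tokens routed to each expert, one expert at a time
ref = torch.empty(10, output_size)
for e in range(num_experts):
    mask = expert_idx == e
    ref[mask] = tokens[mask] @ weight[e].T

assert torch.allclose(out, ref, atol=1e-5)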

View File

@@ -4,7 +4,6 @@ https://github.com/shawntan/scattermoe
 https://arxiv.org/abs/2403.08245
 """
-import gc
 import torch
 from torch import nn
@@ -15,7 +14,7 @@ from axolotl.monkeypatch.moe.linear import ParallelExperts
 class FusedExperts(nn.Module):
     def __init__(
         self,
-        experts: nn.ModuleList =None,
+        experts=None,
         hidden_dim=128,
         ffn_dim=512,
         num_experts=8,
@@ -28,26 +27,31 @@ class FusedExperts(nn.Module):
"""
super(FusedExperts, self).__init__()
device = experts[0].w1.weight.device
self.num_experts = num_experts
self.hidden_dim = hidden_dim
self.ffn_dim = ffn_dim
self.experts = ParallelExperts(num_experts, hidden_dim, 2 * ffn_dim, device=device)
self.output_experts = ParallelExperts(num_experts, ffn_dim, hidden_dim, device=device)
self.experts = ParallelExperts(num_experts, hidden_dim, 2 * ffn_dim)
self.output_experts = ParallelExperts(num_experts, ffn_dim, hidden_dim)
self.top_k = min(top_k, self.num_experts)
self.activation = activation
# parallelize all w1 and w3 computation by concat + stack
with torch.no_grad():
for i in range(len(experts)):
self.experts.weight.data[i].copy_(
torch.cat(
[experts[i].w1.weight.detach(), experts[i].w3.weight.detach()],
dim=0
)
)
self.output_experts.weight.data[i].copy_(
experts[i].w2.weight.detach()
)
torch.stack(
[
torch.cat([experts[i].w1.weight, experts[i].w3.weight], dim=0)
for i in range(len(experts))
],
dim=0,
out=self.experts.weight.data,
)
# parallelize all w2 computation by stack
torch.stack(
[expert.w2.weight for expert in experts],
dim=0,
out=self.output_experts.weight.data,
)
def forward(
self, x: torch.Tensor, routing_weights: torch.Tensor, selected_experts: torch.Tensor
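
The revised constructor above builds both fused parameters by concatenating each expert's w1 and w3 along the output dimension and stacking across experts directly into the pre-allocated storage via out=, rather than copying expert by expert inside torch.no_grad(). A minimal, self-contained sketch of that fusion with illustrative toy shapes:

import torch
from torch import nn

num_experts, hidden_dim, ffn_dim = 4, 8, 16
w1 = [torch.randn(ffn_dim, hidden_dim) for _ in range(num_experts)]
w3 = [torch.randn(ffn_dim, hidden_dim) for _ in range(num_experts)]

# Pre-allocated parameter with the ParallelExperts layout: (experts, 2 * ffn_dim, hidden_dim)
fused = nn.Parameter(torch.empty(num_experts, 2 * ffn_dim, hidden_dim))

# Concatenate w1 and w3 per expert, then stack across experts straight into the existing storage
torch.stack(
    [torch.cat([a, b], dim=0) for a, b in zip(w1, w3)],
    dim=0,
    out=fused.data,
)

# The first ffn_dim rows of each expert slice are w1, the remaining rows are w3
w1_back, w3_back = torch.split(fused.data[0], ffn_dim, dim=0)
assert torch.equal(w1_back, w1[0]) and torch.equal(w3_back, w3[0])

One caveat worth noting: out= writes do not support autograd, so when the source weights require gradients this kind of copy is still typically wrapped in torch.no_grad().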

View File

@@ -21,14 +21,37 @@ class SparseMoeBlock(nn.Module):
         )
     def _post_training(self, model, name):
-        # get original weights back: reverse the concat + stack in the fused experts
+        # Get original weights back: reverse the concat + stack in the fused experts
         w1s, w3s = torch.split(torch.unbind(self.experts.experts.weight, dim=0), 2, dim=1)
         w2s = torch.unbind(self.experts.output_experts.weight, dim=0)
-        # TODO: recreate MoE class with original weights
-        experts = []
-        for i in range(self.num_experts):
-            pass
+        # Recreate the structure of the original MixtralSparseMoeBlock
+        original_moe = nn.Module()
+        original_moe.hidden_dim = self.hidden_dim
+        original_moe.ffn_dim = self.ffn_dim
+        original_moe.num_experts = self.num_experts
+        original_moe.top_k = self.top_k
+        # Recreate the gating module
+        original_moe.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False)
+        original_moe.gate.weight.data = self.gate.weight.data
+        # Recreate the experts as a ModuleList
+        original_moe.experts = nn.ModuleList()
+        for expert_idx in range(self.num_experts):
+            expert = nn.Module()
+            expert.w1 = nn.Linear(self.hidden_dim, 2 * self.ffn_dim, bias=False)
+            expert.w2 = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False)
+            expert.w3 = nn.Linear(self.hidden_dim, 2 * self.ffn_dim, bias=False)
+            expert.act_fn = self.experts.activation
+            expert.w1.weight.data = torch.cat([w1s[expert_idx], w3s[expert_idx]], dim=0)
+            expert.w2.weight.data = w2s[expert_idx]
+            original_moe.experts.append(expert)
+        # Replace the SparseMoeBlock with the recreated MixtralSparseMoeBlock structure
+        setattr(model, name, original_moe)
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         batch_size, sequence_length, hidden_dim = hidden_states.shape
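
For context, Mixtral's experts use w1 and w3 as Linear(hidden_dim, ffn_dim) and w2 as Linear(ffn_dim, hidden_dim), so reversing the fusion amounts to unbinding the stacked tensors over the expert dimension and splitting each (2 * ffn_dim, hidden_dim) slice back into its w1 and w3 halves. The round trip below is a hedged, self-contained sketch of that reversal on toy tensors; it is illustrative only and not the repository's _post_training implementation.

import torch

num_experts, hidden_dim, ffn_dim = 4, 8, 16
w1 = [torch.randn(ffn_dim, hidden_dim) for _ in range(num_experts)]
w3 = [torch.randn(ffn_dim, hidden_dim) for _ in range(num_experts)]
w2 = [torch.randn(hidden_dim, ffn_dim) for _ in range(num_experts)]

# Fuse as in fused.py: cat w1/w3 per expert along the output dim, then stack across experts
fused_in = torch.stack([torch.cat([a, b], dim=0) for a, b in zip(w1, w3)], dim=0)
fused_out = torch.stack(w2, dim=0)

# Reverse the fusion: unbind over experts, then split each slice into its two halves
for i, fused_slice in enumerate(torch.unbind(fused_in, dim=0)):
    w1_back, w3_back = torch.split(fused_slice, ffn_dim, dim=0)
    assert torch.equal(w1_back, w1[i]) and torch.equal(w3_back, w3[i])

for i, w2_back in enumerate(torch.unbind(fused_out, dim=0)):
    assert torch.equal(w2_back, w2[i])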