Simplify creating parameters
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
import gc
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
from axolotl.monkeypatch.moe.moe import SparseMoeBlock
|
||||
@@ -16,8 +17,14 @@ def compute_memory_used_pct(device):
|
||||
model_path = "mistralai/Mixtral-8x7B-Instruct-v0.1"
|
||||
|
||||
# Load model
|
||||
config = MixtralConfig.from_pretrained(model_path, max_position_embeddings=2048)
|
||||
model = MixtralForCausalLM.from_pretrained(model_path, config=config, device_map="auto", low_cpu_mem_usage=True, torch_dtype=torch.float16)
|
||||
config = MixtralConfig.from_pretrained(model_path, max_position_embeddings=2048, use_cache=False)
|
||||
model = MixtralForCausalLM.from_pretrained(
|
||||
model_path,
|
||||
config=config,
|
||||
device_map="auto",
|
||||
low_cpu_mem_usage=True,
|
||||
torch_dtype=torch.float16,
|
||||
)
|
||||
modules = {k:v for k,v in model.named_modules() if isinstance(v, MixtralSparseMoeBlock)}
|
||||
|
||||
for device_index in range(torch.cuda.device_count()):
|
||||
@@ -34,7 +41,13 @@ with tqdm(modules.items(), desc="scatter moe") as pbar:
|
||||
num_experts=module.num_experts,
|
||||
top_k=module.top_k,
|
||||
)
|
||||
old_module = model.model.layers[i].block_sparse_moe
|
||||
setattr(model.model.layers[i], "block_sparse_moe", smoe)
|
||||
del old_module
|
||||
torch.cuda.empty_cache()
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
for device_index in range(torch.cuda.device_count()):
|
||||
device_memory_pct = compute_memory_used_pct(device_index)
|
||||
print(device_index, device_memory_pct)
|
||||
|
||||
@@ -39,13 +39,15 @@ class FusedExperts(nn.Module):
|
||||
|
||||
with torch.no_grad():
|
||||
for i in range(len(experts)):
|
||||
self.experts.weight.data[i] = torch.cat([experts[i].w1.weight, experts[i].w3.weight], dim=0)
|
||||
self.output_experts.weight.data[i] = experts[i].w2.weight
|
||||
|
||||
experts = experts.cpu()
|
||||
del experts
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
self.experts.weight.data[i].copy_(
|
||||
torch.cat(
|
||||
[experts[i].w1.weight.detach(), experts[i].w3.weight.detach()],
|
||||
dim=0
|
||||
)
|
||||
)
|
||||
self.output_experts.weight.data[i].copy_(
|
||||
experts[i].w2.weight.detach()
|
||||
)
|
||||
|
||||
def forward(
|
||||
self, x: torch.Tensor, routing_weights: torch.Tensor, selected_experts: torch.Tensor
|
||||
|
||||
Reference in New Issue
Block a user