Simplify creating parameters

This commit is contained in:
Casper Hansen
2024-03-18 12:32:59 +00:00
parent 5bfc470d57
commit 10328b3429
2 changed files with 24 additions and 9 deletions

View File

@@ -1,3 +1,4 @@
import gc
import torch
from tqdm import tqdm
from axolotl.monkeypatch.moe.moe import SparseMoeBlock
@@ -16,8 +17,14 @@ def compute_memory_used_pct(device):
model_path = "mistralai/Mixtral-8x7B-Instruct-v0.1"
# Load model
config = MixtralConfig.from_pretrained(model_path, max_position_embeddings=2048)
model = MixtralForCausalLM.from_pretrained(model_path, config=config, device_map="auto", low_cpu_mem_usage=True, torch_dtype=torch.float16)
config = MixtralConfig.from_pretrained(model_path, max_position_embeddings=2048, use_cache=False)
model = MixtralForCausalLM.from_pretrained(
model_path,
config=config,
device_map="auto",
low_cpu_mem_usage=True,
torch_dtype=torch.float16,
)
modules = {k:v for k,v in model.named_modules() if isinstance(v, MixtralSparseMoeBlock)}
for device_index in range(torch.cuda.device_count()):
@@ -34,7 +41,13 @@ with tqdm(modules.items(), desc="scatter moe") as pbar:
num_experts=module.num_experts,
top_k=module.top_k,
)
old_module = model.model.layers[i].block_sparse_moe
setattr(model.model.layers[i], "block_sparse_moe", smoe)
del old_module
torch.cuda.empty_cache()
gc.collect()
torch.cuda.empty_cache()
for device_index in range(torch.cuda.device_count()):
device_memory_pct = compute_memory_used_pct(device_index)
print(device_index, device_memory_pct)