From 10328b342936f65a8cd4a552ad486548df5d0175 Mon Sep 17 00:00:00 2001 From: Casper Hansen Date: Mon, 18 Mar 2024 12:32:59 +0000 Subject: [PATCH] Simplify creating parameters --- examples/mistral/mixtral_fused.py | 17 +++++++++++++++-- src/axolotl/monkeypatch/moe/mlp.py | 16 +++++++++------- 2 files changed, 24 insertions(+), 9 deletions(-) diff --git a/examples/mistral/mixtral_fused.py b/examples/mistral/mixtral_fused.py index beb2cdc9e..9e85faeb7 100644 --- a/examples/mistral/mixtral_fused.py +++ b/examples/mistral/mixtral_fused.py @@ -1,3 +1,4 @@ +import gc import torch from tqdm import tqdm from axolotl.monkeypatch.moe.moe import SparseMoeBlock @@ -16,8 +17,14 @@ def compute_memory_used_pct(device): model_path = "mistralai/Mixtral-8x7B-Instruct-v0.1" # Load model -config = MixtralConfig.from_pretrained(model_path, max_position_embeddings=2048) -model = MixtralForCausalLM.from_pretrained(model_path, config=config, device_map="auto", low_cpu_mem_usage=True, torch_dtype=torch.float16) +config = MixtralConfig.from_pretrained(model_path, max_position_embeddings=2048, use_cache=False) +model = MixtralForCausalLM.from_pretrained( + model_path, + config=config, + device_map="auto", + low_cpu_mem_usage=True, + torch_dtype=torch.float16, +) modules = {k:v for k,v in model.named_modules() if isinstance(v, MixtralSparseMoeBlock)} for device_index in range(torch.cuda.device_count()): @@ -34,7 +41,13 @@ with tqdm(modules.items(), desc="scatter moe") as pbar: num_experts=module.num_experts, top_k=module.top_k, ) + old_module = model.model.layers[i].block_sparse_moe setattr(model.model.layers[i], "block_sparse_moe", smoe) + del old_module + torch.cuda.empty_cache() + gc.collect() + torch.cuda.empty_cache() + for device_index in range(torch.cuda.device_count()): device_memory_pct = compute_memory_used_pct(device_index) print(device_index, device_memory_pct) diff --git a/src/axolotl/monkeypatch/moe/mlp.py b/src/axolotl/monkeypatch/moe/mlp.py index dff55b0f6..556ab49d3 100644 --- a/src/axolotl/monkeypatch/moe/mlp.py +++ b/src/axolotl/monkeypatch/moe/mlp.py @@ -39,13 +39,15 @@ class FusedExperts(nn.Module): with torch.no_grad(): for i in range(len(experts)): - self.experts.weight.data[i] = torch.cat([experts[i].w1.weight, experts[i].w3.weight], dim=0) - self.output_experts.weight.data[i] = experts[i].w2.weight - - experts = experts.cpu() - del experts - gc.collect() - torch.cuda.empty_cache() + self.experts.weight.data[i].copy_( + torch.cat( + [experts[i].w1.weight.detach(), experts[i].w3.weight.detach()], + dim=0 + ) + ) + self.output_experts.weight.data[i].copy_( + experts[i].w2.weight.detach() + ) def forward( self, x: torch.Tensor, routing_weights: torch.Tensor, selected_experts: torch.Tensor