From 5bfc470d5712eb986fa841efe983680e90e482b6 Mon Sep 17 00:00:00 2001
From: Casper Hansen
Date: Mon, 18 Mar 2024 11:47:47 +0000
Subject: [PATCH] Stop transformers from using all memory

---
 examples/mistral/mixtral_fused.py  | 15 ++++++++++-----
 src/axolotl/monkeypatch/moe/mlp.py |  9 +++++----
 2 files changed, 15 insertions(+), 9 deletions(-)

diff --git a/examples/mistral/mixtral_fused.py b/examples/mistral/mixtral_fused.py
index d954d5995..beb2cdc9e 100644
--- a/examples/mistral/mixtral_fused.py
+++ b/examples/mistral/mixtral_fused.py
@@ -1,8 +1,8 @@
 import torch
 from tqdm import tqdm
 from axolotl.monkeypatch.moe.moe import SparseMoeBlock
-from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer
-from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
+from transformers import AutoTokenizer, TextStreamer
+from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock, MixtralForCausalLM, MixtralConfig
 
 def compute_memory_used_pct(device):
     memory_used = torch.cuda.max_memory_allocated(device) / (1024**3)
@@ -16,11 +16,16 @@ def compute_memory_used_pct(device):
 model_path = "mistralai/Mixtral-8x7B-Instruct-v0.1"
 
 # Load model
-model = AutoModelForCausalLM.from_pretrained(model_path, device_map="auto")
+config = MixtralConfig.from_pretrained(model_path, max_position_embeddings=2048)
+model = MixtralForCausalLM.from_pretrained(model_path, config=config, device_map="auto", low_cpu_mem_usage=True, torch_dtype=torch.float16)
 modules = {k:v for k,v in model.named_modules() if isinstance(v, MixtralSparseMoeBlock)}
 
+for device_index in range(torch.cuda.device_count()):
+    device_memory_pct = compute_memory_used_pct(device_index)
+    print(device_index, device_memory_pct)
+
 with tqdm(modules.items(), desc="scatter moe") as pbar:
-    for name, module in pbar:
+    for i, (name, module) in enumerate(pbar):
         smoe = SparseMoeBlock(
             experts=module.experts,
             gate=module.gate,
@@ -29,7 +34,7 @@ with tqdm(modules.items(), desc="scatter moe") as pbar:
             num_experts=module.num_experts,
             top_k=module.top_k,
         )
-        setattr(model, name, smoe)
+        setattr(model.model.layers[i], "block_sparse_moe", smoe)
 for device_index in range(torch.cuda.device_count()):
     device_memory_pct = compute_memory_used_pct(device_index)
     print(device_index, device_memory_pct)
diff --git a/src/axolotl/monkeypatch/moe/mlp.py b/src/axolotl/monkeypatch/moe/mlp.py
index 5dbce5d61..dff55b0f6 100644
--- a/src/axolotl/monkeypatch/moe/mlp.py
+++ b/src/axolotl/monkeypatch/moe/mlp.py
@@ -15,7 +15,7 @@ from axolotl.monkeypatch.moe.linear import ParallelExperts
 class FusedExperts(nn.Module):
     def __init__(
         self,
-        experts=None,
+        experts: nn.ModuleList = None,
         hidden_dim=128,
         ffn_dim=512,
         num_experts=8,
@@ -42,9 +42,10 @@ class FusedExperts(nn.Module):
             self.experts.weight.data[i] = torch.cat([experts[i].w1.weight, experts[i].w3.weight], dim=0)
             self.output_experts.weight.data[i] = experts[i].w2.weight
 
-            del experts[i].w1, experts[i].w2, experts[i].w3
-            gc.collect()
-            torch.cuda.empty_cache()
+        experts = experts.cpu()
+        del experts
+        gc.collect()
+        torch.cuda.empty_cache()
 
     def forward(
         self, x: torch.Tensor, routing_weights: torch.Tensor, selected_experts: torch.Tensor
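
Note: for reference, a minimal standalone sketch of the GPU memory-release
idiom the mlp.py hunk relies on (move the modules to CPU, drop the last
Python reference, then let the caching allocator hand the blocks back to
the driver). The layer shapes and the gpu_memory_used_pct helper below are
illustrative stand-ins, not part of the patch itself:

import gc
import torch
import torch.nn as nn

def gpu_memory_used_pct(device: int) -> float:
    # Same idea as compute_memory_used_pct in the example script:
    # currently allocated memory as a percentage of total device memory.
    used = torch.cuda.memory_allocated(device) / (1024**3)
    total = torch.cuda.get_device_properties(device).total_memory / (1024**3)
    return 100 * used / total

# Hypothetical stand-in for a set of per-expert MLP weights on the GPU.
experts = nn.ModuleList(
    [nn.Linear(4096, 14336, bias=False) for _ in range(8)]
).cuda().half()

print(f"before release: {gpu_memory_used_pct(0):.1f}%")

# The release pattern used by the patch: Module.cpu() rebinds every
# parameter to a CPU copy, del drops the local reference, gc.collect()
# clears any lingering cycles, and empty_cache() returns the freed blocks
# to the driver so other processes (and nvidia-smi) see the memory as free.
experts = experts.cpu()
del experts
gc.collect()
torch.cuda.empty_cache()

print(f"after release: {gpu_memory_used_pct(0):.1f}%")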