Stop transformers from using all memory

This commit is contained in:
Casper Hansen
2024-03-18 11:47:47 +00:00
parent 04168801c9
commit 5bfc470d57
2 changed files with 15 additions and 9 deletions

View File

@@ -1,8 +1,8 @@
import torch import torch
from tqdm import tqdm from tqdm import tqdm
from axolotl.monkeypatch.moe.moe import SparseMoeBlock from axolotl.monkeypatch.moe.moe import SparseMoeBlock
from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer from transformers import AutoTokenizer, TextStreamer
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock, MixtralForCausalLM, MixtralConfig
def compute_memory_used_pct(device): def compute_memory_used_pct(device):
memory_used = torch.cuda.max_memory_allocated(device) / (1024**3) memory_used = torch.cuda.max_memory_allocated(device) / (1024**3)
@@ -16,11 +16,16 @@ def compute_memory_used_pct(device):
model_path = "mistralai/Mixtral-8x7B-Instruct-v0.1" model_path = "mistralai/Mixtral-8x7B-Instruct-v0.1"
# Load model # Load model
model = AutoModelForCausalLM.from_pretrained(model_path, device_map="auto") config = MixtralConfig.from_pretrained(model_path, max_position_embeddings=2048)
model = MixtralForCausalLM.from_pretrained(model_path, config=config, device_map="auto", low_cpu_mem_usage=True, torch_dtype=torch.float16)
modules = {k:v for k,v in model.named_modules() if isinstance(v, MixtralSparseMoeBlock)} modules = {k:v for k,v in model.named_modules() if isinstance(v, MixtralSparseMoeBlock)}
for device_index in range(torch.cuda.device_count()):
device_memory_pct = compute_memory_used_pct(device_index)
print(device_index, device_memory_pct)
with tqdm(modules.items(), desc="scatter moe") as pbar: with tqdm(modules.items(), desc="scatter moe") as pbar:
for name, module in pbar: for i, (name, module) in enumerate(pbar):
smoe = SparseMoeBlock( smoe = SparseMoeBlock(
experts=module.experts, experts=module.experts,
gate=module.gate, gate=module.gate,
@@ -29,7 +34,7 @@ with tqdm(modules.items(), desc="scatter moe") as pbar:
num_experts=module.num_experts, num_experts=module.num_experts,
top_k=module.top_k, top_k=module.top_k,
) )
setattr(model, name, smoe) setattr(model.model.layers[i], "block_sparse_moe", smoe)
for device_index in range(torch.cuda.device_count()): for device_index in range(torch.cuda.device_count()):
device_memory_pct = compute_memory_used_pct(device_index) device_memory_pct = compute_memory_used_pct(device_index)
print(device_index, device_memory_pct) print(device_index, device_memory_pct)

View File

@@ -15,7 +15,7 @@ from axolotl.monkeypatch.moe.linear import ParallelExperts
class FusedExperts(nn.Module): class FusedExperts(nn.Module):
def __init__( def __init__(
self, self,
experts=None, experts: nn.ModuleList =None,
hidden_dim=128, hidden_dim=128,
ffn_dim=512, ffn_dim=512,
num_experts=8, num_experts=8,
@@ -42,9 +42,10 @@ class FusedExperts(nn.Module):
self.experts.weight.data[i] = torch.cat([experts[i].w1.weight, experts[i].w3.weight], dim=0) self.experts.weight.data[i] = torch.cat([experts[i].w1.weight, experts[i].w3.weight], dim=0)
self.output_experts.weight.data[i] = experts[i].w2.weight self.output_experts.weight.data[i] = experts[i].w2.weight
del experts[i].w1, experts[i].w2, experts[i].w3 experts = experts.cpu()
gc.collect() del experts
torch.cuda.empty_cache() gc.collect()
torch.cuda.empty_cache()
def forward( def forward(
self, x: torch.Tensor, routing_weights: torch.Tensor, selected_experts: torch.Tensor self, x: torch.Tensor, routing_weights: torch.Tensor, selected_experts: torch.Tensor