Stop transformers from using all memory
This commit is contained in:
@@ -1,8 +1,8 @@
|
|||||||
import torch
|
import torch
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from axolotl.monkeypatch.moe.moe import SparseMoeBlock
|
from axolotl.monkeypatch.moe.moe import SparseMoeBlock
|
||||||
from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer
|
from transformers import AutoTokenizer, TextStreamer
|
||||||
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
|
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock, MixtralForCausalLM, MixtralConfig
|
||||||
|
|
||||||
def compute_memory_used_pct(device):
|
def compute_memory_used_pct(device):
|
||||||
memory_used = torch.cuda.max_memory_allocated(device) / (1024**3)
|
memory_used = torch.cuda.max_memory_allocated(device) / (1024**3)
|
||||||
@@ -16,11 +16,16 @@ def compute_memory_used_pct(device):
|
|||||||
model_path = "mistralai/Mixtral-8x7B-Instruct-v0.1"
|
model_path = "mistralai/Mixtral-8x7B-Instruct-v0.1"
|
||||||
|
|
||||||
# Load model
|
# Load model
|
||||||
model = AutoModelForCausalLM.from_pretrained(model_path, device_map="auto")
|
config = MixtralConfig.from_pretrained(model_path, max_position_embeddings=2048)
|
||||||
|
model = MixtralForCausalLM.from_pretrained(model_path, config=config, device_map="auto", low_cpu_mem_usage=True, torch_dtype=torch.float16)
|
||||||
modules = {k:v for k,v in model.named_modules() if isinstance(v, MixtralSparseMoeBlock)}
|
modules = {k:v for k,v in model.named_modules() if isinstance(v, MixtralSparseMoeBlock)}
|
||||||
|
|
||||||
|
for device_index in range(torch.cuda.device_count()):
|
||||||
|
device_memory_pct = compute_memory_used_pct(device_index)
|
||||||
|
print(device_index, device_memory_pct)
|
||||||
|
|
||||||
with tqdm(modules.items(), desc="scatter moe") as pbar:
|
with tqdm(modules.items(), desc="scatter moe") as pbar:
|
||||||
for name, module in pbar:
|
for i, (name, module) in enumerate(pbar):
|
||||||
smoe = SparseMoeBlock(
|
smoe = SparseMoeBlock(
|
||||||
experts=module.experts,
|
experts=module.experts,
|
||||||
gate=module.gate,
|
gate=module.gate,
|
||||||
@@ -29,7 +34,7 @@ with tqdm(modules.items(), desc="scatter moe") as pbar:
|
|||||||
num_experts=module.num_experts,
|
num_experts=module.num_experts,
|
||||||
top_k=module.top_k,
|
top_k=module.top_k,
|
||||||
)
|
)
|
||||||
setattr(model, name, smoe)
|
setattr(model.model.layers[i], "block_sparse_moe", smoe)
|
||||||
for device_index in range(torch.cuda.device_count()):
|
for device_index in range(torch.cuda.device_count()):
|
||||||
device_memory_pct = compute_memory_used_pct(device_index)
|
device_memory_pct = compute_memory_used_pct(device_index)
|
||||||
print(device_index, device_memory_pct)
|
print(device_index, device_memory_pct)
|
||||||
|
|||||||
@@ -15,7 +15,7 @@ from axolotl.monkeypatch.moe.linear import ParallelExperts
|
|||||||
class FusedExperts(nn.Module):
|
class FusedExperts(nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
experts=None,
|
experts: nn.ModuleList =None,
|
||||||
hidden_dim=128,
|
hidden_dim=128,
|
||||||
ffn_dim=512,
|
ffn_dim=512,
|
||||||
num_experts=8,
|
num_experts=8,
|
||||||
@@ -42,9 +42,10 @@ class FusedExperts(nn.Module):
|
|||||||
self.experts.weight.data[i] = torch.cat([experts[i].w1.weight, experts[i].w3.weight], dim=0)
|
self.experts.weight.data[i] = torch.cat([experts[i].w1.weight, experts[i].w3.weight], dim=0)
|
||||||
self.output_experts.weight.data[i] = experts[i].w2.weight
|
self.output_experts.weight.data[i] = experts[i].w2.weight
|
||||||
|
|
||||||
del experts[i].w1, experts[i].w2, experts[i].w3
|
experts = experts.cpu()
|
||||||
gc.collect()
|
del experts
|
||||||
torch.cuda.empty_cache()
|
gc.collect()
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self, x: torch.Tensor, routing_weights: torch.Tensor, selected_experts: torch.Tensor
|
self, x: torch.Tensor, routing_weights: torch.Tensor, selected_experts: torch.Tensor
|
||||||
|
|||||||
Reference in New Issue
Block a user