device_map auto

This commit is contained in:
Casper
2024-03-17 19:52:56 +01:00
parent 884d81331e
commit d43a79b7bf

View File

@@ -5,7 +5,7 @@ from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
model_path = "mistralai/Mixtral-8x7B-Instruct-v0.1"
# Load model
model = AutoModelForCausalLM.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(model_path, device_map="auto")
for name, module in model.named_modules():
if isinstance(module, MixtralSparseMoeBlock):