device_map auto
This commit is contained in:
@@ -5,7 +5,7 @@ from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
|
|||||||
model_path = "mistralai/Mixtral-8x7B-Instruct-v0.1"
|
model_path = "mistralai/Mixtral-8x7B-Instruct-v0.1"
|
||||||
|
|
||||||
# Load model
|
# Load model
|
||||||
model = AutoModelForCausalLM.from_pretrained(model_path)
|
model = AutoModelForCausalLM.from_pretrained(model_path, device_map="auto")
|
||||||
|
|
||||||
for name, module in model.named_modules():
|
for name, module in model.named_modules():
|
||||||
if isinstance(module, MixtralSparseMoeBlock):
|
if isinstance(module, MixtralSparseMoeBlock):
|
||||||
|
|||||||
Reference in New Issue
Block a user