diff --git a/examples/mistral/mixtral_fused.py b/examples/mistral/mixtral_fused.py index 5aaf9652b..5e72e2266 100644 --- a/examples/mistral/mixtral_fused.py +++ b/examples/mistral/mixtral_fused.py @@ -5,7 +5,7 @@ from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock model_path = "mistralai/Mixtral-8x7B-Instruct-v0.1" # Load model -model = AutoModelForCausalLM.from_pretrained(model_path) +model = AutoModelForCausalLM.from_pretrained(model_path, device_map="auto") for name, module in model.named_modules(): if isinstance(module, MixtralSparseMoeBlock):