diff --git a/tests/e2e/patched/lora_kernels/test_lora_kernel_patching.py b/tests/e2e/patched/lora_kernels/test_lora_kernel_patching.py
index 69c84f445..bada305b3 100644
--- a/tests/e2e/patched/lora_kernels/test_lora_kernel_patching.py
+++ b/tests/e2e/patched/lora_kernels/test_lora_kernel_patching.py
@@ -144,7 +144,7 @@ def test_swiglu_mlp_integration(small_llama_model):
 def test_geglu_model_integration():
     """Test GeGLU activation with Gemma model."""
     model = AutoModelForCausalLM.from_pretrained(
-        "mhenrichsen/gemma-2b", torch_dtype=torch.float16, device_map="auto"
+        "mhenrichsen/gemma-2b", torch_dtype=torch.float16, device_map="cuda:0"
     )
     peft_config = get_peft_config(
         {
@@ -347,7 +347,7 @@ def test_model_architecture(model_config):
     """Test LoRA kernel patches across different model architectures."""
     # Load model with appropriate dtype
     model = AutoModelForCausalLM.from_pretrained(
-        model_config["name"], torch_dtype=model_config["dtype"], device_map="auto"
+        model_config["name"], torch_dtype=model_config["dtype"], device_map="cuda:0"
     )
 
     # Apply LoRA configuration