From 198d775d6d0307a8c168ff762339c4904e233a3f Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Tue, 15 Apr 2025 22:15:42 -0700
Subject: [PATCH] make sure all of the model is on the same device, so this test will pass on multigpu (#2524) [skip ci]

---
 tests/e2e/patched/lora_kernels/test_lora_kernel_patching.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/tests/e2e/patched/lora_kernels/test_lora_kernel_patching.py b/tests/e2e/patched/lora_kernels/test_lora_kernel_patching.py
index 69c84f445..bada305b3 100644
--- a/tests/e2e/patched/lora_kernels/test_lora_kernel_patching.py
+++ b/tests/e2e/patched/lora_kernels/test_lora_kernel_patching.py
@@ -144,7 +144,7 @@ def test_swiglu_mlp_integration(small_llama_model):
 def test_geglu_model_integration():
     """Test GeGLU activation with Gemma model."""
     model = AutoModelForCausalLM.from_pretrained(
-        "mhenrichsen/gemma-2b", torch_dtype=torch.float16, device_map="auto"
+        "mhenrichsen/gemma-2b", torch_dtype=torch.float16, device_map="cuda:0"
     )
     peft_config = get_peft_config(
         {
@@ -347,7 +347,7 @@ def test_model_architecture(model_config):
     """Test LoRA kernel patches across different model architectures."""
     # Load model with appropriate dtype
     model = AutoModelForCausalLM.from_pretrained(
-        model_config["name"], torch_dtype=model_config["dtype"], device_map="auto"
+        model_config["name"], torch_dtype=model_config["dtype"], device_map="cuda:0"
     )
 
     # Apply LoRA configuration
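
Note (not part of the patch): a minimal sketch of why pinning the device matters on a multi-GPU host. With device_map="auto", accelerate may shard the model's layers across several GPUs, so parameters end up on different devices and device-sensitive assertions in the patched tests can fail; device_map="cuda:0" keeps every weight on a single GPU. The model name is taken from the patched test; the assertion below is illustrative, not the repository's actual test code.

    import torch
    from transformers import AutoModelForCausalLM

    # Pin the whole model to one GPU instead of letting accelerate shard it
    # across devices (which device_map="auto" may do on multi-GPU machines).
    model = AutoModelForCausalLM.from_pretrained(
        "mhenrichsen/gemma-2b", torch_dtype=torch.float16, device_map="cuda:0"
    )

    # With a single-device placement, every parameter reports the same device.
    assert len({p.device for p in model.parameters()}) == 1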