make sure the all of the model is on the same device, so this test will pass on multigpu (#2524) [skip ci]
This commit is contained in:
@@ -144,7 +144,7 @@ def test_swiglu_mlp_integration(small_llama_model):
|
|||||||
def test_geglu_model_integration():
|
def test_geglu_model_integration():
|
||||||
"""Test GeGLU activation with Gemma model."""
|
"""Test GeGLU activation with Gemma model."""
|
||||||
model = AutoModelForCausalLM.from_pretrained(
|
model = AutoModelForCausalLM.from_pretrained(
|
||||||
"mhenrichsen/gemma-2b", torch_dtype=torch.float16, device_map="auto"
|
"mhenrichsen/gemma-2b", torch_dtype=torch.float16, device_map="cuda:0"
|
||||||
)
|
)
|
||||||
peft_config = get_peft_config(
|
peft_config = get_peft_config(
|
||||||
{
|
{
|
||||||
@@ -347,7 +347,7 @@ def test_model_architecture(model_config):
|
|||||||
"""Test LoRA kernel patches across different model architectures."""
|
"""Test LoRA kernel patches across different model architectures."""
|
||||||
# Load model with appropriate dtype
|
# Load model with appropriate dtype
|
||||||
model = AutoModelForCausalLM.from_pretrained(
|
model = AutoModelForCausalLM.from_pretrained(
|
||||||
model_config["name"], torch_dtype=model_config["dtype"], device_map="auto"
|
model_config["name"], torch_dtype=model_config["dtype"], device_map="cuda:0"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Apply LoRA configuration
|
# Apply LoRA configuration
|
||||||
|
|||||||
Reference in New Issue
Block a user