Feat: Add support for gemma3_text and add e2e for gemma2 (#2406)

This commit is contained in:
NanoCode012
2025-03-23 07:33:21 +07:00
committed by GitHub
parent 86bac48d14
commit 9f00465a5c
12 changed files with 348 additions and 6 deletions

View File

@@ -144,7 +144,7 @@ def test_swiglu_mlp_integration(small_llama_model):
def test_geglu_model_integration():
"""Test GeGLU activation with Gemma model."""
model = AutoModelForCausalLM.from_pretrained(
- "mhenrichsen/gemma-2b", torch_dtype=torch.float16, device_map="cuda"
+ "mhenrichsen/gemma-2b", torch_dtype=torch.float16, device_map="auto"
)
peft_config = get_peft_config(
{
@@ -347,7 +347,7 @@ def test_model_architecture(model_config):
"""Test LoRA kernel patches across different model architectures."""
# Load model with appropriate dtype
model = AutoModelForCausalLM.from_pretrained(
- model_config["name"], torch_dtype=model_config["dtype"], device_map="cuda"
+ model_config["name"], torch_dtype=model_config["dtype"], device_map="auto"
)
# Apply LoRA configuration