torch_dtype -> dtype (#3177)

* torch_dtype -> dtype

* torch_dtype -> dtype
This commit is contained in:
VED
2025-10-01 13:32:51 +05:30
committed by GitHub
parent f4376748f3
commit a6bfbe3400
5 changed files with 6 additions and 8 deletions

View File

@@ -160,7 +160,7 @@ def test_geglu_model_integration():
"""Test GeGLU activation with Gemma model."""
model = AutoModelForCausalLM.from_pretrained(
"trl-internal-testing/tiny-Gemma2ForCausalLM",
torch_dtype=torch.float16,
dtype=torch.float16,
device_map="cuda:0",
)
peft_config = get_peft_config(

View File

@@ -39,7 +39,7 @@ def model():
dummy_model = AutoModelForCausalLM.from_pretrained(
"Qwen/Qwen2-0.5B",
device_map="auto",
torch_dtype=torch.bfloat16,
dtype=torch.bfloat16,
)
with torch.device(dummy_model.device):
dummy_model.model.embed_tokens = torch.nn.Embedding(