torch_dtype -> dtype (#3177)
* torch_dtype -> dtype * torch_dtype -> dtype
This commit is contained in:
@@ -39,7 +39,7 @@ def model():
|
||||
dummy_model = AutoModelForCausalLM.from_pretrained(
|
||||
"Qwen/Qwen2-0.5B",
|
||||
device_map="auto",
|
||||
torch_dtype=torch.bfloat16,
|
||||
dtype=torch.bfloat16,
|
||||
)
|
||||
with torch.device(dummy_model.device):
|
||||
dummy_model.model.embed_tokens = torch.nn.Embedding(
|
||||
|
||||
Reference in New Issue
Block a user