From 4c92b51cd57c2cd17726e34403241f0759b5ddcf Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Thu, 11 Apr 2024 08:56:46 -0400 Subject: [PATCH] fix the torch dtype check --- tests/e2e/test_lora_llama.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/e2e/test_lora_llama.py b/tests/e2e/test_lora_llama.py index 29f8a036c..02d71d174 100644 --- a/tests/e2e/test_lora_llama.py +++ b/tests/e2e/test_lora_llama.py @@ -125,6 +125,6 @@ class TestLoraLlama(unittest.TestCase): config = f_handle.read() config = json.loads(config) if is_torch_bf16_gpu_available(): - assert config["torch_dtype"] == "torch.bfloat16" + assert config["torch_dtype"] == "bfloat16" else: - assert config["torch_dtype"] == "torch.float16" + assert config["torch_dtype"] == "float16"