diff --git a/tests/e2e/test_lora_llama.py b/tests/e2e/test_lora_llama.py index 29f8a036c..02d71d174 100644 --- a/tests/e2e/test_lora_llama.py +++ b/tests/e2e/test_lora_llama.py @@ -125,6 +125,6 @@ class TestLoraLlama(unittest.TestCase): config = f_handle.read() config = json.loads(config) if is_torch_bf16_gpu_available(): - assert config["torch_dtype"] == "torch.bfloat16" + assert config["torch_dtype"] == "bfloat16" else: - assert config["torch_dtype"] == "torch.float16" + assert config["torch_dtype"] == "float16"