fix the torch dtype check

This commit is contained in:
Wing Lian
2024-04-11 08:56:46 -04:00
parent 5767eea874
commit 4c92b51cd5

View File

@@ -125,6 +125,6 @@ class TestLoraLlama(unittest.TestCase):
config = f_handle.read()
config = json.loads(config)
if is_torch_bf16_gpu_available():
assert config["torch_dtype"] == "torch.bfloat16"
assert config["torch_dtype"] == "bfloat16"
else:
assert config["torch_dtype"] == "torch.float16"
assert config["torch_dtype"] == "float16"