Migrate QAT API; fix axolotl quantize for QAT-ed models; add NVFP4 (#3107)
This commit is contained in:
@@ -90,6 +90,18 @@ def require_torch_2_7_0(test_case):
|
||||
return unittest.skipUnless(is_min_2_7_0(), "test requires torch>=2.7.0")(test_case)
|
||||
|
||||
|
||||
def require_torch_2_8_0(test_case):
    """
    Decorator marking a test that requires torch >= 2.8.0

    The wrapped test is skipped (not failed) on older torch versions.
    """

    # Compare the installed torch version against 2.8.0 using packaging's
    # version parser so pre-release/dev suffixes are ordered correctly.
    def is_min_2_8_0():
        torch_version = version.parse(torch.__version__)
        return torch_version >= version.parse("2.8.0")

    return unittest.skipUnless(is_min_2_8_0(), "test requires torch>=2.8.0")(test_case)
|
||||
|
||||
|
||||
def require_torch_lt_2_6_0(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires torch < 2.6.0
|
||||
@@ -128,6 +140,24 @@ def require_llmcompressor(test_case):
|
||||
)(test_case)
|
||||
|
||||
|
||||
def requires_sm_ge_100(test_case):
    """
    Decorator that skips *test_case* unless a CUDA device with compute
    capability >= 10.0 (sm_100 or newer) is available.
    """
    # Short-circuit so the capability query only runs on a real CUDA setup.
    cuda_ready = torch.cuda.is_available() and bool(torch.version.cuda)
    capable = cuda_ready and torch.cuda.get_device_capability() >= (10, 0)
    return unittest.skipUnless(capable, "test requires sm>=100")(test_case)
|
||||
|
||||
|
||||
def requires_cuda_ge_8_9(test_case):
    """
    Decorator that skips *test_case* unless a CUDA device with compute
    capability >= 8.9 is available.
    """
    # Only query the device capability once CUDA presence is confirmed.
    cuda_ready = torch.cuda.is_available() and bool(torch.version.cuda)
    meets_bar = cuda_ready and torch.cuda.get_device_capability() >= (8, 9)
    return unittest.skipUnless(meets_bar, "test requires cuda>=8.9")(test_case)
|
||||
|
||||
|
||||
def is_hopper():
    """
    Return True iff the current CUDA device is Hopper (compute capability 9.0).

    Returns False on hosts without an available CUDA device instead of
    raising, so the helper is safe to call from CPU-only environments.
    """
    # Guard: torch.cuda.get_device_capability() raises when CUDA is unavailable.
    if not torch.cuda.is_available():
        return False
    compute_capability = torch.cuda.get_device_capability()
    return compute_capability == (9, 0)
|
||||
|
||||
Reference in New Issue
Block a user