Fix: Fail bf16 check when running on cpu during merge (#631)
This commit is contained in:
@@ -351,3 +351,26 @@ class ValidationTest(unittest.TestCase):
|
||||
regex_exp = r".*set only one of max_packed_sequence_len \(deprecated soon\) or sample_packing.*"
|
||||
with pytest.raises(ValueError, match=regex_exp):
|
||||
validate_config(cfg)
|
||||
|
||||
def test_merge_lora_no_bf16_fail(self):
|
||||
"""
|
||||
This is assumed to be run on a CPU machine, so bf16 is not supported.
|
||||
"""
|
||||
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"bf16": True,
|
||||
}
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match=r".*AMP is not supported on this GPU*"):
|
||||
validate_config(cfg)
|
||||
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"bf16": True,
|
||||
"merge_lora": True,
|
||||
}
|
||||
)
|
||||
|
||||
validate_config(cfg)
|
||||
|
||||
Reference in New Issue
Block a user