Fix: bf16 check failing when running on CPU during merge (#631)

This commit is contained in:
NanoCode012
2023-09-25 13:48:18 +09:00
committed by GitHub
parent 67b9888630
commit cfbce020e9
2 changed files with 24 additions and 1 deletions

View File

@@ -94,7 +94,7 @@ def validate_config(cfg):
if not cfg.bf16 and not cfg.bfloat16:
LOG.info("bf16 support detected, but not enabled for this configuration.")
else:
if cfg.bf16 or cfg.bfloat16:
if not cfg.merge_lora and (cfg.bf16 or cfg.bfloat16):
raise ValueError(
"bf16 requested, but AMP is not supported on this GPU. Requires Ampere series or above."
)

View File

@@ -351,3 +351,26 @@ class ValidationTest(unittest.TestCase):
regex_exp = r".*set only one of max_packed_sequence_len \(deprecated soon\) or sample_packing.*"
with pytest.raises(ValueError, match=regex_exp):
validate_config(cfg)
def test_merge_lora_no_bf16_fail(self):
    """
    Check that the bf16 validation error is skipped when merging a LoRA.

    This is assumed to be run on a CPU machine, so bf16 is not supported.
    NOTE(review): on a machine where bf16 is supported, the first
    expectation (a ValueError) presumably would not hold -- confirm before
    running this test on GPU runners.
    """
    # Requesting bf16 without merge_lora must be rejected by validation.
    cfg = DictDefault(
        {
            "bf16": True,
        }
    )

    # pytest.raises(match=...) searches the message, so the pattern only has
    # to appear somewhere in it; "GPU.*" matches the literal "GPU" followed by
    # anything (the previous "GPU*" quantified only the letter "U").
    with pytest.raises(ValueError, match=r".*AMP is not supported on this GPU.*"):
        validate_config(cfg)

    # The same bf16 request with merge_lora enabled must validate without
    # raising.
    cfg = DictDefault(
        {
            "bf16": True,
            "merge_lora": True,
        }
    )

    validate_config(cfg)