Fix: Fail bf16 check when running on cpu during merge (#631)

This commit is contained in:
NanoCode012
2023-09-25 13:48:18 +09:00
committed by GitHub
parent 67b9888630
commit cfbce020e9
2 changed files with 24 additions and 1 deletions

View File

@@ -94,7 +94,7 @@ def validate_config(cfg):
if not cfg.bf16 and not cfg.bfloat16: if not cfg.bf16 and not cfg.bfloat16:
LOG.info("bf16 support detected, but not enabled for this configuration.") LOG.info("bf16 support detected, but not enabled for this configuration.")
else: else:
if cfg.bf16 or cfg.bfloat16: if not cfg.merge_lora and (cfg.bf16 or cfg.bfloat16):
raise ValueError( raise ValueError(
"bf16 requested, but AMP is not supported on this GPU. Requires Ampere series or above." "bf16 requested, but AMP is not supported on this GPU. Requires Ampere series or above."
) )

View File

@@ -351,3 +351,26 @@ class ValidationTest(unittest.TestCase):
regex_exp = r".*set only one of max_packed_sequence_len \(deprecated soon\) or sample_packing.*" regex_exp = r".*set only one of max_packed_sequence_len \(deprecated soon\) or sample_packing.*"
with pytest.raises(ValueError, match=regex_exp): with pytest.raises(ValueError, match=regex_exp):
validate_config(cfg) validate_config(cfg)
def test_merge_lora_no_bf16_fail(self):
    """
    Check that requesting bf16 only fails validation when not merging LoRA.

    This is assumed to be run on a CPU machine, so bf16 is not supported.
    NOTE(review): on a machine where bf16 is supported, the first
    assertion below would presumably not raise -- confirm the CI
    environment is CPU-only.
    """
    # bf16 requested without merge_lora: validation must reject the config.
    bf16_only_cfg = DictDefault(
        {
            "bf16": True,
        }
    )

    # The trailing ".*" applies to the rest of the message; the previous
    # pattern ended in "GPU*", which only quantified the letter "U".
    with pytest.raises(ValueError, match=r".*AMP is not supported on this GPU.*"):
        validate_config(bf16_only_cfg)

    # Same bf16 request, but with merge_lora set: validation must pass
    # (the call is expected to return without raising).
    merge_cfg = DictDefault(
        {
            "bf16": True,
            "merge_lora": True,
        }
    )

    validate_config(merge_cfg)