Fix: Fail bf16 check when running on cpu during merge (#631)
This commit is contained in:
@@ -94,7 +94,7 @@ def validate_config(cfg):
|
|||||||
if not cfg.bf16 and not cfg.bfloat16:
|
if not cfg.bf16 and not cfg.bfloat16:
|
||||||
LOG.info("bf16 support detected, but not enabled for this configuration.")
|
LOG.info("bf16 support detected, but not enabled for this configuration.")
|
||||||
else:
|
else:
|
||||||
if cfg.bf16 or cfg.bfloat16:
|
if not cfg.merge_lora and (cfg.bf16 or cfg.bfloat16):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"bf16 requested, but AMP is not supported on this GPU. Requires Ampere series or above."
|
"bf16 requested, but AMP is not supported on this GPU. Requires Ampere series or above."
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -351,3 +351,26 @@ class ValidationTest(unittest.TestCase):
|
|||||||
regex_exp = r".*set only one of max_packed_sequence_len \(deprecated soon\) or sample_packing.*"
|
regex_exp = r".*set only one of max_packed_sequence_len \(deprecated soon\) or sample_packing.*"
|
||||||
with pytest.raises(ValueError, match=regex_exp):
|
with pytest.raises(ValueError, match=regex_exp):
|
||||||
validate_config(cfg)
|
validate_config(cfg)
|
||||||
|
|
||||||
|
def test_merge_lora_no_bf16_fail(self):
|
||||||
|
"""
|
||||||
|
This is assumed to be run on a CPU machine, so bf16 is not supported.
|
||||||
|
"""
|
||||||
|
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"bf16": True,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match=r".*AMP is not supported on this GPU*"):
|
||||||
|
validate_config(cfg)
|
||||||
|
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"bf16": True,
|
||||||
|
"merge_lora": True,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
validate_config(cfg)
|
||||||
|
|||||||
Reference in New Issue
Block a user