allow bf16 flag but warn (#3563) [skip ci]
* allow bf16 flag but warn
  Reason: when doing e.g. LoRA merges with CUDA_VISIBLE_DEVICES= (i.e. no visible GPU), the bf16 validation raises an error and crashes unnecessarily, even though the LoRA merge operation would have finished successfully. This seems to warrant changing the error to a warning instead, as the code will most likely crash later anyway if bf16 is unavailable and training begins.
* don't use deprecated LOG.warn
* update tests to reflect validation change
This commit is contained in:
@@ -1352,8 +1352,8 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
|
||||
and not self.is_preprocess
|
||||
and (self.bf16 is True or self.bfloat16 is True)
|
||||
):
|
||||
raise ValueError(
|
||||
"bf16 requested, but AMP is not supported on this GPU. Requires Ampere series or above."
|
||||
LOG.warning(
|
||||
"bf16 requested, but AMP is not supported on this GPU. Requires Ampere series or above. Training will fail, but other operations (such as merging) are still functional."
|
||||
)
|
||||
return self
|
||||
|
||||
|
||||
@@ -726,8 +726,12 @@ class TestValidation(BaseValidation):
|
||||
| minimal_cfg
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match=r".*AMP is not supported on this GPU*"):
|
||||
with self._caplog.at_level("WARNING"):
|
||||
AxolotlConfigWCapabilities(**cfg.to_dict())
|
||||
assert any(
|
||||
"AMP is not supported" in record.message
|
||||
for record in self._caplog.records
|
||||
)
|
||||
|
||||
cfg = (
|
||||
DictDefault(
|
||||
|
||||
Reference in New Issue
Block a user