allow bf16 flag but warn (#3563) [skip ci]
* Allow the bf16 flag but emit a warning instead of raising. Reason: when doing e.g. LoRA merges with CUDA_VISIBLE_DEVICES= (no GPU visible), validation would unnecessarily crash, even though the LoRA merge operation would have finished successfully. This warrants downgrading the error to a warning, since the code will most likely crash later anyway if bf16 is unavailable and training begins. * Don't use the deprecated LOG.warn (use LOG.warning instead). * Update tests to reflect the validation change.
This commit is contained in:
@@ -1352,8 +1352,8 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
|
|||||||
and not self.is_preprocess
|
and not self.is_preprocess
|
||||||
and (self.bf16 is True or self.bfloat16 is True)
|
and (self.bf16 is True or self.bfloat16 is True)
|
||||||
):
|
):
|
||||||
raise ValueError(
|
LOG.warning(
|
||||||
"bf16 requested, but AMP is not supported on this GPU. Requires Ampere series or above."
|
"bf16 requested, but AMP is not supported on this GPU. Requires Ampere series or above. Training will fail, but other operations (such as merging) are still functional."
|
||||||
)
|
)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
|||||||
@@ -726,8 +726,12 @@ class TestValidation(BaseValidation):
|
|||||||
| minimal_cfg
|
| minimal_cfg
|
||||||
)
|
)
|
||||||
|
|
||||||
with pytest.raises(ValueError, match=r".*AMP is not supported on this GPU*"):
|
with self._caplog.at_level("WARNING"):
|
||||||
AxolotlConfigWCapabilities(**cfg.to_dict())
|
AxolotlConfigWCapabilities(**cfg.to_dict())
|
||||||
|
assert any(
|
||||||
|
"AMP is not supported" in record.message
|
||||||
|
for record in self._caplog.records
|
||||||
|
)
|
||||||
|
|
||||||
cfg = (
|
cfg = (
|
||||||
DictDefault(
|
DictDefault(
|
||||||
|
|||||||
Reference in New Issue
Block a user