Enable or disable bf16 support based on availability (#1116)

This commit is contained in:
Simon Hällqvist
2024-01-14 18:06:56 +01:00
committed by GitHub
parent 2202a20f60
commit 086561326f
2 changed files with 29 additions and 0 deletions

View File

@@ -61,6 +61,14 @@ def normalize_config(cfg):
cfg.device_map = {"": int(os.environ.get("LOCAL_RANK", 0))} cfg.device_map = {"": int(os.environ.get("LOCAL_RANK", 0))}
cfg.batch_size = cfg.batch_size * cfg.world_size cfg.batch_size = cfg.batch_size * cfg.world_size
if cfg.bf16 == "auto":
if is_torch_bf16_gpu_available():
LOG.debug("bf16 support detected, enabling for this configuration.")
cfg.bf16 = True
else:
LOG.debug("bf16 support not detected, disabling for this configuration.")
cfg.bf16 = False
if cfg.device == "mps": if cfg.device == "mps":
cfg.load_in_8bit = False cfg.load_in_8bit = False
cfg.tf32 = False cfg.tf32 = False

View File

@@ -2,6 +2,7 @@
Test classes for checking functionality of the cfg normalization
"""
import unittest import unittest
from unittest.mock import patch
from axolotl.utils.config import normalize_cfg_datasets, normalize_config from axolotl.utils.config import normalize_cfg_datasets, normalize_config
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
@@ -67,3 +68,23 @@ class NormalizeConfigTestCase(unittest.TestCase):
assert cfg.datasets[0].conversation == "vicuna_v1.1" assert cfg.datasets[0].conversation == "vicuna_v1.1"
assert cfg.datasets[1].conversation == "chatml" assert cfg.datasets[1].conversation == "chatml"
@patch("axolotl.utils.config.is_torch_bf16_gpu_available")
def test_bf16_auto_setter_available(self, mock_bf16_avail):
    """Check that ``bf16 = "auto"`` becomes ``True`` when bf16 is available.

    ``is_torch_bf16_gpu_available`` is patched where ``axolotl.utils.config``
    looks it up, so the result does not depend on the machine running the
    test.
    """
    cfg = self._get_base_cfg()
    cfg.bf16 = "auto"
    mock_bf16_avail.return_value = True

    normalize_config(cfg)

    # The string "auto" is itself truthy, so assertTrue would still pass if
    # the setting had been left unresolved; require the real boolean.
    self.assertIs(cfg.bf16, True)
@patch("axolotl.utils.config.is_torch_bf16_gpu_available")
def test_bf16_auto_setter_not_available(self, mock_bf16_avail):
    """Check that ``bf16 = "auto"`` ends up falsy when bf16 is unavailable.

    ``is_torch_bf16_gpu_available`` is patched where ``axolotl.utils.config``
    looks it up and made to report that bf16 is not supported.
    """
    cfg = self._get_base_cfg()
    cfg.bf16 = "auto"
    # Make the patched availability probe answer "no".
    mock_bf16_avail.configure_mock(return_value=False)

    normalize_config(cfg)

    self.assertFalse(cfg.bf16)