From 131afdbd897c5e93aabb0dd607c1d20abc55fbae Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Sun, 17 Sep 2023 13:49:03 -0400 Subject: [PATCH] add bf16 check (#587) --- src/axolotl/utils/config.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/src/axolotl/utils/config.py b/src/axolotl/utils/config.py index dab6961a2..1c0487ff8 100644 --- a/src/axolotl/utils/config.py +++ b/src/axolotl/utils/config.py @@ -4,6 +4,7 @@ import logging import os import torch +from transformers.utils import is_torch_bf16_gpu_available from axolotl.utils.bench import log_gpu_memory_usage from axolotl.utils.models import load_model_config @@ -89,6 +90,14 @@ def normalize_config(cfg): def validate_config(cfg): + if is_torch_bf16_gpu_available(): + if not cfg.bf16 and not cfg.bfloat16: + LOG.info("bf16 support detected, but not enabled for this configuration.") + else: + if cfg.bf16 or cfg.bfloat16: + raise ValueError( + "bf16 requested, but AMP is not supported on this GPU. Requires Ampere series or above." + ) if cfg.max_packed_sequence_len and cfg.sample_packing: raise ValueError( "please set only one of max_packed_sequence_len (deprecated soon) or sample_packing"