From 71a43f8479a1cef0247ceb2cc00c7c1a048ed863 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Sun, 28 May 2023 08:56:08 -0400 Subject: [PATCH] add validation/warning for bettertransformers and torch version --- src/axolotl/utils/validation.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/axolotl/utils/validation.py b/src/axolotl/utils/validation.py index ba5feafe8..db19900cc 100644 --- a/src/axolotl/utils/validation.py +++ b/src/axolotl/utils/validation.py @@ -1,7 +1,7 @@ """Module for validating config files""" import logging - +import torch def validate_config(cfg): if cfg.gradient_accumulation_steps and cfg.batch_size: @@ -63,7 +63,10 @@ def validate_config(cfg): if cfg.fp16 or cfg.bf16: raise ValueError("AMP is not supported with BetterTransformer") if cfg.float16 is not True: - logging.warning("You should probably set float16 to true") + logging.warning("You should probably set float16 to true to load the model in float16 for BetterTransformers") + if int(torch.__version__.split(".")[0]) < 2: + logging.warning("torch>=2.0.0 required") + raise ValueError(f"flash_optimum for BetterTransformers may not be used with {torch.__version__}") # TODO # MPT 7b