Add validation/warning for BetterTransformers and torch version

This commit is contained in:
Wing Lian
2023-05-28 08:56:08 -04:00
parent 39619028a3
commit 71a43f8479

View File

@@ -1,7 +1,7 @@
"""Module for validating config files"""
import logging
import torch
def validate_config(cfg):
if cfg.gradient_accumulation_steps and cfg.batch_size:
@@ -63,7 +63,10 @@ def validate_config(cfg):
if cfg.fp16 or cfg.bf16:
raise ValueError("AMP is not supported with BetterTransformer")
if cfg.float16 is not True:
logging.warning("You should probably set float16 to true")
logging.warning("You should probably set float16 to true to load the model in float16 for BetterTransformers")
if torch.__version__.split(".")[0] < 2:
logging.warning("torch>=2.0.0 required")
raise ValueError(f"flash_optimum for BetterTransformers may not be used with {torch.__version__}")
# TODO
# MPT 7b