add validation/warning for bettertransformers and torch version
This commit is contained in:
@@ -1,7 +1,7 @@
|
|||||||
"""Module for validating config files"""
|
"""Module for validating config files"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
import torch
|
||||||
|
|
||||||
def validate_config(cfg):
|
def validate_config(cfg):
|
||||||
if cfg.gradient_accumulation_steps and cfg.batch_size:
|
if cfg.gradient_accumulation_steps and cfg.batch_size:
|
||||||
@@ -54,7 +54,10 @@ def validate_config(cfg):
|
|||||||
if cfg.fp16 or cfg.bf16:
|
if cfg.fp16 or cfg.bf16:
|
||||||
raise ValueError("AMP is not supported with BetterTransformer")
|
raise ValueError("AMP is not supported with BetterTransformer")
|
||||||
if cfg.float16 is not True:
|
if cfg.float16 is not True:
|
||||||
logging.warning("You should probably set float16 to true")
|
logging.warning("You should probably set float16 to true to load the model in float16 for BetterTransformers")
|
||||||
|
if torch.__version__.split(".")[0] < 2:
|
||||||
|
logging.warning("torch>=2.0.0 required")
|
||||||
|
raise ValueError(f"flash_optimum for BetterTransformers may not be used with {torch.__version__}")
|
||||||
# TODO
|
# TODO
|
||||||
# MPT 7b
|
# MPT 7b
|
||||||
# https://github.com/facebookresearch/bitsandbytes/issues/25
|
# https://github.com/facebookresearch/bitsandbytes/issues/25
|
||||||
|
|||||||
Reference in New Issue
Block a user