add support for opimum bettertransformers
This commit is contained in:
@@ -11,7 +11,8 @@ import bitsandbytes as bnb
|
||||
import torch
|
||||
import transformers
|
||||
from transformers import PreTrainedModel # noqa: F401
|
||||
from transformers import ( # noqa: F401
|
||||
from optimum.bettertransformer import BetterTransformer
|
||||
from transformers import (
|
||||
AutoConfig,
|
||||
AutoModelForCausalLM,
|
||||
AutoTokenizer,
|
||||
@@ -137,7 +138,7 @@ def load_model(
|
||||
|
||||
if cfg.bf16:
|
||||
torch_dtype = torch.bfloat16
|
||||
elif cfg.load_in_8bit or cfg.fp16:
|
||||
elif cfg.load_in_8bit or cfg.fp16 or cfg.float16:
|
||||
torch_dtype = torch.float16
|
||||
else:
|
||||
torch_dtype = torch.float32
|
||||
@@ -342,6 +343,9 @@ def load_model(
|
||||
logging.warning("there are no parameters that require gradient updates")
|
||||
model.config.use_cache = False
|
||||
|
||||
if cfg.flash_optimum:
|
||||
model = BetterTransformer.transform(model)
|
||||
|
||||
# TODO resume_from_checkpoint handling
|
||||
return model, lora_config
|
||||
|
||||
|
||||
@@ -57,6 +57,14 @@ def validate_config(cfg):
|
||||
if (cfg.base_model and "falcon" in cfg.base_model.lower()) and cfg.fsdp:
|
||||
raise ValueError("FSDP is not supported for falcon models")
|
||||
|
||||
if cfg.flash_optimum is True:
|
||||
if cfg.adapter:
|
||||
logging.warning("BetterTransformers probably doesn't work with PEFT adapters")
|
||||
if cfg.fp16 or cfg.bf16:
|
||||
raise ValueError("AMP is not supported with BetterTransformer")
|
||||
if cfg.float16 is not True:
|
||||
logging.warning("You should probably set float16 to true")
|
||||
|
||||
# TODO
|
||||
# MPT 7b
|
||||
# https://github.com/facebookresearch/bitsandbytes/issues/25
|
||||
|
||||
Reference in New Issue
Block a user