black formatting

This commit is contained in:
Wing Lian
2023-05-10 16:01:08 -04:00
parent 7a490a4646
commit 2bc1a5bde1
11 changed files with 132 additions and 64 deletions

View File

@@ -8,15 +8,19 @@ import transformers
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
PreTrainedModel, AutoConfig,
PreTrainedModel,
AutoConfig,
)
try:
from transformers import (
LlamaForCausalLM,
LlamaTokenizer,
)
except:
logging.warning("This version of transformers does not support Llama. Consider upgrading.")
logging.warning(
"This version of transformers does not support Llama. Consider upgrading."
)
from axolotl.prompt_tokenizers import LLAMA_DEFAULT_PAD_TOKEN
@@ -40,7 +44,9 @@ def load_model(
# TODO refactor as a kwarg
load_in_8bit = cfg.load_in_8bit
tokenizer = None
is_llama_derived_model = "llama" in base_model or (cfg.model_type and "llama" in cfg.model_type.lower())
is_llama_derived_model = "llama" in base_model or (
cfg.model_type and "llama" in cfg.model_type.lower()
)
if is_llama_derived_model and cfg.flash_attention:
if cfg.device not in ["mps", "cpu"] and inference is False:
@@ -49,11 +55,16 @@ def load_model(
logging.info("patching with flash attention")
replace_llama_attn_with_flash_attn()
elif is_llama_derived_model and cfg.xformers_attention:
from alpaca_lora_4bit.monkeypatch.llama_attn_hijack_xformers import hijack_llama_attention
from alpaca_lora_4bit.monkeypatch.llama_attn_hijack_xformers import (
hijack_llama_attention,
)
logging.info("patching with xformers attention")
hijack_llama_attention()
torch_dtype = torch.float16 if cfg.load_in_8bit or cfg.fp16 or cfg.bf16 else torch.float32
torch_dtype = (
torch.float16 if cfg.load_in_8bit or cfg.fp16 or cfg.bf16 else torch.float32
)
try:
if cfg.load_4bit:
from alpaca_lora_4bit.monkeypatch.peft_tuners_lora_monkey_patch import (
@@ -74,8 +85,12 @@ def load_model(
try:
snapshot_download_kwargs = {}
if cfg.base_model_ignore_patterns:
snapshot_download_kwargs["ignore_patterns"] = cfg.base_model_ignore_patterns
cache_model_path = Path(snapshot_download(base_model, **snapshot_download_kwargs))
snapshot_download_kwargs[
"ignore_patterns"
] = cfg.base_model_ignore_patterns
cache_model_path = Path(
snapshot_download(base_model, **snapshot_download_kwargs)
)
files = (
list(cache_model_path.glob("*.pt"))
+ list(cache_model_path.glob("*.safetensors"))