Lint models.py

commit f4e5d86268
parent daf47ccf45
Author: NanoCode012
Date: 2023-05-29 13:32:04 +09:00


@@ -1,13 +1,16 @@
"""Module for models and model loading"""
import logging
import math
import os
from pathlib import Path
from typing import Optional, Tuple, TYPE_CHECKING
from typing import Optional, Tuple, TYPE_CHECKING # noqa: F401
import bitsandbytes as bnb
import torch
import transformers
from transformers import (
from transformers import ( # noqa: F401
AutoModelForCausalLM,
AutoTokenizer,
PreTrainedModel,
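Note on the noqa markers added in this hunk: F401 is flake8's "imported but unused" code, and the trailing suppression records that an import is a deliberate re-export rather than dead code. A minimal sketch of the idiom (the surrounding one-line module is hypothetical):

# Nothing below uses AutoModelForCausalLM directly, so flake8 would report
# F401; the noqa comment marks the re-export as intentional.
from transformers import AutoModelForCausalLM  # noqa: F401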
@@ -18,9 +21,8 @@ from transformers import (
 try:
     from transformers import (
         LlamaForCausalLM,
         LlamaTokenizer,
     )
-except:
+except ImportError:
     logging.warning(
         "This version of transformers does not support Llama. Consider upgrading."
     )
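The bare except: fixed here would also have swallowed KeyboardInterrupt and SystemExit; except ImportError: catches exactly the failure this fallback is written for. The optional-import pattern in isolation:

import logging

try:
    # Older transformers releases do not ship the Llama classes.
    from transformers import LlamaForCausalLM, LlamaTokenizer  # noqa: F401
except ImportError:
    logging.warning(
        "This version of transformers does not support Llama. Consider upgrading."
    )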
@@ -28,9 +30,9 @@ except:
 from axolotl.prompt_tokenizers import LLAMA_DEFAULT_PAD_TOKEN

 if TYPE_CHECKING:
-    from peft import PeftModel, PeftConfig
-    from axolotl.utils.dict import DictDefault
-    from transformers import PreTrainedTokenizer
+    from peft import PeftConfig  # noqa: F401
+    from axolotl.utils.dict import DictDefault  # noqa: F401
+    from transformers import PreTrainedTokenizer  # noqa: F401

 def load_tokenizer(
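These imports exist only to support type annotations: the TYPE_CHECKING guard keeps them off the runtime import path, and because flake8 still sees them as unused, each line needs its own noqa. A self-contained sketch of the pattern (the function is a hypothetical stand-in):

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Evaluated by type checkers only, never at runtime.
    from transformers import PreTrainedTokenizer  # noqa: F401


def describe(tokenizer: "PreTrainedTokenizer") -> str:
    # The quoted annotation is resolved lazily, so no runtime import is needed.
    return type(tokenizer).__name__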
@@ -62,8 +64,8 @@ def load_tokenizer(
os.environ["TOKENIZERS_PARALLELISM"] = "false"
if cfg.special_tokens:
for k, v in cfg.special_tokens.items():
tokenizer.add_special_tokens({k: v})
for k, val in cfg.special_tokens.items():
tokenizer.add_special_tokens({k: val})
if cfg.tokens:
tokenizer.add_tokens(list(cfg.tokens))
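The rename from v to val only placates pylint's invalid-name check; behavior is unchanged. For context, a usage sketch of the loop with hypothetical config values (the model name is an arbitrary example):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
special_tokens = {"pad_token": "[PAD]"}  # stand-in for cfg.special_tokens

for k, val in special_tokens.items():
    # Each call registers one named special-token slot on the tokenizer.
    tokenizer.add_special_tokens({k: val})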
@@ -80,6 +82,9 @@ def load_model(
     inference=False,
 ):
     # type: (str, str, str, str, DictDefault, Optional[str], bool) -> Tuple[PreTrainedModel, PreTrainedTokenizer, Optional[PeftConfig]]
+    """
+    Load a model from a base model and a model type.
+    """

     # TODO refactor as a kwarg
     load_in_8bit = cfg.load_in_8bit
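For readers unused to comment-style annotations: the # type: comment above carries the full signature. Spelled out inline it would read roughly as below; the parameter names before inference are inferred from how the arguments are used later in the diff, so treat them as assumptions:

from typing import Optional, Tuple


def load_model(
    base_model: str,
    base_model_config: str,
    model_type: str,
    tokenizer_type: str,
    cfg: "DictDefault",
    adapter: Optional[str] = "lora",
    inference: bool = False,
) -> Tuple["PreTrainedModel", "PreTrainedTokenizer", Optional["PeftConfig"]]:
    """
    Load a model from a base model and a model type.
    """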
@@ -115,9 +120,9 @@ def load_model(
             replace_peft_model_with_int4_lora_model()

         from peft import prepare_model_for_int8_training
-    except Exception as e:
-        logging.exception(e)
-        raise e
+    except Exception as err:
+        logging.exception(err)
+        raise err

     model_kwargs = {}
     if cfg.adapter == "qlora" and cfg.load_in_4bit:
@@ -155,7 +160,7 @@ def load_model(
"unable to find a cached model file, this will likely fail..."
)
model_path = str(cache_model_path)
except:
except Exception: # pylint: disable=broad-exception-caught
model_path = cfg.base_model
model, _ = load_llama_model_4bit_low_ram(
base_model_config if base_model_config else base_model,
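Here the broad catch is deliberate (any failure probing the local cache should fall back to cfg.base_model), so the new pylint suppression is scoped to the one line instead of narrowing the exception type. The shape of the pattern, with a hypothetical helper standing in for the snapshot lookup:

def find_cached_model_path() -> str:
    # Hypothetical stand-in for the cached-snapshot lookup above.
    raise FileNotFoundError("no cached snapshot")


try:
    model_path = find_cached_model_path()
except Exception:  # pylint: disable=broad-exception-caught
    model_path = "base-model-id"  # fall back, as the code falls back to cfg.base_model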
@@ -210,13 +215,13 @@ def load_model(
                 load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
                 torch_dtype=torch_dtype,
                 device_map=cfg.device_map,
-                trust_remote_code=True if cfg.trust_remote_code is True else False,
+                trust_remote_code=cfg.trust_remote_code or False,
                 **model_kwargs,
             )
         else:
             config = AutoConfig.from_pretrained(
                 base_model,
-                trust_remote_code=True if cfg.trust_remote_code is True else False,
+                trust_remote_code=cfg.trust_remote_code or False,
             )
             model = AutoModelForCausalLM.from_pretrained(
                 base_model,
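A caveat on this simplification: True if x is True else False and x or False only agree when the config value is a bool or None, which is what the linted code assumes; a truthy non-bool would now pass through unchanged. The agreement over the expected values, checked directly:

for value in (True, False, None):
    old_form = True if value is True else False
    new_form = value or False
    assert old_form == new_form  # identical results for bool/None config values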
@@ -225,30 +230,29 @@ def load_model(
                 load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
                 torch_dtype=torch_dtype,
                 device_map=cfg.device_map,
-                trust_remote_code=True if cfg.trust_remote_code is True else False,
+                trust_remote_code=cfg.trust_remote_code or False,
                 **model_kwargs,
             )
-    except Exception as e:
+    except Exception as err:  # pylint: disable=broad-exception-caught
         logging.error(
             "Exception raised attempting to load model, retrying with AutoModelForCausalLM"
         )
-        logging.exception(e)
+        logging.exception(err)
         model = AutoModelForCausalLM.from_pretrained(
             base_model,
             load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
             torch_dtype=torch_dtype,
             device_map=cfg.device_map,
-            trust_remote_code=True if cfg.trust_remote_code is True else False,
+            trust_remote_code=cfg.trust_remote_code or False,
             **model_kwargs,
         )

     embeddings_len = math.ceil(len(tokenizer) / 32) * 32
     model.resize_token_embeddings(embeddings_len)

-    if (
-        ((cfg.adapter == "lora" and load_in_8bit) or cfg.adapter == "qlora")
-        and not cfg.gptq
-        and (load_in_8bit or cfg.load_in_4bit)
+    if not cfg.gptq and (
+        (cfg.adapter == "lora" and load_in_8bit)
+        or (cfg.adapter == "qlora" and cfg.load_in_4bit)
     ):
         logging.info("converting PEFT model w/ prepare_model_for_int8_training")
         model = prepare_model_for_int8_training(model)
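Worth flagging that this hunk changes behavior rather than just formatting: previously qlora with load_in_8bit (and no 4-bit flag) also entered the int8 preparation path, while the new condition requires qlora to be paired with load_in_4bit. The single diverging case, checked explicitly:

adapter, gptq, load_in_8bit, load_in_4bit = "qlora", False, True, False

old = ((adapter == "lora" and load_in_8bit) or adapter == "qlora") and not gptq and (
    load_in_8bit or load_in_4bit
)
new = not gptq and (
    (adapter == "lora" and load_in_8bit) or (adapter == "qlora" and load_in_4bit)
)

assert old is True and new is False  # the refactor tightens the qlora path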
@@ -261,14 +265,14 @@ def load_model(
     if cfg.gptq:
         # Scales to half
         logging.info("Fitting 4bit scales and zeros to half")
-        for n, m in model.named_modules():
-            if "Autograd4bitQuantLinear" in str(type(m)) or "Linear4bitLt" in str(
-                type(m)
+        for _, module in model.named_modules():
+            if "Autograd4bitQuantLinear" in str(type(module)) or "Linear4bitLt" in str(
+                type(module)
             ):
-                if hasattr(m, "is_v1_model") and m.is_v1_model:
-                    m.zeros = m.zeros.half()
-                m.scales = m.scales.half()
-                m.bias = m.bias.half()
+                if hasattr(module, "is_v1_model") and module.is_v1_model:
+                    module.zeros = module.zeros.half()
+                module.scales = module.scales.half()
+                module.bias = module.bias.half()

     if (
         torch.cuda.device_count() > 1