Lint models.py

This commit is contained in:
NanoCode012
2023-05-29 13:32:04 +09:00
parent daf47ccf45
commit f4e5d86268

View File

@@ -1,13 +1,16 @@
"""Module for models and model loading"""
import logging import logging
import math import math
import os import os
from pathlib import Path from pathlib import Path
from typing import Optional, Tuple, TYPE_CHECKING from typing import Optional, Tuple, TYPE_CHECKING # noqa: F401
import bitsandbytes as bnb import bitsandbytes as bnb
import torch import torch
import transformers import transformers
from transformers import ( from transformers import ( # noqa: F401
AutoModelForCausalLM, AutoModelForCausalLM,
AutoTokenizer, AutoTokenizer,
PreTrainedModel, PreTrainedModel,
@@ -18,9 +21,8 @@ from transformers import (
try: try:
from transformers import ( from transformers import (
LlamaForCausalLM, LlamaForCausalLM,
LlamaTokenizer,
) )
except: except ImportError:
logging.warning( logging.warning(
"This version of transformers does not support Llama. Consider upgrading." "This version of transformers does not support Llama. Consider upgrading."
) )
@@ -28,9 +30,9 @@ except:
from axolotl.prompt_tokenizers import LLAMA_DEFAULT_PAD_TOKEN from axolotl.prompt_tokenizers import LLAMA_DEFAULT_PAD_TOKEN
if TYPE_CHECKING: if TYPE_CHECKING:
from peft import PeftModel, PeftConfig from peft import PeftConfig # noqa: F401
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault # noqa: F401
from transformers import PreTrainedTokenizer from transformers import PreTrainedTokenizer # noqa: F401
def load_tokenizer( def load_tokenizer(
@@ -62,8 +64,8 @@ def load_tokenizer(
os.environ["TOKENIZERS_PARALLELISM"] = "false" os.environ["TOKENIZERS_PARALLELISM"] = "false"
if cfg.special_tokens: if cfg.special_tokens:
for k, v in cfg.special_tokens.items(): for k, val in cfg.special_tokens.items():
tokenizer.add_special_tokens({k: v}) tokenizer.add_special_tokens({k: val})
if cfg.tokens: if cfg.tokens:
tokenizer.add_tokens(list(cfg.tokens)) tokenizer.add_tokens(list(cfg.tokens))
@@ -80,6 +82,9 @@ def load_model(
inference=False, inference=False,
): ):
# type: (str, str, str, str, DictDefault, Optional[str], bool) -> Tuple[PreTrainedModel, PreTrainedTokenizer, Optional[PeftConfig]] # type: (str, str, str, str, DictDefault, Optional[str], bool) -> Tuple[PreTrainedModel, PreTrainedTokenizer, Optional[PeftConfig]]
"""
Load a model from a base model and a model type.
"""
# TODO refactor as a kwarg # TODO refactor as a kwarg
load_in_8bit = cfg.load_in_8bit load_in_8bit = cfg.load_in_8bit
@@ -115,9 +120,9 @@ def load_model(
replace_peft_model_with_int4_lora_model() replace_peft_model_with_int4_lora_model()
from peft import prepare_model_for_int8_training from peft import prepare_model_for_int8_training
except Exception as e: except Exception as err:
logging.exception(e) logging.exception(err)
raise e raise err
model_kwargs = {} model_kwargs = {}
if cfg.adapter == "qlora" and cfg.load_in_4bit: if cfg.adapter == "qlora" and cfg.load_in_4bit:
@@ -155,7 +160,7 @@ def load_model(
"unable to find a cached model file, this will likely fail..." "unable to find a cached model file, this will likely fail..."
) )
model_path = str(cache_model_path) model_path = str(cache_model_path)
except: except Exception: # pylint: disable=broad-exception-caught
model_path = cfg.base_model model_path = cfg.base_model
model, _ = load_llama_model_4bit_low_ram( model, _ = load_llama_model_4bit_low_ram(
base_model_config if base_model_config else base_model, base_model_config if base_model_config else base_model,
@@ -210,13 +215,13 @@ def load_model(
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None, load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
torch_dtype=torch_dtype, torch_dtype=torch_dtype,
device_map=cfg.device_map, device_map=cfg.device_map,
trust_remote_code=True if cfg.trust_remote_code is True else False, trust_remote_code=cfg.trust_remote_code or False,
**model_kwargs, **model_kwargs,
) )
else: else:
config = AutoConfig.from_pretrained( config = AutoConfig.from_pretrained(
base_model, base_model,
trust_remote_code=True if cfg.trust_remote_code is True else False, trust_remote_code=cfg.trust_remote_code or False,
) )
model = AutoModelForCausalLM.from_pretrained( model = AutoModelForCausalLM.from_pretrained(
base_model, base_model,
@@ -225,30 +230,29 @@ def load_model(
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None, load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
torch_dtype=torch_dtype, torch_dtype=torch_dtype,
device_map=cfg.device_map, device_map=cfg.device_map,
trust_remote_code=True if cfg.trust_remote_code is True else False, trust_remote_code=cfg.trust_remote_code or False,
**model_kwargs, **model_kwargs,
) )
except Exception as e: except Exception as err: # pylint: disable=broad-exception-caught
logging.error( logging.error(
"Exception raised attempting to load model, retrying with AutoModelForCausalLM" "Exception raised attempting to load model, retrying with AutoModelForCausalLM"
) )
logging.exception(e) logging.exception(err)
model = AutoModelForCausalLM.from_pretrained( model = AutoModelForCausalLM.from_pretrained(
base_model, base_model,
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None, load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
torch_dtype=torch_dtype, torch_dtype=torch_dtype,
device_map=cfg.device_map, device_map=cfg.device_map,
trust_remote_code=True if cfg.trust_remote_code is True else False, trust_remote_code=cfg.trust_remote_code or False,
**model_kwargs, **model_kwargs,
) )
embeddings_len = math.ceil(len(tokenizer) / 32) * 32 embeddings_len = math.ceil(len(tokenizer) / 32) * 32
model.resize_token_embeddings(embeddings_len) model.resize_token_embeddings(embeddings_len)
if ( if not cfg.gptq and (
((cfg.adapter == "lora" and load_in_8bit) or cfg.adapter == "qlora") (cfg.adapter == "lora" and load_in_8bit)
and not cfg.gptq or (cfg.adapter == "qlora" and cfg.load_in_4bit)
and (load_in_8bit or cfg.load_in_4bit)
): ):
logging.info("converting PEFT model w/ prepare_model_for_int8_training") logging.info("converting PEFT model w/ prepare_model_for_int8_training")
model = prepare_model_for_int8_training(model) model = prepare_model_for_int8_training(model)
@@ -261,14 +265,14 @@ def load_model(
if cfg.gptq: if cfg.gptq:
# Scales to half # Scales to half
logging.info("Fitting 4bit scales and zeros to half") logging.info("Fitting 4bit scales and zeros to half")
for n, m in model.named_modules(): for _, module in model.named_modules():
if "Autograd4bitQuantLinear" in str(type(m)) or "Linear4bitLt" in str( if "Autograd4bitQuantLinear" in str(type(module)) or "Linear4bitLt" in str(
type(m) type(module)
): ):
if hasattr(m, "is_v1_model") and m.is_v1_model: if hasattr(module, "is_v1_model") and module.is_v1_model:
m.zeros = m.zeros.half() module.zeros = module.zeros.half()
m.scales = m.scales.half() module.scales = module.scales.half()
m.bias = m.bias.half() module.bias = module.bias.half()
if ( if (
torch.cuda.device_count() > 1 torch.cuda.device_count() > 1