Lint models.py
@@ -1,13 +1,16 @@
+"""Module for models and model loading"""
+
+
 import logging
 import math
 import os
 from pathlib import Path
-from typing import Optional, Tuple, TYPE_CHECKING
+from typing import Optional, Tuple, TYPE_CHECKING  # noqa: F401

 import bitsandbytes as bnb
 import torch
 import transformers
-from transformers import (
+from transformers import (  # noqa: F401
     AutoModelForCausalLM,
     AutoTokenizer,
     PreTrainedModel,
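A note on the `# noqa: F401` markers: flake8's F401 rule flags imports that a module never references. Several of these names are referenced only inside `# type:` comments further down, which flake8 does not parse, so the rule is suppressed per line rather than by deleting the imports. A minimal illustration:

```python
from typing import TYPE_CHECKING  # noqa: F401

# Without the trailing comment, flake8 reports:
#   F401 'typing.TYPE_CHECKING' imported but unused
```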
@@ -18,9 +21,8 @@ from transformers import (
 try:
     from transformers import (
         LlamaForCausalLM,
         LlamaTokenizer,
     )
-except:
+except ImportError:
     logging.warning(
         "This version of transformers does not support Llama. Consider upgrading."
     )
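Narrowing the bare `except:` matters beyond style: a bare clause also catches `SystemExit` and `KeyboardInterrupt`, so a Ctrl-C during import would have been swallowed and mis-reported as an unsupported transformers version. The pattern, as a self-contained sketch:

```python
import logging

try:
    from transformers import LlamaForCausalLM, LlamaTokenizer  # noqa: F401
except ImportError:
    # Only a missing or old transformers lands here; everything else propagates.
    logging.warning("This version of transformers does not support Llama.")
```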
@@ -28,9 +30,9 @@ except:
 from axolotl.prompt_tokenizers import LLAMA_DEFAULT_PAD_TOKEN

 if TYPE_CHECKING:
-    from peft import PeftModel, PeftConfig
-    from axolotl.utils.dict import DictDefault
-    from transformers import PreTrainedTokenizer
+    from peft import PeftConfig  # noqa: F401
+    from axolotl.utils.dict import DictDefault  # noqa: F401
+    from transformers import PreTrainedTokenizer  # noqa: F401


 def load_tokenizer(
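`typing.TYPE_CHECKING` is `False` at runtime, so this block never executes; the names exist purely for the type checker, and because they appear only in `# type:` comments, flake8 cannot see a use. The genuinely unused `PeftModel` is dropped outright at the same time. An illustrative sketch of the idiom (the names below are examples, not the project's):

```python
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Evaluated by type checkers only; at runtime TYPE_CHECKING is False,
    # so flake8 sees the import as unused, hence the noqa.
    from pathlib import Path  # noqa: F401


def as_str(path):
    # type: (Path) -> str
    return str(path)
```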
@@ -62,8 +64,8 @@ def load_tokenizer(
     os.environ["TOKENIZERS_PARALLELISM"] = "false"

     if cfg.special_tokens:
-        for k, v in cfg.special_tokens.items():
-            tokenizer.add_special_tokens({k: v})
+        for k, val in cfg.special_tokens.items():
+            tokenizer.add_special_tokens({k: val})
     if cfg.tokens:
         tokenizer.add_tokens(list(cfg.tokens))
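Only `v` was renamed because pylint's naming check (C0103) exempts names on its `good-names` list, which by default is `i, j, k, ex, Run, _`: `k` is on it, `v` is not. For example:

```python
def show(special_tokens):
    """Demonstrate the rename; "k" is allowed by default, "v" is not."""
    for k, val in special_tokens.items():
        print(k, val)


show({"pad_token": "[PAD]"})
```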
@@ -80,6 +82,9 @@ def load_model(
     inference=False,
 ):
     # type: (str, str, str, str, DictDefault, Optional[str], bool) -> Tuple[PreTrainedModel, PreTrainedTokenizer, Optional[PeftConfig]]
+    """
+    Load a model from a base model and a model type.
+    """

     # TODO refactor as a kwarg
     load_in_8bit = cfg.load_in_8bit
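The added docstring clears pylint's `missing-function-docstring` (C0116); the comment-style `# type:` signature stays, which is also why the `typing` names above needed their `# noqa: F401`. A hypothetical mini-example of the combined shape (`load_config` is illustrative, not from the file):

```python
def load_config(path):
    # type: (str) -> dict
    """Load a config mapping from a path."""
    return {"path": path}
```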
@@ -115,9 +120,9 @@ def load_model(

             replace_peft_model_with_int4_lora_model()
         from peft import prepare_model_for_int8_training
-    except Exception as e:
-        logging.exception(e)
-        raise e
+    except Exception as err:
+        logging.exception(err)
+        raise err

     model_kwargs = {}
     if cfg.adapter == "qlora" and cfg.load_in_4bit:
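`e` becomes `err` under the same naming rule (`e` is not on pylint's default `good-names` list; `ex` is). Note that `logging.exception` already records the full traceback, and the exception still propagates via `raise err`. A self-contained sketch (function names here are illustrative):

```python
import logging


def risky():
    """Stand-in for the 4-bit patching above."""
    raise RuntimeError("boom")


def load():
    """Log with traceback, then propagate."""
    try:
        risky()
    except Exception as err:
        logging.exception(err)  # message plus full traceback, at ERROR level
        raise err  # same object re-raised; a bare `raise` is the usual idiom


try:
    load()
except RuntimeError:
    pass  # callers still see the original RuntimeError
```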
@@ -155,7 +160,7 @@ def load_model(
                 "unable to find a cached model file, this will likely fail..."
             )
             model_path = str(cache_model_path)
-        except:
+        except Exception:  # pylint: disable=broad-exception-caught
             model_path = cfg.base_model
         model, _ = load_llama_model_4bit_low_ram(
             base_model_config if base_model_config else base_model,
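Here the broad catch is intentional: any failure while probing the model cache should fall back to `cfg.base_model`. So rather than narrowing the exception, the commit keeps `Exception` (still tighter than the bare `except:`, which would also trap `SystemExit` and `KeyboardInterrupt`) and silences pylint on that one line. The shape of the pattern, with hypothetical names:

```python
def read_or_default(path, default=""):
    """Return the file's text, or the default if anything goes wrong."""
    try:
        with open(path, encoding="utf-8") as handle:
            return handle.read()
    except Exception:  # pylint: disable=broad-exception-caught
        # Deliberately broad: any failure at all means "use the fallback".
        return default


print(read_or_default("/no/such/file", default="fallback"))
```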
@@ -210,13 +215,13 @@ def load_model(
                 load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
                 torch_dtype=torch_dtype,
                 device_map=cfg.device_map,
-                trust_remote_code=True if cfg.trust_remote_code is True else False,
+                trust_remote_code=cfg.trust_remote_code or False,
                 **model_kwargs,
             )
         else:
             config = AutoConfig.from_pretrained(
                 base_model,
-                trust_remote_code=True if cfg.trust_remote_code is True else False,
+                trust_remote_code=cfg.trust_remote_code or False,
             )
             model = AutoModelForCausalLM.from_pretrained(
                 base_model,
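`True if cfg.trust_remote_code is True else False` collapses to `cfg.trust_remote_code or False`. The two agree whenever the value is a bool or `None`, which is what the config supplies; a truthy non-bool would pass through `or` unchanged where the old ternary mapped it to `False`. A quick check of the boundary cases:

```python
def normalized(flag):
    # Stand-in for the inlined expression; assumes flag is a bool or None.
    return flag or False


assert normalized(True) is True
assert normalized(False) is False
assert normalized(None) is False  # the old ternary agreed on all three
```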
@@ -225,30 +230,29 @@ def load_model(
                 load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
                 torch_dtype=torch_dtype,
                 device_map=cfg.device_map,
-                trust_remote_code=True if cfg.trust_remote_code is True else False,
+                trust_remote_code=cfg.trust_remote_code or False,
                 **model_kwargs,
             )
-    except Exception as e:
+    except Exception as err:  # pylint: disable=broad-exception-caught
         logging.error(
             "Exception raised attempting to load model, retrying with AutoModelForCausalLM"
         )
-        logging.exception(e)
+        logging.exception(err)
         model = AutoModelForCausalLM.from_pretrained(
             base_model,
             load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
             torch_dtype=torch_dtype,
             device_map=cfg.device_map,
-            trust_remote_code=True if cfg.trust_remote_code is True else False,
+            trust_remote_code=cfg.trust_remote_code or False,
             **model_kwargs,
         )

     embeddings_len = math.ceil(len(tokenizer) / 32) * 32
     model.resize_token_embeddings(embeddings_len)

-    if (
-        ((cfg.adapter == "lora" and load_in_8bit) or cfg.adapter == "qlora")
-        and not cfg.gptq
-        and (load_in_8bit or cfg.load_in_4bit)
+    if not cfg.gptq and (
+        (cfg.adapter == "lora" and load_in_8bit)
+        or (cfg.adapter == "qlora" and cfg.load_in_4bit)
     ):
         logging.info("converting PEFT model w/ prepare_model_for_int8_training")
         model = prepare_model_for_int8_training(model)
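This rewrite is not purely cosmetic: the old grouping enabled the int8-prepare path for `adapter == "qlora"` whenever either quantization flag was set, while the new form pairs each adapter with its own flag (lora with 8-bit, qlora with 4-bit). A standalone comparison:

```python
def old_check(adapter, in_8bit, in_4bit, gptq):
    return (
        ((adapter == "lora" and in_8bit) or adapter == "qlora")
        and not gptq
        and (in_8bit or in_4bit)
    )


def new_check(adapter, in_8bit, in_4bit, gptq):
    return not gptq and (
        (adapter == "lora" and in_8bit) or (adapter == "qlora" and in_4bit)
    )


# qlora with 8-bit set but 4-bit unset: old fired, new does not.
print(old_check("qlora", True, False, False))  # True
print(new_check("qlora", True, False, False))  # False
```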
@@ -261,14 +265,14 @@ def load_model(
     if cfg.gptq:
         # Scales to half
         logging.info("Fitting 4bit scales and zeros to half")
-        for n, m in model.named_modules():
-            if "Autograd4bitQuantLinear" in str(type(m)) or "Linear4bitLt" in str(
-                type(m)
+        for _, module in model.named_modules():
+            if "Autograd4bitQuantLinear" in str(type(module)) or "Linear4bitLt" in str(
+                type(module)
             ):
-                if hasattr(m, "is_v1_model") and m.is_v1_model:
-                    m.zeros = m.zeros.half()
-                m.scales = m.scales.half()
-                m.bias = m.bias.half()
+                if hasattr(module, "is_v1_model") and module.is_v1_model:
+                    module.zeros = module.zeros.half()
+                module.scales = module.scales.half()
+                module.bias = module.bias.half()

     if (
         torch.cuda.device_count() > 1
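`for n, m in ...` becomes `for _, module in ...`: `named_modules()` yields `(qualified_name, module)` pairs, the name is never used, and `_` is the conventional throwaway that pylint's unused-variable check skips, while `m` would also trip the short-name rule. A minimal demonstration on a toy model:

```python
import torch

# named_modules() yields (qualified_name, module) pairs; only the module
# is needed, so the name is discarded into "_".
model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU())
for _, module in model.named_modules():
    print(type(module).__name__)  # Sequential, Linear, ReLU
```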