Lint models.py
This commit is contained in:
@@ -1,13 +1,16 @@
|
|||||||
|
"""Module for models and model loading"""
|
||||||
|
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional, Tuple, TYPE_CHECKING
|
from typing import Optional, Tuple, TYPE_CHECKING # noqa: F401
|
||||||
|
|
||||||
import bitsandbytes as bnb
|
import bitsandbytes as bnb
|
||||||
import torch
|
import torch
|
||||||
import transformers
|
import transformers
|
||||||
from transformers import (
|
from transformers import ( # noqa: F401
|
||||||
AutoModelForCausalLM,
|
AutoModelForCausalLM,
|
||||||
AutoTokenizer,
|
AutoTokenizer,
|
||||||
PreTrainedModel,
|
PreTrainedModel,
|
||||||
@@ -18,9 +21,8 @@ from transformers import (
|
|||||||
try:
|
try:
|
||||||
from transformers import (
|
from transformers import (
|
||||||
LlamaForCausalLM,
|
LlamaForCausalLM,
|
||||||
LlamaTokenizer,
|
|
||||||
)
|
)
|
||||||
except:
|
except ImportError:
|
||||||
logging.warning(
|
logging.warning(
|
||||||
"This version of transformers does not support Llama. Consider upgrading."
|
"This version of transformers does not support Llama. Consider upgrading."
|
||||||
)
|
)
|
||||||
@@ -28,9 +30,9 @@ except:
|
|||||||
from axolotl.prompt_tokenizers import LLAMA_DEFAULT_PAD_TOKEN
|
from axolotl.prompt_tokenizers import LLAMA_DEFAULT_PAD_TOKEN
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from peft import PeftModel, PeftConfig
|
from peft import PeftConfig # noqa: F401
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault # noqa: F401
|
||||||
from transformers import PreTrainedTokenizer
|
from transformers import PreTrainedTokenizer # noqa: F401
|
||||||
|
|
||||||
|
|
||||||
def load_tokenizer(
|
def load_tokenizer(
|
||||||
@@ -62,8 +64,8 @@ def load_tokenizer(
|
|||||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||||
|
|
||||||
if cfg.special_tokens:
|
if cfg.special_tokens:
|
||||||
for k, v in cfg.special_tokens.items():
|
for k, val in cfg.special_tokens.items():
|
||||||
tokenizer.add_special_tokens({k: v})
|
tokenizer.add_special_tokens({k: val})
|
||||||
if cfg.tokens:
|
if cfg.tokens:
|
||||||
tokenizer.add_tokens(list(cfg.tokens))
|
tokenizer.add_tokens(list(cfg.tokens))
|
||||||
|
|
||||||
@@ -80,6 +82,9 @@ def load_model(
|
|||||||
inference=False,
|
inference=False,
|
||||||
):
|
):
|
||||||
# type: (str, str, str, str, DictDefault, Optional[str], bool) -> Tuple[PreTrainedModel, PreTrainedTokenizer, Optional[PeftConfig]]
|
# type: (str, str, str, str, DictDefault, Optional[str], bool) -> Tuple[PreTrainedModel, PreTrainedTokenizer, Optional[PeftConfig]]
|
||||||
|
"""
|
||||||
|
Load a model from a base model and a model type.
|
||||||
|
"""
|
||||||
|
|
||||||
# TODO refactor as a kwarg
|
# TODO refactor as a kwarg
|
||||||
load_in_8bit = cfg.load_in_8bit
|
load_in_8bit = cfg.load_in_8bit
|
||||||
@@ -115,9 +120,9 @@ def load_model(
|
|||||||
|
|
||||||
replace_peft_model_with_int4_lora_model()
|
replace_peft_model_with_int4_lora_model()
|
||||||
from peft import prepare_model_for_int8_training
|
from peft import prepare_model_for_int8_training
|
||||||
except Exception as e:
|
except Exception as err:
|
||||||
logging.exception(e)
|
logging.exception(err)
|
||||||
raise e
|
raise err
|
||||||
|
|
||||||
model_kwargs = {}
|
model_kwargs = {}
|
||||||
if cfg.adapter == "qlora" and cfg.load_in_4bit:
|
if cfg.adapter == "qlora" and cfg.load_in_4bit:
|
||||||
@@ -155,7 +160,7 @@ def load_model(
|
|||||||
"unable to find a cached model file, this will likely fail..."
|
"unable to find a cached model file, this will likely fail..."
|
||||||
)
|
)
|
||||||
model_path = str(cache_model_path)
|
model_path = str(cache_model_path)
|
||||||
except:
|
except Exception: # pylint: disable=broad-exception-caught
|
||||||
model_path = cfg.base_model
|
model_path = cfg.base_model
|
||||||
model, _ = load_llama_model_4bit_low_ram(
|
model, _ = load_llama_model_4bit_low_ram(
|
||||||
base_model_config if base_model_config else base_model,
|
base_model_config if base_model_config else base_model,
|
||||||
@@ -210,13 +215,13 @@ def load_model(
|
|||||||
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
|
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
|
||||||
torch_dtype=torch_dtype,
|
torch_dtype=torch_dtype,
|
||||||
device_map=cfg.device_map,
|
device_map=cfg.device_map,
|
||||||
trust_remote_code=True if cfg.trust_remote_code is True else False,
|
trust_remote_code=cfg.trust_remote_code or False,
|
||||||
**model_kwargs,
|
**model_kwargs,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
config = AutoConfig.from_pretrained(
|
config = AutoConfig.from_pretrained(
|
||||||
base_model,
|
base_model,
|
||||||
trust_remote_code=True if cfg.trust_remote_code is True else False,
|
trust_remote_code=cfg.trust_remote_code or False,
|
||||||
)
|
)
|
||||||
model = AutoModelForCausalLM.from_pretrained(
|
model = AutoModelForCausalLM.from_pretrained(
|
||||||
base_model,
|
base_model,
|
||||||
@@ -225,30 +230,29 @@ def load_model(
|
|||||||
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
|
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
|
||||||
torch_dtype=torch_dtype,
|
torch_dtype=torch_dtype,
|
||||||
device_map=cfg.device_map,
|
device_map=cfg.device_map,
|
||||||
trust_remote_code=True if cfg.trust_remote_code is True else False,
|
trust_remote_code=cfg.trust_remote_code or False,
|
||||||
**model_kwargs,
|
**model_kwargs,
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as err: # pylint: disable=broad-exception-caught
|
||||||
logging.error(
|
logging.error(
|
||||||
"Exception raised attempting to load model, retrying with AutoModelForCausalLM"
|
"Exception raised attempting to load model, retrying with AutoModelForCausalLM"
|
||||||
)
|
)
|
||||||
logging.exception(e)
|
logging.exception(err)
|
||||||
model = AutoModelForCausalLM.from_pretrained(
|
model = AutoModelForCausalLM.from_pretrained(
|
||||||
base_model,
|
base_model,
|
||||||
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
|
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
|
||||||
torch_dtype=torch_dtype,
|
torch_dtype=torch_dtype,
|
||||||
device_map=cfg.device_map,
|
device_map=cfg.device_map,
|
||||||
trust_remote_code=True if cfg.trust_remote_code is True else False,
|
trust_remote_code=cfg.trust_remote_code or False,
|
||||||
**model_kwargs,
|
**model_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
embeddings_len = math.ceil(len(tokenizer) / 32) * 32
|
embeddings_len = math.ceil(len(tokenizer) / 32) * 32
|
||||||
model.resize_token_embeddings(embeddings_len)
|
model.resize_token_embeddings(embeddings_len)
|
||||||
|
|
||||||
if (
|
if not cfg.gptq and (
|
||||||
((cfg.adapter == "lora" and load_in_8bit) or cfg.adapter == "qlora")
|
(cfg.adapter == "lora" and load_in_8bit)
|
||||||
and not cfg.gptq
|
or (cfg.adapter == "qlora" and cfg.load_in_4bit)
|
||||||
and (load_in_8bit or cfg.load_in_4bit)
|
|
||||||
):
|
):
|
||||||
logging.info("converting PEFT model w/ prepare_model_for_int8_training")
|
logging.info("converting PEFT model w/ prepare_model_for_int8_training")
|
||||||
model = prepare_model_for_int8_training(model)
|
model = prepare_model_for_int8_training(model)
|
||||||
@@ -261,14 +265,14 @@ def load_model(
|
|||||||
if cfg.gptq:
|
if cfg.gptq:
|
||||||
# Scales to half
|
# Scales to half
|
||||||
logging.info("Fitting 4bit scales and zeros to half")
|
logging.info("Fitting 4bit scales and zeros to half")
|
||||||
for n, m in model.named_modules():
|
for _, module in model.named_modules():
|
||||||
if "Autograd4bitQuantLinear" in str(type(m)) or "Linear4bitLt" in str(
|
if "Autograd4bitQuantLinear" in str(type(module)) or "Linear4bitLt" in str(
|
||||||
type(m)
|
type(module)
|
||||||
):
|
):
|
||||||
if hasattr(m, "is_v1_model") and m.is_v1_model:
|
if hasattr(module, "is_v1_model") and module.is_v1_model:
|
||||||
m.zeros = m.zeros.half()
|
module.zeros = module.zeros.half()
|
||||||
m.scales = m.scales.half()
|
module.scales = module.scales.half()
|
||||||
m.bias = m.bias.half()
|
module.bias = module.bias.half()
|
||||||
|
|
||||||
if (
|
if (
|
||||||
torch.cuda.device_count() > 1
|
torch.cuda.device_count() > 1
|
||||||
|
|||||||
Reference in New Issue
Block a user