add float16 docs and tweak typehints
This commit is contained in:
@@ -11,13 +11,14 @@ import bitsandbytes as bnb
|
||||
import torch
|
||||
import transformers
|
||||
from optimum.bettertransformer import BetterTransformer
|
||||
from transformers import PreTrainedModel # noqa: F401
|
||||
from transformers import (
|
||||
from transformers import ( # noqa: F401
|
||||
AutoConfig,
|
||||
AutoModelForCausalLM,
|
||||
AutoTokenizer,
|
||||
BitsAndBytesConfig,
|
||||
LlamaConfig,
|
||||
PreTrainedModel,
|
||||
PreTrainedTokenizerBase,
|
||||
)
|
||||
|
||||
from axolotl.prompt_tokenizers import LLAMA_DEFAULT_PAD_TOKEN
|
||||
@@ -71,7 +72,7 @@ def load_tokenizer(
|
||||
def load_model(
|
||||
base_model, base_model_config, model_type, tokenizer, cfg, adapter="lora"
|
||||
):
|
||||
# type: (str, str, str, AutoTokenizer, DictDefault, Optional[str]) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
|
||||
# type: (str, str, str, PreTrainedTokenizerBase, DictDefault, Optional[str]) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
|
||||
"""
|
||||
Load a model from a base model and a model type.
|
||||
"""
|
||||
@@ -284,6 +285,7 @@ def load_model(
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
base_model,
|
||||
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
|
||||
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
|
||||
torch_dtype=torch_dtype,
|
||||
device_map=cfg.device_map,
|
||||
trust_remote_code=cfg.trust_remote_code or False,
|
||||
|
||||
Reference in New Issue
Block a user