Merge pull request #36 from OpenAccess-AI-Collective/qlora

Qlora
Wing Lian, 2023-05-24 22:57:37 -04:00 (committed by GitHub)
5 changed files with 35 additions and 4 deletions

requirements.txt

@@ -1,10 +1,10 @@
 peft @ git+https://github.com/huggingface/peft.git
 transformers @ git+https://github.com/huggingface/transformers.git
+bitsandbytes>=0.39.0
 attrdict
 fire
 PyYAML==6.0
 black
-bitsandbytes==0.37.2
 datasets
 accelerate>=0.19.0
 sentencepiece
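
The bitsandbytes bump from the 0.37.2 pin to >=0.39.0 is what unlocks 4-bit quantization: 0.39.0 is the first release shipping the 4-bit (QLoRA) kernels that the BitsAndBytesConfig load_in_4bit path below depends on. A minimal environment check, as a sketch:

    # Sketch: verify the installed bitsandbytes satisfies the new 4-bit floor.
    import bitsandbytes
    from packaging.version import Version

    assert Version(bitsandbytes.__version__) >= Version("0.39.0"), \
        "QLoRA's 4-bit path needs bitsandbytes>=0.39.0"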

scripts/finetune.py

@@ -14,6 +14,7 @@ from attrdict import AttrDefault
 # add src to the pythonpath so we don't need to pip install this
 from axolotl.utils.tokenization import check_dataset_labels
+from axolotl.utils.validation import validate_config
 
 project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
 src_dir = os.path.join(project_root, "src")
@@ -158,6 +159,8 @@ def train(
         cfg.fp16 = True
         cfg.bf16 = False
 
+    validate_config(cfg)
+
     # Load the model and tokenizer
     logging.info("loading model, tokenizer, and peft_config...")
     model, tokenizer, peft_config = load_model(
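
Calling validate_config(cfg) before load_model means a bad combination of flags fails fast instead of failing after a long model download. The actual checks live in axolotl.utils.validation and are not part of this diff; a hypothetical sketch of the kind of guard this hook enables:

    # Hypothetical sketch only; the real checks in axolotl.utils.validation
    # are not shown in this diff.
    def validate_config(cfg):
        # 8-bit and 4-bit loading are mutually exclusive quantization modes.
        if cfg.load_in_8bit and cfg.load_in_4bit:
            raise ValueError("load_in_8bit and load_in_4bit cannot both be enabled")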

src/axolotl/prompters.py

@@ -11,6 +11,7 @@ class PromptStyle(Enum):
     instruct = "instruct"
     chat = "chat"
 
+
 class AlpacaPrompter:
     system_prompt = "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n"
     system_no_input_prompt = "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n"
@@ -50,6 +51,10 @@ class AlpacaPrompter:
         return output.split(self.response_split)[1].strip()
 
 
+class UnpromptedPrompter(AlpacaPrompter):
+    system_prompt = ""
+    system_no_input_prompt = ""
+
 class JeopardyPrompter(AlpacaPrompter):
     prompt_input = "Below is a Jeopardy clue paired with input providing the category of the clue. Write a concise response that best answers the clue given the category.\n\n### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
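
The prompter hierarchy customizes behavior purely by overriding class-level template strings, which is why the new UnpromptedPrompter is only a few lines: blanking both system prompts yields bare instruction/response pairs while everything else is inherited from AlpacaPrompter. A minimal sketch of the pattern (build_prompt is simplified here, not the actual method):

    # Sketch of the override-a-class-attribute pattern; build_prompt is
    # simplified, not the real AlpacaPrompter implementation.
    class AlpacaPrompter:
        system_no_input_prompt = (
            "Below is an instruction that describes a task. "
            "Write a response that appropriately completes the request.\n\n"
        )

        def build_prompt(self, instruction: str) -> str:
            return (
                f"{self.system_no_input_prompt}"
                f"### Instruction:\n{instruction}\n\n### Response:\n"
            )

    class UnpromptedPrompter(AlpacaPrompter):
        system_no_input_prompt = ""  # the override is the whole subclass

    print(UnpromptedPrompter().build_prompt("Say hi"))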

src/axolotl/utils/data.py

@@ -98,6 +98,11 @@ def load_tokenized_prepared_datasets(tokenizer, cfg, default_dataset_prepared_path):
                 ds = load_dataset("json", data_files=fp, streaming=False, split=None)
             if not ds:
                 raise Exception("unhandled dataset load")
+            # support for using a subset of the data
+            if d.shards:
+                ds = ds.shuffle(seed=42)["train"].shard(
+                    num_shards=d.shards, index=0
+                )
             d_type = d.type
             d_type_split = d_type.split(":")
             d_base_type = d_type_split[0]
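
shard keeps a deterministic 1/num_shards slice of the shuffled rows, so d.shards acts as a cheap "train on a fraction of the data" switch. A small sketch with the datasets library:

    # Sketch: shard(num_shards=4, index=0) keeps a deterministic quarter
    # of the (seeded) shuffled dataset.
    from datasets import Dataset

    ds = Dataset.from_dict({"text": [f"row {i}" for i in range(100)]})
    subset = ds.shuffle(seed=42).shard(num_shards=4, index=0)
    print(len(subset))  # 25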

src/axolotl/utils/models.py

@@ -6,11 +6,12 @@ from typing import Optional, Tuple, TYPE_CHECKING
 import torch
 import transformers
 from torch import nn
 from transformers import (
     AutoModelForCausalLM,
     AutoTokenizer,
     PreTrainedModel,
-    AutoConfig,
+    AutoConfig,
+    BitsAndBytesConfig,
 )
 
 try:
@@ -81,6 +82,16 @@ def load_model(
         logging.exception(e)
         raise e
 
+    model_kwargs = {}
+    if cfg.adapter == "qlora":
+        model_kwargs["quantization_config"] = BitsAndBytesConfig(
+            load_in_4bit=True,
+            llm_int8_threshold=6.0,
+            llm_int8_has_fp16_weight=False,
+            bnb_4bit_compute_dtype=torch.float16,
+            bnb_4bit_use_double_quant=True,
+            bnb_4bit_quant_type="nf4",
+        )
     try:
         if cfg.load_4bit and is_llama_derived_model:
             from alpaca_lora_4bit.autograd_4bit import load_llama_model_4bit_low_ram
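
This BitsAndBytesConfig is the "q" in QLoRA: base weights are stored as 4-bit NF4 with double quantization, while compute runs in fp16. Passed along through model_kwargs, it reaches the from_pretrained calls below; a standalone sketch of the same pattern (the model id is illustrative only):

    # Sketch: loading a causal LM in 4-bit with the same quantization settings.
    # "huggyllama/llama-7b" is only an illustrative model id.
    import torch
    from transformers import AutoModelForCausalLM, BitsAndBytesConfig

    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
    )
    model = AutoModelForCausalLM.from_pretrained(
        "huggyllama/llama-7b",
        quantization_config=bnb_config,
        device_map="auto",
    )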
@@ -123,8 +134,10 @@ def load_model(
         model = LlamaForCausalLM.from_pretrained(
             base_model,
             load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
+            load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
             torch_dtype=torch_dtype,
             device_map=cfg.device_map,
+            **model_kwargs,
         )
     # elif model_type == "GPTNeoXForCausalLM" and cfg.flash_attention:
     # This is a WIP, still an issue with the backward pass
@@ -156,9 +169,11 @@ def load_model(
         model = getattr(transformers, model_type).from_pretrained(
             base_model,
             load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
+            load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
             torch_dtype=torch_dtype,
             device_map=cfg.device_map,
             trust_remote_code=True if cfg.trust_remote_code is True else False,
+            **model_kwargs,
         )
     else:
         config = AutoConfig.from_pretrained(
@@ -169,9 +184,11 @@ def load_model(
             base_model,
             config=config,
             load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
+            load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
             torch_dtype=torch_dtype,
             device_map=cfg.device_map,
             trust_remote_code=True if cfg.trust_remote_code is True else False,
+            **model_kwargs,
         )
     except Exception as e:
         logging.error(
@@ -184,6 +201,7 @@ def load_model(
             torch_dtype=torch_dtype,
             device_map=cfg.device_map,
             trust_remote_code=True if cfg.trust_remote_code is True else False,
+            **model_kwargs,
         )
 
     if not tokenizer:
@@ -225,7 +243,7 @@ def load_model(
         embeddings_len = math.ceil(len(tokenizer) / 32) * 32
         model.resize_token_embeddings(embeddings_len)
 
-    if cfg.adapter and load_in_8bit and not cfg.load_4bit:
+    if ((cfg.adapter == "lora" and load_in_8bit) or cfg.adapter == "qlora") and not cfg.load_4bit:
         logging.info("converting PEFT model w/ prepare_model_for_int8_training")
         model = prepare_model_for_int8_training(model)
@@ -270,7 +288,7 @@ def load_adapter(model, cfg, adapter):
     if adapter is None:
         return model, None
-    if adapter == "lora":
+    if adapter == "lora" or adapter == "qlora":
         return load_lora(model, cfg)
     if adapter == "llama-adapter":
         return load_llama_adapter(model, cfg)
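
Routing adapter == "qlora" through load_lora is the point of the whole change: once the base model is quantized to 4-bit, attaching the adapter is ordinary LoRA. load_lora's body is outside this diff; a hedged sketch of the standard peft recipe such a helper wraps (r, alpha, and target_modules are illustrative, not axolotl's actual defaults):

    # Sketch of the usual peft LoRA attach step; hyperparameters and
    # target_modules are illustrative, not axolotl's real defaults.
    from peft import LoraConfig, get_peft_model

    def load_lora_sketch(model):
        lora_config = LoraConfig(
            r=8,
            lora_alpha=16,
            target_modules=["q_proj", "v_proj"],  # LLaMA attention projections
            lora_dropout=0.05,
            bias="none",
            task_type="CAUSAL_LM",
        )
        model = get_peft_model(model, lora_config)
        model.print_trainable_parameters()  # only adapter weights are trainable
        return model, lora_config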