recast loralayer, norm, lmhead + embed token weights per original qlora (#393)

* recast loralayer, norm, lmhead + embed token weights per original qlora

* try again for the fix

* refactor torch dtype picking

* linter fixes

* missing import for LoraLayer

* fix install for tests now that peft is involved
commit 96deb6bd67 (parent 50682a3c06)
Author: Wing Lian
Date: 2023-08-21 18:41:12 -04:00
Committed by: GitHub
4 changed files with 33 additions and 22 deletions

View File

@@ -24,7 +24,7 @@ jobs:
       - name: Install dependencies
         run: |
-          pip install -e .
+          pip install -e .[peft]
           pip install -r requirements-tests.txt
       - name: Run tests

View File

@@ -32,5 +32,8 @@ setup(
"extras": [
"deepspeed",
],
"peft": [
"peft @ git+https://github.com/huggingface/peft.git",
],
},
)

View File

@@ -62,6 +62,13 @@ def normalize_config(cfg):
     else:
         torch.backends.cuda.matmul.allow_tf32 = cfg.tf32 or False
+    if cfg.bf16 or cfg.bfloat16:
+        cfg.torch_dtype = torch.bfloat16
+    elif cfg.load_in_8bit or cfg.fp16 or cfg.float16:
+        cfg.torch_dtype = torch.float16
+    else:
+        cfg.torch_dtype = torch.float32
     log_gpu_memory_usage(LOG, "baseline", cfg.device)
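The dtype choice now lives on the config object, so callers read cfg.torch_dtype instead of recomputing it. A minimal, runnable sketch of the same precedence (bf16 first, then 8-bit/fp16, otherwise full precision); SimpleNamespace and pick_torch_dtype are stand-ins for illustration, not names from the repo:

# Sketch of the precedence normalize_config applies above.
from types import SimpleNamespace

import torch

def pick_torch_dtype(cfg):
    # bf16 wins over everything else; 8-bit loading implies fp16 compute;
    # otherwise fall back to fp32.
    if cfg.bf16 or cfg.bfloat16:
        return torch.bfloat16
    if cfg.load_in_8bit or cfg.fp16 or cfg.float16:
        return torch.float16
    return torch.float32

cfg = SimpleNamespace(bf16=False, bfloat16=False, load_in_8bit=True, fp16=False, float16=False)
assert pick_torch_dtype(cfg) is torch.float16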

View File

@@ -11,6 +11,7 @@ import bitsandbytes as bnb
 import torch
 import transformers
 from optimum.bettertransformer import BetterTransformer
+from peft.tuners.lora import LoraLayer
 from transformers import (  # noqa: F401
     AutoConfig,
     AutoModelForCausalLM,
@@ -146,12 +147,6 @@ def load_model(
         LOG.info("patching _expand_mask")
         hijack_expand_mask()
-    if cfg.bf16 or cfg.bfloat16:
-        torch_dtype = torch.bfloat16
-    elif cfg.load_in_8bit or cfg.fp16 or cfg.float16:
-        torch_dtype = torch.float16
-    else:
-        torch_dtype = torch.float32
     try:
         if cfg.gptq:
             from alpaca_lora_4bit.monkeypatch.peft_tuners_lora_monkey_patch import (
@@ -183,7 +178,7 @@ def load_model(
                 load_in_4bit=True,
                 llm_int8_threshold=6.0,
                 llm_int8_has_fp16_weight=False,
-                bnb_4bit_compute_dtype=torch_dtype,
+                bnb_4bit_compute_dtype=cfg.torch_dtype,
                 bnb_4bit_use_double_quant=True,
                 bnb_4bit_quant_type="nf4",
             )
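With the local torch_dtype variable gone, the normalized cfg.torch_dtype drives both the 4-bit compute dtype here and the torch_dtype passed to the from_pretrained calls in the hunks below. A hedged sketch of that flow (not the repo's exact wiring; the model id is a placeholder):

# Illustrative only: cfg.torch_dtype feeds the bnb compute dtype and the load dtype.
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

compute_dtype = torch.bfloat16  # i.e. cfg.torch_dtype after normalize_config

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=compute_dtype,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
)

model = AutoModelForCausalLM.from_pretrained(
    "huggyllama/llama-7b",          # placeholder model id
    quantization_config=bnb_config,
    torch_dtype=compute_dtype,
    device_map="auto",
)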
@@ -242,7 +237,7 @@ def load_model(
                 device_map=cfg.device_map,
                 load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
                 load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
-                torch_dtype=torch_dtype,
+                torch_dtype=cfg.torch_dtype,
                 **model_kwargs,
             )
         # elif model_type == "GPTNeoXForCausalLM" and cfg.flash_attention:
@@ -277,7 +272,7 @@ def load_model(
                 device_map=cfg.device_map,
                 load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
                 load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
-                torch_dtype=torch_dtype,
+                torch_dtype=cfg.torch_dtype,
                 trust_remote_code=cfg.trust_remote_code or False,
                 **model_kwargs,
             )
@@ -308,7 +303,7 @@ def load_model(
                 device_map=cfg.device_map,
                 load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
                 load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
-                torch_dtype=torch_dtype,
+                torch_dtype=cfg.torch_dtype,
                 trust_remote_code=cfg.trust_remote_code or False,
                 **model_kwargs,
             )
@@ -322,7 +317,7 @@ def load_model(
                 device_map=cfg.device_map,
                 load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
                 load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
-                torch_dtype=torch_dtype,
+                torch_dtype=cfg.torch_dtype,
                 trust_remote_code=cfg.trust_remote_code or False,
                 **model_kwargs,
             )
@@ -356,16 +351,6 @@ def load_model(
             model, use_gradient_checkpointing=cfg.gradient_checkpointing
         )
-    # LlamaRMSNorm layers are in fp32 after kbit_training, so we need to
-    # convert them back to fp16/bf16 for flash-attn compatibility.
-    if cfg.flash_attention and cfg.is_llama_derived_model:
-        for name, module in model.named_modules():
-            if "norm" in name:
-                module.to(torch_dtype)
-            if "lm_head" in name or "embed_tokens" in name:
-                if hasattr(module, "weight"):
-                    module.to(torch_dtype)
     model, lora_config = load_adapter(model, cfg, cfg.adapter)
     if cfg.ddp and not load_in_8bit:
@@ -509,6 +494,22 @@ def load_lora(model, cfg):
     else:
         model = get_peft_model(model, lora_config)
+    for name, module in model.named_modules():
+        if isinstance(module, LoraLayer):
+            module = module.to(cfg.torch_dtype)
+        if "norm" in name:
+            module = module.to(torch.float32)
+        if "lm_head" in name or "embed_tokens" in name:
+            if hasattr(module, "weight"):
+                module = module.to(cfg.torch_dtype)
+
+    # LlamaRMSNorm layers are in fp32 after kbit_training, so we need to
+    # convert them back to fp16/bf16 for flash-attn compatibility.
+    if cfg.flash_attention and cfg.is_llama_derived_model:
+        for name, module in model.named_modules():
+            if "norm" in name:
+                module = module.to(cfg.torch_dtype)
     model.print_trainable_parameters()
     return model, lora_config
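This mirrors the original QLoRA training script: LoRA adapter layers and the lm_head/embed_tokens weights are cast to the compute dtype, norm layers are kept in fp32 for stability, and norms are only dropped back to half precision when flash-attention on a llama-derived model requires it. A standalone sketch of the same recast applied to an already-built PEFT model; recast_qlora_dtypes is a hypothetical helper name, not part of axolotl:

# Sketch of the QLoRA-style recast from the hunk above, factored into a helper.
import torch
from peft.tuners.lora import LoraLayer

def recast_qlora_dtypes(model, compute_dtype, keep_norms_fp32=True):
    for name, module in model.named_modules():
        if isinstance(module, LoraLayer):
            module.to(compute_dtype)          # adapter weights in bf16/fp16
        if "norm" in name and keep_norms_fp32:
            module.to(torch.float32)          # norms stay fp32 for stability
        if "lm_head" in name or "embed_tokens" in name:
            if hasattr(module, "weight"):
                module.to(compute_dtype)      # output head and input embeddings
    return model

Calling such a helper right after get_peft_model(...) and before training reproduces the dtype layout this hunk sets up; pass keep_norms_fp32=False when a flash-attention kernel needs the norms in the compute dtype.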