recast loralayer, norm, lmhead + embed token weights per original qlora (#393)
* recast loralayer, norm, lmhead + embed token weights per original qlora * try again for the fix * refactor torch dtype picking * linter fixes * missing import for LoraLayer * fix install for tests now that peft is involved
This commit is contained in:
2
.github/workflows/tests.yml
vendored
2
.github/workflows/tests.yml
vendored
@@ -24,7 +24,7 @@ jobs:
|
|||||||
|
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
run: |
|
run: |
|
||||||
pip install -e .
|
pip install -e .[peft]
|
||||||
pip install -r requirements-tests.txt
|
pip install -r requirements-tests.txt
|
||||||
|
|
||||||
- name: Run tests
|
- name: Run tests
|
||||||
|
|||||||
3
setup.py
3
setup.py
@@ -32,5 +32,8 @@ setup(
|
|||||||
"extras": [
|
"extras": [
|
||||||
"deepspeed",
|
"deepspeed",
|
||||||
],
|
],
|
||||||
|
"peft": [
|
||||||
|
"peft @ git+https://github.com/huggingface/peft.git",
|
||||||
|
],
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -62,6 +62,13 @@ def normalize_config(cfg):
|
|||||||
else:
|
else:
|
||||||
torch.backends.cuda.matmul.allow_tf32 = cfg.tf32 or False
|
torch.backends.cuda.matmul.allow_tf32 = cfg.tf32 or False
|
||||||
|
|
||||||
|
if cfg.bf16 or cfg.bfloat16:
|
||||||
|
cfg.torch_dtype = torch.bfloat16
|
||||||
|
elif cfg.load_in_8bit or cfg.fp16 or cfg.float16:
|
||||||
|
cfg.torch_dtype = torch.float16
|
||||||
|
else:
|
||||||
|
cfg.torch_dtype = torch.float32
|
||||||
|
|
||||||
log_gpu_memory_usage(LOG, "baseline", cfg.device)
|
log_gpu_memory_usage(LOG, "baseline", cfg.device)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ import bitsandbytes as bnb
|
|||||||
import torch
|
import torch
|
||||||
import transformers
|
import transformers
|
||||||
from optimum.bettertransformer import BetterTransformer
|
from optimum.bettertransformer import BetterTransformer
|
||||||
|
from peft.tuners.lora import LoraLayer
|
||||||
from transformers import ( # noqa: F401
|
from transformers import ( # noqa: F401
|
||||||
AutoConfig,
|
AutoConfig,
|
||||||
AutoModelForCausalLM,
|
AutoModelForCausalLM,
|
||||||
@@ -146,12 +147,6 @@ def load_model(
|
|||||||
LOG.info("patching _expand_mask")
|
LOG.info("patching _expand_mask")
|
||||||
hijack_expand_mask()
|
hijack_expand_mask()
|
||||||
|
|
||||||
if cfg.bf16 or cfg.bfloat16:
|
|
||||||
torch_dtype = torch.bfloat16
|
|
||||||
elif cfg.load_in_8bit or cfg.fp16 or cfg.float16:
|
|
||||||
torch_dtype = torch.float16
|
|
||||||
else:
|
|
||||||
torch_dtype = torch.float32
|
|
||||||
try:
|
try:
|
||||||
if cfg.gptq:
|
if cfg.gptq:
|
||||||
from alpaca_lora_4bit.monkeypatch.peft_tuners_lora_monkey_patch import (
|
from alpaca_lora_4bit.monkeypatch.peft_tuners_lora_monkey_patch import (
|
||||||
@@ -183,7 +178,7 @@ def load_model(
|
|||||||
load_in_4bit=True,
|
load_in_4bit=True,
|
||||||
llm_int8_threshold=6.0,
|
llm_int8_threshold=6.0,
|
||||||
llm_int8_has_fp16_weight=False,
|
llm_int8_has_fp16_weight=False,
|
||||||
bnb_4bit_compute_dtype=torch_dtype,
|
bnb_4bit_compute_dtype=cfg.torch_dtype,
|
||||||
bnb_4bit_use_double_quant=True,
|
bnb_4bit_use_double_quant=True,
|
||||||
bnb_4bit_quant_type="nf4",
|
bnb_4bit_quant_type="nf4",
|
||||||
)
|
)
|
||||||
@@ -242,7 +237,7 @@ def load_model(
|
|||||||
device_map=cfg.device_map,
|
device_map=cfg.device_map,
|
||||||
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
|
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
|
||||||
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
|
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
|
||||||
torch_dtype=torch_dtype,
|
torch_dtype=cfg.torch_dtype,
|
||||||
**model_kwargs,
|
**model_kwargs,
|
||||||
)
|
)
|
||||||
# elif model_type == "GPTNeoXForCausalLM" and cfg.flash_attention:
|
# elif model_type == "GPTNeoXForCausalLM" and cfg.flash_attention:
|
||||||
@@ -277,7 +272,7 @@ def load_model(
|
|||||||
device_map=cfg.device_map,
|
device_map=cfg.device_map,
|
||||||
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
|
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
|
||||||
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
|
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
|
||||||
torch_dtype=torch_dtype,
|
torch_dtype=cfg.torch_dtype,
|
||||||
trust_remote_code=cfg.trust_remote_code or False,
|
trust_remote_code=cfg.trust_remote_code or False,
|
||||||
**model_kwargs,
|
**model_kwargs,
|
||||||
)
|
)
|
||||||
@@ -308,7 +303,7 @@ def load_model(
|
|||||||
device_map=cfg.device_map,
|
device_map=cfg.device_map,
|
||||||
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
|
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
|
||||||
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
|
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
|
||||||
torch_dtype=torch_dtype,
|
torch_dtype=cfg.torch_dtype,
|
||||||
trust_remote_code=cfg.trust_remote_code or False,
|
trust_remote_code=cfg.trust_remote_code or False,
|
||||||
**model_kwargs,
|
**model_kwargs,
|
||||||
)
|
)
|
||||||
@@ -322,7 +317,7 @@ def load_model(
|
|||||||
device_map=cfg.device_map,
|
device_map=cfg.device_map,
|
||||||
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
|
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
|
||||||
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
|
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
|
||||||
torch_dtype=torch_dtype,
|
torch_dtype=cfg.torch_dtype,
|
||||||
trust_remote_code=cfg.trust_remote_code or False,
|
trust_remote_code=cfg.trust_remote_code or False,
|
||||||
**model_kwargs,
|
**model_kwargs,
|
||||||
)
|
)
|
||||||
@@ -356,16 +351,6 @@ def load_model(
|
|||||||
model, use_gradient_checkpointing=cfg.gradient_checkpointing
|
model, use_gradient_checkpointing=cfg.gradient_checkpointing
|
||||||
)
|
)
|
||||||
|
|
||||||
# LlamaRMSNorm layers are in fp32 after kbit_training, so we need to
|
|
||||||
# convert them back to fp16/bf16 for flash-attn compatibility.
|
|
||||||
if cfg.flash_attention and cfg.is_llama_derived_model:
|
|
||||||
for name, module in model.named_modules():
|
|
||||||
if "norm" in name:
|
|
||||||
module.to(torch_dtype)
|
|
||||||
if "lm_head" in name or "embed_tokens" in name:
|
|
||||||
if hasattr(module, "weight"):
|
|
||||||
module.to(torch_dtype)
|
|
||||||
|
|
||||||
model, lora_config = load_adapter(model, cfg, cfg.adapter)
|
model, lora_config = load_adapter(model, cfg, cfg.adapter)
|
||||||
|
|
||||||
if cfg.ddp and not load_in_8bit:
|
if cfg.ddp and not load_in_8bit:
|
||||||
@@ -509,6 +494,22 @@ def load_lora(model, cfg):
|
|||||||
else:
|
else:
|
||||||
model = get_peft_model(model, lora_config)
|
model = get_peft_model(model, lora_config)
|
||||||
|
|
||||||
|
for name, module in model.named_modules():
|
||||||
|
if isinstance(module, LoraLayer):
|
||||||
|
module = module.to(cfg.torch_dtype)
|
||||||
|
if "norm" in name:
|
||||||
|
module = module.to(torch.float32)
|
||||||
|
if "lm_head" in name or "embed_tokens" in name:
|
||||||
|
if hasattr(module, "weight"):
|
||||||
|
module = module.to(cfg.torch_dtype)
|
||||||
|
|
||||||
|
# LlamaRMSNorm layers are in fp32 after kbit_training, so we need to
|
||||||
|
# convert them back to fp16/bf16 for flash-attn compatibility.
|
||||||
|
if cfg.flash_attention and cfg.is_llama_derived_model:
|
||||||
|
for name, module in model.named_modules():
|
||||||
|
if "norm" in name:
|
||||||
|
module = module.to(cfg.torch_dtype)
|
||||||
|
|
||||||
model.print_trainable_parameters()
|
model.print_trainable_parameters()
|
||||||
|
|
||||||
return model, lora_config
|
return model, lora_config
|
||||||
|
|||||||
Reference in New Issue
Block a user