fix types w lora (#478)
This commit is contained in:
@@ -11,7 +11,6 @@ import bitsandbytes as bnb
|
||||
import torch
|
||||
import transformers
|
||||
from optimum.bettertransformer import BetterTransformer
|
||||
from peft.tuners.lora import LoraLayer
|
||||
from transformers import ( # noqa: F401
|
||||
AutoConfig,
|
||||
AutoModelForCausalLM,
|
||||
@@ -348,6 +347,14 @@ def load_model(
|
||||
if model.device.type == "cuda":
|
||||
log_gpu_memory_usage(LOG, "after model load", model.device)
|
||||
|
||||
# make sure these are fp32 per Ramesh et al. (2021)
|
||||
for name, module in model.named_modules():
|
||||
if "norm" in name:
|
||||
module.to(torch.float32)
|
||||
if "lm_head" in name or "embed_tokens" in name:
|
||||
if hasattr(module, "weight"):
|
||||
module.to(torch.float32)
|
||||
|
||||
if not cfg.gptq and (
|
||||
(cfg.adapter == "lora" and load_in_8bit)
|
||||
or (cfg.adapter == "qlora" and cfg.load_in_4bit)
|
||||
@@ -357,6 +364,16 @@ def load_model(
|
||||
model, use_gradient_checkpointing=cfg.gradient_checkpointing
|
||||
)
|
||||
|
||||
# LlamaRMSNorm layers are in fp32 after kbit_training, so we need to
|
||||
# convert them back to fp16/bf16 for flash-attn compatibility.
|
||||
if cfg.flash_attention and cfg.is_llama_derived_model:
|
||||
for name, module in model.named_modules():
|
||||
if "norm" in name:
|
||||
module.to(cfg.torch_dtype)
|
||||
if "lm_head" in name or "embed_tokens" in name:
|
||||
if hasattr(module, "weight"):
|
||||
module.to(cfg.torch_dtype)
|
||||
|
||||
model, lora_config = load_adapter(model, cfg, cfg.adapter)
|
||||
|
||||
if cfg.ddp and not load_in_8bit:
|
||||
@@ -500,22 +517,6 @@ def load_lora(model, cfg):
|
||||
else:
|
||||
model = get_peft_model(model, lora_config)
|
||||
|
||||
for name, module in model.named_modules():
|
||||
if isinstance(module, LoraLayer):
|
||||
module = module.to(cfg.torch_dtype)
|
||||
if "norm" in name:
|
||||
module = module.to(torch.float32)
|
||||
if "lm_head" in name or "embed_tokens" in name:
|
||||
if hasattr(module, "weight"):
|
||||
module = module.to(cfg.torch_dtype)
|
||||
|
||||
# LlamaRMSNorm layers are in fp32 after kbit_training, so we need to
|
||||
# convert them back to fp16/bf16 for flash-attn compatibility.
|
||||
if cfg.flash_attention and cfg.is_llama_derived_model:
|
||||
for name, module in model.named_modules():
|
||||
if "norm" in name:
|
||||
module = module.to(cfg.torch_dtype)
|
||||
|
||||
model.print_trainable_parameters()
|
||||
|
||||
return model, lora_config
|
||||
|
||||
Reference in New Issue
Block a user