From 0b7ba57ec42559bf75e5d1bc6ba58354a314d12e Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Fri, 25 Aug 2023 02:03:24 -0400
Subject: [PATCH] fix types w lora (#478)

---
 src/axolotl/utils/models.py | 35 ++++++++++++++++++-----------------
 1 file changed, 18 insertions(+), 17 deletions(-)

diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index 64c80109e..261acd934 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -11,7 +11,6 @@ import bitsandbytes as bnb
 import torch
 import transformers
 from optimum.bettertransformer import BetterTransformer
-from peft.tuners.lora import LoraLayer
 from transformers import (  # noqa: F401
     AutoConfig,
     AutoModelForCausalLM,
@@ -348,6 +347,14 @@ def load_model(
     if model.device.type == "cuda":
         log_gpu_memory_usage(LOG, "after model load", model.device)
 
+    # make sure these are fp32 per Ramesh et al. (2021)
+    for name, module in model.named_modules():
+        if "norm" in name:
+            module.to(torch.float32)
+        if "lm_head" in name or "embed_tokens" in name:
+            if hasattr(module, "weight"):
+                module.to(torch.float32)
+
     if not cfg.gptq and (
         (cfg.adapter == "lora" and load_in_8bit)
         or (cfg.adapter == "qlora" and cfg.load_in_4bit)
@@ -357,6 +364,16 @@ def load_model(
             model, use_gradient_checkpointing=cfg.gradient_checkpointing
         )
 
+        # LlamaRMSNorm layers are in fp32 after kbit_training, so we need to
+        # convert them back to fp16/bf16 for flash-attn compatibility.
+        if cfg.flash_attention and cfg.is_llama_derived_model:
+            for name, module in model.named_modules():
+                if "norm" in name:
+                    module.to(cfg.torch_dtype)
+                if "lm_head" in name or "embed_tokens" in name:
+                    if hasattr(module, "weight"):
+                        module.to(cfg.torch_dtype)
+
     model, lora_config = load_adapter(model, cfg, cfg.adapter)
 
     if cfg.ddp and not load_in_8bit:
@@ -500,22 +517,6 @@ def load_lora(model, cfg):
     else:
         model = get_peft_model(model, lora_config)
 
-    for name, module in model.named_modules():
-        if isinstance(module, LoraLayer):
-            module = module.to(cfg.torch_dtype)
-        if "norm" in name:
-            module = module.to(torch.float32)
-        if "lm_head" in name or "embed_tokens" in name:
-            if hasattr(module, "weight"):
-                module = module.to(cfg.torch_dtype)
-
-    # LlamaRMSNorm layers are in fp32 after kbit_training, so we need to
-    # convert them back to fp16/bf16 for flash-attn compatibility.
-    if cfg.flash_attention and cfg.is_llama_derived_model:
-        for name, module in model.named_modules():
-            if "norm" in name:
-                module = module.to(cfg.torch_dtype)
-
     model.print_trainable_parameters()
 
     return model, lora_config
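
Note: for reference, below is a minimal standalone sketch of the dtype fixup this patch consolidates in load_model(). The helper name normalize_dtypes and the bare flash_attention argument are illustrative stand-ins for the cfg plumbing in axolotl, not part of its API; the two loops mirror the added hunks above.

import torch


def normalize_dtypes(model, torch_dtype, flash_attention=False):
    # After loading, keep norm / embedding / lm_head weights in fp32 for
    # numerical stability (the "fp32 per Ramesh et al. (2021)" comment above).
    for name, module in model.named_modules():
        if "norm" in name:
            module.to(torch.float32)
        if "lm_head" in name or "embed_tokens" in name:
            if hasattr(module, "weight"):
                module.to(torch.float32)

    # flash-attn kernels expect uniform fp16/bf16 activations, so for
    # llama-derived models with flash attention enabled the same modules
    # are cast back down to the training dtype (e.g. torch.bfloat16).
    if flash_attention:
        for name, module in model.named_modules():
            if "norm" in name:
                module.to(torch_dtype)
            if "lm_head" in name or "embed_tokens" in name:
                if hasattr(module, "weight"):
                    module.to(torch_dtype)
    return model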