qlora w flash attention fixes (#333)

Commit: 77085ea24e
Parent: db2a3586f3
Author: Wing Lian
Date: 2023-08-01 23:26:16 -04:00
Committed by: GitHub


@@ -407,6 +407,14 @@ def load_llama_adapter(model, cfg):
     else:
         model = get_peft_model(model, peft_config)
 
+    if cfg.flash_attention:
+        for name, module in model.named_modules():
+            if "norm" in name:
+                module.to(torch.float16)
+            if "lm_head" in name or "embed_tokens" in name:
+                if hasattr(module, "weight"):
+                    module.to(torch.float16)
+
     model.print_trainable_parameters()
 
     return model, peft_config
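
Why this change: flash-attention kernels expect the attention inputs in half precision, but peft's prepare_model_for_kbit_training typically upcasts the norm layers (and lm_head / embed_tokens) to float32 for training stability, so QLoRA runs with cfg.flash_attention enabled can hit a dtype mismatch. The added block casts those modules back to float16 after the PEFT wrapper is applied. Below is a minimal standalone sketch of the same fix outside axolotl, assuming a plain transformers + peft QLoRA setup; the checkpoint name and LoRA settings are illustrative assumptions, not part of this commit.

import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

# Illustrative 4-bit load; the checkpoint is an assumed example, not from this commit.
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    quantization_config=BitsAndBytesConfig(load_in_4bit=True),
    torch_dtype=torch.float16,
)
model = prepare_model_for_kbit_training(model)  # upcasts norms/embeddings to fp32
model = get_peft_model(model, LoraConfig(task_type="CAUSAL_LM", r=8, lora_alpha=16))

# Mirror the commit's fix: cast norm layers and the embedding/output head
# back to fp16 so flash attention sees consistent half-precision tensors.
for name, module in model.named_modules():
    if "norm" in name:
        module.to(torch.float16)
    if "lm_head" in name or "embed_tokens" in name:
        if hasattr(module, "weight"):
            module.to(torch.float16)

model.print_trainable_parameters()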