From 77085ea24e33c5f2676e04e09a3e857504753554 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Tue, 1 Aug 2023 23:26:16 -0400 Subject: [PATCH] qlora w flash attention fixes (#333) --- src/axolotl/utils/models.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index fa6a7c1c3..2c199a16e 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -407,6 +407,14 @@ def load_llama_adapter(model, cfg): else: model = get_peft_model(model, peft_config) + if cfg.flash_attention: + for name, module in model.named_modules(): + if "norm" in name: + module.to(torch.float16) + if "lm_head" in name or "embed_tokens" in name: + if hasattr(module, "weight"): + module.to(torch.float16) + model.print_trainable_parameters() return model, peft_config