From 248bf90f89721a13bf869b7566c398ebc3380d49 Mon Sep 17 00:00:00 2001
From: Aman Karmani
Date: Wed, 2 Aug 2023 20:15:03 +0000
Subject: [PATCH] ensure flash-attn fixes happen in both adapter/lora modes,
 and use torch_dtype

---
 src/axolotl/utils/models.py | 16 ++++++++--------
 1 file changed, 8 insertions(+), 8 deletions(-)

diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index 2c199a16e..d4bda130c 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -331,6 +331,14 @@ def load_model(
             model, use_gradient_checkpointing=cfg.gradient_checkpointing
         )
 
+    if cfg.flash_attention:
+        for name, module in model.named_modules():
+            if "norm" in name:
+                module.to(torch_dtype)
+            if "lm_head" in name or "embed_tokens" in name:
+                if hasattr(module, "weight"):
+                    module.to(torch_dtype)
+
     model, lora_config = load_adapter(model, cfg, adapter)
 
     if cfg.ddp and not load_in_8bit:
@@ -407,14 +415,6 @@ def load_llama_adapter(model, cfg):
     else:
         model = get_peft_model(model, peft_config)
 
-    if cfg.flash_attention:
-        for name, module in model.named_modules():
-            if "norm" in name:
-                module.to(torch.float16)
-            if "lm_head" in name or "embed_tokens" in name:
-                if hasattr(module, "weight"):
-                    module.to(torch.float16)
-
     model.print_trainable_parameters()
 
     return model, peft_config
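
For reference, below is a minimal standalone sketch (not part of the patch) of the cast that the first hunk adds to load_model(). The helper name cast_norms_and_embeddings is hypothetical, and the assumption that torch_dtype is resolved elsewhere from the config (for example torch.bfloat16 when bf16 training is enabled) is mine; the point of the change is that this cast now runs for both adapter and LoRA paths and uses the configured dtype instead of a hard-coded torch.float16.

    # Sketch only: mirrors the logic added in the hunk above.
    # `cast_norms_and_embeddings` is a hypothetical helper name, not from the patch.
    import torch
    import torch.nn as nn

    def cast_norms_and_embeddings(model: nn.Module, torch_dtype: torch.dtype) -> None:
        # Cast norm layers, plus lm_head/embed_tokens weights, to the training dtype
        # so flash-attention kernels see a consistent precision across modules.
        for name, module in model.named_modules():
            if "norm" in name:
                module.to(torch_dtype)
            if "lm_head" in name or "embed_tokens" in name:
                if hasattr(module, "weight"):
                    module.to(torch_dtype)

    # Example call, assuming the dtype was derived from the config earlier:
    # cast_norms_and_embeddings(model, torch.bfloat16 if cfg.bf16 else torch.float16)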