diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index fa6a7c1c3..2c199a16e 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -407,6 +407,14 @@ def load_llama_adapter(model, cfg): else: model = get_peft_model(model, peft_config) + if cfg.flash_attention: + for name, module in model.named_modules(): + if "norm" in name: + module.to(torch.float16) + if "lm_head" in name or "embed_tokens" in name: + if hasattr(module, "weight"): + module.to(torch.float16) + model.print_trainable_parameters() return model, peft_config