diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index 2c199a16e..d4bda130c 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -331,6 +331,14 @@ def load_model( model, use_gradient_checkpointing=cfg.gradient_checkpointing ) + if cfg.flash_attention: + for name, module in model.named_modules(): + if "norm" in name: + module.to(torch_dtype) + if "lm_head" in name or "embed_tokens" in name: + if hasattr(module, "weight"): + module.to(torch_dtype) + model, lora_config = load_adapter(model, cfg, adapter) if cfg.ddp and not load_in_8bit: @@ -407,14 +415,6 @@ def load_llama_adapter(model, cfg): else: model = get_peft_model(model, peft_config) - if cfg.flash_attention: - for name, module in model.named_modules(): - if "norm" in name: - module.to(torch.float16) - if "lm_head" in name or "embed_tokens" in name: - if hasattr(module, "weight"): - module.to(torch.float16) - model.print_trainable_parameters() return model, peft_config