Merge pull request #336 from tmm1/flash-attn
Fix flash-attn + qlora not working with llama models
This commit is contained in:
@@ -92,7 +92,9 @@ def load_model(
|
||||
|
||||
if cfg.is_llama_derived_model and cfg.flash_attention:
|
||||
if cfg.device not in ["mps", "cpu"] and not cfg.inference:
|
||||
from axolotl.flash_attn import replace_llama_attn_with_flash_attn
|
||||
from axolotl.monkeypatch.llama_attn_hijack_flash import (
|
||||
replace_llama_attn_with_flash_attn,
|
||||
)
|
||||
|
||||
LOG.info("patching with flash attention")
|
||||
replace_llama_attn_with_flash_attn()
|
||||
@@ -331,6 +333,16 @@ def load_model(
|
||||
model, use_gradient_checkpointing=cfg.gradient_checkpointing
|
||||
)
|
||||
|
||||
# LlamaRMSNorm layers are in fp32 after kbit_training, so we need to
|
||||
# convert them back to fp16/bf16 for flash-attn compatibility.
|
||||
if cfg.flash_attention and cfg.is_llama_derived_model:
|
||||
for name, module in model.named_modules():
|
||||
if "norm" in name:
|
||||
module.to(torch_dtype)
|
||||
if "lm_head" in name or "embed_tokens" in name:
|
||||
if hasattr(module, "weight"):
|
||||
module.to(torch_dtype)
|
||||
|
||||
model, lora_config = load_adapter(model, cfg, adapter)
|
||||
|
||||
if cfg.ddp and not load_in_8bit:
|
||||
@@ -407,14 +419,6 @@ def load_llama_adapter(model, cfg):
|
||||
else:
|
||||
model = get_peft_model(model, peft_config)
|
||||
|
||||
if cfg.flash_attention:
|
||||
for name, module in model.named_modules():
|
||||
if "norm" in name:
|
||||
module.to(torch.float16)
|
||||
if "lm_head" in name or "embed_tokens" in name:
|
||||
if hasattr(module, "weight"):
|
||||
module.to(torch.float16)
|
||||
|
||||
model.print_trainable_parameters()
|
||||
|
||||
return model, peft_config
|
||||
|
||||
Reference in New Issue
Block a user