Implement fused modules (#747)
* MLP: Memory saving * Remove RMSNorm restrictions * Map packed weights to original * FusedAttention module * Simplify code * Move fused modules * Fix critical typo * Split inplace * Add FFT config * Add validation of fused arguments * Add fused arguments to config * Update docs * Fix validation logic * Add fused modules to flash attn * Only fuse during training * Remove timing * Formatting * Formatting * Formatting * chore: lint * chore: lint * add e2e tests for fused llama * no lora for tests --------- Co-authored-by: Wing Lian <wing.lian@gmail.com>
This commit is contained in:
@@ -272,6 +272,20 @@ def load_model(
|
||||
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
|
||||
**model_kwargs,
|
||||
)
|
||||
|
||||
if cfg.flash_attention and not inference:
|
||||
from axolotl.monkeypatch.llama_attn_hijack_flash import (
|
||||
replace_llama_mlp_with_swiglu,
|
||||
replace_llama_qkv_with_fused,
|
||||
)
|
||||
|
||||
if cfg.flash_attn_fuse_mlp:
|
||||
LOG.info("patching with SwiGLU")
|
||||
replace_llama_mlp_with_swiglu(model)
|
||||
|
||||
if cfg.flash_attn_fuse_qkv:
|
||||
LOG.info("patching with fused QKV")
|
||||
replace_llama_qkv_with_fused(model)
|
||||
# elif model_type == "GPTNeoXForCausalLM" and cfg.flash_attention:
|
||||
# This is a WIP, still an issue with the backward pass
|
||||
# RuntimeError: grad can be implicitly created only for scalar outputs
|
||||
|
||||
Reference in New Issue
Block a user