diff --git a/src/axolotl/monkeypatch/llama_attn_hijack_flash.py b/src/axolotl/monkeypatch/llama_attn_hijack_flash.py
index 4bded9b02..8d409527e 100644
--- a/src/axolotl/monkeypatch/llama_attn_hijack_flash.py
+++ b/src/axolotl/monkeypatch/llama_attn_hijack_flash.py
@@ -90,37 +90,43 @@ def replace_llama_attn_with_flash_attn(
             llama_model_forward
         )
 
-    # skip only if explicitly disabled
     if cross_entropy:
-        try:
-            from flash_attn.losses.cross_entropy import CrossEntropyLoss
+        patch_cross_entropy()
 
-            LOG.info("patching with flash_attn.losses.cross_entropy")
-            transformers.models.llama.modeling_llama.CrossEntropyLoss = partial(
-                CrossEntropyLoss, inplace_backward=True
-            )
-        except ImportError:
-            LOG.info(
-                "optimized flash-attention CrossEntropyLoss not found (run `pip install 'git+https://github.com/Dao-AILab/flash-attention.git#egg=xentropy_cuda_lib&subdirectory=csrc/xentropy'`)"
-            )
-
-    # skip only if explicitly disabled
     if rms_norm:
-        try:
-            from flash_attn.ops.rms_norm import RMSNorm
+        patch_rms_norm()
 
-            class LlamaRMSNorm(RMSNorm):
-                """Patched LLamaRMSNorm"""
 
-                def __init__(self, hidden_size, eps=1e-6):
-                    super().__init__(hidden_size, eps=eps)
+def patch_cross_entropy():
+    try:
+        from flash_attn.losses.cross_entropy import CrossEntropyLoss
 
-            LOG.info("patching with flash_attn.ops.rms_norm")
-            transformers.models.llama.modeling_llama.LlamaRMSNorm = LlamaRMSNorm
-        except ImportError:
-            LOG.info(
-                "optimized flash-attention RMSNorm not found (run `pip install 'git+https://github.com/Dao-AILab/flash-attention.git#egg=dropout_layer_norm&subdirectory=csrc/layer_norm'`)"
-            )
+        LOG.info("patching with flash_attn.losses.cross_entropy")
+        transformers.models.llama.modeling_llama.CrossEntropyLoss = partial(
+            CrossEntropyLoss, inplace_backward=True
+        )
+    except ImportError:
+        LOG.info(
+            "optimized flash-attention CrossEntropyLoss not found (run `pip install 'git+https://github.com/Dao-AILab/flash-attention.git#egg=xentropy_cuda_lib&subdirectory=csrc/xentropy'`)"
+        )
+
+
+def patch_rms_norm():
+    try:
+        from flash_attn.ops.rms_norm import RMSNorm
+
+        class LlamaRMSNorm(RMSNorm):
+            """Patched LLamaRMSNorm"""
+
+            def __init__(self, hidden_size, eps=1e-6):
+                super().__init__(hidden_size, eps=eps)
+
+        LOG.info("patching with flash_attn.ops.rms_norm")
+        transformers.models.llama.modeling_llama.LlamaRMSNorm = LlamaRMSNorm
+    except ImportError:
+        LOG.info(
+            "optimized flash-attention RMSNorm not found (run `pip install 'git+https://github.com/Dao-AILab/flash-attention.git#egg=dropout_layer_norm&subdirectory=csrc/layer_norm'`)"
+        )
 
 
 class FusedAttention(LlamaAttention):
diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index 72427f645..2fee9cf5a 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -24,6 +24,12 @@ from transformers import (  # noqa: F401
 from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
 
 from axolotl.models.mamba import fix_mamba_attn_for_loss
+from axolotl.monkeypatch.llama_attn_hijack_flash import (
+    patch_cross_entropy as llama_patch_cross_entropy,
+)
+from axolotl.monkeypatch.llama_attn_hijack_flash import (
+    patch_rms_norm as llama_patch_rms_norm,
+)
 from axolotl.prompt_tokenizers import LLAMA_DEFAULT_EOS_TOKEN
 from axolotl.utils.bench import log_gpu_memory_usage
 from axolotl.utils.chat_templates import chat_templates
@@ -281,15 +287,7 @@ def load_model(
             replace_llama_attn_with_flash_attn,
         )
 
-        if cfg.sample_packing:
-            if cfg.device not in ["mps", "cpu"] and not inference:
-                LOG.info("patching with flash attention for sample packing")
-                replace_llama_attn_with_flash_attn(
-                    packed=True,
-                    cross_entropy=cfg.flash_attn_cross_entropy,
-                    rms_norm=cfg.flash_attn_rms_norm,
-                )
-        elif cfg.s2_attention:
+        if cfg.s2_attention:
             LOG.info("patching w/ flash-enabled, shifted-sparse attention")
             replace_llama_attn_with_flash_attn(
                 packed=False,
@@ -297,6 +295,21 @@ def load_model(
                 rms_norm=cfg.flash_attn_rms_norm,
                 use_shifted_sparse_attn=True,
             )
+        elif cfg.device not in ["mps", "cpu"] and not inference:
+            if cfg.sample_packing:
+                LOG.info("patching with flash attention for sample packing")
+                replace_llama_attn_with_flash_attn(
+                    packed=True,
+                    cross_entropy=cfg.flash_attn_cross_entropy,
+                    rms_norm=cfg.flash_attn_rms_norm,
+                )
+            else:
+                if cfg.flash_attn_cross_entropy:
+                    llama_patch_cross_entropy()
+
+                if cfg.flash_attn_rms_norm:
+                    llama_patch_rms_norm()
+
     elif cfg.xformers_attention:
         from axolotl.monkeypatch.llama_attn_hijack_xformers import (
             hijack_llama_attention,