From fefa95e35069a01c96583853e075bf0319e55e0a Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Thu, 22 Aug 2024 16:39:23 -0400 Subject: [PATCH] most model types now support flash attention 2 regardless of multipack support (#1854) --- src/axolotl/monkeypatch/multipack.py | 1 + src/axolotl/utils/models.py | 14 ++++---------- 2 files changed, 5 insertions(+), 10 deletions(-) diff --git a/src/axolotl/monkeypatch/multipack.py b/src/axolotl/monkeypatch/multipack.py index 904352010..44fc4cb47 100644 --- a/src/axolotl/monkeypatch/multipack.py +++ b/src/axolotl/monkeypatch/multipack.py @@ -17,6 +17,7 @@ SUPPORTED_MULTIPACK_MODEL_TYPES = [ "qwen2_moe", "falcon", "phi", + "phi3", "gemma", "gemma2", "gemmoe", diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index 4f47d59bf..8d24524a2 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -591,16 +591,10 @@ def load_model( "flash_attention_2" ) else: - if model_config.model_type in SUPPORTED_MULTIPACK_MODEL_TYPES: - model_kwargs["attn_implementation"] = "flash_attention_2" - model_config._attn_implementation = ( # pylint: disable=protected-access - "flash_attention_2" - ) - else: - model_kwargs["attn_implementation"] = "eager" - model_config._attn_implementation = ( # pylint: disable=protected-access - "eager" - ) + model_kwargs["attn_implementation"] = "flash_attention_2" + model_config._attn_implementation = ( # pylint: disable=protected-access + "flash_attention_2" + ) elif cfg.sdp_attention: model_kwargs["attn_implementation"] = "sdpa" model_config._attn_implementation = "sdpa" # pylint: disable=protected-access