From 22f4eafa557bc5009877443c601e40a762832c2b Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Fri, 23 Aug 2024 20:23:08 -0400 Subject: [PATCH] simplify logic (#1856) --- src/axolotl/utils/models.py | 19 ++++++------------- 1 file changed, 6 insertions(+), 13 deletions(-) diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index 6261ce20f..e18330199 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -589,19 +589,12 @@ def load_model( # sample packing uses custom FA2 patch if cfg.flash_attention: - if not cfg.sample_packing: - if cfg.s2_attention: - pass - # most other models support flash attention, we can define exceptions as they come up - model_kwargs["attn_implementation"] = "flash_attention_2" - model_config._attn_implementation = ( # pylint: disable=protected-access - "flash_attention_2" - ) - else: - model_kwargs["attn_implementation"] = "flash_attention_2" - model_config._attn_implementation = ( # pylint: disable=protected-access - "flash_attention_2" - ) + if not cfg.sample_packing and cfg.s2_attention: + pass + model_kwargs["attn_implementation"] = "flash_attention_2" + model_config._attn_implementation = ( # pylint: disable=protected-access + "flash_attention_2" + ) elif cfg.sdp_attention: model_kwargs["attn_implementation"] = "sdpa" model_config._attn_implementation = "sdpa" # pylint: disable=protected-access