diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index 41a3582ea..8f148a342 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -324,6 +324,10 @@ def load_model( model_config._attn_implementation = ( # pylint: disable=protected-access "flash_attention_2" ) + else: + model_config._attn_implementation = ( # pylint: disable=protected-access + "eager" + ) try: if cfg.is_llama_derived_model and not cfg.trust_remote_code and not cfg.gptq: