most model types now support flash attention 2 regardless of multipack support (#1854)
This commit is contained in:
@@ -17,6 +17,7 @@ SUPPORTED_MULTIPACK_MODEL_TYPES = [
|
|||||||
"qwen2_moe",
|
"qwen2_moe",
|
||||||
"falcon",
|
"falcon",
|
||||||
"phi",
|
"phi",
|
||||||
|
"phi3",
|
||||||
"gemma",
|
"gemma",
|
||||||
"gemma2",
|
"gemma2",
|
||||||
"gemmoe",
|
"gemmoe",
|
||||||
|
|||||||
@@ -591,16 +591,10 @@ def load_model(
|
|||||||
"flash_attention_2"
|
"flash_attention_2"
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
if model_config.model_type in SUPPORTED_MULTIPACK_MODEL_TYPES:
|
model_kwargs["attn_implementation"] = "flash_attention_2"
|
||||||
model_kwargs["attn_implementation"] = "flash_attention_2"
|
model_config._attn_implementation = ( # pylint: disable=protected-access
|
||||||
model_config._attn_implementation = ( # pylint: disable=protected-access
|
"flash_attention_2"
|
||||||
"flash_attention_2"
|
)
|
||||||
)
|
|
||||||
else:
|
|
||||||
model_kwargs["attn_implementation"] = "eager"
|
|
||||||
model_config._attn_implementation = ( # pylint: disable=protected-access
|
|
||||||
"eager"
|
|
||||||
)
|
|
||||||
elif cfg.sdp_attention:
|
elif cfg.sdp_attention:
|
||||||
model_kwargs["attn_implementation"] = "sdpa"
|
model_kwargs["attn_implementation"] = "sdpa"
|
||||||
model_config._attn_implementation = "sdpa" # pylint: disable=protected-access
|
model_config._attn_implementation = "sdpa" # pylint: disable=protected-access
|
||||||
|
|||||||
Reference in New Issue
Block a user