Add flex attention support

This commit is contained in:
Sunny
2025-01-06 19:54:11 -05:00
parent 61ad375bf4
commit bcd9ad44e0

View File

@@ -710,7 +710,13 @@ class ModelLoader:
"""
sample packing uses custom FA2 patch
"""
if self.cfg.flash_attention:
if self.cfg.flex_attention:
self.model_kwargs["attn_implementation"] = "flex_attention"
self.model_config._attn_implementation = ( # pylint: disable=protected-access
"flex_attention"
)
elif self.cfg.flash_attention:
if not self.cfg.sample_packing and self.cfg.s2_attention:
pass
self.model_kwargs["attn_implementation"] = "flash_attention_2"