flex attention support
This commit is contained in:
@@ -710,7 +710,13 @@ class ModelLoader:
|
|||||||
"""
|
"""
|
||||||
sample packing uses custom FA2 patch
|
sample packing uses custom FA2 patch
|
||||||
"""
|
"""
|
||||||
if self.cfg.flash_attention:
|
|
||||||
|
if self.cfg.flex_attention:
|
||||||
|
self.model_kwargs["attn_implementation"] = "flex_attention"
|
||||||
|
self.model_config._attn_implementation = ( # pylint: disable=protected-access
|
||||||
|
"flex_attention"
|
||||||
|
)
|
||||||
|
elif self.cfg.flash_attention:
|
||||||
if not self.cfg.sample_packing and self.cfg.s2_attention:
|
if not self.cfg.sample_packing and self.cfg.s2_attention:
|
||||||
pass
|
pass
|
||||||
self.model_kwargs["attn_implementation"] = "flash_attention_2"
|
self.model_kwargs["attn_implementation"] = "flash_attention_2"
|
||||||
|
|||||||
Reference in New Issue
Block a user