flex attention support
This commit is contained in:
@@ -710,7 +710,13 @@ class ModelLoader:
|
||||
"""
|
||||
sample packing uses custom FA2 patch
|
||||
"""
|
||||
if self.cfg.flash_attention:
|
||||
|
||||
if self.cfg.flex_attention:
|
||||
self.model_kwargs["attn_implementation"] = "flex_attention"
|
||||
self.model_config._attn_implementation = ( # pylint: disable=protected-access
|
||||
"flex_attention"
|
||||
)
|
||||
elif self.cfg.flash_attention:
|
||||
if not self.cfg.sample_packing and self.cfg.s2_attention:
|
||||
pass
|
||||
self.model_kwargs["attn_implementation"] = "flash_attention_2"
|
||||
|
||||
Reference in New Issue
Block a user