simplify logic (#1856)
This commit is contained in:
@@ -589,19 +589,12 @@ def load_model(
|
|||||||
|
|
||||||
# sample packing uses custom FA2 patch
|
# sample packing uses custom FA2 patch
|
||||||
if cfg.flash_attention:
|
if cfg.flash_attention:
|
||||||
if not cfg.sample_packing:
|
if not cfg.sample_packing and cfg.s2_attention:
|
||||||
if cfg.s2_attention:
|
pass
|
||||||
pass
|
model_kwargs["attn_implementation"] = "flash_attention_2"
|
||||||
# most other models support flash attention, we can define exceptions as they come up
|
model_config._attn_implementation = ( # pylint: disable=protected-access
|
||||||
model_kwargs["attn_implementation"] = "flash_attention_2"
|
"flash_attention_2"
|
||||||
model_config._attn_implementation = ( # pylint: disable=protected-access
|
)
|
||||||
"flash_attention_2"
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
model_kwargs["attn_implementation"] = "flash_attention_2"
|
|
||||||
model_config._attn_implementation = ( # pylint: disable=protected-access
|
|
||||||
"flash_attention_2"
|
|
||||||
)
|
|
||||||
elif cfg.sdp_attention:
|
elif cfg.sdp_attention:
|
||||||
model_kwargs["attn_implementation"] = "sdpa"
|
model_kwargs["attn_implementation"] = "sdpa"
|
||||||
model_config._attn_implementation = "sdpa" # pylint: disable=protected-access
|
model_config._attn_implementation = "sdpa" # pylint: disable=protected-access
|
||||||
|
|||||||
Reference in New Issue
Block a user