remove dead gemma4 branch in _set_attention_config
This commit is contained in:
@@ -633,20 +633,13 @@ class ModelLoader:
|
|||||||
# replaces F.scaled_dot_product_attention post-load, so load under sdpa.
|
# replaces F.scaled_dot_product_attention post-load, so load under sdpa.
|
||||||
# Every other canonical name (and hub-kernel paths) is passed through
|
# Every other canonical name (and hub-kernel paths) is passed through
|
||||||
# verbatim — xformers/sage/flash_attention_* are registered under their
|
# verbatim — xformers/sage/flash_attention_* are registered under their
|
||||||
# own names in ALL_ATTENTION_FUNCTIONS before model load.
|
# own names in ALL_ATTENTION_FUNCTIONS before model load. gemma4_hybrid
|
||||||
|
# is already pinned to flash_attention_2 by normalize_attn_implementation.
|
||||||
_LOAD_TIME_OVERRIDE = {"s2": "flash_attention_2", "fp8": "sdpa"}
|
_LOAD_TIME_OVERRIDE = {"s2": "flash_attention_2", "fp8": "sdpa"}
|
||||||
if self.cfg.gemma4_hybrid_attn_impl:
|
if self.cfg.attn_implementation:
|
||||||
# Load with flash_attention_2 for sliding-window layers; global
|
|
||||||
# layers are swapped to sdpa post-load.
|
|
||||||
hf_impl = "flash_attention_2"
|
|
||||||
elif self.cfg.attn_implementation:
|
|
||||||
hf_impl = _LOAD_TIME_OVERRIDE.get(
|
hf_impl = _LOAD_TIME_OVERRIDE.get(
|
||||||
self.cfg.attn_implementation, self.cfg.attn_implementation
|
self.cfg.attn_implementation, self.cfg.attn_implementation
|
||||||
)
|
)
|
||||||
else:
|
|
||||||
hf_impl = None
|
|
||||||
|
|
||||||
if hf_impl is not None:
|
|
||||||
self.model_kwargs["attn_implementation"] = hf_impl
|
self.model_kwargs["attn_implementation"] = hf_impl
|
||||||
self.model_config._attn_implementation = hf_impl
|
self.model_config._attn_implementation = hf_impl
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user