remove dead gemma4 branch in _set_attention_config
@@ -633,20 +633,13 @@ class ModelLoader:
         # replaces F.scaled_dot_product_attention post-load, so load under sdpa.
         # Every other canonical name (and hub-kernel paths) is passed through
         # verbatim — xformers/sage/flash_attention_* are registered under their
-        # own names in ALL_ATTENTION_FUNCTIONS before model load.
+        # own names in ALL_ATTENTION_FUNCTIONS before model load. gemma4_hybrid
+        # is already pinned to flash_attention_2 by normalize_attn_implementation.
         _LOAD_TIME_OVERRIDE = {"s2": "flash_attention_2", "fp8": "sdpa"}
-        if self.cfg.gemma4_hybrid_attn_impl:
-            # Load with flash_attention_2 for sliding-window layers; global
-            # layers are swapped to sdpa post-load.
-            hf_impl = "flash_attention_2"
-        elif self.cfg.attn_implementation:
+        if self.cfg.attn_implementation:
             hf_impl = _LOAD_TIME_OVERRIDE.get(
                 self.cfg.attn_implementation, self.cfg.attn_implementation
             )
         else:
             hf_impl = None

         if hf_impl is not None:
             self.model_kwargs["attn_implementation"] = hf_impl
             self.model_config._attn_implementation = hf_impl
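Why "fp8" maps to "sdpa" at load time: per the first comment line, the fp8 path replaces F.scaled_dot_product_attention after the model is built, so the model must be loaded under stock sdpa for every layer to route through the patched function. A minimal sketch of that post-load swap; _fp8_sdpa and its body are illustrative stand-ins, not the actual kernel:

import torch.nn.functional as F

_orig_sdpa = F.scaled_dot_product_attention

def _fp8_sdpa(query, key, value, *args, **kwargs):
    # Illustrative stand-in: a real fp8 path would quantize q/k/v and call a
    # dedicated low-precision kernel; delegating keeps the sketch runnable.
    return _orig_sdpa(query, key, value, *args, **kwargs)

# Post-load swap: every layer that loaded under "sdpa" now hits the wrapper.
F.scaled_dot_product_attention = _fp8_sdpa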
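Why the other names pass through verbatim: transformers resolves non-canonical attn_implementation strings against its ALL_ATTENTION_FUNCTIONS registry, so a custom name only survives from_pretrained() if it was registered first. A minimal sketch of that pre-load registration using the public AttentionInterface API; the sage_attention wrapper body and the placeholder checkpoint are assumptions, not this repo's code:

import torch.nn.functional as F
from transformers import AttentionInterface, AutoModelForCausalLM

def sage_attention_forward(module, query, key, value, attention_mask,
                           scaling=None, dropout=0.0, **kwargs):
    # Stand-in body: a real integration would dispatch to the SageAttention
    # kernel here; plain SDPA keeps the sketch runnable.
    out = F.scaled_dot_product_attention(
        query, key, value, attn_mask=attention_mask, dropout_p=dropout, scale=scaling
    )
    # transformers expects (batch, seq, heads, head_dim) plus attn weights.
    return out.transpose(1, 2).contiguous(), None

# Must run before model load so the name resolves at load time.
AttentionInterface.register("sage_attention", sage_attention_forward)

model = AutoModelForCausalLM.from_pretrained(
    "org/model",  # placeholder checkpoint
    attn_implementation="sage_attention",
)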
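And why the gemma4 branch was dead: per the new comment, normalize_attn_implementation rewrites the config before _set_attention_config runs, so the removed if self.cfg.gemma4_hybrid_attn_impl: arm could only ever reproduce what the generic lookup already returns. A hedged sketch of that upstream normalization; the function name comes from the diff, the body is an assumption:

def normalize_attn_implementation(cfg):
    # Runs before _set_attention_config: the gemma4 hybrid path always loads
    # under flash_attention_2 (global layers get swapped to sdpa post-load),
    # so pin the canonical name here once.
    if cfg.gemma4_hybrid_attn_impl:
        cfg.attn_implementation = "flash_attention_2"
    return cfg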