update doc snippets + reject gemma4-hybrid with non-FA2 backend
This commit is contained in:
@@ -116,11 +116,24 @@ class TestGemma4Hybrid:
|
||||
)
|
||||
assert result["attn_implementation"] == "flash_attention_2"
|
||||
|
||||
def test_gemma4_hybrid_respects_explicit(self):
|
||||
def test_gemma4_hybrid_with_incompatible_impl_raises(self):
|
||||
"""Setting gemma4_hybrid alongside a non-FA2 attn_implementation is a
|
||||
configuration error — the hybrid path requires FA2 under the hood."""
|
||||
with pytest.raises(
|
||||
ValueError, match="requires attn_implementation=flash_attention_2"
|
||||
):
|
||||
AxolotlInputConfig.normalize_attn_implementation(
|
||||
{"gemma4_hybrid_attn_impl": True, "attn_implementation": "sdpa"}
|
||||
)
|
||||
|
||||
def test_gemma4_hybrid_with_explicit_fa2_passes(self):
|
||||
result = AxolotlInputConfig.normalize_attn_implementation(
|
||||
{"gemma4_hybrid_attn_impl": True, "attn_implementation": "sdpa"}
|
||||
{
|
||||
"gemma4_hybrid_attn_impl": True,
|
||||
"attn_implementation": "flash_attention_2",
|
||||
}
|
||||
)
|
||||
assert result["attn_implementation"] == "sdpa"
|
||||
assert result["attn_implementation"] == "flash_attention_2"
|
||||
|
||||
|
||||
class TestFieldValidator:
|
||||
|
||||
Reference in New Issue
Block a user