Files
axolotl/tests/prompt_strategies/test_jinja_template_analyzer.py
Dan Saunders c907ac173e adding pre-commit auto-update GH action and bumping plugin versions (#2428)
* adding pre-commit auto-update GH action and bumping plugin versions

* running updated pre-commit plugins

* sorry to revert, but pylint complained

* Update .pre-commit-config.yaml

Co-authored-by: Wing Lian <wing.lian@gmail.com>

---------

Co-authored-by: Dan Saunders <dan@axolotl.ai>
Co-authored-by: Wing Lian <wing.lian@gmail.com>
2025-03-21 11:02:43 -04:00

161 lines
5.8 KiB
Python

"""
tests for jinja_template_analyzer
"""
import logging
import pytest
from axolotl.prompt_strategies.jinja_template_analyzer import JinjaTemplateAnalyzer
# Import-time logging setup: configure the root logger at DEBUG level and grab
# the "axolotl" logger used by the tests below for progress messages.
# NOTE(review): basicConfig here is a module-import side effect that can affect
# logging for the whole pytest session; consider relying on pytest's
# log_level / log_cli settings instead.
logging.basicConfig(level=logging.DEBUG)
LOG = logging.getLogger("axolotl")
class TestJinjaTemplateAnalyzer:
    """
    Tests for JinjaTemplateAnalyzer.

    Covers top-level variable extraction, property-access tracking on
    variables, loop (``is_iterated``) and conditional (``is_conditional``)
    detection, and discovery of the properties read from chat messages.

    The ``basic_jinja_template_analyzer`` and ``mistral_jinja_template_analyzer``
    fixtures are not defined in this module (presumably in a conftest.py).
    """

    def test_basic_variable_extraction(self, basic_jinja_template_analyzer):
        """Test that all top-level variables of the basic template are extracted."""
        LOG.info("Testing basic variable extraction")
        variables = basic_jinja_template_analyzer.get_template_variables()
        expected_vars = {"messages", "add_generation_prompt", "eos_token", "message"}
        assert set(variables.keys()) == expected_vars

    def test_mixtral_variable_extraction(self, mistral_jinja_template_analyzer):
        """Test that all top-level variables of the Mistral template are extracted,
        including the properties read from each ``message``."""
        LOG.info("Testing mistral variable extraction")
        variables = mistral_jinja_template_analyzer.get_template_variables()
        expected_vars = {
            "messages",
            "content",
            "eos_token",
            "message",
            "tools",
            "system_message",
            "loop_messages",
            "ns",
            "tool_call",
            "tool",
            "loop",
            "bos_token",
            "raise_exception",
        }
        assert set(variables.keys()) == expected_vars
        # get_template_variables maps each variable to the set of properties
        # accessed on it; the loop variable `message` carries the chat fields.
        message_vars = variables["message"]
        assert message_vars == {"role", "content", "tool_calls", "tool_call_id"}

    def test_message_property_access(self, basic_jinja_template_analyzer):
        """Test that properties accessed on the 'message' variable are identified."""
        LOG.info("Testing message property access")
        variables = basic_jinja_template_analyzer.get_template_variables()
        assert "messages" in variables
        assert "message" in variables
        assert "role" in variables["message"]
        assert "content" in variables["message"]

    def test_detailed_analysis(self, basic_jinja_template_analyzer):
        """Test the per-variable usage flags and accessed properties from analyze_template."""
        LOG.info("Testing detailed analysis")
        analysis = basic_jinja_template_analyzer.analyze_template()
        assert analysis["messages"]["is_iterated"] is True
        assert "role" in analysis["message"]["accessed_properties"]
        assert "content" in analysis["message"]["accessed_properties"]
        assert analysis["add_generation_prompt"]["is_conditional"] is True
        assert len(analysis["add_generation_prompt"]["accessed_properties"]) == 0
        assert not analysis["eos_token"]["is_iterated"]
        assert len(analysis["eos_token"]["accessed_properties"]) == 0

    def test_nested_property_access(self):
        """Test that attribute and subscript access both register as properties."""
        LOG.info("Testing nested property access")
        template = """{{ user.profile.name }}{{ user.settings['preference'] }}"""
        analyzer = JinjaTemplateAnalyzer(template)
        variables = analyzer.get_template_variables()
        assert "user" in variables
        assert "profile" in variables["user"]
        assert "settings" in variables["user"]

    def test_loop_variable_handling(self):
        """Test handling of (nested) loop variables and their properties."""
        LOG.info("Testing loop variable handling")
        template = """
        {% for item in items %}
            {{ item.name }}
            {% for subitem in item.subitems %}
                {{ subitem.value }}
            {% endfor %}
        {% endfor %}
        """
        analyzer = JinjaTemplateAnalyzer(template)
        analysis = analyzer.analyze_template()
        assert analysis["items"]["is_iterated"]
        assert "name" in analysis["item"]["accessed_properties"]
        # `item.subitems` is both accessed and iterated by the inner loop.
        assert "subitems" in analysis["item"]["accessed_properties"]

    def test_conditional_variable_usage(self):
        """Test detection of variables used in conditional statements."""
        LOG.info("Testing conditional variable usage")
        template = """
        {% if user.is_admin and config.debug_mode %}
            {{ debug_info }}
        {% endif %}
        """
        analyzer = JinjaTemplateAnalyzer(template)
        analysis = analyzer.analyze_template()
        assert analysis["user"]["is_conditional"]
        assert analysis["config"]["is_conditional"]
        assert "is_admin" in analysis["user"]["accessed_properties"]
        assert "debug_mode" in analysis["config"]["accessed_properties"]

    def test_complex_expressions(self):
        """Test handling of filters, boolean expressions and chained subscripts."""
        LOG.info("Testing complex expressions and filters")
        template = """
        {{ user.name | upper }}
        {{ messages | length > 0 and messages[0].content }}
        {{ data['key'].nested['value'] }}
        """
        analyzer = JinjaTemplateAnalyzer(template)
        variables = analyzer.get_template_variables()
        assert "user" in variables
        assert "name" in variables["user"]
        assert "messages" in variables
        # `messages[0].content` attributes the property to `messages` itself.
        assert "content" in variables["messages"]
        assert "data" in variables

    def test_basic_msg_vars(self, basic_jinja_template_analyzer):
        """Test that the basic template's message variables are correctly identified."""
        LOG.info("Testing basic message variables")
        variables = basic_jinja_template_analyzer.get_message_vars()
        assert variables == {"role", "content"}

    def test_mixtral_msg_vars(self, mistral_jinja_template_analyzer):
        """Test that the Mistral template's message variables are correctly identified."""
        LOG.info("Testing mixtral message variables")
        variables = mistral_jinja_template_analyzer.get_message_vars()
        assert variables == {"role", "content", "tool_calls", "tool_call_id"}
if __name__ == "__main__":
    # pytest.main returns an exit code rather than exiting; propagate it so
    # running this file directly reports failures with a non-zero status.
    raise SystemExit(pytest.main([__file__]))