* adding pre-commit auto-update GH action and bumping plugin versions * running updated pre-commit plugins * sorry to revert, but pylint complained * Update .pre-commit-config.yaml Co-authored-by: Wing Lian <wing.lian@gmail.com> --------- Co-authored-by: Dan Saunders <dan@axolotl.ai> Co-authored-by: Wing Lian <wing.lian@gmail.com>
161 lines
5.8 KiB
Python
161 lines
5.8 KiB
Python
"""
|
|
tests for jinja_template_analyzer
|
|
"""
|
|
|
|
import logging
|
|
|
|
import pytest
|
|
|
|
from axolotl.prompt_strategies.jinja_template_analyzer import JinjaTemplateAnalyzer
|
|
|
|
# Ask for DEBUG-level output from the root logger. Per the logging docs,
# basicConfig does nothing if the root logger already has handlers.
logging.basicConfig(level="DEBUG")

# Named logger used by the test methods to announce each scenario.
LOG = logging.getLogger(name="axolotl")
|
|
|
|
|
|
class TestJinjaTemplateAnalyzer:
    """
    Tests for JinjaTemplateAnalyzer.

    Exercises the three analyzer entry points used below:
    ``get_template_variables`` (maps each variable to the properties read from it),
    ``analyze_template`` (per-variable ``is_iterated`` / ``is_conditional`` /
    ``accessed_properties`` details) and ``get_message_vars`` (the properties
    read from each chat message).

    ``basic_jinja_template_analyzer`` and ``mistral_jinja_template_analyzer`` are
    pytest fixtures defined outside this file (presumably in a conftest).
    """

    def test_basic_variable_extraction(self, basic_jinja_template_analyzer):
        """Test that all top-level variables of the basic template are extracted."""
        LOG.info("Testing basic template variable extraction")

        variables = basic_jinja_template_analyzer.get_template_variables()
        expected_vars = {"messages", "add_generation_prompt", "eos_token", "message"}
        assert set(variables.keys()) == expected_vars

    def test_mixtral_variable_extraction(self, mistral_jinja_template_analyzer):
        """
        Test that all top-level variables of the Mistral template are extracted,
        and that the properties read from ``message`` are identified.
        """
        LOG.info("Testing mixtral template variable extraction")

        variables = mistral_jinja_template_analyzer.get_template_variables()
        # The expected set includes loop-local names (``message``, ``tool_call``,
        # ``tool``, ``loop``, ``ns``) and callables (``raise_exception``), not
        # only the values a caller passes into the template.
        expected_vars = {
            "messages",
            "content",
            "eos_token",
            "message",
            "tools",
            "system_message",
            "loop_messages",
            "ns",
            "tool_call",
            "tool",
            "loop",
            "bos_token",
            "raise_exception",
        }
        assert set(variables.keys()) == expected_vars

        message_vars = variables["message"]
        assert message_vars == {"role", "content", "tool_calls", "tool_call_id"}

    def test_message_property_access(self, basic_jinja_template_analyzer):
        """Test that properties accessed on 'message' variable are correctly identified."""
        LOG.info("Testing message property access")

        variables = basic_jinja_template_analyzer.get_template_variables()
        assert "messages" in variables
        assert "message" in variables
        assert "role" in variables["message"]
        assert "content" in variables["message"]

    def test_detailed_analysis(self, basic_jinja_template_analyzer):
        """Test the detailed analysis of variable usage."""
        LOG.info("Testing detailed analysis")

        analysis = basic_jinja_template_analyzer.analyze_template()

        # ``messages`` is expected to be looped over, with role/content read
        # from each ``message``.
        assert analysis["messages"]["is_iterated"] is True
        assert "role" in analysis["message"]["accessed_properties"]
        assert "content" in analysis["message"]["accessed_properties"]

        # ``add_generation_prompt`` is expected only as a plain condition flag.
        assert analysis["add_generation_prompt"]["is_conditional"] is True
        assert len(analysis["add_generation_prompt"]["accessed_properties"]) == 0

        # ``eos_token`` is expected to be emitted directly: not looped, no properties.
        assert not analysis["eos_token"]["is_iterated"]
        assert len(analysis["eos_token"]["accessed_properties"]) == 0

    def test_nested_property_access(self):
        """Test handling of nested property access."""
        LOG.info("Testing nested property access")

        template = """{{ user.profile.name }}{{ user.settings['preference'] }}"""
        analyzer = JinjaTemplateAnalyzer(template)
        variables = analyzer.get_template_variables()

        # Only the first attribute level below ``user`` is asserted; both
        # dotted (``.profile``) and chained (``.settings[...]``) access count.
        assert "user" in variables
        assert "profile" in variables["user"]
        assert "settings" in variables["user"]

    def test_loop_variable_handling(self):
        """Test handling of loop variables and their properties."""
        LOG.info("Testing loop variable handling")

        template = """
        {% for item in items %}
            {{ item.name }}
            {% for subitem in item.subitems %}
                {{ subitem.value }}
            {% endfor %}
        {% endfor %}
        """
        analyzer = JinjaTemplateAnalyzer(template)
        analysis = analyzer.analyze_template()

        assert analysis["items"]["is_iterated"]
        assert "name" in analysis["item"]["accessed_properties"]
        # ``item.subitems`` is used as a nested loop source, and must still be
        # recorded as a property of the outer loop variable.
        assert "subitems" in analysis["item"]["accessed_properties"]

    def test_conditional_variable_usage(self):
        """Test detection of variables used in conditional statements."""
        LOG.info("Testing conditional variable usage")

        template = """
        {% if user.is_admin and config.debug_mode %}
            {{ debug_info }}
        {% endif %}
        """
        analyzer = JinjaTemplateAnalyzer(template)
        analysis = analyzer.analyze_template()

        # Both operands of the ``and`` must be flagged as conditional.
        assert analysis["user"]["is_conditional"]
        assert analysis["config"]["is_conditional"]
        assert "is_admin" in analysis["user"]["accessed_properties"]
        assert "debug_mode" in analysis["config"]["accessed_properties"]

    def test_complex_expressions(self):
        """Test handling of complex expressions and filters."""
        LOG.info("Testing complex expressions and filters")

        template = """
        {{ user.name | upper }}
        {{ messages | length > 0 and messages[0].content }}
        {{ data['key'].nested['value'] }}
        """
        analyzer = JinjaTemplateAnalyzer(template)
        variables = analyzer.get_template_variables()

        # A filter applied to an attribute must not hide the attribute access.
        assert "user" in variables
        assert "name" in variables["user"]
        # Property access through an index (``messages[0].content``) is
        # expected to be attributed to the indexed variable.
        assert "messages" in variables
        assert "content" in variables["messages"]
        assert "data" in variables

    def test_basic_msg_vars(self, basic_jinja_template_analyzer):
        """Test that the basic message variables are correctly identified."""
        LOG.info("Testing basic message variables")

        variables = basic_jinja_template_analyzer.get_message_vars()
        assert variables == {"role", "content"}

    def test_mixtral_msg_vars(self, mistral_jinja_template_analyzer):
        """Test that the mixtral message variables are correctly identified."""
        LOG.info("Testing mixtral message variables")

        variables = mistral_jinja_template_analyzer.get_message_vars()
        assert variables == {"role", "content", "tool_calls", "tool_call_id"}
|
|
|
|
|
|
if __name__ == "__main__":
    # pytest.main returns the session's exit code; re-raise it so running this
    # file directly exits non-zero when any test fails (it previously always
    # exited 0).
    raise SystemExit(pytest.main([__file__]))
|