Add comprehensive unit tests for all agent service components
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -1,5 +1,8 @@
|
||||
"""Unit tests for ToolCallValidator and validate_agent_tools."""
|
||||
import pytest
|
||||
from agent_service.llm.tool_validator import ToolCallValidator, AgentConfigError
|
||||
from agent_service.llm.tool_validator import (
|
||||
ToolCallValidator, AgentConfigError, ToolValidationError, validate_agent_tools,
|
||||
)
|
||||
|
||||
SAMPLE_TOOLS = [
|
||||
{'name': 'get_invoices', 'parameters': {
|
||||
@@ -16,20 +19,25 @@ SAMPLE_TOOLS = [
|
||||
|
||||
def test_validator_init():
|
||||
v = ToolCallValidator(SAMPLE_TOOLS)
|
||||
assert 'get_invoices' in v._tool_map
|
||||
assert 'get_invoices' in v._tools
|
||||
|
||||
|
||||
def test_raises_on_too_many_tools():
|
||||
many_tools = [{'name': f'tool_{i}', 'parameters': {}} for i in range(9)]
|
||||
with pytest.raises(AgentConfigError, match='MAX_TOOLS_PER_AGENT'):
|
||||
ToolCallValidator(many_tools)
|
||||
with pytest.raises(AgentConfigError, match='max is 8'):
|
||||
validate_agent_tools(many_tools, 'test_agent')
|
||||
|
||||
|
||||
def test_validate_agent_tools_ok_for_8():
|
||||
tools_8 = [{'name': f'tool_{i}', 'parameters': {}} for i in range(8)]
|
||||
validate_agent_tools(tools_8, 'test_agent') # should not raise
|
||||
|
||||
|
||||
def test_valid_tool_call():
|
||||
v = ToolCallValidator(SAMPLE_TOOLS)
|
||||
result = v.validate({'name': 'get_invoices', 'arguments': {'partner_id': 42}})
|
||||
assert result.name == 'get_invoices'
|
||||
assert result.arguments['partner_id'] == 42
|
||||
assert result['name'] == 'get_invoices'
|
||||
assert result['arguments']['partner_id'] == 42
|
||||
|
||||
|
||||
def test_strips_hallucinated_params():
|
||||
@@ -37,37 +45,71 @@ def test_strips_hallucinated_params():
|
||||
result = v.validate({'name': 'get_invoices', 'arguments': {
|
||||
'partner_id': 1, 'nonexistent_param': 'bad',
|
||||
}})
|
||||
assert 'nonexistent_param' not in result.arguments
|
||||
assert 'nonexistent_param' not in result['arguments']
|
||||
|
||||
|
||||
def test_missing_required_param_raises():
|
||||
v = ToolCallValidator(SAMPLE_TOOLS)
|
||||
with pytest.raises(ValueError, match='partner_id'):
|
||||
with pytest.raises(ToolValidationError, match='partner_id'):
|
||||
v.validate({'name': 'get_invoices', 'arguments': {}})
|
||||
|
||||
|
||||
def test_type_coercion_int():
|
||||
v = ToolCallValidator(SAMPLE_TOOLS)
|
||||
result = v.validate({'name': 'get_invoices', 'arguments': {'partner_id': '42'}})
|
||||
assert result.arguments['partner_id'] == 42
|
||||
assert result['arguments']['partner_id'] == 42
|
||||
|
||||
|
||||
def test_unknown_tool_raises():
|
||||
v = ToolCallValidator(SAMPLE_TOOLS)
|
||||
with pytest.raises(ValueError, match='Unknown tool'):
|
||||
with pytest.raises(ToolValidationError, match='Unknown tool'):
|
||||
v.validate({'name': 'nonexistent_tool', 'arguments': {}})
|
||||
|
||||
|
||||
def test_parse_or_fallback_returns_none_on_bad_json():
|
||||
def test_parse_or_fallback_returns_none_on_unknown_tool():
|
||||
v = ToolCallValidator(SAMPLE_TOOLS)
|
||||
result = v.parse_or_fallback('not json at all', context='test')
|
||||
result = v.parse_or_fallback({'name': 'bad_tool', 'arguments': {}})
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_parse_or_fallback_valid_json():
|
||||
def test_parse_or_fallback_returns_none_on_missing_required():
|
||||
v = ToolCallValidator(SAMPLE_TOOLS)
|
||||
import json
|
||||
raw = json.dumps({'name': 'send_reminder', 'arguments': {'invoice_id': 5}})
|
||||
result = v.parse_or_fallback(raw, context='test')
|
||||
result = v.parse_or_fallback({'name': 'get_invoices', 'arguments': {}})
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_parse_or_fallback_valid_call():
|
||||
v = ToolCallValidator(SAMPLE_TOOLS)
|
||||
result = v.parse_or_fallback({'name': 'send_reminder', 'arguments': {'invoice_id': 5}})
|
||||
assert result is not None
|
||||
assert result.name == 'send_reminder'
|
||||
assert result['name'] == 'send_reminder'
|
||||
|
||||
|
||||
def test_parse_or_fallback_none_input():
|
||||
v = ToolCallValidator(SAMPLE_TOOLS)
|
||||
result = v.parse_or_fallback(None)
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_optional_param_not_required():
|
||||
v = ToolCallValidator(SAMPLE_TOOLS)
|
||||
result = v.validate({'name': 'get_invoices', 'arguments': {'partner_id': 1}})
|
||||
assert result['name'] == 'get_invoices'
|
||||
|
||||
|
||||
def test_enum_validation_rejects_invalid_value():
|
||||
tools = [{'name': 'set_state', 'parameters': {
|
||||
'state': {'type': 'string', 'enum': ['draft', 'posted']},
|
||||
}}]
|
||||
v = ToolCallValidator(tools)
|
||||
with pytest.raises(ToolValidationError):
|
||||
v.validate({'name': 'set_state', 'arguments': {'state': 'invalid_state'}})
|
||||
|
||||
|
||||
def test_enum_validation_accepts_valid_value():
|
||||
tools = [{'name': 'set_state', 'parameters': {
|
||||
'state': {'type': 'string', 'enum': ['draft', 'posted']},
|
||||
}}]
|
||||
v = ToolCallValidator(tools)
|
||||
result = v.validate({'name': 'set_state', 'arguments': {'state': 'posted'}})
|
||||
assert result['arguments']['state'] == 'posted'
|
||||
|
||||
Reference in New Issue
Block a user