"""Unit tests for ToolCallValidator and validate_agent_tools.""" import pytest from agent_service.llm.tool_validator import ( ToolCallValidator, AgentConfigError, ToolValidationError, validate_agent_tools, ) SAMPLE_TOOLS = [ {'name': 'get_invoices', 'parameters': { 'state': {'type': 'string', 'optional': True}, 'limit': {'type': 'integer', 'optional': True}, 'partner_id': {'type': 'integer'}, }}, {'name': 'send_reminder', 'parameters': { 'invoice_id': {'type': 'integer'}, 'message': {'type': 'string', 'optional': True}, }}, ] def test_validator_init(): v = ToolCallValidator(SAMPLE_TOOLS) assert 'get_invoices' in v._tools def test_raises_on_too_many_tools(): many_tools = [{'name': f'tool_{i}', 'parameters': {}} for i in range(9)] with pytest.raises(AgentConfigError, match='max is 8'): validate_agent_tools(many_tools, 'test_agent') def test_validate_agent_tools_ok_for_8(): tools_8 = [{'name': f'tool_{i}', 'parameters': {}} for i in range(8)] validate_agent_tools(tools_8, 'test_agent') # should not raise def test_valid_tool_call(): v = ToolCallValidator(SAMPLE_TOOLS) result = v.validate({'name': 'get_invoices', 'arguments': {'partner_id': 42}}) assert result['name'] == 'get_invoices' assert result['arguments']['partner_id'] == 42 def test_strips_hallucinated_params(): v = ToolCallValidator(SAMPLE_TOOLS) result = v.validate({'name': 'get_invoices', 'arguments': { 'partner_id': 1, 'nonexistent_param': 'bad', }}) assert 'nonexistent_param' not in result['arguments'] def test_missing_required_param_raises(): v = ToolCallValidator(SAMPLE_TOOLS) with pytest.raises(ToolValidationError, match='partner_id'): v.validate({'name': 'get_invoices', 'arguments': {}}) def test_type_coercion_int(): v = ToolCallValidator(SAMPLE_TOOLS) result = v.validate({'name': 'get_invoices', 'arguments': {'partner_id': '42'}}) assert result['arguments']['partner_id'] == 42 def test_unknown_tool_raises(): v = ToolCallValidator(SAMPLE_TOOLS) with pytest.raises(ToolValidationError, match='Unknown tool'): v.validate({'name': 'nonexistent_tool', 'arguments': {}}) def test_parse_or_fallback_returns_none_on_unknown_tool(): v = ToolCallValidator(SAMPLE_TOOLS) result = v.parse_or_fallback({'name': 'bad_tool', 'arguments': {}}) assert result is None def test_parse_or_fallback_returns_none_on_missing_required(): v = ToolCallValidator(SAMPLE_TOOLS) result = v.parse_or_fallback({'name': 'get_invoices', 'arguments': {}}) assert result is None def test_parse_or_fallback_valid_call(): v = ToolCallValidator(SAMPLE_TOOLS) result = v.parse_or_fallback({'name': 'send_reminder', 'arguments': {'invoice_id': 5}}) assert result is not None assert result['name'] == 'send_reminder' def test_parse_or_fallback_none_input(): v = ToolCallValidator(SAMPLE_TOOLS) result = v.parse_or_fallback(None) assert result is None def test_optional_param_not_required(): v = ToolCallValidator(SAMPLE_TOOLS) result = v.validate({'name': 'get_invoices', 'arguments': {'partner_id': 1}}) assert result['name'] == 'get_invoices' def test_enum_validation_rejects_invalid_value(): tools = [{'name': 'set_state', 'parameters': { 'state': {'type': 'string', 'enum': ['draft', 'posted']}, }}] v = ToolCallValidator(tools) with pytest.raises(ToolValidationError): v.validate({'name': 'set_state', 'arguments': {'state': 'invalid_state'}}) def test_enum_validation_accepts_valid_value(): tools = [{'name': 'set_state', 'parameters': { 'state': {'type': 'string', 'enum': ['draft', 'posted']}, }}] v = ToolCallValidator(tools) result = v.validate({'name': 'set_state', 'arguments': {'state': 'posted'}}) assert result['arguments']['state'] == 'posted'