Files
odoo-ai/agent_service/llm/tool_validator.py

90 lines
3.8 KiB
Python

from __future__ import annotations

import copy
import json
import logging
from collections.abc import Mapping
from typing import Any
logger = logging.getLogger(__name__)
MAX_TOOLS_PER_AGENT = 8
_TYPE_MAP = {
'string': str, 'integer': int, 'number': (int, float),
'boolean': bool, 'object': dict, 'array': list,
}
class ToolValidationError(Exception):
    """Raised when a tool call does not match the agent's registered tool definitions."""
class AgentConfigError(Exception):
    """Raised when an agent's tool configuration is invalid (e.g. it has too many tools)."""
def validate_agent_tools(tools, agent_name):
    """Fail fast when an agent is configured with more tools than allowed.

    Args:
        tools: The agent's tool definitions; only ``len(tools)`` is used.
        agent_name: Name of the agent, used in the error message.

    Raises:
        AgentConfigError: If ``len(tools)`` exceeds ``MAX_TOOLS_PER_AGENT``.
    """
    tool_count = len(tools)
    if tool_count <= MAX_TOOLS_PER_AGENT:
        return
    raise AgentConfigError(
        f'{agent_name} has {tool_count} tools but max is {MAX_TOOLS_PER_AGENT}. '
        'Split into contextual tool groups.'
    )
class ToolCallValidator:
    """Validate tool calls against the tool definitions registered for one agent.

    Each tool definition is a dict with a ``'name'`` key and an optional
    ``'parameters'`` dict mapping parameter name -> per-parameter schema.
    A per-parameter schema may declare ``'type'`` (a key of ``_TYPE_MAP``),
    ``'optional'``, ``'default'`` and ``'enum'``.
    """

    def __init__(self, tools):
        """Index *tools* by their ``'name'``; a later duplicate name replaces an earlier one."""
        self._tools = {t["name"]: t for t in tools}

    def validate(self, tool_call):
        """Validate *tool_call* and return ``{'name': str, 'arguments': dict}``.

        The name and arguments are read from the top level of *tool_call*,
        falling back to a nested ``tool_call['function']`` mapping. Arguments
        may be a dict or a JSON-encoded string. A missing, ``None``, blank or
        JSON ``null`` payload is treated as "no arguments".

        Normalisation applied to the arguments:

        * parameters not declared in the tool's schema are dropped (a warning
          is logged);
        * a missing required parameter receives a deep copy of its
          ``'default'``;
        * ``None`` for an optional parameter that declares a mapped ``'type'``
          is dropped, i.e. treated as not supplied;
        * values are type-checked / coerced via ``_coerce`` and then checked
          against ``'enum'``.

        Raises:
            ToolValidationError: the call is not a mapping, has no name, names
                an unknown tool, has malformed arguments, lacks a required
                parameter, or has a value of the wrong type or outside its enum.
        """
        if not isinstance(tool_call, Mapping):
            # Without this check, a non-mapping would escape as AttributeError
            # and bypass parse_or_fallback().
            raise ToolValidationError(
                f'Tool call must be a mapping, got {type(tool_call).__name__}')
        function = tool_call.get('function')
        if not isinstance(function, Mapping):
            function = {}

        name = tool_call.get('name') or function.get('name')
        if not name:
            raise ToolValidationError('Tool call missing name field')
        # The str check also keeps an unhashable name from raising TypeError here.
        if not isinstance(name, str) or name not in self._tools:
            raise ToolValidationError(f'Unknown tool {name!r} not in agent tool list')
        # Tolerate an explicit ``'parameters': None`` in the tool definition.
        params_schema = self._tools[name].get('parameters') or {}

        arguments = self._parse_arguments(
            name, tool_call.get('arguments') or function.get('arguments'))

        cleaned: dict[str, Any] = {}
        for key, value in arguments.items():
            if key not in params_schema:
                logger.warning('ToolCallValidator: stripping hallucinated param %r from tool %r', key, name)
                continue
            cleaned[key] = value

        for key, schema in params_schema.items():
            if key in cleaned or schema.get('optional', False):
                continue
            if 'default' not in schema:
                raise ToolValidationError(f'Tool {name!r} missing required param {key!r}')
            # Copy so that mutating the returned arguments cannot alter the
            # default stored in the shared tool definition.
            cleaned[key] = copy.deepcopy(schema['default'])

        for key in list(cleaned):
            schema = params_schema[key]
            expected_type = schema.get('type')
            # isinstance() guard: an unhashable 'type' (e.g. a list) would make
            # the dict membership test raise TypeError.
            if isinstance(expected_type, str) and expected_type in _TYPE_MAP:
                value = cleaned[key]
                if value is None and schema.get('optional', False):
                    # No mapped type admits None, so null on an optional
                    # parameter is treated as "not supplied".
                    del cleaned[key]
                    continue
                cleaned[key] = self._coerce(name, key, value, expected_type)
            if 'enum' in schema and cleaned[key] not in schema['enum']:
                raise ToolValidationError(
                    f'Tool {name!r} param {key!r} value {cleaned[key]!r} not in {schema["enum"]}')
        return {'name': name, 'arguments': cleaned}

    @staticmethod
    def _parse_arguments(name, raw):
        """Decode the raw arguments payload of tool *name* into a dict.

        Raises:
            ToolValidationError: *raw* is invalid JSON or does not decode to a dict.
        """
        if isinstance(raw, str):
            if not raw.strip():
                # A blank string carries no arguments; treat it like a missing
                # payload instead of reporting it as invalid JSON.
                return {}
            try:
                raw = json.loads(raw)
            except json.JSONDecodeError as exc:
                raise ToolValidationError(f'Tool {name!r} arguments is invalid JSON: {exc}') from exc
        if raw is None:
            return {}
        if not isinstance(raw, dict):
            raise ToolValidationError(f'Tool {name!r} arguments must be dict, got {type(raw)}')
        return raw

    @staticmethod
    def _coerce(name, key, value, expected_type):
        """Return *value* checked against, or coerced to, schema type *expected_type*.

        * Integer/number strings are parsed.
        * Floats with no fractional part are accepted as integers.
        * Non-bool numbers are stringified for ``'string'``.

        Any other mismatch raises ToolValidationError.
        """
        expected = _TYPE_MAP[expected_type]
        numeric = expected_type in ('integer', 'number')
        # bool subclasses int, so isinstance(True, int) is True; JSON treats
        # booleans and numbers as distinct types, so bools must not pass here.
        is_bool = isinstance(value, bool)
        if isinstance(value, expected) and not (numeric and is_bool):
            return value
        if numeric and not is_bool:
            if expected_type == 'integer' and isinstance(value, float) and value.is_integer():
                return int(value)
            if isinstance(value, str):
                try:
                    return int(value) if expected_type == 'integer' else float(value)
                except ValueError as exc:
                    raise ToolValidationError(
                        f'Tool {name!r} param {key!r} expected {expected_type}, got {value!r}') from exc
        elif expected_type == 'string' and isinstance(value, (int, float)) and not is_bool:
            # Only plain numbers are stringified. str() of None, bool, dict or
            # list would yield a Python repr ("None", "True", "{'a': 1}").
            return str(value)
        raise ToolValidationError(
            f'Tool {name!r} param {key!r} expected {expected_type}, got {type(value).__name__}')

    def parse_or_fallback(self, tool_call):
        """Return ``validate(tool_call)``, or ``None`` when *tool_call* is ``None`` or invalid.

        Validation failures are logged as warnings rather than raised.
        """
        if tool_call is None:
            return None
        try:
            return self.validate(tool_call)
        except ToolValidationError as exc:
            logger.warning('ToolCallValidator fallback: %s', exc)
            return None