90 lines
3.8 KiB
Python
90 lines
3.8 KiB
Python
from __future__ import annotations

import json
import logging
from collections.abc import Mapping
from typing import Any
|
|
|
|
logger = logging.getLogger(__name__)

# Upper bound on how many tools a single agent may be configured with;
# checked by validate_agent_tools.
MAX_TOOLS_PER_AGENT = 8

# Maps a parameter schema's "type" string to the Python type(s) that
# ToolCallValidator accepts via isinstance().
_TYPE_MAP = dict(
    string=str,
    integer=int,
    number=(int, float),
    boolean=bool,
    object=dict,
    array=list,
)
|
|
|
|
|
|
class ToolValidationError(Exception):
    """Raised when a tool call does not match the agent's tool definitions."""
|
|
class AgentConfigError(Exception):
    """Raised when an agent's configuration is invalid (e.g. too many tools)."""
|
|
|
|
|
|
def validate_agent_tools(tools, agent_name):
    """Reject an agent configuration that declares too many tools.

    Args:
        tools: The agent's tool definitions; only ``len(tools)`` is inspected.
        agent_name: Agent name, used only in the error message.

    Raises:
        AgentConfigError: if ``tools`` has more than ``MAX_TOOLS_PER_AGENT``
            entries.
    """
    tool_count = len(tools)
    if tool_count <= MAX_TOOLS_PER_AGENT:
        return
    message = (
        f'{agent_name} has {tool_count} tools but max is {MAX_TOOLS_PER_AGENT}. '
        'Split into contextual tool groups.'
    )
    raise AgentConfigError(message)
|
|
|
|
|
|
class ToolCallValidator:
    """Validate and normalise tool calls against an agent's tool definitions.

    Each tool definition is a mapping with a ``"name"`` key and an optional
    ``"parameters"`` mapping of ``{param_name: param_schema}``. A param schema
    may carry ``"type"`` (a key of ``_TYPE_MAP``), ``"optional"``,
    ``"default"`` and ``"enum"``.

    Two call shapes are accepted: flat ``{"name": ..., "arguments": ...}`` and
    nested ``{"function": {"name": ..., "arguments": ...}}``; truthy
    top-level values take precedence over the nested ones.
    """

    def __init__(self, tools):
        """Index ``tools`` by their ``"name"`` key.

        If two definitions share a name, the later one replaces the earlier.
        """
        self._tools = {t["name"]: t for t in tools}

    def validate(self, tool_call):
        """Validate ``tool_call`` and return ``{"name": ..., "arguments": ...}``.

        Arguments not declared in the tool's parameters are dropped with a
        warning. Missing non-optional parameters are filled from their
        ``"default"`` when one is declared. Numeric strings are coerced to
        ``int``/``float``, and any value is coerced to ``str`` for ``"string"``
        parameters. Missing, ``None`` or blank-string arguments mean "no
        arguments".

        Raises:
            ToolValidationError: if the call is not a mapping, has no name or
                names an unknown tool, carries undecodable or non-dict
                arguments, misses a required parameter, has a value of the
                wrong type, or a value outside its ``"enum"``.
        """
        if not isinstance(tool_call, Mapping):
            # Keep malformed input (e.g. a raw model string) on the
            # ToolValidationError path so parse_or_fallback can recover.
            raise ToolValidationError(
                f'Tool call must be a mapping, got {type(tool_call).__name__}')
        function = tool_call.get('function')
        if not isinstance(function, Mapping):
            function = {}

        name = tool_call.get('name') or function.get('name')
        if not name:
            raise ToolValidationError('Tool call missing name field')
        if name not in self._tools:
            raise ToolValidationError(f'Unknown tool {name!r} not in agent tool list')

        # ``or {}`` also covers an explicit ``"parameters": None``.
        params_schema = self._tools[name].get('parameters') or {}
        arguments = self._parse_arguments(
            name, tool_call.get('arguments') or function.get('arguments'))

        # Drop arguments the tool does not declare.
        cleaned: dict[str, Any] = {}
        for key, value in arguments.items():
            if key not in params_schema:
                logger.warning(
                    'ToolCallValidator: stripping hallucinated param %r from tool %r',
                    key, name)
                continue
            cleaned[key] = value

        # Fill or reject missing non-optional parameters.
        for key, schema in params_schema.items():
            if schema.get('optional', False) or key in cleaned:
                continue
            if 'default' not in schema:
                raise ToolValidationError(f'Tool {name!r} missing required param {key!r}')
            cleaned[key] = schema['default']

        # Type-check / coerce, then enforce enums on the coerced value.
        for key in list(cleaned):
            schema = params_schema.get(key, {})
            value = self._coerce_value(name, key, cleaned[key], schema)
            if 'enum' in schema and value not in schema['enum']:
                raise ToolValidationError(
                    f'Tool {name!r} param {key!r} value {value!r} not in {schema["enum"]}')
            cleaned[key] = value

        return {'name': name, 'arguments': cleaned}

    @staticmethod
    def _parse_arguments(name, arguments):
        """Return the raw call arguments as a dict, decoding JSON strings."""
        # Treat "no payload" the same in the flat and nested shapes; before,
        # a nested "" or None raised while a flat "" was accepted.
        if arguments is None:
            return {}
        if isinstance(arguments, str):
            if not arguments.strip():
                return {}
            try:
                arguments = json.loads(arguments)
            except json.JSONDecodeError as exc:
                raise ToolValidationError(f'Tool {name!r} arguments is invalid JSON: {exc}') from exc
        if not isinstance(arguments, dict):
            raise ToolValidationError(f'Tool {name!r} arguments must be dict, got {type(arguments)}')
        return arguments

    @staticmethod
    def _coerce_value(name, key, value, schema):
        """Return ``value`` checked (and, where allowed, coerced) against ``schema["type"]``."""
        expected_type = schema.get('type')
        # Unknown or non-string types (e.g. JSON-Schema unions such as
        # ["string", "null"], which are unhashable) are left unchecked.
        if not isinstance(expected_type, str) or expected_type not in _TYPE_MAP:
            return value
        numeric = expected_type in ('integer', 'number')
        # bool subclasses int, so isinstance(True, int) holds; without this
        # guard true/false would be accepted for integer/number params.
        if numeric and isinstance(value, bool):
            raise ToolValidationError(
                f'Tool {name!r} param {key!r} expected {expected_type}, got {type(value).__name__}')
        if isinstance(value, _TYPE_MAP[expected_type]):
            return value
        if numeric and isinstance(value, str):
            try:
                return int(value) if expected_type == 'integer' else float(value)
            except ValueError as exc:
                raise ToolValidationError(
                    f'Tool {name!r} param {key!r} expected {expected_type}, got {value!r}') from exc
        if expected_type == 'string':
            return str(value)
        raise ToolValidationError(
            f'Tool {name!r} param {key!r} expected {expected_type}, got {type(value).__name__}')

    def parse_or_fallback(self, tool_call):
        """Return ``validate(tool_call)``, or ``None`` on a missing/invalid call.

        ``None`` input returns ``None`` directly; a ``ToolValidationError`` is
        logged as a warning and also yields ``None``.
        """
        if tool_call is None:
            return None
        try:
            return self.validate(tool_call)
        except ToolValidationError as exc:
            logger.warning('ToolCallValidator fallback: %s', exc)
            return None
|