feat: add LLM abstraction layer (router, Ollama backend, Claude backend)
This commit is contained in:
89
agent_service/llm/tool_validator.py
Normal file
89
agent_service/llm/tool_validator.py
Normal file
@@ -0,0 +1,89 @@
|
||||
from __future__ import annotations
|
||||
import json, logging
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
MAX_TOOLS_PER_AGENT = 8
|
||||
|
||||
_TYPE_MAP = {
|
||||
'string': str, 'integer': int, 'number': (int, float),
|
||||
'boolean': bool, 'object': dict, 'array': list,
|
||||
}
|
||||
|
||||
|
||||
class ToolValidationError(Exception): pass
|
||||
class AgentConfigError(Exception): pass
|
||||
|
||||
|
||||
def validate_agent_tools(tools, agent_name):
|
||||
if len(tools) > MAX_TOOLS_PER_AGENT:
|
||||
raise AgentConfigError(
|
||||
f'{agent_name} has {len(tools)} tools but max is {MAX_TOOLS_PER_AGENT}. '
|
||||
f'Split into contextual tool groups.')
|
||||
|
||||
|
||||
class ToolCallValidator:
|
||||
def __init__(self, tools):
|
||||
self._tools = {t["name"]: t for t in tools}
|
||||
|
||||
def validate(self, tool_call):
|
||||
name = tool_call.get('name') or (tool_call.get('function') or {}).get('name')
|
||||
if not name:
|
||||
raise ToolValidationError('Tool call missing name field')
|
||||
if name not in self._tools:
|
||||
raise ToolValidationError(f'Unknown tool {name!r} not in agent tool list')
|
||||
tool_def = self._tools[name]
|
||||
params_schema = tool_def.get('parameters', {})
|
||||
arguments = tool_call.get('arguments') or (tool_call.get('function') or {}).get('arguments', {})
|
||||
if isinstance(arguments, str):
|
||||
try:
|
||||
arguments = json.loads(arguments)
|
||||
except json.JSONDecodeError as exc:
|
||||
raise ToolValidationError(f'Tool {name!r} arguments is invalid JSON: {exc}') from exc
|
||||
if not isinstance(arguments, dict):
|
||||
raise ToolValidationError(f'Tool {name!r} arguments must be dict, got {type(arguments)}')
|
||||
cleaned: dict[str, Any] = {}
|
||||
for key, value in arguments.items():
|
||||
if key not in params_schema:
|
||||
logger.warning('ToolCallValidator: stripping hallucinated param %r from tool %r', key, name)
|
||||
continue
|
||||
cleaned[key] = value
|
||||
for key, schema in params_schema.items():
|
||||
if schema.get('optional', False):
|
||||
continue
|
||||
if key not in cleaned:
|
||||
if 'default' in schema:
|
||||
cleaned[key] = schema['default']
|
||||
else:
|
||||
raise ToolValidationError(f'Tool {name!r} missing required param {key!r}')
|
||||
for key, value in list(cleaned.items()):
|
||||
schema = params_schema.get(key, {})
|
||||
expected_type = schema.get('type')
|
||||
if expected_type and expected_type in _TYPE_MAP:
|
||||
expected = _TYPE_MAP[expected_type]
|
||||
if not isinstance(value, expected):
|
||||
if expected_type in ('integer', 'number') and isinstance(value, str):
|
||||
try:
|
||||
cleaned[key] = int(value) if expected_type == 'integer' else float(value)
|
||||
except ValueError:
|
||||
raise ToolValidationError(
|
||||
f'Tool {name!r} param {key!r} expected {expected_type}, got {value!r}')
|
||||
elif expected_type == 'string':
|
||||
cleaned[key] = str(value)
|
||||
else:
|
||||
raise ToolValidationError(
|
||||
f'Tool {name!r} param {key!r} expected {expected_type}, got {type(value).__name__}')
|
||||
if 'enum' in schema and cleaned.get(key) not in schema['enum']:
|
||||
raise ToolValidationError(
|
||||
f'Tool {name!r} param {key!r} value {cleaned.get(key)!r} not in {schema["enum"]}')
|
||||
return {'name': name, 'arguments': cleaned}
|
||||
|
||||
def parse_or_fallback(self, tool_call):
|
||||
if tool_call is None:
|
||||
return None
|
||||
try:
|
||||
return self.validate(tool_call)
|
||||
except ToolValidationError as exc:
|
||||
logger.warning('ToolCallValidator fallback: %s', exc)
|
||||
return None
|
||||
Reference in New Issue
Block a user