feat: add LLM abstraction layer (router, Ollama backend, Claude backend)
This commit is contained in:
0
agent_service/llm/__init__.py
Normal file
0
agent_service/llm/__init__.py
Normal file
83
agent_service/llm/claude_backend.py
Normal file
83
agent_service/llm/claude_backend.py
Normal file
@@ -0,0 +1,83 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
import asyncio, logging, time
|
||||||
|
from typing import Any
|
||||||
|
from .llm_types import LLMResponse, ClaudeTimeoutError, ClaudeAuthError, ClaudeRateLimitError
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _to_claude_tools(tools):
|
||||||
|
result = []
|
||||||
|
for t in tools:
|
||||||
|
params = t.get('parameters', {})
|
||||||
|
props, req = {}, []
|
||||||
|
for k, v in params.items():
|
||||||
|
props[k] = {k2: v2 for k2, v2 in v.items() if k2 != 'optional'}
|
||||||
|
if not v.get('optional', False):
|
||||||
|
req.append(k)
|
||||||
|
result.append({'name': t['name'], 'description': t.get('description', t['name']),
|
||||||
|
'input_schema': {'type': 'object', 'properties': props, 'required': req}})
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
class ClaudeBackend:
|
||||||
|
def __init__(self, api_key, model, timeout=60, max_concurrent=5):
|
||||||
|
import anthropic
|
||||||
|
self._client = anthropic.AsyncAnthropic(api_key=api_key)
|
||||||
|
self._model = model
|
||||||
|
self._timeout = timeout
|
||||||
|
self._semaphore = asyncio.Semaphore(max_concurrent)
|
||||||
|
self._active = 0
|
||||||
|
|
||||||
|
async def submit(self, messages, tools=None, caller='unknown'):
|
||||||
|
import anthropic
|
||||||
|
wait_start = time.monotonic()
|
||||||
|
async with self._semaphore:
|
||||||
|
wait_ms = int((time.monotonic() - wait_start) * 1000)
|
||||||
|
self._active += 1
|
||||||
|
t0 = time.monotonic()
|
||||||
|
try:
|
||||||
|
system = None
|
||||||
|
conv = []
|
||||||
|
for m in messages:
|
||||||
|
if m.get('role') == 'system':
|
||||||
|
system = m['content']
|
||||||
|
else:
|
||||||
|
conv.append(m)
|
||||||
|
kw: dict[str, Any] = {'model': self._model, 'max_tokens': 4096, 'messages': conv}
|
||||||
|
if system:
|
||||||
|
kw['system'] = system
|
||||||
|
if tools:
|
||||||
|
kw['tools'] = _to_claude_tools(tools)
|
||||||
|
try:
|
||||||
|
resp = await asyncio.wait_for(self._client.messages.create(**kw), timeout=self._timeout)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
raise ClaudeTimeoutError(f'Claude timeout after {self._timeout}s caller={caller}')
|
||||||
|
except anthropic.AuthenticationError as exc:
|
||||||
|
raise ClaudeAuthError(f'Claude auth error: {exc}') from exc
|
||||||
|
except anthropic.RateLimitError:
|
||||||
|
logger.warning('Claude rate limit, backing off 5s caller=%s', caller)
|
||||||
|
await asyncio.sleep(5)
|
||||||
|
try:
|
||||||
|
resp = await asyncio.wait_for(self._client.messages.create(**kw), timeout=self._timeout)
|
||||||
|
except anthropic.RateLimitError as exc2:
|
||||||
|
raise ClaudeRateLimitError(f'Claude rate limit persists: {exc2}') from exc2
|
||||||
|
ms = int((time.monotonic() - t0) * 1000)
|
||||||
|
text, tool_calls = '', None
|
||||||
|
for block in resp.content:
|
||||||
|
if block.type == 'text':
|
||||||
|
text += block.text
|
||||||
|
elif block.type == 'tool_use':
|
||||||
|
if tool_calls is None: tool_calls = []
|
||||||
|
tool_calls.append({'name': block.name, 'arguments': block.input})
|
||||||
|
tin, tout = resp.usage.input_tokens, resp.usage.output_tokens
|
||||||
|
cost = (tin * 3 + tout * 15) / 1_000_000
|
||||||
|
logger.info('claude caller=%s model=%s wait_ms=%d ms=%d tin=%d tout=%d cost=%.5f',
|
||||||
|
caller, self._model, wait_ms, ms, tin, tout, cost)
|
||||||
|
return LLMResponse(content=text, tool_calls=tool_calls, backend_used='claude',
|
||||||
|
model_used=self._model, tokens_in=tin, tokens_out=tout, latency_ms=ms)
|
||||||
|
finally:
|
||||||
|
self._active -= 1
|
||||||
|
|
||||||
|
@property
|
||||||
|
def active_count(self): return self._active
|
||||||
35
agent_service/llm/llm_config_store.py
Normal file
35
agent_service/llm/llm_config_store.py
Normal file
@@ -0,0 +1,35 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
import logging
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class LLMConfigStore:
|
||||||
|
def __init__(self, pg_pool):
|
||||||
|
self._pool = pg_pool
|
||||||
|
|
||||||
|
async def get_backend(self, caller):
|
||||||
|
async with self._pool.acquire(timeout=10) as conn:
|
||||||
|
row = await conn.fetchrow(
|
||||||
|
'SELECT backend FROM ab_llm_config WHERE caller = $1', caller)
|
||||||
|
return row['backend'] if row else None
|
||||||
|
|
||||||
|
async def set_backend(self, caller, backend, set_by, note=None):
|
||||||
|
async with self._pool.acquire(timeout=10) as conn:
|
||||||
|
await conn.execute(
|
||||||
|
"""INSERT INTO ab_llm_config (caller, backend, set_by, note)
|
||||||
|
VALUES ($1, $2, $3, $4)
|
||||||
|
ON CONFLICT (caller) DO UPDATE
|
||||||
|
SET backend=$2, set_by=$3, set_at=NOW(), note=$4""",
|
||||||
|
caller, backend, set_by, note)
|
||||||
|
logger.info('LLMConfigStore.set_backend caller=%s backend=%s set_by=%s', caller, backend, set_by)
|
||||||
|
|
||||||
|
async def get_all(self):
|
||||||
|
async with self._pool.acquire(timeout=10) as conn:
|
||||||
|
rows = await conn.fetch('SELECT * FROM ab_llm_config ORDER BY caller')
|
||||||
|
return [dict(r) for r in rows]
|
||||||
|
|
||||||
|
async def reset(self, caller):
|
||||||
|
async with self._pool.acquire(timeout=10) as conn:
|
||||||
|
await conn.execute('DELETE FROM ab_llm_config WHERE caller = $1', caller)
|
||||||
|
logger.info('LLMConfigStore.reset caller=%s', caller)
|
||||||
112
agent_service/llm/llm_router.py
Normal file
112
agent_service/llm/llm_router.py
Normal file
@@ -0,0 +1,112 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
import logging, os
|
||||||
|
from .llm_types import LLMResponse, OllamaUnavailableError, ClaudeTimeoutError, ClaudeRateLimitError
|
||||||
|
from .ollama_backend import OllamaBackend
|
||||||
|
from .llm_config_store import LLMConfigStore
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
HIPAA_LOCKED_AGENTS = frozenset({'finance_agent', 'accounting_agent', 'employees_agent', 'expenses_agent'})
|
||||||
|
|
||||||
|
|
||||||
|
class LLMRouter:
|
||||||
|
def __init__(self, config, pg_pool=None):
|
||||||
|
self._config = config
|
||||||
|
self._privacy_mode = getattr(config, 'llm_privacy_mode', 'local')
|
||||||
|
self._config_store = LLMConfigStore(pg_pool) if pg_pool else None
|
||||||
|
self._ollama = OllamaBackend(
|
||||||
|
url=config.ollama_url, model=config.ollama_model,
|
||||||
|
timeout=config.ollama_timeout, max_concurrent=config.ollama_max_concurrent)
|
||||||
|
self._claude = None
|
||||||
|
if self._privacy_mode != 'local':
|
||||||
|
api_key = getattr(config, 'anthropic_api_key', None)
|
||||||
|
if api_key:
|
||||||
|
from .claude_backend import ClaudeBackend
|
||||||
|
self._claude = ClaudeBackend(
|
||||||
|
api_key=api_key, model=config.claude_model,
|
||||||
|
timeout=config.claude_timeout, max_concurrent=config.claude_max_concurrent)
|
||||||
|
logger.info('ClaudeBackend initialized mode=%s', self._privacy_mode)
|
||||||
|
elif self._privacy_mode == 'cloud':
|
||||||
|
logger.error('Privacy mode is cloud but ANTHROPIC_API_KEY not set')
|
||||||
|
else:
|
||||||
|
logger.warning('Privacy mode is hybrid but ANTHROPIC_API_KEY not set')
|
||||||
|
|
||||||
|
async def submit(self, messages, tools=None, caller='unknown'):
|
||||||
|
backend_name = await self.get_backend(caller)
|
||||||
|
if backend_name == 'claude':
|
||||||
|
if self._claude is None:
|
||||||
|
logger.warning('Claude requested but unavailable, fallback to Ollama caller=%s', caller)
|
||||||
|
backend_name = 'ollama'
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
return await self._claude.submit(messages, tools, caller)
|
||||||
|
except (ClaudeTimeoutError, ClaudeRateLimitError) as exc:
|
||||||
|
logger.warning('Claude failed caller=%s (%s), falling back to Ollama', caller, exc)
|
||||||
|
return await self._ollama.submit(messages, tools, caller)
|
||||||
|
return await self._ollama.submit(messages, tools, caller)
|
||||||
|
|
||||||
|
async def get_backend(self, caller):
|
||||||
|
if caller in HIPAA_LOCKED_AGENTS:
|
||||||
|
return 'ollama'
|
||||||
|
if self._privacy_mode == 'local':
|
||||||
|
return 'ollama'
|
||||||
|
if self._privacy_mode == 'cloud':
|
||||||
|
return 'claude'
|
||||||
|
if self._privacy_mode == 'hybrid':
|
||||||
|
if self._config_store:
|
||||||
|
try:
|
||||||
|
db_val = await self._config_store.get_backend(caller)
|
||||||
|
if db_val:
|
||||||
|
return db_val
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning('LLMConfigStore lookup failed: %s', exc)
|
||||||
|
env_key = f'AGENT_BACKEND_{caller.upper()}'
|
||||||
|
env_val = os.environ.get(env_key)
|
||||||
|
if env_val in ('ollama', 'claude'):
|
||||||
|
return env_val
|
||||||
|
return 'ollama'
|
||||||
|
logger.error('Unknown privacy mode %s, defaulting to ollama', self._privacy_mode)
|
||||||
|
return 'ollama'
|
||||||
|
|
||||||
|
async def set_backend(self, caller, backend, set_by, note=None):
|
||||||
|
if caller in HIPAA_LOCKED_AGENTS:
|
||||||
|
raise ValueError(f'Cannot override backend for HIPAA-locked agent: {caller}')
|
||||||
|
if backend not in ('ollama', 'claude'):
|
||||||
|
raise ValueError(f'Invalid backend: {backend}')
|
||||||
|
if not self._config_store:
|
||||||
|
raise RuntimeError('No Postgres pool for runtime config store')
|
||||||
|
await self._config_store.set_backend(caller, backend, set_by, note)
|
||||||
|
|
||||||
|
async def set_privacy_mode(self, mode, set_by):
|
||||||
|
if mode not in ('local', 'hybrid', 'cloud'):
|
||||||
|
raise ValueError(f'Invalid privacy mode: {mode}')
|
||||||
|
self._privacy_mode = mode
|
||||||
|
if self._config_store:
|
||||||
|
await self._config_store.set_backend('__system__', mode, set_by,
|
||||||
|
f'Privacy mode changed to {mode}')
|
||||||
|
if mode == 'local':
|
||||||
|
self._claude = None
|
||||||
|
logger.info('Privacy mode set to local - ClaudeBackend disabled')
|
||||||
|
elif mode in ('hybrid', 'cloud') and self._claude is None:
|
||||||
|
api_key = getattr(self._config, 'anthropic_api_key', None)
|
||||||
|
if api_key:
|
||||||
|
from .claude_backend import ClaudeBackend
|
||||||
|
self._claude = ClaudeBackend(
|
||||||
|
api_key=api_key, model=self._config.claude_model,
|
||||||
|
timeout=self._config.claude_timeout, max_concurrent=self._config.claude_max_concurrent)
|
||||||
|
logger.info('Privacy mode set to %s by user_id=%s', mode, set_by)
|
||||||
|
|
||||||
|
async def get_status(self):
|
||||||
|
s = {'privacy_mode': self._privacy_mode,
|
||||||
|
'ollama': {'active': self._ollama.active_count}}
|
||||||
|
if self._claude:
|
||||||
|
s['claude'] = {'active': self._claude.active_count}
|
||||||
|
else:
|
||||||
|
s['claude'] = {'available': False, 'reason': 'local mode or no API key'}
|
||||||
|
return s
|
||||||
|
|
||||||
|
@property
|
||||||
|
def ollama_queue_depth(self): return self._ollama.active_count
|
||||||
|
|
||||||
|
@property
|
||||||
|
def claude_active_count(self): return self._claude.active_count if self._claude else 0
|
||||||
21
agent_service/llm/llm_types.py
Normal file
21
agent_service/llm/llm_types.py
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class LLMResponse:
|
||||||
|
content: str
|
||||||
|
tool_calls: object # list | None
|
||||||
|
backend_used: str # 'ollama' | 'claude'
|
||||||
|
model_used: str
|
||||||
|
tokens_in: int
|
||||||
|
tokens_out: int
|
||||||
|
latency_ms: int
|
||||||
|
|
||||||
|
|
||||||
|
class LLMError(Exception): pass
|
||||||
|
class OllamaTimeoutError(LLMError): pass
|
||||||
|
class OllamaUnavailableError(LLMError): pass
|
||||||
|
class ClaudeTimeoutError(LLMError): pass
|
||||||
|
class ClaudeAuthError(LLMError): pass
|
||||||
|
class ClaudeRateLimitError(LLMError): pass
|
||||||
54
agent_service/llm/ollama_backend.py
Normal file
54
agent_service/llm/ollama_backend.py
Normal file
@@ -0,0 +1,54 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
import asyncio, logging, time
|
||||||
|
from .llm_types import LLMResponse, OllamaTimeoutError, OllamaUnavailableError
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class OllamaBackend:
|
||||||
|
def __init__(self, url, model, timeout=120, max_concurrent=2):
|
||||||
|
self._url = url
|
||||||
|
self._model = model
|
||||||
|
self._timeout = timeout
|
||||||
|
self._semaphore = asyncio.Semaphore(max_concurrent)
|
||||||
|
self._active = 0
|
||||||
|
|
||||||
|
async def submit(self, messages, tools=None, caller='unknown'):
|
||||||
|
import ollama
|
||||||
|
wait_start = time.monotonic()
|
||||||
|
async with self._semaphore:
|
||||||
|
wait_ms = int((time.monotonic() - wait_start) * 1000)
|
||||||
|
self._active += 1
|
||||||
|
t0 = time.monotonic()
|
||||||
|
try:
|
||||||
|
kwargs = {'model': self._model, 'messages': messages}
|
||||||
|
if tools:
|
||||||
|
kwargs['tools'] = tools
|
||||||
|
client = ollama.AsyncClient(host=self._url)
|
||||||
|
try:
|
||||||
|
response = await asyncio.wait_for(client.chat(**kwargs), timeout=self._timeout)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
raise OllamaTimeoutError(f'Ollama timeout after {self._timeout}s caller={caller}')
|
||||||
|
except Exception as exc:
|
||||||
|
s = str(exc).lower()
|
||||||
|
if 'connect' in s or 'refused' in s or 'unreachable' in s:
|
||||||
|
raise OllamaUnavailableError(f'Ollama unreachable: {exc}') from exc
|
||||||
|
raise OllamaUnavailableError(f'Ollama error: {exc}') from exc
|
||||||
|
ms = int((time.monotonic() - t0) * 1000)
|
||||||
|
msg = response.message
|
||||||
|
tool_calls = None
|
||||||
|
if hasattr(msg, 'tool_calls') and msg.tool_calls:
|
||||||
|
tool_calls = [{'name': tc.function.name, 'arguments': tc.function.arguments}
|
||||||
|
for tc in msg.tool_calls]
|
||||||
|
tin = response.prompt_eval_count or 0
|
||||||
|
tout = response.eval_count or 0
|
||||||
|
logger.info('ollama caller=%s wait_ms=%d inf_ms=%d tin=%d tout=%d',
|
||||||
|
caller, wait_ms, ms, tin, tout)
|
||||||
|
return LLMResponse(content=msg.content or '', tool_calls=tool_calls,
|
||||||
|
backend_used='ollama', model_used=self._model,
|
||||||
|
tokens_in=tin, tokens_out=tout, latency_ms=ms)
|
||||||
|
finally:
|
||||||
|
self._active -= 1
|
||||||
|
|
||||||
|
@property
|
||||||
|
def active_count(self): return self._active
|
||||||
89
agent_service/llm/tool_validator.py
Normal file
89
agent_service/llm/tool_validator.py
Normal file
@@ -0,0 +1,89 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
import json, logging
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
MAX_TOOLS_PER_AGENT = 8
|
||||||
|
|
||||||
|
_TYPE_MAP = {
|
||||||
|
'string': str, 'integer': int, 'number': (int, float),
|
||||||
|
'boolean': bool, 'object': dict, 'array': list,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class ToolValidationError(Exception): pass
|
||||||
|
class AgentConfigError(Exception): pass
|
||||||
|
|
||||||
|
|
||||||
|
def validate_agent_tools(tools, agent_name):
|
||||||
|
if len(tools) > MAX_TOOLS_PER_AGENT:
|
||||||
|
raise AgentConfigError(
|
||||||
|
f'{agent_name} has {len(tools)} tools but max is {MAX_TOOLS_PER_AGENT}. '
|
||||||
|
f'Split into contextual tool groups.')
|
||||||
|
|
||||||
|
|
||||||
|
class ToolCallValidator:
|
||||||
|
def __init__(self, tools):
|
||||||
|
self._tools = {t["name"]: t for t in tools}
|
||||||
|
|
||||||
|
def validate(self, tool_call):
|
||||||
|
name = tool_call.get('name') or (tool_call.get('function') or {}).get('name')
|
||||||
|
if not name:
|
||||||
|
raise ToolValidationError('Tool call missing name field')
|
||||||
|
if name not in self._tools:
|
||||||
|
raise ToolValidationError(f'Unknown tool {name!r} not in agent tool list')
|
||||||
|
tool_def = self._tools[name]
|
||||||
|
params_schema = tool_def.get('parameters', {})
|
||||||
|
arguments = tool_call.get('arguments') or (tool_call.get('function') or {}).get('arguments', {})
|
||||||
|
if isinstance(arguments, str):
|
||||||
|
try:
|
||||||
|
arguments = json.loads(arguments)
|
||||||
|
except json.JSONDecodeError as exc:
|
||||||
|
raise ToolValidationError(f'Tool {name!r} arguments is invalid JSON: {exc}') from exc
|
||||||
|
if not isinstance(arguments, dict):
|
||||||
|
raise ToolValidationError(f'Tool {name!r} arguments must be dict, got {type(arguments)}')
|
||||||
|
cleaned: dict[str, Any] = {}
|
||||||
|
for key, value in arguments.items():
|
||||||
|
if key not in params_schema:
|
||||||
|
logger.warning('ToolCallValidator: stripping hallucinated param %r from tool %r', key, name)
|
||||||
|
continue
|
||||||
|
cleaned[key] = value
|
||||||
|
for key, schema in params_schema.items():
|
||||||
|
if schema.get('optional', False):
|
||||||
|
continue
|
||||||
|
if key not in cleaned:
|
||||||
|
if 'default' in schema:
|
||||||
|
cleaned[key] = schema['default']
|
||||||
|
else:
|
||||||
|
raise ToolValidationError(f'Tool {name!r} missing required param {key!r}')
|
||||||
|
for key, value in list(cleaned.items()):
|
||||||
|
schema = params_schema.get(key, {})
|
||||||
|
expected_type = schema.get('type')
|
||||||
|
if expected_type and expected_type in _TYPE_MAP:
|
||||||
|
expected = _TYPE_MAP[expected_type]
|
||||||
|
if not isinstance(value, expected):
|
||||||
|
if expected_type in ('integer', 'number') and isinstance(value, str):
|
||||||
|
try:
|
||||||
|
cleaned[key] = int(value) if expected_type == 'integer' else float(value)
|
||||||
|
except ValueError:
|
||||||
|
raise ToolValidationError(
|
||||||
|
f'Tool {name!r} param {key!r} expected {expected_type}, got {value!r}')
|
||||||
|
elif expected_type == 'string':
|
||||||
|
cleaned[key] = str(value)
|
||||||
|
else:
|
||||||
|
raise ToolValidationError(
|
||||||
|
f'Tool {name!r} param {key!r} expected {expected_type}, got {type(value).__name__}')
|
||||||
|
if 'enum' in schema and cleaned.get(key) not in schema['enum']:
|
||||||
|
raise ToolValidationError(
|
||||||
|
f'Tool {name!r} param {key!r} value {cleaned.get(key)!r} not in {schema["enum"]}')
|
||||||
|
return {'name': name, 'arguments': cleaned}
|
||||||
|
|
||||||
|
def parse_or_fallback(self, tool_call):
|
||||||
|
if tool_call is None:
|
||||||
|
return None
|
||||||
|
try:
|
||||||
|
return self.validate(tool_call)
|
||||||
|
except ToolValidationError as exc:
|
||||||
|
logger.warning('ToolCallValidator fallback: %s', exc)
|
||||||
|
return None
|
||||||
Reference in New Issue
Block a user