From 7d92c2ea6f0fd55c304791facf54b495e88867e3 Mon Sep 17 00:00:00 2001 From: ActiveBlue Build Date: Sun, 12 Apr 2026 16:46:18 -0400 Subject: [PATCH] feat: add LLM abstraction layer (router, Ollama backend, Claude backend) --- agent_service/llm/__init__.py | 0 agent_service/llm/claude_backend.py | 83 +++++++++++++++++++ agent_service/llm/llm_config_store.py | 35 ++++++++ agent_service/llm/llm_router.py | 112 ++++++++++++++++++++++++++ agent_service/llm/llm_types.py | 21 +++++ agent_service/llm/ollama_backend.py | 54 +++++++++++++ agent_service/llm/tool_validator.py | 89 ++++++++++++++++++++ 7 files changed, 394 insertions(+) create mode 100644 agent_service/llm/__init__.py create mode 100644 agent_service/llm/claude_backend.py create mode 100644 agent_service/llm/llm_config_store.py create mode 100644 agent_service/llm/llm_router.py create mode 100644 agent_service/llm/llm_types.py create mode 100644 agent_service/llm/ollama_backend.py create mode 100644 agent_service/llm/tool_validator.py diff --git a/agent_service/llm/__init__.py b/agent_service/llm/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/agent_service/llm/claude_backend.py b/agent_service/llm/claude_backend.py new file mode 100644 index 0000000..a00c7e9 --- /dev/null +++ b/agent_service/llm/claude_backend.py @@ -0,0 +1,83 @@ +from __future__ import annotations +import asyncio, logging, time +from typing import Any +from .llm_types import LLMResponse, ClaudeTimeoutError, ClaudeAuthError, ClaudeRateLimitError + +logger = logging.getLogger(__name__) + + +def _to_claude_tools(tools): + result = [] + for t in tools: + params = t.get('parameters', {}) + props, req = {}, [] + for k, v in params.items(): + props[k] = {k2: v2 for k2, v2 in v.items() if k2 != 'optional'} + if not v.get('optional', False): + req.append(k) + result.append({'name': t['name'], 'description': t.get('description', t['name']), + 'input_schema': {'type': 'object', 'properties': props, 'required': req}}) + return result 
class ClaudeBackend:
    """Anthropic Claude chat backend with bounded concurrency.

    A semaphore caps in-flight requests; SDK and timeout failures are mapped
    onto the package's LLMError hierarchy (ClaudeTimeoutError,
    ClaudeAuthError, ClaudeRateLimitError) so callers can treat backends
    uniformly.
    """

    def __init__(self, api_key, model, timeout=60, max_concurrent=5):
        # Imported lazily so deployments without the anthropic SDK can still
        # import this module (the router only constructs this class when a
        # key is configured).
        import anthropic
        self._client = anthropic.AsyncAnthropic(api_key=api_key)
        self._model = model
        self._timeout = timeout
        self._semaphore = asyncio.Semaphore(max_concurrent)
        self._active = 0  # requests currently past the semaphore

    async def submit(self, messages, tools=None, caller='unknown'):
        """Send a chat request to Claude and return an LLMResponse.

        Raises ClaudeTimeoutError / ClaudeAuthError / ClaudeRateLimitError.
        A rate-limited request is retried once after a 5s backoff.
        """
        import anthropic
        wait_start = time.monotonic()
        async with self._semaphore:
            wait_ms = int((time.monotonic() - wait_start) * 1000)
            self._active += 1
            t0 = time.monotonic()
            try:
                # Anthropic takes the system prompt as a separate kwarg, not
                # as a message. NOTE(review): if several system messages are
                # present only the last one wins — confirm that is intended.
                system = None
                conv = []
                for m in messages:
                    if m.get('role') == 'system':
                        system = m['content']
                    else:
                        conv.append(m)
                kw: dict[str, Any] = {'model': self._model, 'max_tokens': 4096, 'messages': conv}
                if system:
                    kw['system'] = system
                if tools:
                    kw['tools'] = _to_claude_tools(tools)
                try:
                    resp = await asyncio.wait_for(self._client.messages.create(**kw), timeout=self._timeout)
                except asyncio.TimeoutError:
                    raise ClaudeTimeoutError(f'Claude timeout after {self._timeout}s caller={caller}')
                except anthropic.AuthenticationError as exc:
                    raise ClaudeAuthError(f'Claude auth error: {exc}') from exc
                except anthropic.RateLimitError:
                    logger.warning('Claude rate limit, backing off 5s caller=%s', caller)
                    await asyncio.sleep(5)
                    try:
                        resp = await asyncio.wait_for(self._client.messages.create(**kw), timeout=self._timeout)
                    except asyncio.TimeoutError:
                        # Bug fix: the retry previously let a raw
                        # asyncio.TimeoutError escape; callers only handle the
                        # LLMError hierarchy.
                        raise ClaudeTimeoutError(f'Claude timeout after {self._timeout}s caller={caller}')
                    except anthropic.AuthenticationError as exc2:
                        # Bug fix: map auth failures on the retry path too.
                        raise ClaudeAuthError(f'Claude auth error: {exc2}') from exc2
                    except anthropic.RateLimitError as exc2:
                        raise ClaudeRateLimitError(f'Claude rate limit persists: {exc2}') from exc2
                ms = int((time.monotonic() - t0) * 1000)
                text, tool_calls = '', None
                for block in resp.content:
                    if block.type == 'text':
                        text += block.text
                    elif block.type == 'tool_use':
                        if tool_calls is None:
                            tool_calls = []
                        tool_calls.append({'name': block.name, 'arguments': block.input})
                tin, tout = resp.usage.input_tokens, resp.usage.output_tokens
                # $3 / $15 per million input/output tokens — hard-coded
                # pricing; NOTE(review): revisit if the model changes.
                cost = (tin * 3 + tout * 15) / 1_000_000
                logger.info('claude caller=%s model=%s wait_ms=%d ms=%d tin=%d tout=%d cost=%.5f',
                            caller, self._model, wait_ms, ms, tin, tout, cost)
                return LLMResponse(content=text, tool_calls=tool_calls, backend_used='claude',
                                   model_used=self._model, tokens_in=tin, tokens_out=tout, latency_ms=ms)
            finally:
                self._active -= 1

    @property
    def active_count(self):
        # Requests currently being processed (past the semaphore).
        return self._active


class LLMConfigStore:
    """Postgres-backed per-caller backend overrides (table ab_llm_config)."""

    def __init__(self, pg_pool):
        self._pool = pg_pool

    async def get_backend(self, caller):
        """Return the configured backend for *caller*, or None if unset."""
        async with self._pool.acquire(timeout=10) as conn:
            row = await conn.fetchrow(
                'SELECT backend FROM ab_llm_config WHERE caller = $1', caller)
            return row['backend'] if row else None

    async def set_backend(self, caller, backend, set_by, note=None):
        """Upsert the backend override for *caller*, recording who set it."""
        async with self._pool.acquire(timeout=10) as conn:
            await conn.execute(
                """INSERT INTO ab_llm_config (caller, backend, set_by, note)
                VALUES ($1, $2, $3, $4)
                ON CONFLICT (caller) DO UPDATE
                SET backend=$2, set_by=$3, set_at=NOW(), note=$4""",
                caller, backend, set_by, note)
        logger.info('LLMConfigStore.set_backend caller=%s backend=%s set_by=%s', caller, backend, set_by)

    async def get_all(self):
        """Return every override row as a plain dict, ordered by caller."""
        async with self._pool.acquire(timeout=10) as conn:
            rows = await conn.fetch('SELECT * FROM ab_llm_config ORDER BY caller')
            return [dict(r) for r in rows]

    async def reset(self, caller):
        """Delete the override for *caller* (router falls back to defaults)."""
        async with self._pool.acquire(timeout=10) as conn:
            await conn.execute('DELETE FROM ab_llm_config WHERE caller = $1', caller)
        logger.info('LLMConfigStore.reset caller=%s', caller)
# --- agent_service/llm/llm_router.py ---
from __future__ import annotations
import logging, os
from .llm_types import LLMResponse, OllamaUnavailableError, ClaudeTimeoutError, ClaudeRateLimitError
from .ollama_backend import OllamaBackend
from .llm_config_store import LLMConfigStore

logger = logging.getLogger(__name__)

# Agents handling sensitive (HIPAA-scoped) data: always routed to the local
# Ollama backend and never overridable (enforced in get_backend/set_backend).
HIPAA_LOCKED_AGENTS = frozenset({'finance_agent', 'accounting_agent', 'employees_agent', 'expenses_agent'})


class LLMRouter:
    """Route LLM requests to a local Ollama backend or a cloud Claude backend.

    Routing is decided by a privacy mode ('local' | 'hybrid' | 'cloud'),
    optional per-caller overrides stored in Postgres, and
    AGENT_BACKEND_<CALLER> environment variables. HIPAA-locked agents are
    always pinned to Ollama regardless of mode or overrides.
    """

    def __init__(self, config, pg_pool=None):
        # config: object exposing ollama_*/claude_* settings, and optionally
        # llm_privacy_mode / anthropic_api_key. pg_pool enables the runtime
        # override store; without it only env vars and the mode apply.
        self._config = config
        self._privacy_mode = getattr(config, 'llm_privacy_mode', 'local')
        self._config_store = LLMConfigStore(pg_pool) if pg_pool else None
        self._ollama = OllamaBackend(
            url=config.ollama_url, model=config.ollama_model,
            timeout=config.ollama_timeout, max_concurrent=config.ollama_max_concurrent)
        self._claude = None
        if self._privacy_mode != 'local':
            api_key = getattr(config, 'anthropic_api_key', None)
            if api_key:
                # Imported lazily so 'local' deployments never need the SDK.
                from .claude_backend import ClaudeBackend
                self._claude = ClaudeBackend(
                    api_key=api_key, model=config.claude_model,
                    timeout=config.claude_timeout, max_concurrent=config.claude_max_concurrent)
                logger.info('ClaudeBackend initialized mode=%s', self._privacy_mode)
            elif self._privacy_mode == 'cloud':
                logger.error('Privacy mode is cloud but ANTHROPIC_API_KEY not set')
            else:
                logger.warning('Privacy mode is hybrid but ANTHROPIC_API_KEY not set')

    async def submit(self, messages, tools=None, caller='unknown'):
        """Dispatch a chat request to the backend selected for *caller*.

        Falls back to Ollama when Claude is not configured, times out, or
        stays rate-limited. Note: ClaudeAuthError is NOT caught here and
        propagates to the caller.
        """
        backend_name = await self.get_backend(caller)
        if backend_name == 'claude':
            if self._claude is None:
                logger.warning('Claude requested but unavailable, fallback to Ollama caller=%s', caller)
                backend_name = 'ollama'
            else:
                try:
                    return await self._claude.submit(messages, tools, caller)
                except (ClaudeTimeoutError, ClaudeRateLimitError) as exc:
                    logger.warning('Claude failed caller=%s (%s), falling back to Ollama', caller, exc)
                    return await self._ollama.submit(messages, tools, caller)
        return await self._ollama.submit(messages, tools, caller)

    async def get_backend(self, caller):
        """Return 'ollama' or 'claude' for *caller*.

        Precedence: HIPAA lock > privacy mode (local/cloud) > hybrid lookup
        (DB override, then AGENT_BACKEND_<CALLER> env var, then 'ollama').
        DB failures are logged and treated as "no override".
        """
        if caller in HIPAA_LOCKED_AGENTS:
            return 'ollama'
        if self._privacy_mode == 'local':
            return 'ollama'
        if self._privacy_mode == 'cloud':
            return 'claude'
        if self._privacy_mode == 'hybrid':
            if self._config_store:
                try:
                    db_val = await self._config_store.get_backend(caller)
                    if db_val:
                        return db_val
                except Exception as exc:
                    # Best-effort: a DB outage must not block routing.
                    logger.warning('LLMConfigStore lookup failed: %s', exc)
            env_key = f'AGENT_BACKEND_{caller.upper()}'
            env_val = os.environ.get(env_key)
            if env_val in ('ollama', 'claude'):
                return env_val
            return 'ollama'
        logger.error('Unknown privacy mode %s, defaulting to ollama', self._privacy_mode)
        return 'ollama'

    async def set_backend(self, caller, backend, set_by, note=None):
        """Persist a per-caller backend override (rejected for HIPAA agents).

        Raises ValueError on invalid caller/backend, RuntimeError when no
        Postgres pool was supplied at construction.
        """
        if caller in HIPAA_LOCKED_AGENTS:
            raise ValueError(f'Cannot override backend for HIPAA-locked agent: {caller}')
        if backend not in ('ollama', 'claude'):
            raise ValueError(f'Invalid backend: {backend}')
        if not self._config_store:
            raise RuntimeError('No Postgres pool for runtime config store')
        await self._config_store.set_backend(caller, backend, set_by, note)

    async def set_privacy_mode(self, mode, set_by):
        """Switch privacy mode at runtime and (de)activate the Claude backend.

        The change is audit-logged to the config store under the reserved
        '__system__' caller. Raises ValueError for an unknown mode.
        """
        if mode not in ('local', 'hybrid', 'cloud'):
            raise ValueError(f'Invalid privacy mode: {mode}')
        self._privacy_mode = mode
        if self._config_store:
            # NOTE(review): this stores the mode in the 'backend' column of
            # the '__system__' row — an audit-trail convention; confirm no
            # caller is literally named '__system__'.
            await self._config_store.set_backend('__system__', mode, set_by,
                                                 f'Privacy mode changed to {mode}')
        if mode == 'local':
            self._claude = None
            logger.info('Privacy mode set to local - ClaudeBackend disabled')
        elif mode in ('hybrid', 'cloud') and self._claude is None:
            api_key = getattr(self._config, 'anthropic_api_key', None)
            if api_key:
                from .claude_backend import ClaudeBackend
                self._claude = ClaudeBackend(
                    api_key=api_key, model=self._config.claude_model,
                    timeout=self._config.claude_timeout, max_concurrent=self._config.claude_max_concurrent)
        logger.info('Privacy mode set to %s by user_id=%s', mode, set_by)

    async def get_status(self):
        """Return a dict snapshot of the router state for health endpoints."""
        s = {'privacy_mode': self._privacy_mode,
             'ollama': {'active': self._ollama.active_count}}
        if self._claude:
            s['claude'] = {'active': self._claude.active_count}
        else:
            s['claude'] = {'available': False, 'reason': 'local mode or no API key'}
        return s

    @property
    def ollama_queue_depth(self):
        # In-flight Ollama requests (past the backend semaphore).
        return self._ollama.active_count

    @property
    def claude_active_count(self):
        # In-flight Claude requests; 0 when the backend is disabled.
        return self._claude.active_count if self._claude else 0


# --- agent_service/llm/llm_types.py ---
from __future__ import annotations
from dataclasses import dataclass


@dataclass
class LLMResponse:
    """Normalized result returned by every backend."""
    content: str
    tool_calls: object  # list | None
    backend_used: str   # 'ollama' | 'claude'
    model_used: str
    tokens_in: int
    tokens_out: int
    latency_ms: int


# Exception hierarchy shared by all backends; the router catches the
# Claude timeout/rate-limit subclasses to trigger Ollama fallback.
class LLMError(Exception): pass
class OllamaTimeoutError(LLMError): pass
class OllamaUnavailableError(LLMError): pass
class ClaudeTimeoutError(LLMError): pass
class ClaudeAuthError(LLMError): pass
class ClaudeRateLimitError(LLMError): pass


# --- agent_service/llm/ollama_backend.py ---
from __future__ import annotations
import asyncio, logging, time
from .llm_types import LLMResponse, OllamaTimeoutError, OllamaUnavailableError

logger = logging.getLogger(__name__)


class OllamaBackend:
    """Local Ollama chat backend with bounded concurrency."""

    def __init__(self, url, model, timeout=120, max_concurrent=2):
        self._url = url
        self._model = model
        self._timeout = timeout
        self._semaphore = asyncio.Semaphore(max_concurrent)
        self._active = 0  # requests currently past the semaphore

    async def submit(self, messages, tools=None, caller='unknown'):
        """Send a chat request to Ollama and return an LLMResponse.

        Raises OllamaTimeoutError on timeout; any other client error is
        wrapped in OllamaUnavailableError.
        """
        # Imported lazily so the module imports without the ollama package.
        import ollama
        wait_start = time.monotonic()
        async with self._semaphore:
            wait_ms = int((time.monotonic() - wait_start) * 1000)
            self._active += 1
            t0 = time.monotonic()
            try:
                kwargs = {'model': self._model, 'messages': messages}
                if tools:
                    kwargs['tools'] = tools
                client = ollama.AsyncClient(host=self._url)
                try:
                    response = await asyncio.wait_for(client.chat(**kwargs), timeout=self._timeout)
                except asyncio.TimeoutError:
                    raise OllamaTimeoutError(f'Ollama timeout after {self._timeout}s caller={caller}')
                except Exception as exc:
                    # Heuristic: classify connection-ish errors by message
                    # text; everything else is still surfaced as unavailable.
                    s = str(exc).lower()
                    if 'connect' in s or 'refused' in s or 'unreachable' in s:
                        raise OllamaUnavailableError(f'Ollama unreachable: {exc}') from exc
                    raise OllamaUnavailableError(f'Ollama error: {exc}') from exc
                ms = int((time.monotonic() - t0) * 1000)
                msg = response.message
                tool_calls = None
                if hasattr(msg, 'tool_calls') and msg.tool_calls:
                    tool_calls = [{'name': tc.function.name, 'arguments': tc.function.arguments}
                                  for tc in msg.tool_calls]
                # Token counts may be absent on some responses; default to 0.
                tin = response.prompt_eval_count or 0
                tout = response.eval_count or 0
                logger.info('ollama caller=%s wait_ms=%d inf_ms=%d tin=%d tout=%d',
                            caller, wait_ms, ms, tin, tout)
                return LLMResponse(content=msg.content or '', tool_calls=tool_calls,
                                   backend_used='ollama', model_used=self._model,
                                   tokens_in=tin, tokens_out=tout, latency_ms=ms)
            finally:
                self._active -= 1

    @property
    def active_count(self):
        # Requests currently being processed (past the semaphore).
        return self._active


# --- agent_service/llm/tool_validator.py ---
from __future__ import annotations
import json, logging
from typing import Any

logger = logging.getLogger(__name__)

# Hard cap on tools per agent; larger tool lists degrade tool selection.
MAX_TOOLS_PER_AGENT = 8

# JSON-schema type name -> Python type(s) accepted by the isinstance checks.
_TYPE_MAP = {
    'string': str, 'integer': int, 'number': (int, float),
    'boolean': bool, 'object': dict, 'array': list,
}


class ToolValidationError(Exception): pass
class AgentConfigError(Exception): pass


def validate_agent_tools(tools, agent_name):
    """Raise AgentConfigError if *agent_name* declares too many tools."""
    if len(tools) > MAX_TOOLS_PER_AGENT:
        raise AgentConfigError(
            f'{agent_name} has {len(tools)} tools but max is {MAX_TOOLS_PER_AGENT}. '
            f'Split into contextual tool groups.')


class ToolCallValidator:
    """Validate and sanitize model-emitted tool calls against tool schemas."""

    def __init__(self, tools):
        # Index tool definitions by name for O(1) lookup during validation.
        self._tools = {t["name"]: t for t in tools}

    def validate(self, tool_call):
        """Return {'name', 'arguments'} after schema checks, or raise ToolValidationError.

        Accepts both flat {'name', 'arguments'} and nested {'function': {...}}
        call shapes. Strips params not in the schema, fills defaults for
        missing required params, coerces strings to numbers (and values to
        strings) where the schema asks, and enforces enum membership.
        """
        name = tool_call.get('name') or (tool_call.get('function') or {}).get('name')
        if not name:
            raise ToolValidationError('Tool call missing name field')
        if name not in self._tools:
            raise ToolValidationError(f'Unknown tool {name!r} not in agent tool list')
        tool_def = self._tools[name]
        params_schema = tool_def.get('parameters', {})
        arguments = tool_call.get('arguments') or (tool_call.get('function') or {}).get('arguments', {})
        # Models frequently emit arguments as a JSON string; parse it first.
        if isinstance(arguments, str):
            try:
                arguments = json.loads(arguments)
            except json.JSONDecodeError as exc:
                raise ToolValidationError(f'Tool {name!r} arguments is invalid JSON: {exc}') from exc
        if not isinstance(arguments, dict):
            raise ToolValidationError(f'Tool {name!r} arguments must be dict, got {type(arguments)}')
        # Pass 1: drop params the schema does not declare (hallucinations).
        cleaned: dict[str, Any] = {}
        for key, value in arguments.items():
            if key not in params_schema:
                logger.warning('ToolCallValidator: stripping hallucinated param %r from tool %r', key, name)
                continue
            cleaned[key] = value
        # Pass 2: required params — fill defaults or fail.
        for key, schema in params_schema.items():
            if schema.get('optional', False):
                continue
            if key not in cleaned:
                if 'default' in schema:
                    cleaned[key] = schema['default']
                else:
                    raise ToolValidationError(f'Tool {name!r} missing required param {key!r}')
        # Pass 3: type coercion and enum checks.
        for key, value in list(cleaned.items()):
            schema = params_schema.get(key, {})
            expected_type = schema.get('type')
            if expected_type and expected_type in _TYPE_MAP:
                expected = _TYPE_MAP[expected_type]
                # NOTE(review): isinstance(True, int) is True, so a boolean
                # passes 'integer'/'number' checks unmodified — confirm
                # whether bools should be rejected here.
                if not isinstance(value, expected):
                    if expected_type in ('integer', 'number') and isinstance(value, str):
                        try:
                            cleaned[key] = int(value) if expected_type == 'integer' else float(value)
                        except ValueError:
                            raise ToolValidationError(
                                f'Tool {name!r} param {key!r} expected {expected_type}, got {value!r}')
                    elif expected_type == 'string':
                        cleaned[key] = str(value)
                    else:
                        raise ToolValidationError(
                            f'Tool {name!r} param {key!r} expected {expected_type}, got {type(value).__name__}')
            if 'enum' in schema and cleaned.get(key) not in schema['enum']:
                raise ToolValidationError(
                    f'Tool {name!r} param {key!r} value {cleaned.get(key)!r} not in {schema["enum"]}')
        return {'name': name, 'arguments': cleaned}

    def parse_or_fallback(self, tool_call):
        """Like validate(), but return None (with a warning) instead of raising."""
        if tool_call is None:
            return None
        try:
            return self.validate(tool_call)
        except ToolValidationError as exc:
            logger.warning('ToolCallValidator fallback: %s', exc)
            return None