feat: add base agent and peer bus
This commit is contained in:
0
agent_service/agents/__init__.py
Normal file
0
agent_service/agents/__init__.py
Normal file
196
agent_service/agents/base_agent.py
Normal file
196
agent_service/agents/base_agent.py
Normal file
@@ -0,0 +1,196 @@
|
||||
from __future__ import annotations
|
||||
import asyncio, json, logging, time
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
from ..llm.tool_validator import ToolCallValidator, ToolValidationError, validate_agent_tools
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class AgentDirective:
|
||||
directive_id: str
|
||||
agent: str
|
||||
task: str
|
||||
params: dict
|
||||
context: object # DirectiveContext
|
||||
authorized_actions: list
|
||||
constraints: dict
|
||||
approved: bool = False
|
||||
approval_item_id: object = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class DirectiveContext:
|
||||
client_profile: dict = field(default_factory=dict)
|
||||
recent_findings: list = field(default_factory=list)
|
||||
conversation_summary: str = ''
|
||||
peer_data: dict = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass
|
||||
class AgentReport:
|
||||
directive_id: str
|
||||
agent: str
|
||||
status: str # complete | partial | failed | escalated
|
||||
summary: str
|
||||
actions_taken: list = field(default_factory=list)
|
||||
escalations: list = field(default_factory=list)
|
||||
peer_calls_made: list = field(default_factory=list)
|
||||
recommendations: list = field(default_factory=list)
|
||||
data: dict = field(default_factory=dict)
|
||||
error: object = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class SweepReport:
|
||||
agent: str
|
||||
findings: list = field(default_factory=list)
|
||||
actions_taken: list = field(default_factory=list)
|
||||
recommendations: list = field(default_factory=list)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolResult:
|
||||
name: str
|
||||
success: bool
|
||||
data: Any = None
|
||||
error: object = None
|
||||
|
||||
|
||||
AUTO_ACT_THRESHOLDS = {
|
||||
'read': True, 'search': True, 'report': True,
|
||||
'create_non_financial': True,
|
||||
'create_financial': False,
|
||||
'write_non_financial': True,
|
||||
'write_financial_over_5k': False,
|
||||
'post_chatter': True,
|
||||
'send_email': True,
|
||||
'confirm_document': False,
|
||||
'delete': False,
|
||||
'archive': False,
|
||||
'post_journal_entry': False,
|
||||
'hipaa_sensitive': False,
|
||||
}
|
||||
|
||||
|
||||
class BaseAgent(ABC):
|
||||
name: str = 'base'
|
||||
domain: str = 'base'
|
||||
required_odoo_module: str = 'base'
|
||||
system_prompt_file: str = ''
|
||||
tools: list = []
|
||||
|
||||
def __init__(self, odoo, llm, peer_bus=None):
|
||||
self._odoo = odoo
|
||||
self._llm = llm
|
||||
self._peer_bus = peer_bus
|
||||
self._directive: AgentDirective | None = None
|
||||
self._gathered: dict = {}
|
||||
self._messages: list = []
|
||||
validate_agent_tools(self.tools, self.name)
|
||||
self._validator = ToolCallValidator(self.tools)
|
||||
|
||||
async def execute(self, directive: AgentDirective) -> AgentReport:
|
||||
self._directive = directive
|
||||
self._gathered = {}
|
||||
self._messages = []
|
||||
t0 = time.monotonic()
|
||||
try:
|
||||
await self._receive(directive)
|
||||
plan = await self._plan()
|
||||
await self._gather(plan)
|
||||
reasoning = await self._reason()
|
||||
await self._act(reasoning)
|
||||
report = await self._report()
|
||||
ms = int((time.monotonic() - t0) * 1000)
|
||||
logger.info('agent=%s directive=%s status=%s ms=%d',
|
||||
self.name, directive.directive_id, report.status, ms)
|
||||
return report
|
||||
except Exception as exc:
|
||||
logger.error('agent=%s directive=%s FAILED: %s', self.name, directive.directive_id, exc)
|
||||
return AgentReport(
|
||||
directive_id=directive.directive_id, agent=self.name,
|
||||
status='failed', summary=f'Agent failed: {exc}', error=str(exc))
|
||||
|
||||
async def _receive(self, directive):
|
||||
logger.info('agent=%s received directive=%s task=%s', self.name, directive.directive_id, directive.task[:80])
|
||||
|
||||
@abstractmethod
|
||||
async def _plan(self): ...
|
||||
|
||||
@abstractmethod
|
||||
async def _gather(self, plan): ...
|
||||
|
||||
@abstractmethod
|
||||
async def _reason(self): ...
|
||||
|
||||
@abstractmethod
|
||||
async def _act(self, reasoning): ...
|
||||
|
||||
@abstractmethod
|
||||
async def _report(self) -> AgentReport: ...
|
||||
|
||||
async def _run_tool(self, name, params) -> ToolResult:
|
||||
logger.info('agent=%s tool=%s params=%s', self.name, name, list(params.keys()))
|
||||
tool_fn = getattr(self, f'_tool_{name}', None)
|
||||
if tool_fn is None:
|
||||
return ToolResult(name=name, success=False, error=f'Unknown tool: {name}')
|
||||
try:
|
||||
data = await tool_fn(**params)
|
||||
return ToolResult(name=name, success=True, data=data)
|
||||
except Exception as exc:
|
||||
logger.error('agent=%s tool=%s failed: %s', self.name, name, exc)
|
||||
return ToolResult(name=name, success=False, error=str(exc))
|
||||
|
||||
async def _should_auto_act(self, action_type, params=None) -> tuple:
|
||||
auto = AUTO_ACT_THRESHOLDS.get(action_type, False)
|
||||
if auto:
|
||||
return True, 'auto-approved by threshold policy'
|
||||
return False, f'action_type={action_type} requires human approval'
|
||||
|
||||
async def _escalate(self, reason, model, record_id, severity='medium'):
|
||||
item = {
|
||||
'reason': reason, 'model': model, 'record_id': record_id,
|
||||
'severity': severity, 'agent': self.name,
|
||||
'directive_id': self._directive.directive_id if self._directive else None
|
||||
}
|
||||
if not hasattr(self, '_escalations'):
|
||||
self._escalations = []
|
||||
self._escalations.append(item)
|
||||
logger.info('agent=%s escalation severity=%s reason=%s', self.name, severity, reason)
|
||||
return item
|
||||
|
||||
async def _loop(self, messages, tools=None, max_iter=10) -> str:
|
||||
current = list(messages)
|
||||
active_tools = tools or self.tools
|
||||
for iteration in range(max_iter):
|
||||
resp = await self._llm.submit(current, tools=active_tools, caller=self.name)
|
||||
if resp.tool_calls:
|
||||
for raw_call in resp.tool_calls:
|
||||
validated = self._validator.parse_or_fallback(raw_call)
|
||||
if validated is None:
|
||||
# Fallback: ask model for plain text
|
||||
current.append({'role': 'user',
|
||||
'content': 'Tool call was invalid. Please provide your analysis as plain text.'})
|
||||
break
|
||||
tool_result = await self._run_tool(validated['name'], validated['arguments'])
|
||||
current.append({'role': 'assistant', 'content': resp.content or '', 'tool_calls': [raw_call]})
|
||||
current.append({
|
||||
'role': 'tool',
|
||||
'name': validated['name'],
|
||||
'content': json.dumps({'success': tool_result.success, 'data': tool_result.data,
|
||||
'error': tool_result.error})
|
||||
})
|
||||
else:
|
||||
return resp.content
|
||||
logger.warning('agent=%s max iterations %d reached', self.name, max_iter)
|
||||
await self._escalate('Max iterations exceeded', '', 0, severity='medium')
|
||||
return 'Reached maximum iterations. Partial results only.'
|
||||
|
||||
async def handle_peer_request(self, request_type, params, directive_id) -> dict:
|
||||
return {'success': False, 'error': 'Peer requests not implemented for this agent'}
|
||||
|
||||
async def sweep(self) -> SweepReport:
|
||||
return SweepReport(agent=self.name, findings=[], recommendations=[])
|
||||
66
agent_service/agents/peer_bus.py
Normal file
66
agent_service/agents/peer_bus.py
Normal file
@@ -0,0 +1,66 @@
|
||||
from __future__ import annotations
|
||||
import asyncio, logging
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
MAX_DEPTH = 3
|
||||
|
||||
|
||||
@dataclass
|
||||
class PeerResponse:
|
||||
available: bool
|
||||
success: bool = False
|
||||
data: dict = field(default_factory=dict)
|
||||
agent: str = ''
|
||||
request_type: str = ''
|
||||
error: object = None
|
||||
|
||||
|
||||
class PeerCircularRequestError(Exception): pass
|
||||
|
||||
|
||||
class PeerBus:
|
||||
def __init__(self, registry, directive_id):
|
||||
self._registry = registry
|
||||
self._directive_id = directive_id
|
||||
self._call_log: list[dict] = []
|
||||
self._call_chain: list[str] = []
|
||||
|
||||
async def request(self, from_agent, to_agent, request_type, params, reason):
|
||||
if to_agent in self._call_chain:
|
||||
logger.warning('PeerBus: circular request blocked %s->%s chain=%s',
|
||||
from_agent, to_agent, self._call_chain)
|
||||
raise PeerCircularRequestError(f'Circular peer request: {from_agent}->{to_agent}')
|
||||
if len(self._call_chain) >= MAX_DEPTH:
|
||||
logger.warning('PeerBus: max depth %d reached', MAX_DEPTH)
|
||||
return PeerResponse(available=False, error=f'Max peer depth {MAX_DEPTH} reached')
|
||||
if not await self._registry.is_active(to_agent):
|
||||
logger.debug('PeerBus: agent %s inactive, from=%s', to_agent, from_agent)
|
||||
return PeerResponse(available=False, agent=to_agent, request_type=request_type)
|
||||
agent = self._registry.get_agent_instance(to_agent)
|
||||
if agent is None:
|
||||
return PeerResponse(available=False, agent=to_agent, request_type=request_type)
|
||||
self._call_chain.append(from_agent)
|
||||
try:
|
||||
result = await asyncio.wait_for(
|
||||
agent.handle_peer_request(request_type, params, self._directive_id),
|
||||
timeout=30)
|
||||
entry = {'from': from_agent, 'to': to_agent, 'type': request_type,
|
||||
'params': params, 'reason': reason, 'success': result.get('success', True)}
|
||||
self._call_log.append(entry)
|
||||
logger.debug('PeerBus: %s->%s type=%s ok', from_agent, to_agent, request_type)
|
||||
return PeerResponse(available=True, success=True, data=result,
|
||||
agent=to_agent, request_type=request_type)
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning('PeerBus: timeout %s->%s', from_agent, to_agent)
|
||||
return PeerResponse(available=True, success=False, agent=to_agent,
|
||||
request_type=request_type, error='Peer timeout after 30s')
|
||||
except Exception as exc:
|
||||
logger.error('PeerBus: error %s->%s: %s', from_agent, to_agent, exc)
|
||||
return PeerResponse(available=True, success=False, agent=to_agent,
|
||||
request_type=request_type, error=str(exc))
|
||||
finally:
|
||||
self._call_chain.pop()
|
||||
|
||||
@property
|
||||
def call_log(self): return self._call_log
|
||||
40
agent_service/agents/registry.py
Normal file
40
agent_service/agents/registry.py
Normal file
@@ -0,0 +1,40 @@
|
||||
from __future__ import annotations
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AgentRegistry:
|
||||
def __init__(self):
|
||||
self._agents: dict = {} # agent_key -> BaseAgent instance
|
||||
self._active: set = set()
|
||||
self._capabilities: dict = {}
|
||||
|
||||
async def load_from_odoo(self, odoo_client):
|
||||
try:
|
||||
rows = await odoo_client.search_read(
|
||||
'ab.ai.agent.registry',
|
||||
[['is_active', '=', True]],
|
||||
['agent_key', 'name', 'capabilities_summary', 'sweep_enabled', 'sweep_interval_hours'])
|
||||
self._active = {r['agent_key'] for r in rows}
|
||||
self._capabilities = {r['agent_key']: r.get('capabilities_summary', '') for r in rows}
|
||||
logger.info('AgentRegistry loaded: active=%s', list(self._active))
|
||||
except Exception as exc:
|
||||
logger.error('AgentRegistry.load_from_odoo failed: %s', exc)
|
||||
|
||||
async def get_active_agents(self):
|
||||
return [{'agent_key': k, 'capabilities_summary': self._capabilities.get(k, '')}
|
||||
for k in self._active]
|
||||
|
||||
async def is_active(self, agent_key):
|
||||
return agent_key in self._active
|
||||
|
||||
async def sync(self, active_keys):
|
||||
self._active = set(active_keys)
|
||||
logger.info('AgentRegistry synced: active=%s', active_keys)
|
||||
|
||||
def register(self, agent_key, agent_instance):
|
||||
self._agents[agent_key] = agent_instance
|
||||
|
||||
def get_agent_instance(self, agent_key):
|
||||
return self._agents.get(agent_key)
|
||||
Reference in New Issue
Block a user