feat: add base agent and peer bus
This commit is contained in:
196
agent_service/agents/base_agent.py
Normal file
196
agent_service/agents/base_agent.py
Normal file
@@ -0,0 +1,196 @@
|
||||
from __future__ import annotations
|
||||
import asyncio, json, logging, time
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
from ..llm.tool_validator import ToolCallValidator, ToolValidationError, validate_agent_tools
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class AgentDirective:
|
||||
directive_id: str
|
||||
agent: str
|
||||
task: str
|
||||
params: dict
|
||||
context: object # DirectiveContext
|
||||
authorized_actions: list
|
||||
constraints: dict
|
||||
approved: bool = False
|
||||
approval_item_id: object = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class DirectiveContext:
|
||||
client_profile: dict = field(default_factory=dict)
|
||||
recent_findings: list = field(default_factory=list)
|
||||
conversation_summary: str = ''
|
||||
peer_data: dict = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass
|
||||
class AgentReport:
|
||||
directive_id: str
|
||||
agent: str
|
||||
status: str # complete | partial | failed | escalated
|
||||
summary: str
|
||||
actions_taken: list = field(default_factory=list)
|
||||
escalations: list = field(default_factory=list)
|
||||
peer_calls_made: list = field(default_factory=list)
|
||||
recommendations: list = field(default_factory=list)
|
||||
data: dict = field(default_factory=dict)
|
||||
error: object = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class SweepReport:
|
||||
agent: str
|
||||
findings: list = field(default_factory=list)
|
||||
actions_taken: list = field(default_factory=list)
|
||||
recommendations: list = field(default_factory=list)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolResult:
|
||||
name: str
|
||||
success: bool
|
||||
data: Any = None
|
||||
error: object = None
|
||||
|
||||
|
||||
AUTO_ACT_THRESHOLDS = {
|
||||
'read': True, 'search': True, 'report': True,
|
||||
'create_non_financial': True,
|
||||
'create_financial': False,
|
||||
'write_non_financial': True,
|
||||
'write_financial_over_5k': False,
|
||||
'post_chatter': True,
|
||||
'send_email': True,
|
||||
'confirm_document': False,
|
||||
'delete': False,
|
||||
'archive': False,
|
||||
'post_journal_entry': False,
|
||||
'hipaa_sensitive': False,
|
||||
}
|
||||
|
||||
|
||||
class BaseAgent(ABC):
|
||||
name: str = 'base'
|
||||
domain: str = 'base'
|
||||
required_odoo_module: str = 'base'
|
||||
system_prompt_file: str = ''
|
||||
tools: list = []
|
||||
|
||||
def __init__(self, odoo, llm, peer_bus=None):
|
||||
self._odoo = odoo
|
||||
self._llm = llm
|
||||
self._peer_bus = peer_bus
|
||||
self._directive: AgentDirective | None = None
|
||||
self._gathered: dict = {}
|
||||
self._messages: list = []
|
||||
validate_agent_tools(self.tools, self.name)
|
||||
self._validator = ToolCallValidator(self.tools)
|
||||
|
||||
async def execute(self, directive: AgentDirective) -> AgentReport:
|
||||
self._directive = directive
|
||||
self._gathered = {}
|
||||
self._messages = []
|
||||
t0 = time.monotonic()
|
||||
try:
|
||||
await self._receive(directive)
|
||||
plan = await self._plan()
|
||||
await self._gather(plan)
|
||||
reasoning = await self._reason()
|
||||
await self._act(reasoning)
|
||||
report = await self._report()
|
||||
ms = int((time.monotonic() - t0) * 1000)
|
||||
logger.info('agent=%s directive=%s status=%s ms=%d',
|
||||
self.name, directive.directive_id, report.status, ms)
|
||||
return report
|
||||
except Exception as exc:
|
||||
logger.error('agent=%s directive=%s FAILED: %s', self.name, directive.directive_id, exc)
|
||||
return AgentReport(
|
||||
directive_id=directive.directive_id, agent=self.name,
|
||||
status='failed', summary=f'Agent failed: {exc}', error=str(exc))
|
||||
|
||||
async def _receive(self, directive):
|
||||
logger.info('agent=%s received directive=%s task=%s', self.name, directive.directive_id, directive.task[:80])
|
||||
|
||||
@abstractmethod
|
||||
async def _plan(self): ...
|
||||
|
||||
@abstractmethod
|
||||
async def _gather(self, plan): ...
|
||||
|
||||
@abstractmethod
|
||||
async def _reason(self): ...
|
||||
|
||||
@abstractmethod
|
||||
async def _act(self, reasoning): ...
|
||||
|
||||
@abstractmethod
|
||||
async def _report(self) -> AgentReport: ...
|
||||
|
||||
async def _run_tool(self, name, params) -> ToolResult:
|
||||
logger.info('agent=%s tool=%s params=%s', self.name, name, list(params.keys()))
|
||||
tool_fn = getattr(self, f'_tool_{name}', None)
|
||||
if tool_fn is None:
|
||||
return ToolResult(name=name, success=False, error=f'Unknown tool: {name}')
|
||||
try:
|
||||
data = await tool_fn(**params)
|
||||
return ToolResult(name=name, success=True, data=data)
|
||||
except Exception as exc:
|
||||
logger.error('agent=%s tool=%s failed: %s', self.name, name, exc)
|
||||
return ToolResult(name=name, success=False, error=str(exc))
|
||||
|
||||
async def _should_auto_act(self, action_type, params=None) -> tuple:
|
||||
auto = AUTO_ACT_THRESHOLDS.get(action_type, False)
|
||||
if auto:
|
||||
return True, 'auto-approved by threshold policy'
|
||||
return False, f'action_type={action_type} requires human approval'
|
||||
|
||||
async def _escalate(self, reason, model, record_id, severity='medium'):
|
||||
item = {
|
||||
'reason': reason, 'model': model, 'record_id': record_id,
|
||||
'severity': severity, 'agent': self.name,
|
||||
'directive_id': self._directive.directive_id if self._directive else None
|
||||
}
|
||||
if not hasattr(self, '_escalations'):
|
||||
self._escalations = []
|
||||
self._escalations.append(item)
|
||||
logger.info('agent=%s escalation severity=%s reason=%s', self.name, severity, reason)
|
||||
return item
|
||||
|
||||
async def _loop(self, messages, tools=None, max_iter=10) -> str:
|
||||
current = list(messages)
|
||||
active_tools = tools or self.tools
|
||||
for iteration in range(max_iter):
|
||||
resp = await self._llm.submit(current, tools=active_tools, caller=self.name)
|
||||
if resp.tool_calls:
|
||||
for raw_call in resp.tool_calls:
|
||||
validated = self._validator.parse_or_fallback(raw_call)
|
||||
if validated is None:
|
||||
# Fallback: ask model for plain text
|
||||
current.append({'role': 'user',
|
||||
'content': 'Tool call was invalid. Please provide your analysis as plain text.'})
|
||||
break
|
||||
tool_result = await self._run_tool(validated['name'], validated['arguments'])
|
||||
current.append({'role': 'assistant', 'content': resp.content or '', 'tool_calls': [raw_call]})
|
||||
current.append({
|
||||
'role': 'tool',
|
||||
'name': validated['name'],
|
||||
'content': json.dumps({'success': tool_result.success, 'data': tool_result.data,
|
||||
'error': tool_result.error})
|
||||
})
|
||||
else:
|
||||
return resp.content
|
||||
logger.warning('agent=%s max iterations %d reached', self.name, max_iter)
|
||||
await self._escalate('Max iterations exceeded', '', 0, severity='medium')
|
||||
return 'Reached maximum iterations. Partial results only.'
|
||||
|
||||
async def handle_peer_request(self, request_type, params, directive_id) -> dict:
|
||||
return {'success': False, 'error': 'Peer requests not implemented for this agent'}
|
||||
|
||||
async def sweep(self) -> SweepReport:
|
||||
return SweepReport(agent=self.name, findings=[], recommendations=[])
|
||||
Reference in New Issue
Block a user