197 lines
7.1 KiB
Python
197 lines
7.1 KiB
Python
from __future__ import annotations
|
|
import asyncio, json, logging, time
|
|
from abc import ABC, abstractmethod
|
|
from dataclasses import dataclass, field
|
|
from typing import Any
|
|
from ..llm.tool_validator import ToolCallValidator, ToolValidationError, validate_agent_tools
|
|
|
|
# Module-level logger; BaseAgent reports directive lifecycle, tool calls,
# failures and escalations through it.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
@dataclass
class AgentDirective:
    """A unit of work handed to ``BaseAgent.execute``.

    ``BaseAgent`` itself reads ``directive_id`` (logging, failure reports,
    escalation records) and ``task`` (logged, truncated to 80 characters);
    the remaining fields are interpreted by concrete agents.
    """

    # Correlation id echoed into reports and escalation records.
    directive_id: str
    # Name of the agent this directive targets (presumably matches
    # BaseAgent.name -- confirm against the dispatcher).
    agent: str
    # Free-text description of what should be done.
    task: str
    # Task-specific parameters; schema is defined by each agent.
    params: dict
    # Shared context for the task.
    context: DirectiveContext
    # Actions the agent is permitted to take; element format is not
    # established in this module.
    authorized_actions: list
    # Limits on the agent's behaviour; keys are not defined in this module.
    constraints: dict
    # Approval flag; not read by BaseAgent itself.
    approved: bool = False
    # Identifier of a related approval item, if any (type not fixed here).
    approval_item_id: object = None
|
|
|
|
|
|
@dataclass
class DirectiveContext:
    """Context bundle carried on ``AgentDirective.context``.

    Every field defaults to an empty value; each instance gets its own
    fresh containers. None of these fields are read in this module, so their
    contents are defined by whoever builds the directive (confirm upstream).
    """

    # Profile data about the client the directive concerns.
    client_profile: dict = field(default_factory=dict)
    # Findings from earlier work that may inform this task.
    recent_findings: list = field(default_factory=list)
    # Short text summary of the surrounding conversation.
    conversation_summary: str = ''
    # Data contributed by peer agents.
    peer_data: dict = field(default_factory=dict)
|
|
|
|
|
|
@dataclass
class AgentReport:
    """Outcome of one ``BaseAgent.execute`` run for a single directive.

    The four required fields identify the run and summarize it. The list and
    dict fields default to fresh empty containers per instance.
    """

    # Id of the directive this report answers.
    directive_id: str
    # Name of the reporting agent.
    agent: str
    # One of: complete | partial | failed | escalated
    status: str
    # Human-readable summary of the run.
    summary: str
    # Record of what the agent did, escalated, asked of peers and
    # recommends; element formats are defined by concrete agents.
    actions_taken: list = field(default_factory=list)
    escalations: list = field(default_factory=list)
    peer_calls_made: list = field(default_factory=list)
    recommendations: list = field(default_factory=list)
    # Free-form structured payload.
    data: dict = field(default_factory=dict)
    # Error detail; BaseAgent.execute stores str(exc) here on failure.
    error: object = None
|
|
|
|
|
|
@dataclass
class SweepReport:
    """Result of ``BaseAgent.sweep``.

    The base implementation of ``sweep`` returns one of these with empty lists;
    each list defaults to a fresh container per instance.
    """

    # Name of the agent that ran the sweep.
    agent: str
    # What the sweep found, what it did about it, and what it suggests;
    # element formats are defined by concrete agents.
    findings: list = field(default_factory=list)
    actions_taken: list = field(default_factory=list)
    recommendations: list = field(default_factory=list)
|
|
|
|
|
|
@dataclass
class ToolResult:
    """Outcome of a single tool invocation made via ``BaseAgent._run_tool``.

    On success ``data`` carries the tool's return value; on failure
    ``success`` is False and ``error`` carries a message string.
    """

    # Tool name as requested.
    name: str
    # Whether the tool ran without raising.
    success: bool
    # Value returned by the tool (any type).
    data: Any = None
    # Error message when the call failed, otherwise None.
    error: object = None
|
|
|
|
|
|
# Action type -> auto-approval policy, consulted by BaseAgent._should_auto_act.
# True means the agent may act without a human; False means approval is
# required. Action types absent from this table are looked up with a default
# of False, i.e. they also require approval.
# NOTE(review): 'write_financial_over_5k' names an amount threshold, but no
# visible code compares amounts -- _should_auto_act keys only on the
# action-type string.
AUTO_ACT_THRESHOLDS = dict(
    # read-only / informational
    read=True,
    search=True,
    report=True,
    # record creation
    create_non_financial=True,
    create_financial=False,
    # record updates
    write_non_financial=True,
    write_financial_over_5k=False,
    # outbound communication
    post_chatter=True,
    send_email=True,
    # irreversible or sensitive operations
    confirm_document=False,
    delete=False,
    archive=False,
    post_journal_entry=False,
    hipaa_sensitive=False,
)
|
|
|
|
|
|
class BaseAgent(ABC):
    """Template-method base class for domain agents.

    ``execute`` runs a fixed pipeline -- ``_receive`` -> ``_plan`` ->
    ``_gather`` -> ``_reason`` -> ``_act`` -> ``_report`` -- and converts any
    ``Exception`` raised along the way into a ``status='failed'``
    :class:`AgentReport`. Subclasses implement the five abstract steps and
    expose tools as awaitable ``_tool_<name>`` methods, which ``_run_tool``
    dispatches by name.
    """

    # Identity/configuration; subclasses are expected to override these.
    name: str = 'base'
    domain: str = 'base'
    required_odoo_module: str = 'base'
    system_prompt_file: str = ''
    # Tool schemas offered to the LLM and used to build the call validator.
    # NOTE: this default list is a single object shared by every subclass that
    # does not override it -- assign a new list in the subclass, never mutate it.
    tools: list = []

    def __init__(self, odoo, llm, peer_bus=None):
        """Bind collaborators and validate this agent's tool schemas.

        Args:
            odoo: Odoo client handle, stored for subclass use.
            llm: LLM client; ``_loop`` awaits
                ``llm.submit(messages, tools=..., caller=...)`` on it.
            peer_bus: Optional bus for agent-to-agent calls, stored for
                subclass use.

        Raises:
            Whatever ``validate_agent_tools`` raises for invalid tool schemas
            (presumably ``ToolValidationError`` -- confirm).
        """
        self._odoo = odoo
        self._llm = llm
        self._peer_bus = peer_bus
        self._directive: AgentDirective | None = None
        self._gathered: dict = {}
        self._messages: list = []
        # Escalation records for the current directive (filled by _escalate).
        self._escalations: list = []
        validate_agent_tools(self.tools, self.name)
        self._validator = ToolCallValidator(self.tools)

    async def execute(self, directive: AgentDirective) -> AgentReport:
        """Run the full pipeline for *directive* and return its report.

        Per-directive state (gathered data, message history, escalations) is
        reset first so nothing leaks from an earlier directive handled by the
        same instance. Any ``Exception`` raised by a step -- including while
        logging the finished report -- is logged with its traceback and turned
        into a ``status='failed'`` report instead of propagating.
        """
        self._directive = directive
        self._gathered = {}
        self._messages = []
        # Reset alongside _gathered/_messages: without this, escalations from
        # previous directives accumulated on long-lived agent instances.
        self._escalations = []
        t0 = time.monotonic()
        try:
            await self._receive(directive)
            plan = await self._plan()
            await self._gather(plan)
            reasoning = await self._reason()
            await self._act(reasoning)
            report = await self._report()
            ms = int((time.monotonic() - t0) * 1000)
            logger.info('agent=%s directive=%s status=%s ms=%d',
                        self.name, directive.directive_id, report.status, ms)
            return report
        except Exception as exc:
            # Top-level boundary: keep the traceback in the log, but hand the
            # caller a structured failure rather than an exception.
            logger.exception('agent=%s directive=%s FAILED: %s', self.name, directive.directive_id, exc)
            return AgentReport(
                directive_id=directive.directive_id, agent=self.name,
                status='failed', summary=f'Agent failed: {exc}', error=str(exc))

    async def _receive(self, directive):
        """First pipeline step: log receipt of *directive* (task truncated to 80 chars)."""
        logger.info('agent=%s received directive=%s task=%s', self.name, directive.directive_id, directive.task[:80])

    @abstractmethod
    async def _plan(self):
        """Decide what to collect; the return value is passed to ``_gather``."""

    @abstractmethod
    async def _gather(self, plan):
        """Collect the data described by *plan*; ``execute`` ignores the return value."""

    @abstractmethod
    async def _reason(self):
        """Analyse gathered data; the return value is passed to ``_act``."""

    @abstractmethod
    async def _act(self, reasoning):
        """Carry out actions based on *reasoning*; ``execute`` ignores the return value."""

    @abstractmethod
    async def _report(self) -> AgentReport:
        """Build the final :class:`AgentReport` returned by ``execute``."""

    async def _run_tool(self, name, params) -> ToolResult:
        """Dispatch *name* to ``self._tool_<name>(**params)`` and wrap the outcome.

        An unknown tool name, or any ``Exception`` raised while awaiting the
        tool, is reported as ``ToolResult(success=False, error=...)`` rather
        than raised. Only the parameter names are logged, not their values.
        """
        logger.info('agent=%s tool=%s params=%s', self.name, name, list(params.keys()))
        tool_fn = getattr(self, f'_tool_{name}', None)
        if tool_fn is None:
            return ToolResult(name=name, success=False, error=f'Unknown tool: {name}')
        try:
            data = await tool_fn(**params)
            return ToolResult(name=name, success=True, data=data)
        except Exception as exc:
            logger.error('agent=%s tool=%s failed: %s', self.name, name, exc)
            return ToolResult(name=name, success=False, error=str(exc))

    async def _should_auto_act(self, action_type, params=None) -> tuple:
        """Return ``(allowed, reason)`` for *action_type* per AUTO_ACT_THRESHOLDS.

        Unknown action types require approval. *params* is accepted for
        interface compatibility but is currently not consulted, so
        amount-based rules are not evaluated here.
        """
        auto = AUTO_ACT_THRESHOLDS.get(action_type, False)
        if auto:
            return True, 'auto-approved by threshold policy'
        return False, f'action_type={action_type} requires human approval'

    async def _escalate(self, reason, model, record_id, severity='medium'):
        """Record an escalation for the current directive and return the record.

        The record is only appended to ``self._escalations`` and logged; this
        method does not notify anyone itself.
        """
        item = {
            'reason': reason, 'model': model, 'record_id': record_id,
            'severity': severity, 'agent': self.name,
            'directive_id': self._directive.directive_id if self._directive else None,
        }
        # Defensive: instances created without BaseAgent.__init__ lack the list.
        if not hasattr(self, '_escalations'):
            self._escalations = []
        self._escalations.append(item)
        logger.info('agent=%s escalation severity=%s reason=%s', self.name, severity, reason)
        return item

    async def _loop(self, messages, tools=None, max_iter=10) -> str:
        """Run a tool-calling conversation with the LLM and return its final text.

        Args:
            messages: Initial chat messages; copied, so the caller's list is
                not mutated.
            tools: Tool schemas to offer; a falsy value (``None`` or empty)
                falls back to ``self.tools``.
            max_iter: Maximum number of LLM round-trips.

        Returns:
            The ``content`` of the first response that requests no tools, or a
            fixed "partial results" message once ``max_iter`` round-trips have
            all requested tools (an escalation is recorded in that case).
        """
        current = list(messages)
        active_tools = tools or self.tools
        for _ in range(max_iter):
            resp = await self._llm.submit(current, tools=active_tools, caller=self.name)
            if not resp.tool_calls:
                return resp.content
            for raw_call in resp.tool_calls:
                # NOTE: validation uses the validator built from self.tools in
                # __init__, even when a narrower `tools` list was passed.
                validated = self._validator.parse_or_fallback(raw_call)
                if validated is None:
                    # Ask for plain text and skip the rest of this response's calls.
                    current.append({'role': 'user',
                                    'content': 'Tool call was invalid. Please provide your analysis as plain text.'})
                    break
                tool_result = await self._run_tool(validated['name'], validated['arguments'])
                current.append({'role': 'assistant', 'content': resp.content or '', 'tool_calls': [raw_call]})
                current.append({
                    'role': 'tool',
                    'name': validated['name'],
                    # default=str: the payload is text for the LLM, so values
                    # json cannot encode natively (dates, Decimals, ...) are
                    # stringified instead of aborting the whole directive.
                    'content': json.dumps({'success': tool_result.success, 'data': tool_result.data,
                                           'error': tool_result.error}, default=str),
                })
        logger.warning('agent=%s max iterations %d reached', self.name, max_iter)
        await self._escalate('Max iterations exceeded', '', 0, severity='medium')
        return 'Reached maximum iterations. Partial results only.'

    async def handle_peer_request(self, request_type, params, directive_id) -> dict:
        """Default peer-request handler: always answers that peer requests are unsupported."""
        return {'success': False, 'error': 'Peer requests not implemented for this agent'}

    async def sweep(self) -> SweepReport:
        """Default sweep: return an empty :class:`SweepReport` for this agent."""
        return SweepReport(agent=self.name, findings=[], recommendations=[])
|