feat: add 3-tier memory layer
This commit is contained in:
0
agent_service/memory/__init__.py
Normal file
0
agent_service/memory/__init__.py
Normal file
47
agent_service/memory/conversation_store.py
Normal file
47
agent_service/memory/conversation_store.py
Normal file
@@ -0,0 +1,47 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
import logging
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)

# Ceiling on non-summary messages per user.  Nothing in this module enforces
# it; MemoryManager imports it to warn when a user's count approaches it.
HARD_CAP = 200


class ConversationStore:
    """Tier-1 memory: per-user chat rows in ``ab_conversation_memory``.

    ``pool`` must offer ``acquire(timeout=...)`` as an async context manager
    yielding a connection with ``execute`` / ``fetch`` / ``fetchrow``
    (the asyncpg pool interface).
    """

    def __init__(self, pool):
        self._pool = pool

    async def append(self, user_id, role, content, directive_id=None, is_summary=False):
        """Insert one message (or, with ``is_summary=True``, a summary row)."""
        async with self._pool.acquire(timeout=10) as conn:
            await conn.execute(
                """INSERT INTO ab_conversation_memory
                       (user_id, role, content, directive_id, is_summary)
                   VALUES ($1, $2, $3, $4, $5)""",
                user_id, role, content, directive_id, is_summary)

    async def get(self, user_id, limit=50):
        """Return the ``limit`` most recent rows for ``user_id``, oldest first.

        Summary rows are included.  Each row is a dict with the keys
        ``id, role, content, directive_id, is_summary, created_at``.
        """
        async with self._pool.acquire(timeout=10) as conn:
            # ``id DESC`` breaks ties between rows sharing a created_at value,
            # so the returned order is deterministic.
            rows = await conn.fetch(
                """SELECT id, role, content, directive_id, is_summary, created_at
                   FROM ab_conversation_memory
                   WHERE user_id = $1
                   ORDER BY created_at DESC, id DESC LIMIT $2""",
                user_id, limit)
        # The query yields newest-first; callers want chronological order.
        return [dict(r) for r in reversed(rows)]

    async def count(self, user_id):
        """Return how many non-summary rows are stored for ``user_id``."""
        async with self._pool.acquire(timeout=10) as conn:
            row = await conn.fetchrow(
                'SELECT COUNT(*) as n FROM ab_conversation_memory WHERE user_id = $1 AND is_summary = false',
                user_id)
        return row['n']

    async def prune_old(self, user_id, keep=50):
        """Delete all non-summary rows except the ``keep`` most recent ones.

        Summary rows are never deleted by this method.
        """
        async with self._pool.acquire(timeout=10) as conn:
            # Same ordering as ``get`` so "most recent" means the same thing
            # in both places, including for rows with equal timestamps.
            await conn.execute(
                """DELETE FROM ab_conversation_memory
                   WHERE user_id = $1 AND is_summary = false
                     AND id NOT IN (
                         SELECT id FROM ab_conversation_memory
                         WHERE user_id = $1 AND is_summary = false
                         ORDER BY created_at DESC, id DESC LIMIT $2
                     )""",
                user_id, keep)
|
||||||
31
agent_service/memory/knowledge_store.py
Normal file
31
agent_service/memory/knowledge_store.py
Normal file
@@ -0,0 +1,31 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
import json, logging
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)


class KnowledgeStore:
    """Tier-3 memory: long-lived facts in ``ab_knowledge_store``.

    Rows are keyed by ``(entity_type, entity_key)``; ``facts`` is written as
    a JSON string.
    """

    def __init__(self, pool):
        self._pool = pool

    async def upsert(self, entity_type, entity_key, facts):
        """Insert or replace the facts for ``(entity_type, entity_key)``.

        ``facts`` must be JSON-serialisable; it is stored via ``json.dumps``.
        An existing row has its facts overwritten (not merged) and its
        ``updated_at`` refreshed.
        """
        async with self._pool.acquire(timeout=10) as conn:
            await conn.execute(
                """INSERT INTO ab_knowledge_store (entity_type, entity_key, facts)
                   VALUES ($1, $2, $3)
                   ON CONFLICT (entity_type, entity_key)
                   DO UPDATE SET facts = $3, updated_at = NOW()""",
                entity_type, entity_key, json.dumps(facts))

    async def get(self, entity_type, entity_key):
        """Return the stored facts, or ``{}`` when there are none.

        ``{}`` is returned both when no row exists and when the row's facts
        are NULL / JSON ``null``, so callers never have to handle ``None``.
        """
        async with self._pool.acquire(timeout=10) as conn:
            row = await conn.fetchrow(
                'SELECT facts FROM ab_knowledge_store WHERE entity_type=$1 AND entity_key=$2',
                entity_type, entity_key)
        if row is None:
            return {}
        facts = row['facts']
        # The driver may hand back either the raw JSON text or an already
        # decoded object; accept both.
        if isinstance(facts, str):
            facts = json.loads(facts)
        return {} if facts is None else facts

    async def get_client_profile(self, partner_id):
        """Return the facts stored under ``('client', 'partner_<partner_id>')``."""
        return await self.get('client', f'partner_{partner_id}')
|
||||||
151
agent_service/memory/memory_manager.py
Normal file
151
agent_service/memory/memory_manager.py
Normal file
@@ -0,0 +1,151 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
import logging
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from .conversation_store import ConversationStore, HARD_CAP
|
||||||
|
from .operational_store import OperationalStore
|
||||||
|
from .knowledge_store import KnowledgeStore
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)

# Non-summary message count at which a user's conversation gets summarised.
SUMMARIZE_THRESHOLD = 50


@dataclass
class MasterContext:
    """Bundle of memory handed back by ``MemoryManager.build_context``.

    Only ``user_id`` is required; every other field defaults to a fresh
    empty container, so instances never share mutable state.
    """

    user_id: int
    # Conversation rows for the user.
    conversation: list = field(default_factory=list)
    # Operational findings selected for the current scope.
    operational_findings: list = field(default_factory=list)
    # Long-term facts; empty when nothing was looked up.
    knowledge: dict = field(default_factory=dict)
    # Directives waiting on approval.
    pending_approvals: list = field(default_factory=list)
    # Directives still in flight.
    active_directives: list = field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
|
class MemoryManager:
    """Single entry point over the conversation, operational and knowledge stores.

    ``pool`` is shared with the three stores and is also used directly for
    the ``ab_directive_log`` lookups.  ``llm_router`` is optional; when given
    it must expose ``await submit(messages, caller=...)`` returning an object
    with a ``content`` attribute.  Without it no summarisation happens.
    """

    # How many of the newest non-summary messages stay in place after a
    # successful summarisation; everything older is folded into the summary
    # and then pruned.
    _KEEP_AFTER_SUMMARY = 20

    def __init__(self, pool, llm_router=None):
        self._pool = pool
        self._llm = llm_router
        self._conv = ConversationStore(pool)
        self._ops = OperationalStore(pool)
        self._know = KnowledgeStore(pool)

    # ------------------------------------------------------------------
    # Tier 1 — Conversation
    # ------------------------------------------------------------------

    async def append_message(self, user_id, role, content, directive_id=None):
        """Store one message, then summarise the conversation if it is due."""
        await self._conv.append(user_id, role, content, directive_id)
        await self.summarize_if_needed(user_id)

    async def get_conversation(self, user_id, limit=50):
        """Return the ``limit`` most recent conversation rows, oldest first."""
        return await self._conv.get(user_id, limit)

    async def summarize_if_needed(self, user_id):
        """Summarise and prune a conversation that reached ``SUMMARIZE_THRESHOLD``.

        Does nothing below the threshold or when no LLM router is configured.
        Otherwise every non-summary message older than the newest
        ``_KEEP_AFTER_SUMMARY`` is sent to the LLM, the answer is stored as a
        ``[SUMMARY]`` row, and exactly those messages are pruned.  LLM or
        storage failures while summarising are logged, not raised, and leave
        the messages in place.
        """
        count = await self._conv.count(user_id)
        if count < SUMMARIZE_THRESHOLD:
            return
        if count >= HARD_CAP - 10:
            logger.warning('Conversation memory near hard cap for user_id=%s count=%d', user_id, count)
        if not self._llm:
            return

        keep = self._KEEP_AFTER_SUMMARY
        # Look at up to HARD_CAP rows so that the text being summarised is
        # the same set of messages prune_old() is about to delete.
        window = await self._conv.get(user_id, limit=HARD_CAP)
        plain = [m for m in window if not m['is_summary']]
        stale = plain[:-keep]
        if not stale:
            return
        if len(plain) < count:
            # Some messages lie beyond the HARD_CAP window; they will be
            # pruned without being part of this summary.
            logger.warning(
                'Conversation for user_id=%s exceeds summary window: %d of %d messages considered',
                user_id, len(plain), count)

        history_text = '\n'.join(f"{m['role']}: {m['content']}" for m in stale)
        messages = [
            {'role': 'system', 'content': 'Summarize this conversation history in 3-5 sentences, preserving key decisions and context.'},
            {'role': 'user', 'content': history_text},
        ]
        try:
            resp = await self._llm.submit(messages, caller='memory_manager')
            summary = resp.content
            if not summary:
                # Never trade real messages for an empty summary.
                logger.warning('Conversation summarization returned no content user_id=%s', user_id)
                return
            await self._conv.append(user_id, 'system', f'[SUMMARY] {summary}', is_summary=True)
            await self._conv.prune_old(user_id, keep=keep)
            logger.info('Conversation summarized for user_id=%s', user_id)
        except Exception as exc:
            # Best effort: a failed summary must not break message handling.
            logger.exception('Conversation summarization failed user_id=%s: %s', user_id, exc)

    # ------------------------------------------------------------------
    # Tier 2 — Operational
    # ------------------------------------------------------------------

    async def store_findings(self, scope, summary, raw_data=None, ttl_days=90, source_directive_id=None):
        """Record an operational finding under ``scope``."""
        await self._ops.store(scope, summary, raw_data, ttl_days, source_directive_id)

    async def get_recent_findings(self, scope, limit=10):
        """Return up to ``limit`` recent findings for ``scope``."""
        return await self._ops.get_recent(scope, limit)

    async def prune_expired_operational(self):
        """Delete operational findings whose expiry has passed."""
        await self._ops.prune_expired()

    # ------------------------------------------------------------------
    # Tier 3 — Long-term Knowledge
    # ------------------------------------------------------------------

    async def upsert_knowledge(self, entity_type, entity_key, facts):
        """Insert or replace the facts stored for an entity."""
        await self._know.upsert(entity_type, entity_key, facts)

    async def get_knowledge(self, entity_type, entity_key):
        """Return the facts stored for an entity."""
        return await self._know.get(entity_type, entity_key)

    async def get_client_profile(self, partner_id):
        """Return the client facts stored for ``partner_id``."""
        return await self._know.get_client_profile(partner_id)

    # ------------------------------------------------------------------
    # Context assembly — called before every Master LLM call
    # ------------------------------------------------------------------

    async def build_context(self, user_id, intent_hint=None):
        """Assemble a ``MasterContext`` for ``user_id``.

        ``intent_hint`` selects the operational scope (``'general'`` when
        falsy).  User knowledge is only looked up when a hint is given.
        Knowledge and directive lookups are best effort: on failure they are
        logged and the corresponding field is left empty.
        """
        conversation = await self._conv.get(user_id, limit=50)
        operational_findings = await self._ops.get_recent(intent_hint or 'general', limit=5)

        knowledge = {}
        # NOTE(review): knowledge is skipped entirely when no intent_hint is
        # passed — confirm this gating is intentional.
        if intent_hint:
            try:
                knowledge = await self._know.get('user', str(user_id))
            except Exception:
                logger.warning(
                    'Knowledge lookup failed for user_id=%s; continuing without it',
                    user_id, exc_info=True)

        pending_approvals = await self._get_pending_approvals(user_id)
        active_directives = await self._get_active_directives(user_id)
        return MasterContext(
            user_id=user_id,
            conversation=conversation,
            operational_findings=operational_findings,
            knowledge=knowledge,
            pending_approvals=pending_approvals,
            active_directives=active_directives,
        )

    async def _fetch_dicts(self, what, query, *args):
        """Run ``query`` and return its rows as dicts; log and return ``[]`` on failure."""
        try:
            async with self._pool.acquire(timeout=10) as conn:
                rows = await conn.fetch(query, *args)
            return [dict(r) for r in rows]
        except Exception:
            logger.warning('Could not load %s; treating as empty', what, exc_info=True)
            return []

    async def _get_pending_approvals(self, user_id):
        """Return up to 10 directives of ``user_id`` awaiting approval, newest first."""
        return await self._fetch_dicts(
            'pending approvals',
            """SELECT directive_id, escalations FROM ab_directive_log
               WHERE user_id = $1 AND status = 'awaiting_approval'
               ORDER BY started_at DESC LIMIT 10""",
            user_id)

    async def _get_active_directives(self, user_id):
        """Return up to 5 pending/processing directives of ``user_id``, newest first."""
        return await self._fetch_dicts(
            'active directives',
            """SELECT directive_id, intent_summary, status, started_at
               FROM ab_directive_log
               WHERE user_id = $1 AND status IN ('pending', 'processing')
               ORDER BY started_at DESC LIMIT 5""",
            user_id)

    async def summarize_long_conversations(self):
        """Run ``summarize_if_needed`` for every user at or above the threshold.

        Failures are logged, never raised; a failure for one user does not
        stop the remaining users from being processed.
        """
        try:
            async with self._pool.acquire(timeout=10) as conn:
                # ``>=`` matches the trigger condition in summarize_if_needed.
                rows = await conn.fetch(
                    """SELECT user_id, COUNT(*) as n FROM ab_conversation_memory
                       WHERE is_summary = false GROUP BY user_id HAVING COUNT(*) >= $1""",
                    SUMMARIZE_THRESHOLD)
        except Exception as exc:
            logger.exception('summarize_long_conversations failed: %s', exc)
            return

        # The connection is released before this loop: summarising acquires
        # its own connections from the same pool.
        for row in rows:
            uid = row['user_id']
            try:
                await self.summarize_if_needed(uid)
            except Exception as exc:
                logger.exception('summarize_long_conversations failed for user_id=%s: %s', uid, exc)
|
||||||
35
agent_service/memory/operational_store.py
Normal file
35
agent_service/memory/operational_store.py
Normal file
@@ -0,0 +1,35 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
import logging
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)


class OperationalStore:
    """Tier-2 memory: scoped findings with an expiry in ``ab_operational_memory``."""

    def __init__(self, pool):
        self._pool = pool

    async def store(self, scope, summary, raw_data=None, ttl_days=90, source_directive_id=None):
        """Insert a finding that expires ``ttl_days`` days from now.

        ``ttl_days=None`` stores a NULL ``expires_at``, which ``get_recent``
        and ``prune_expired`` treat as "never expires".
        """
        # The expiry is computed by the database (NOW() + interval) so it
        # uses the same clock and timezone handling as the NOW() comparisons
        # in get_recent/prune_expired, rather than a naive UTC timestamp
        # taken on the application host.
        ttl = None if ttl_days is None else timedelta(days=ttl_days)
        # NOTE(review): raw_data is passed through unchanged — confirm that
        # callers supply a value the column's codec accepts.
        async with self._pool.acquire(timeout=10) as conn:
            await conn.execute(
                """INSERT INTO ab_operational_memory
                       (scope, summary, raw_data, source_directive_id, expires_at)
                   VALUES ($1, $2, $3, $4, NOW() + $5::interval)""",
                scope, summary, raw_data, source_directive_id, ttl)

    async def get_recent(self, scope, limit=10):
        """Return up to ``limit`` unexpired findings for ``scope``, newest first.

        Each finding is a dict with ``id, scope, summary, raw_data, created_at``.
        """
        async with self._pool.acquire(timeout=10) as conn:
            rows = await conn.fetch(
                """SELECT id, scope, summary, raw_data, created_at
                   FROM ab_operational_memory
                   WHERE scope = $1 AND (expires_at IS NULL OR expires_at > NOW())
                   ORDER BY created_at DESC LIMIT $2""",
                scope, limit)
        return [dict(r) for r in rows]

    async def prune_expired(self):
        """Delete every finding whose expiry has been reached and log the result."""
        async with self._pool.acquire(timeout=10) as conn:
            # ``<=`` is the exact complement of the ``>`` test in get_recent,
            # so every row hidden as expired is also eligible for deletion.
            result = await conn.execute(
                'DELETE FROM ab_operational_memory WHERE expires_at IS NOT NULL AND expires_at <= NOW()')
        logger.info('OperationalStore.prune_expired: %s', result)
|
||||||
Reference in New Issue
Block a user