feat: add 3-tier memory layer
This commit is contained in:
0
agent_service/memory/__init__.py
Normal file
0
agent_service/memory/__init__.py
Normal file
47
agent_service/memory/conversation_store.py
Normal file
47
agent_service/memory/conversation_store.py
Normal file
@@ -0,0 +1,47 @@
|
||||
from __future__ import annotations
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
# Upper bound on non-summary conversation rows per user. Nothing in this
# module enforces it; memory_manager imports it and logs a warning once a
# user's count is within 10 of this value.
HARD_CAP = 200
|
||||
|
||||
|
||||
class ConversationStore:
    """Tier-1 memory: per-user message rows in ``ab_conversation_memory``.

    ``pool`` must expose ``acquire(timeout=...)`` as an async context manager
    yielding a connection with ``execute`` / ``fetch`` / ``fetchrow``.  The
    queries use ``$n`` placeholders, so this is presumably an asyncpg pool --
    confirm against the caller.
    """

    def __init__(self, pool):
        self._pool = pool

    async def append(self, user_id, role, content, directive_id=None, is_summary=False):
        """Insert one row for ``user_id``.

        ``is_summary=True`` marks the row as a summary; ``count`` and
        ``prune_old`` ignore such rows.
        """
        async with self._pool.acquire(timeout=10) as conn:
            await conn.execute(
                """INSERT INTO ab_conversation_memory
                   (user_id, role, content, directive_id, is_summary)
                   VALUES ($1, $2, $3, $4, $5)""",
                user_id, role, content, directive_id, is_summary)

    async def get(self, user_id, limit=50):
        """Return the newest ``limit`` rows (summaries included), oldest first.

        Each row is a dict with ``id``, ``role``, ``content``,
        ``directive_id``, ``is_summary`` and ``created_at``.
        """
        async with self._pool.acquire(timeout=10) as conn:
            # ``id DESC`` breaks ties between rows sharing a ``created_at`` so
            # the result is deterministic.  Assumes ``id`` grows with insertion
            # order -- TODO confirm against the table definition.
            rows = await conn.fetch(
                """SELECT id, role, content, directive_id, is_summary, created_at
                   FROM ab_conversation_memory
                   WHERE user_id = $1
                   ORDER BY created_at DESC, id DESC LIMIT $2""",
                user_id, limit)
        # The query returns newest-first; callers get chronological order.
        return [dict(r) for r in reversed(rows)]

    async def count(self, user_id):
        """Return the number of non-summary rows stored for ``user_id``."""
        async with self._pool.acquire(timeout=10) as conn:
            row = await conn.fetchrow(
                'SELECT COUNT(*) as n FROM ab_conversation_memory WHERE user_id = $1 AND is_summary = false',
                user_id)
        return row['n']

    async def prune_old(self, user_id, keep=50):
        """Delete all non-summary rows of ``user_id`` except the newest ``keep``.

        Summary rows are never deleted by this method.
        """
        async with self._pool.acquire(timeout=10) as conn:
            # Same ordering (including the ``id`` tiebreaker) as ``get`` so
            # the rows kept here are the ones ``get`` reports as newest.
            await conn.execute(
                """DELETE FROM ab_conversation_memory
                   WHERE user_id = $1 AND is_summary = false
                   AND id NOT IN (
                       SELECT id FROM ab_conversation_memory
                       WHERE user_id = $1 AND is_summary = false
                       ORDER BY created_at DESC, id DESC LIMIT $2
                   )""",
                user_id, keep)
|
||||
31
agent_service/memory/knowledge_store.py
Normal file
31
agent_service/memory/knowledge_store.py
Normal file
@@ -0,0 +1,31 @@
|
||||
from __future__ import annotations
|
||||
import json, logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class KnowledgeStore:
    """Tier-3 memory: one JSON ``facts`` document per (entity_type, entity_key).

    Rows live in ``ab_knowledge_store``.  ``pool`` must expose
    ``acquire(timeout=...)`` as an async context manager yielding a connection
    with ``execute`` / ``fetchrow``.
    """

    def __init__(self, pool):
        self._pool = pool

    async def upsert(self, entity_type, entity_key, facts):
        """Insert or replace the facts for an entity.

        ``facts`` is serialized with ``json.dumps``; an existing row for the
        same key has its ``facts`` overwritten (not merged) and its
        ``updated_at`` set to ``NOW()``.
        """
        async with self._pool.acquire(timeout=10) as conn:
            await conn.execute(
                """INSERT INTO ab_knowledge_store (entity_type, entity_key, facts)
                   VALUES ($1, $2, $3)
                   ON CONFLICT (entity_type, entity_key)
                   DO UPDATE SET facts = $3, updated_at = NOW()""",
                entity_type, entity_key, json.dumps(facts))

    async def get(self, entity_type, entity_key):
        """Return the stored facts for an entity, or ``{}`` when there are none.

        ``{}`` is returned both when no row exists and when the stored value
        is NULL / JSON ``null``, so callers can always treat the result as a
        mapping.
        """
        async with self._pool.acquire(timeout=10) as conn:
            row = await conn.fetchrow(
                'SELECT facts FROM ab_knowledge_store WHERE entity_type=$1 AND entity_key=$2',
                entity_type, entity_key)
        if row is None:
            return {}
        facts = row['facts']
        # The driver may hand the column back as raw JSON text or as an
        # already-decoded object; accept both.
        if isinstance(facts, str):
            facts = json.loads(facts)
        return {} if facts is None else facts

    async def get_client_profile(self, partner_id):
        """Return the facts stored under ``('client', 'partner_<partner_id>')``."""
        return await self.get('client', f'partner_{partner_id}')
|
||||
151
agent_service/memory/memory_manager.py
Normal file
151
agent_service/memory/memory_manager.py
Normal file
@@ -0,0 +1,151 @@
|
||||
from __future__ import annotations
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from .conversation_store import ConversationStore, HARD_CAP
|
||||
from .operational_store import OperationalStore
|
||||
from .knowledge_store import KnowledgeStore
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
# Per-user count of non-summary messages below which summarize_if_needed
# returns without doing anything; also the row-count parameter passed to the
# query in summarize_long_conversations.
SUMMARIZE_THRESHOLD = 50
|
||||
|
||||
|
||||
@dataclass
class MasterContext:
    """Snapshot of everything ``MemoryManager.build_context`` gathers for a user.

    Only ``user_id`` is required; every collection defaults to a fresh empty
    container per instance.
    """

    user_id: int
    # Conversation rows, oldest first.
    conversation: list = field(default_factory=list)
    # Recent operational-memory rows for the scope chosen by build_context.
    operational_findings: list = field(default_factory=list)
    # Facts looked up under the ('user', str(user_id)) knowledge key.
    knowledge: dict = field(default_factory=dict)
    # ab_directive_log rows whose status is 'awaiting_approval'.
    pending_approvals: list = field(default_factory=list)
    # ab_directive_log rows whose status is 'pending' or 'processing'.
    active_directives: list = field(default_factory=list)
|
||||
|
||||
|
||||
class MemoryManager:
    """Facade over the three memory tiers plus context assembly.

    Tier 1 is the per-user conversation log, tier 2 the scoped operational
    findings, tier 3 the long-term knowledge store.

    ``pool`` is handed to each store and is also used directly for the
    ``ab_directive_log`` queries.  ``llm_router`` is optional; when given it
    must provide ``await submit(messages, caller=...)`` returning an object
    with a ``content`` attribute.  Without it no summarization happens.
    """

    # Number of newest non-summary messages left in place after a successful
    # summarization; everything older is folded into the summary and pruned.
    _KEEP_AFTER_SUMMARY = 20

    def __init__(self, pool, llm_router=None):
        self._pool = pool
        self._llm = llm_router
        self._conv = ConversationStore(pool)
        self._ops = OperationalStore(pool)
        self._know = KnowledgeStore(pool)

    # ------------------------------------------------------------------
    # Tier 1 — Conversation
    # ------------------------------------------------------------------

    async def append_message(self, user_id, role, content, directive_id=None):
        """Store one message, then summarize the history if it grew too long."""
        await self._conv.append(user_id, role, content, directive_id)
        await self.summarize_if_needed(user_id)

    async def get_conversation(self, user_id, limit=50):
        """Return the newest ``limit`` conversation rows, oldest first."""
        return await self._conv.get(user_id, limit)

    async def summarize_if_needed(self, user_id):
        """Fold older messages into a summary row once the threshold is hit.

        Does nothing below ``SUMMARIZE_THRESHOLD`` non-summary messages or
        when no LLM router is configured.  Otherwise every non-summary
        message except the newest ``_KEEP_AFTER_SUMMARY`` is summarized,
        the summary is stored as an ``is_summary`` row, and those same
        messages are pruned.  LLM/storage failures during that step are
        logged and leave the history untouched.
        """
        count = await self._conv.count(user_id)
        if count < SUMMARIZE_THRESHOLD:
            return
        if count >= HARD_CAP - 10:
            logger.warning('Conversation memory near hard cap for user_id=%s count=%d', user_id, count)
        if not self._llm:
            return

        # Summarize exactly what prune_old() is about to delete; summarizing
        # a smaller slice would drop the remaining messages without a trace.
        stale = await self._messages_to_fold(user_id, count)
        if not stale:
            return

        history_text = '\n'.join(f"{m['role']}: {m['content']}" for m in stale)
        messages = [
            {'role': 'system', 'content': 'Summarize this conversation history in 3-5 sentences, preserving key decisions and context.'},
            {'role': 'user', 'content': history_text},
        ]
        try:
            resp = await self._llm.submit(messages, caller='memory_manager')
            summary = resp.content
            if not summary:
                # Never prune history in exchange for an empty summary.
                logger.error('Conversation summarization returned no content user_id=%s', user_id)
                return
            await self._conv.append(user_id, 'system', f'[SUMMARY] {summary}', is_summary=True)
            await self._conv.prune_old(user_id, keep=self._KEEP_AFTER_SUMMARY)
            logger.info('Conversation summarized for user_id=%s', user_id)
        except Exception as exc:
            logger.exception('Conversation summarization failed user_id=%s: %s', user_id, exc)

    async def _messages_to_fold(self, user_id, count):
        """Return the non-summary messages a prune would delete, oldest first.

        ``ConversationStore.get`` also returns summary rows, which do not
        count towards ``count``; the fetch window is widened until it holds
        all ``count`` non-summary messages or the whole history was read.
        """
        limit = count
        while True:
            window = await self._conv.get(user_id, limit=limit)
            messages = [m for m in window if not m.get('is_summary')]
            missing = count - len(messages)
            if missing <= 0 or len(window) < limit:
                break
            limit += missing
        return messages[:-self._KEEP_AFTER_SUMMARY]

    # ------------------------------------------------------------------
    # Tier 2 — Operational
    # ------------------------------------------------------------------

    async def store_findings(self, scope, summary, raw_data=None, ttl_days=90, source_directive_id=None):
        """Record an operational finding; arguments pass straight to the store."""
        await self._ops.store(scope, summary, raw_data, ttl_days, source_directive_id)

    async def get_recent_findings(self, scope, limit=10):
        """Return up to ``limit`` recent findings for ``scope``."""
        return await self._ops.get_recent(scope, limit)

    async def prune_expired_operational(self):
        """Delete expired operational findings."""
        await self._ops.prune_expired()

    # ------------------------------------------------------------------
    # Tier 3 — Long-term Knowledge
    # ------------------------------------------------------------------

    async def upsert_knowledge(self, entity_type, entity_key, facts):
        """Insert or replace the facts stored for an entity."""
        await self._know.upsert(entity_type, entity_key, facts)

    async def get_knowledge(self, entity_type, entity_key):
        """Return the facts stored for an entity."""
        return await self._know.get(entity_type, entity_key)

    async def get_client_profile(self, partner_id):
        """Return the client facts stored for ``partner_id``."""
        return await self._know.get_client_profile(partner_id)

    # ------------------------------------------------------------------
    # Context assembly — called before every Master LLM call
    # ------------------------------------------------------------------

    async def build_context(self, user_id, intent_hint=None):
        """Assemble a ``MasterContext`` for ``user_id``.

        ``intent_hint`` selects the operational scope (``'general'`` when
        falsy).  User knowledge is only looked up when a hint is given, and
        that lookup, like the two directive queries, is best-effort: a
        failure is logged and yields an empty value.
        """
        conversation = await self._conv.get(user_id, limit=50)
        operational_findings = await self._ops.get_recent(intent_hint or 'general', limit=5)

        knowledge = {}
        if intent_hint:
            try:
                knowledge = await self._know.get('user', str(user_id))
            except Exception:
                logger.warning('Knowledge lookup failed for user_id=%s', user_id, exc_info=True)

        pending_approvals = await self._get_pending_approvals(user_id)
        active_directives = await self._get_active_directives(user_id)
        return MasterContext(
            user_id=user_id,
            conversation=conversation,
            operational_findings=operational_findings,
            knowledge=knowledge,
            pending_approvals=pending_approvals,
            active_directives=active_directives,
        )

    async def _fetch_directives(self, sql, user_id, label):
        """Run a one-parameter directive query; log and return ``[]`` on failure."""
        try:
            async with self._pool.acquire(timeout=10) as conn:
                rows = await conn.fetch(sql, user_id)
            return [dict(r) for r in rows]
        except Exception:
            logger.warning('Could not load %s for user_id=%s', label, user_id, exc_info=True)
            return []

    async def _get_pending_approvals(self, user_id):
        """Return up to 10 newest directives awaiting approval (``[]`` on error)."""
        return await self._fetch_directives(
            """SELECT directive_id, escalations FROM ab_directive_log
               WHERE user_id = $1 AND status = 'awaiting_approval'
               ORDER BY started_at DESC LIMIT 10""",
            user_id, 'pending approvals')

    async def _get_active_directives(self, user_id):
        """Return up to 5 newest pending/processing directives (``[]`` on error)."""
        return await self._fetch_directives(
            """SELECT directive_id, intent_summary, status, started_at
               FROM ab_directive_log
               WHERE user_id = $1 AND status IN ('pending', 'processing')
               ORDER BY started_at DESC LIMIT 5""",
            user_id, 'active directives')

    async def summarize_long_conversations(self):
        """Run ``summarize_if_needed`` for every user at or above the threshold.

        Never raises: a failing lookup aborts the sweep with a log entry, and
        a failure for one user is logged without stopping the remaining ones.
        """
        try:
            async with self._pool.acquire(timeout=10) as conn:
                # ``>=`` matches the trigger condition in summarize_if_needed.
                rows = await conn.fetch(
                    """SELECT user_id, COUNT(*) as n FROM ab_conversation_memory
                       WHERE is_summary = false GROUP BY user_id HAVING COUNT(*) >= $1""",
                    SUMMARIZE_THRESHOLD)
        except Exception as exc:
            logger.exception('summarize_long_conversations failed: %s', exc)
            return

        # The lookup connection is released before the per-user work, which
        # acquires its own connections through the stores.
        for row in rows:
            try:
                await self.summarize_if_needed(row['user_id'])
            except Exception as exc:
                logger.exception(
                    'summarize_long_conversations failed for user_id=%s: %s', row['user_id'], exc)
|
||||
35
agent_service/memory/operational_store.py
Normal file
35
agent_service/memory/operational_store.py
Normal file
@@ -0,0 +1,35 @@
|
||||
from __future__ import annotations
|
||||
import logging
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OperationalStore:
    """Tier-2 memory: scoped findings with an expiry in ``ab_operational_memory``.

    ``pool`` must expose ``acquire(timeout=...)`` as an async context manager
    yielding a connection with ``execute`` / ``fetch``.
    """

    def __init__(self, pool):
        self._pool = pool

    async def store(self, scope, summary, raw_data=None, ttl_days=90, source_directive_id=None):
        """Insert one finding for ``scope``.

        ``ttl_days`` is the lifetime in days from now; ``None`` stores a NULL
        ``expires_at``, which ``get_recent`` treats as never expiring and
        ``prune_expired`` never deletes.  ``raw_data`` and
        ``source_directive_id`` are passed to the query unchanged.
        """
        # Local import: the module header only brings in datetime/timedelta.
        from datetime import timezone

        expires_at = None
        if ttl_days is not None:
            # Naive UTC, i.e. the value the deprecated datetime.utcnow() gave.
            # NOTE(review): whether naive UTC is right depends on the column
            # type of expires_at -- confirm against the schema.
            now_utc = datetime.now(timezone.utc).replace(tzinfo=None)
            expires_at = now_utc + timedelta(days=ttl_days)

        async with self._pool.acquire(timeout=10) as conn:
            await conn.execute(
                """INSERT INTO ab_operational_memory
                   (scope, summary, raw_data, source_directive_id, expires_at)
                   VALUES ($1, $2, $3, $4, $5)""",
                scope, summary, raw_data, source_directive_id, expires_at)

    async def get_recent(self, scope, limit=10):
        """Return up to ``limit`` unexpired findings for ``scope``, newest first.

        Each row is a dict with ``id``, ``scope``, ``summary``, ``raw_data``
        and ``created_at``.
        """
        async with self._pool.acquire(timeout=10) as conn:
            rows = await conn.fetch(
                """SELECT id, scope, summary, raw_data, created_at
                   FROM ab_operational_memory
                   WHERE scope = $1 AND (expires_at IS NULL OR expires_at > NOW())
                   ORDER BY created_at DESC LIMIT $2""",
                scope, limit)
        return [dict(r) for r in rows]

    async def prune_expired(self):
        """Delete every finding whose expiry has passed and log the outcome."""
        async with self._pool.acquire(timeout=10) as conn:
            # ``<=`` is the exact complement of the ``expires_at > NOW()``
            # visibility test in get_recent.
            result = await conn.execute(
                'DELETE FROM ab_operational_memory WHERE expires_at IS NOT NULL AND expires_at <= NOW()')
        logger.info('OperationalStore.prune_expired: %s', result)
|
||||
Reference in New Issue
Block a user