48 lines
1.8 KiB
Python
48 lines
1.8 KiB
Python
from __future__ import annotations
|
|
import logging
|
|
|
|
# Logger named after this module, so records can be filtered per module.
logger: logging.Logger = logging.getLogger(__name__)

# NOTE(review): HARD_CAP is not referenced by the code in this module;
# presumably an upper bound on rows per user used by callers — confirm.
HARD_CAP: int = 200
|
|
|
|
|
|
class ConversationStore:
    """Store and query per-user conversation history in ``ab_conversation_memory``.

    Each method borrows one connection from the pool for a single SQL
    statement. The ``$n`` placeholders and the ``acquire(timeout=...)`` /
    ``execute`` / ``fetch`` / ``fetchrow`` calls follow the asyncpg API.
    NOTE(review): presumably ``pool`` is an ``asyncpg.Pool``; confirm.

    Rows flagged ``is_summary`` are returned by :meth:`get`, but are ignored by
    :meth:`count` and are never deleted by :meth:`prune_old`.
    """

    # Seconds to wait for a free pooled connection (previously repeated inline).
    _ACQUIRE_TIMEOUT = 10

    def __init__(self, pool):
        """Keep a reference to *pool*, which every query is run against."""
        self._pool = pool

    async def append(self, user_id, role, content, directive_id=None, is_summary=False):
        """Insert a single conversation row for *user_id*.

        Args:
            user_id: Owner of the row.
            role: Value stored in the ``role`` column.
            content: Value stored in the ``content`` column.
            directive_id: Optional value stored in the ``directive_id`` column.
            is_summary: Marks the row as a summary; summary rows are excluded
                from :meth:`count` and protected from :meth:`prune_old`.
        """
        async with self._pool.acquire(timeout=self._ACQUIRE_TIMEOUT) as conn:
            await conn.execute(
                """INSERT INTO ab_conversation_memory
                       (user_id, role, content, directive_id, is_summary)
                   VALUES ($1, $2, $3, $4, $5)""",
                user_id, role, content, directive_id, is_summary)

    async def get(self, user_id, limit=50):
        """Return the *limit* most recent rows for *user_id*, oldest first.

        Summary rows are included. Rows are selected newest-first so that
        ``LIMIT`` keeps the latest ones, then reversed into chronological
        order. ``id`` is a secondary sort key: without it, rows sharing a
        ``created_at`` value come back in an unspecified order. This would
        make both the selected set and its ordering nondeterministic.
        NOTE(review): assumes ``id`` grows with insertion order — confirm.

        Args:
            user_id: Owner whose rows are fetched.
            limit: Maximum number of rows to return.

        Returns:
            A list of dicts with keys ``id``, ``role``, ``content``,
            ``directive_id``, ``is_summary`` and ``created_at``.
        """
        async with self._pool.acquire(timeout=self._ACQUIRE_TIMEOUT) as conn:
            rows = await conn.fetch(
                """SELECT id, role, content, directive_id, is_summary, created_at
                   FROM ab_conversation_memory
                   WHERE user_id = $1
                   ORDER BY created_at DESC, id DESC
                   LIMIT $2""",
                user_id, limit)
        # Connection is back in the pool; convert records outside the block.
        return [dict(r) for r in reversed(rows)]

    async def count(self, user_id):
        """Return the number of non-summary rows stored for *user_id*."""
        async with self._pool.acquire(timeout=self._ACQUIRE_TIMEOUT) as conn:
            # COUNT(*) without GROUP BY always yields exactly one row.
            row = await conn.fetchrow(
                'SELECT COUNT(*) AS n FROM ab_conversation_memory '
                'WHERE user_id = $1 AND is_summary = false',
                user_id)
        return row['n']

    async def prune_old(self, user_id, keep=50):
        """Delete all but the *keep* newest non-summary rows for *user_id*.

        Summary rows are never deleted. The survivors are chosen with the same
        ``created_at DESC, id DESC`` ordering as :meth:`get`. The ``id``
        tiebreaker makes it deterministic which rows survive when several
        share the ``created_at`` value at the cut-off.

        Args:
            user_id: Owner whose rows are pruned.
            keep: Number of newest non-summary rows to retain. It is used as
                a SQL ``LIMIT`` and is expected to be non-negative; ``0``
                deletes every non-summary row for the user.
        """
        async with self._pool.acquire(timeout=self._ACQUIRE_TIMEOUT) as conn:
            await conn.execute(
                """DELETE FROM ab_conversation_memory
                   WHERE user_id = $1 AND is_summary = false
                     AND id NOT IN (
                         SELECT id FROM ab_conversation_memory
                         WHERE user_id = $1 AND is_summary = false
                         ORDER BY created_at DESC, id DESC
                         LIMIT $2
                     )""",
                user_id, keep)
|