expenses_agent: batch LLM calls + skip RAG to fix timeout on large uploads
- auto_rag=False: skip PeerBus odoo_doc_agent call on every execute(); eliminates 30s Ollama semaphore contention before parsing even starts - _batch_parse_receipts(): Phase 1 regex (instant per-receipt: amount, date, bank-statement skip); Phase 2 single batched LLM call for all vendor+product_name instead of N individual calls; vision mode falls back to per-receipt calls (can't batch images); LLM fallback on bad JSON or wrong item count - _act() updated to use _batch_parse_receipts() - 7 new tests covering batch happy path, regex-only amounts, private-key cleanup, bank-statement skip, malformed-JSON fallback, wrong-count fallback, no-products short-circuit (99 tests total, all passing) Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -219,6 +219,7 @@ class ExpensesAgent(BaseAgent):
|
|||||||
required_odoo_module = 'hr_expense'
|
required_odoo_module = 'hr_expense'
|
||||||
system_prompt_file = 'expenses_system.txt'
|
system_prompt_file = 'expenses_system.txt'
|
||||||
tools = EXPENSES_TOOLS
|
tools = EXPENSES_TOOLS
|
||||||
|
auto_rag = False # Receipt processing needs no RAG docs; skip the 30s peer-bus call
|
||||||
|
|
||||||
def __init__(self, odoo, llm, peer_bus=None):
|
def __init__(self, odoo, llm, peer_bus=None):
|
||||||
super().__init__(odoo, llm, peer_bus)
|
super().__init__(odoo, llm, peer_bus)
|
||||||
@@ -343,20 +344,9 @@ class ExpensesAgent(BaseAgent):
|
|||||||
logger.info('ocr filename=%r date_hint=%r ocr_len=%d text_preview=%r',
|
logger.info('ocr filename=%r date_hint=%r ocr_len=%d text_preview=%r',
|
||||||
r.get('filename'), r.get('date_from_name'), ocr_len, ocr_preview)
|
r.get('filename'), r.get('date_from_name'), ocr_len, ocr_preview)
|
||||||
|
|
||||||
# Parse all receipts concurrently.
|
# Parse all receipts: regex phase is instant; LLM phase is batched into
|
||||||
# b64 + mimetype are forwarded so _parse_receipt_text can use the
|
# a single call so N receipts cost 1 LLM round-trip instead of N.
|
||||||
# vision LLM path when RECEIPT_VISION_MODE=vision (the default).
|
raw_parsed = await self._batch_parse_receipts(unique_receipts, expense_products)
|
||||||
parse_tasks = [
|
|
||||||
self._parse_receipt_text(
|
|
||||||
r.get('text', ''), r.get('filename', 'receipt'),
|
|
||||||
expense_products=expense_products,
|
|
||||||
date_hint=r.get('date_from_name'),
|
|
||||||
b64=r.get('b64'),
|
|
||||||
mimetype=r.get('mimetype'),
|
|
||||||
)
|
|
||||||
for r in unique_receipts
|
|
||||||
]
|
|
||||||
raw_parsed = await asyncio.gather(*parse_tasks, return_exceptions=True)
|
|
||||||
|
|
||||||
paired: list[tuple[dict, dict]] = []
|
paired: list[tuple[dict, dict]] = []
|
||||||
for receipt, parsed in zip(unique_receipts, raw_parsed):
|
for receipt, parsed in zip(unique_receipts, raw_parsed):
|
||||||
@@ -522,6 +512,164 @@ class ExpensesAgent(BaseAgent):
|
|||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
async def _batch_parse_receipts(self, receipts: list, expense_products: list) -> list:
|
||||||
|
"""Parse all receipts with a single LLM call instead of one per receipt.
|
||||||
|
|
||||||
|
Phase 1 — regex (no LLM, instant):
|
||||||
|
• amount : _extract_amount_from_text
|
||||||
|
• date : filename hint > OCR regex > today
|
||||||
|
• skip flag : bank/card statement detection
|
||||||
|
|
||||||
|
Phase 2 — one batched LLM call:
|
||||||
|
• vendor + product_name for all non-skipped receipts in one prompt
|
||||||
|
• Vision mode (RECEIPT_VISION_MODE=vision): falls back to individual
|
||||||
|
calls because images can't be batched in a single Ollama message
|
||||||
|
• Falls back to individual _parse_receipt_text calls on any failure
|
||||||
|
|
||||||
|
Returns a list of parsed dicts in the same order as `receipts`.
|
||||||
|
Each dict: {vendor, amount, date, time, product_name, skip?}
|
||||||
|
"""
|
||||||
|
today = _date.today().isoformat()
|
||||||
|
results: list[dict] = []
|
||||||
|
needs_llm: list[int] = [] # indices into results that need vendor/cat
|
||||||
|
|
||||||
|
# ── Phase 1: fast per-receipt regex ──────────────────────────────────
|
||||||
|
for r in receipts:
|
||||||
|
filename = r.get('filename', 'receipt')
|
||||||
|
stripped = (r.get('text', '') or '').strip()
|
||||||
|
ocr_failed = not stripped or stripped.startswith('[')
|
||||||
|
|
||||||
|
if not ocr_failed and _is_likely_bank_statement(stripped):
|
||||||
|
n = sum(1 for line in stripped.splitlines() if _ANY_DOLLAR_RE.search(line))
|
||||||
|
logger.warning('receipt %s: bank statement (%d amount lines) — skip', filename, n)
|
||||||
|
results.append({'vendor': filename, 'amount': 0.0,
|
||||||
|
'date': r.get('date_from_name') or today, 'time': None,
|
||||||
|
'product_name': '', 'skip': True})
|
||||||
|
continue
|
||||||
|
|
||||||
|
amount = _extract_amount_from_text(stripped) if not ocr_failed else 0.0
|
||||||
|
date_hint = r.get('date_from_name')
|
||||||
|
date = (date_hint or
|
||||||
|
(_extract_date_from_text(stripped) if not ocr_failed else None) or
|
||||||
|
today)
|
||||||
|
|
||||||
|
results.append({'vendor': filename, 'amount': amount, 'date': date,
|
||||||
|
'time': None, 'product_name': '',
|
||||||
|
# internal keys stripped before returning
|
||||||
|
'_ocr_failed': ocr_failed, '_stripped': stripped,
|
||||||
|
'_b64': r.get('b64'), '_mimetype': r.get('mimetype'),
|
||||||
|
'_filename': filename})
|
||||||
|
needs_llm.append(len(results) - 1)
|
||||||
|
|
||||||
|
product_list = ', '.join(f'"{p["name"]}"' for p in expense_products)
|
||||||
|
if not needs_llm or not product_list:
|
||||||
|
for entry in results:
|
||||||
|
for k in list(entry):
|
||||||
|
if k.startswith('_'):
|
||||||
|
del entry[k]
|
||||||
|
return results
|
||||||
|
|
||||||
|
# ── Phase 2a: vision mode → individual calls (can't batch images) ────
|
||||||
|
use_vision = (
|
||||||
|
_get_vision_mode() == 'vision'
|
||||||
|
and any(results[i].get('_b64') and
|
||||||
|
results[i].get('_mimetype') in _VISION_MIMETYPES
|
||||||
|
for i in needs_llm)
|
||||||
|
)
|
||||||
|
if use_vision:
|
||||||
|
tasks = [
|
||||||
|
self._parse_receipt_text(
|
||||||
|
results[i]['_stripped'], results[i]['_filename'],
|
||||||
|
expense_products=expense_products,
|
||||||
|
b64=results[i].get('_b64'),
|
||||||
|
mimetype=results[i].get('_mimetype'),
|
||||||
|
)
|
||||||
|
for i in needs_llm
|
||||||
|
]
|
||||||
|
individual = await asyncio.gather(*tasks, return_exceptions=True)
|
||||||
|
for i, parsed in zip(needs_llm, individual):
|
||||||
|
if isinstance(parsed, Exception) or not isinstance(parsed, dict):
|
||||||
|
continue
|
||||||
|
results[i]['vendor'] = parsed.get('vendor', results[i]['_filename'])
|
||||||
|
results[i]['product_name'] = parsed.get('product_name', '')
|
||||||
|
for entry in results:
|
||||||
|
for k in list(entry):
|
||||||
|
if k.startswith('_'):
|
||||||
|
del entry[k]
|
||||||
|
return results
|
||||||
|
|
||||||
|
# ── Phase 2b: text mode → single batched LLM call ────────────────────
|
||||||
|
_cat_guide = (
|
||||||
|
'restaurant/cafe/food court/bar → food/meal product; '
|
||||||
|
'airline/airport/transit/taxi/parking/rental car → travel product; '
|
||||||
|
'gas station/petrol/fuel → fuel product; '
|
||||||
|
'hotel/motel/lodging → accommodation product; '
|
||||||
|
'hardware/home improvement/tech/office supply → supplies product; '
|
||||||
|
'return "" if nothing fits'
|
||||||
|
)
|
||||||
|
receipts_block = ''
|
||||||
|
for seq, i in enumerate(needs_llm, 1):
|
||||||
|
entry = results[i]
|
||||||
|
if entry['_ocr_failed']:
|
||||||
|
excerpt = f'[filename: {entry["_filename"]}]'
|
||||||
|
else:
|
||||||
|
excerpt = entry['_stripped'][:300]
|
||||||
|
receipts_block += f'\n=== Receipt {seq} ({entry["_filename"]}) ===\n{excerpt}\n'
|
||||||
|
|
||||||
|
n = len(needs_llm)
|
||||||
|
batch_prompt = (
|
||||||
|
f'Return ONLY a JSON array with exactly {n} objects, one per receipt below.\n'
|
||||||
|
f'Each object must have exactly two keys:\n'
|
||||||
|
f'"vendor": business name from the receipt header '
|
||||||
|
f'(first 1-3 lines; ignore slogans and item names; '
|
||||||
|
f'do NOT substitute a brand not clearly present).\n'
|
||||||
|
f'"product_name": single best match from [{product_list}].\n'
|
||||||
|
f'Category guide: {_cat_guide}\n'
|
||||||
|
f'JSON array only:\n{receipts_block}'
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
resp = await self._llm.submit(
|
||||||
|
[{'role': 'user', 'content': batch_prompt}],
|
||||||
|
caller='expenses_agent_receipt_parser',
|
||||||
|
)
|
||||||
|
raw = (resp.content or '').strip()
|
||||||
|
first, last = raw.find('['), raw.rfind(']')
|
||||||
|
if first == -1 or last <= first:
|
||||||
|
raise ValueError(f'No JSON array in response: {raw[:200]}')
|
||||||
|
batch_data = json.loads(raw[first:last + 1])
|
||||||
|
if len(batch_data) != n:
|
||||||
|
raise ValueError(f'Expected {n} items, got {len(batch_data)}')
|
||||||
|
for i, item in zip(needs_llm, batch_data):
|
||||||
|
v = str(item.get('vendor', '') or '').strip()
|
||||||
|
if v:
|
||||||
|
results[i]['vendor'] = v
|
||||||
|
results[i]['product_name'] = str(item.get('product_name', '') or '').strip()
|
||||||
|
logger.info('expenses_agent: batch LLM parsed %d receipts in 1 call', n)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning('expenses_agent: batch LLM failed (%s) — falling back to individual calls', exc)
|
||||||
|
fallback_tasks = [
|
||||||
|
self._parse_receipt_text(
|
||||||
|
results[i]['_stripped'], results[i]['_filename'],
|
||||||
|
expense_products=expense_products,
|
||||||
|
)
|
||||||
|
for i in needs_llm
|
||||||
|
]
|
||||||
|
fallback = await asyncio.gather(*fallback_tasks, return_exceptions=True)
|
||||||
|
for i, parsed in zip(needs_llm, fallback):
|
||||||
|
if isinstance(parsed, Exception) or not isinstance(parsed, dict):
|
||||||
|
continue
|
||||||
|
v = str(parsed.get('vendor', '') or '').strip()
|
||||||
|
if v:
|
||||||
|
results[i]['vendor'] = v
|
||||||
|
results[i]['product_name'] = str(parsed.get('product_name', '') or '').strip()
|
||||||
|
|
||||||
|
# Strip internal bookkeeping keys before returning
|
||||||
|
for entry in results:
|
||||||
|
for k in list(entry):
|
||||||
|
if k.startswith('_'):
|
||||||
|
del entry[k]
|
||||||
|
return results
|
||||||
|
|
||||||
async def _parse_receipt_text(self, text: str, filename: str,
|
async def _parse_receipt_text(self, text: str, filename: str,
|
||||||
expense_products: list = None,
|
expense_products: list = None,
|
||||||
date_hint: str = None,
|
date_hint: str = None,
|
||||||
|
|||||||
@@ -856,6 +856,190 @@ async def test_non_image_mimetype_uses_text_path_in_vision_mode():
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# _batch_parse_receipts — batched LLM call for vendor + product_name
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _make_receipt(filename='receipt.jpg', text='Acme\nTotal: $10.00',
|
||||||
|
b64='', mimetype='image/jpeg', date_from_name=None):
|
||||||
|
"""Build a minimal receipt dict as produced by parse_upload."""
|
||||||
|
return {'filename': filename, 'text': text, 'b64': b64,
|
||||||
|
'mimetype': mimetype, 'date_from_name': date_from_name,
|
||||||
|
'sha256': 'abc'}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_batch_parse_single_llm_call_for_multiple_receipts():
|
||||||
|
"""N text receipts must result in exactly 1 LLM call (batched prompt)."""
|
||||||
|
agent = _make_agent()
|
||||||
|
receipts = [
|
||||||
|
_make_receipt('a.txt', 'Shell Gas\nTotal: $45.00'),
|
||||||
|
_make_receipt('b.txt', 'Marriott Hotel\nAmount Due: $180.00'),
|
||||||
|
_make_receipt('c.txt', 'Chipotle\nTotal: $12.75'),
|
||||||
|
]
|
||||||
|
products = [{'id': 1, 'name': 'Meals'}, {'id': 2, 'name': 'Travel'}, {'id': 3, 'name': 'Fuel'}]
|
||||||
|
|
||||||
|
llm_resp = MagicMock()
|
||||||
|
llm_resp.content = (
|
||||||
|
'[{"vendor":"Shell","product_name":"Fuel"},'
|
||||||
|
'{"vendor":"Marriott","product_name":"Travel"},'
|
||||||
|
'{"vendor":"Chipotle","product_name":"Meals"}]'
|
||||||
|
)
|
||||||
|
agent._llm.submit = AsyncMock(return_value=llm_resp)
|
||||||
|
|
||||||
|
with patch('agent_service.agents.expenses_agent._get_vision_mode', return_value='text'):
|
||||||
|
results = await agent._batch_parse_receipts(receipts, products)
|
||||||
|
|
||||||
|
agent._llm.submit.assert_called_once()
|
||||||
|
assert len(results) == 3
|
||||||
|
assert results[0]['vendor'] == 'Shell'
|
||||||
|
assert results[1]['vendor'] == 'Marriott'
|
||||||
|
assert results[2]['vendor'] == 'Chipotle'
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_batch_parse_amounts_from_regex_not_llm():
|
||||||
|
"""Amounts must come from regex (Phase 1), not from the LLM batch response."""
|
||||||
|
agent = _make_agent()
|
||||||
|
receipts = [_make_receipt('r.txt', 'Acme Store\nTotal: $99.99')]
|
||||||
|
products = [{'id': 1, 'name': 'Supplies'}]
|
||||||
|
|
||||||
|
llm_resp = MagicMock()
|
||||||
|
llm_resp.content = '[{"vendor":"Acme","product_name":"Supplies"}]'
|
||||||
|
agent._llm.submit = AsyncMock(return_value=llm_resp)
|
||||||
|
|
||||||
|
with patch('agent_service.agents.expenses_agent._get_vision_mode', return_value='text'):
|
||||||
|
results = await agent._batch_parse_receipts(receipts, products)
|
||||||
|
|
||||||
|
assert results[0]['amount'] == 99.99
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_batch_parse_no_private_keys_in_results():
|
||||||
|
"""Internal _-prefixed keys must be stripped from every result dict."""
|
||||||
|
agent = _make_agent()
|
||||||
|
receipts = [_make_receipt('r.txt', 'Acme\nTotal: $10.00')]
|
||||||
|
products = [{'id': 1, 'name': 'Meals'}]
|
||||||
|
|
||||||
|
llm_resp = MagicMock()
|
||||||
|
llm_resp.content = '[{"vendor":"Acme","product_name":"Meals"}]'
|
||||||
|
agent._llm.submit = AsyncMock(return_value=llm_resp)
|
||||||
|
|
||||||
|
with patch('agent_service.agents.expenses_agent._get_vision_mode', return_value='text'):
|
||||||
|
results = await agent._batch_parse_receipts(receipts, products)
|
||||||
|
|
||||||
|
for entry in results:
|
||||||
|
private = [k for k in entry if k.startswith('_')]
|
||||||
|
assert private == [], f'Private keys not cleaned up: {private}'
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_batch_parse_bank_statement_skipped_no_llm():
|
||||||
|
"""Bank statements inside a batch must be skipped; no LLM call for them."""
|
||||||
|
agent = _make_agent()
|
||||||
|
# 12 transaction lines → flagged as bank statement
|
||||||
|
stmt = '\n'.join(f'05/{i+1:02d} MERCHANT {i} ${10 + i}.99' for i in range(12))
|
||||||
|
receipts = [
|
||||||
|
_make_receipt('stmt.pdf', stmt),
|
||||||
|
_make_receipt('real.txt', 'Shell Gas\nTotal: $45.00'),
|
||||||
|
]
|
||||||
|
products = [{'id': 1, 'name': 'Fuel'}]
|
||||||
|
|
||||||
|
llm_resp = MagicMock()
|
||||||
|
llm_resp.content = '[{"vendor":"Shell","product_name":"Fuel"}]'
|
||||||
|
agent._llm.submit = AsyncMock(return_value=llm_resp)
|
||||||
|
|
||||||
|
with patch('agent_service.agents.expenses_agent._get_vision_mode', return_value='text'):
|
||||||
|
results = await agent._batch_parse_receipts(receipts, products)
|
||||||
|
|
||||||
|
# Only 1 item sent to LLM (the real receipt, not the statement)
|
||||||
|
agent._llm.submit.assert_called_once()
|
||||||
|
# Statement entry has skip=True
|
||||||
|
assert results[0].get('skip') is True
|
||||||
|
assert results[0]['amount'] == 0.0
|
||||||
|
# Real receipt parsed normally
|
||||||
|
assert results[1]['vendor'] == 'Shell'
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_batch_parse_falls_back_on_malformed_json():
|
||||||
|
"""When the batch LLM returns malformed JSON, falls back to individual calls."""
|
||||||
|
agent = _make_agent()
|
||||||
|
receipts = [
|
||||||
|
_make_receipt('a.txt', 'Shell\nTotal: $45.00'),
|
||||||
|
_make_receipt('b.txt', 'Marriott\nTotal: $180.00'),
|
||||||
|
]
|
||||||
|
products = [{'id': 1, 'name': 'Travel'}]
|
||||||
|
|
||||||
|
call_count = [0]
|
||||||
|
individual_resp = MagicMock()
|
||||||
|
individual_resp.content = '{"vendor":"Shell","product_name":"Travel"}'
|
||||||
|
|
||||||
|
async def _side_effect(messages, caller=None):
|
||||||
|
call_count[0] += 1
|
||||||
|
if call_count[0] == 1:
|
||||||
|
bad = MagicMock()
|
||||||
|
bad.content = 'not valid json at all'
|
||||||
|
return bad
|
||||||
|
return individual_resp
|
||||||
|
|
||||||
|
agent._llm.submit = _side_effect
|
||||||
|
|
||||||
|
with patch('agent_service.agents.expenses_agent._get_vision_mode', return_value='text'):
|
||||||
|
results = await agent._batch_parse_receipts(receipts, products)
|
||||||
|
|
||||||
|
# 1 batch attempt + 2 individual fallback calls = 3
|
||||||
|
assert call_count[0] == 3
|
||||||
|
assert len(results) == 2
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_batch_parse_falls_back_on_wrong_item_count():
|
||||||
|
"""When the LLM returns a JSON array with wrong length, falls back."""
|
||||||
|
agent = _make_agent()
|
||||||
|
receipts = [
|
||||||
|
_make_receipt('a.txt', 'Shell\nTotal: $45.00'),
|
||||||
|
_make_receipt('b.txt', 'Marriott\nTotal: $180.00'),
|
||||||
|
]
|
||||||
|
products = [{'id': 1, 'name': 'Travel'}]
|
||||||
|
|
||||||
|
call_count = [0]
|
||||||
|
fallback_resp = MagicMock()
|
||||||
|
fallback_resp.content = '{"vendor":"Shell","product_name":"Travel"}'
|
||||||
|
|
||||||
|
async def _side_effect(messages, caller=None):
|
||||||
|
call_count[0] += 1
|
||||||
|
if call_count[0] == 1:
|
||||||
|
# Returns only 1 item, expected 2
|
||||||
|
wrong = MagicMock()
|
||||||
|
wrong.content = '[{"vendor":"Shell","product_name":"Travel"}]'
|
||||||
|
return wrong
|
||||||
|
return fallback_resp
|
||||||
|
|
||||||
|
agent._llm.submit = _side_effect
|
||||||
|
|
||||||
|
with patch('agent_service.agents.expenses_agent._get_vision_mode', return_value='text'):
|
||||||
|
results = await agent._batch_parse_receipts(receipts, products)
|
||||||
|
|
||||||
|
# 1 batch attempt + 2 individual fallback calls = 3
|
||||||
|
assert call_count[0] == 3
|
||||||
|
assert len(results) == 2
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_batch_parse_no_products_skips_llm():
|
||||||
|
"""When there are no expense products, the LLM is not called."""
|
||||||
|
agent = _make_agent()
|
||||||
|
receipts = [_make_receipt('r.txt', 'Acme\nTotal: $10.00')]
|
||||||
|
agent._llm.submit = AsyncMock()
|
||||||
|
|
||||||
|
with patch('agent_service.agents.expenses_agent._get_vision_mode', return_value='text'):
|
||||||
|
results = await agent._batch_parse_receipts(receipts, [])
|
||||||
|
|
||||||
|
agent._llm.submit.assert_not_called()
|
||||||
|
assert len(results) == 1
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# parse_upload — receipt_parser.py
|
# parse_upload — receipt_parser.py
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|||||||
Reference in New Issue
Block a user