async def _batch_parse_receipts(self, receipts: list, expense_products: list) -> list:
    """Parse all receipts, using one batched LLM call where possible.

    Invoked by the upload flow as
    ``await self._batch_parse_receipts(unique_receipts, expense_products)``.
    The caller zips the returned list with ``receipts``, so the output always
    has the same length and order as the input.

    Phase 1 — regex only (no LLM):
      • amount : ``_extract_amount_from_text``
      • date   : filename hint > OCR regex > today
      • skip   : bank/card statement detection (``_is_likely_bank_statement``)

    Phase 2 — vendor + product_name from the LLM:
      • Text mode: a single prompt covering every non-skipped receipt.
      • Vision mode (``_get_vision_mode() == 'vision'`` and at least one
        receipt carries an image in ``_VISION_MIMETYPES``): one
        ``_parse_receipt_text`` call per receipt, because images can't be
        batched in a single Ollama message.
      • Any failure of the batched call (transport error, no JSON array,
        wrong item count, non-object items) falls back to individual
        ``_parse_receipt_text`` calls.

    The LLM is skipped entirely when nothing needs it or when there are no
    expense products; in that case ``vendor`` stays the filename and
    ``product_name`` stays ``''``.

    Args:
        receipts: Receipt dicts; the keys read here are ``filename``,
            ``text``, ``date_from_name``, ``b64`` and ``mimetype``.
        expense_products: Dicts with a ``name`` key used to build the list of
            allowed product names. ``None`` is treated as empty.

    Returns:
        One dict per receipt with keys ``vendor``, ``amount``, ``date``,
        ``time`` (always ``None`` here) and ``product_name``; detected
        statements additionally carry ``skip=True``.
    """
    today = _date.today().isoformat()
    products = expense_products or []
    results: list[dict] = []
    # LLM inputs are tracked separately from the results so the returned dicts
    # never contain bookkeeping keys that would need stripping afterwards.
    pending: list[dict] = []

    # ── Phase 1: fast per-receipt regex ──────────────────────────────────
    for r in receipts:
        filename = r.get('filename', 'receipt')
        stripped = (r.get('text', '') or '').strip()
        # Empty text, or text starting with '[', is treated as failed OCR
        # (presumably a bracketed placeholder emitted upstream — confirm).
        ocr_failed = not stripped or stripped.startswith('[')
        date_hint = r.get('date_from_name')

        if not ocr_failed and _is_likely_bank_statement(stripped):
            amount_lines = sum(
                1 for line in stripped.splitlines() if _ANY_DOLLAR_RE.search(line)
            )
            logger.warning('receipt %s: bank statement (%d amount lines) — skip',
                           filename, amount_lines)
            results.append({'vendor': filename, 'amount': 0.0,
                            'date': date_hint or today, 'time': None,
                            'product_name': '', 'skip': True})
            continue

        amount = 0.0 if ocr_failed else _extract_amount_from_text(stripped)
        date = (date_hint or
                (_extract_date_from_text(stripped) if not ocr_failed else None) or
                today)

        pending.append({'index': len(results), 'filename': filename,
                        'text': stripped, 'ocr_failed': ocr_failed,
                        'b64': r.get('b64'), 'mimetype': r.get('mimetype')})
        results.append({'vendor': filename, 'amount': amount, 'date': date,
                        'time': None, 'product_name': ''})

    product_list = ', '.join(f'"{p["name"]}"' for p in products)
    if not pending or not product_list:
        return results

    def _apply(index: int, parsed: dict) -> None:
        """Copy vendor/product_name from one LLM result onto results[index].

        An empty or missing vendor keeps the filename fallback set in phase 1.
        """
        vendor = str(parsed.get('vendor', '') or '').strip()
        if vendor:
            results[index]['vendor'] = vendor
        results[index]['product_name'] = str(parsed.get('product_name', '') or '').strip()

    async def _parse_individually(with_images: bool) -> None:
        """Run one _parse_receipt_text call per pending receipt, concurrently."""
        tasks = []
        for item in pending:
            # Image payloads are only forwarded on the vision path.
            extra = ({'b64': item['b64'], 'mimetype': item['mimetype']}
                     if with_images else {})
            tasks.append(self._parse_receipt_text(
                item['text'], item['filename'],
                expense_products=expense_products, **extra))
        outcomes = await asyncio.gather(*tasks, return_exceptions=True)
        for item, parsed in zip(pending, outcomes):
            if isinstance(parsed, dict):
                _apply(item['index'], parsed)
            else:
                # Best effort: keep the regex-only entry, but leave a trace.
                logger.warning('receipt %s: individual parse failed (%r) — '
                               'keeping regex-only result',
                               item['filename'], parsed)

    # ── Phase 2a: vision mode → individual calls (can't batch images) ────
    use_vision = (
        _get_vision_mode() == 'vision'
        and any(item['b64'] and item['mimetype'] in _VISION_MIMETYPES
                for item in pending)
    )
    if use_vision:
        await _parse_individually(with_images=True)
        return results

    # ── Phase 2b: text mode → single batched LLM call ────────────────────
    cat_guide = (
        'restaurant/cafe/food court/bar → food/meal product; '
        'airline/airport/transit/taxi/parking/rental car → travel product; '
        'gas station/petrol/fuel → fuel product; '
        'hotel/motel/lodging → accommodation product; '
        'hardware/home improvement/tech/office supply → supplies product; '
        'return "" if nothing fits'
    )
    sections = []
    for seq, item in enumerate(pending, 1):
        if item['ocr_failed']:
            excerpt = f'[filename: {item["filename"]}]'
        else:
            # Only the first 300 characters of OCR text are sent per receipt.
            excerpt = item['text'][:300]
        sections.append(f'\n=== Receipt {seq} ({item["filename"]}) ===\n{excerpt}\n')
    receipts_block = ''.join(sections)

    n = len(pending)
    batch_prompt = (
        f'Return ONLY a JSON array with exactly {n} objects, one per receipt below.\n'
        f'Each object must have exactly two keys:\n'
        f'"vendor": business name from the receipt header '
        f'(first 1-3 lines; ignore slogans and item names; '
        f'do NOT substitute a brand not clearly present).\n'
        f'"product_name": single best match from [{product_list}].\n'
        f'Category guide: {cat_guide}\n'
        f'JSON array only:\n{receipts_block}'
    )
    try:
        resp = await self._llm.submit(
            [{'role': 'user', 'content': batch_prompt}],
            caller='expenses_agent_receipt_parser',
        )
        raw = (resp.content or '').strip()
        # Tolerate chatter around the array: take the outermost [...] span.
        first, last = raw.find('['), raw.rfind(']')
        if first == -1 or last <= first:
            raise ValueError(f'No JSON array in response: {raw[:200]}')
        batch_data = json.loads(raw[first:last + 1])
        if len(batch_data) != n:
            raise ValueError(f'Expected {n} items, got {len(batch_data)}')
        # Validate every item before touching results so a malformed batch
        # can never be half-applied ahead of the fallback.
        if not all(isinstance(obj, dict) for obj in batch_data):
            raise ValueError('Batch response contains non-object items')
    except Exception as exc:
        # Deliberately broad: any batch failure degrades to per-receipt calls.
        logger.warning('expenses_agent: batch LLM failed (%s) — falling back to individual calls', exc)
        await _parse_individually(with_images=False)
    else:
        for item, obj in zip(pending, batch_data):
            _apply(item['index'], obj)
        logger.info('expenses_agent: batch LLM parsed %d receipts in 1 call', n)

    return results
# ---------------------------------------------------------------------------
# _batch_parse_receipts — batched LLM call for vendor + product_name
# ---------------------------------------------------------------------------

def _make_receipt(filename='receipt.jpg', text='Acme\nTotal: $10.00',
                  b64='', mimetype='image/jpeg', date_from_name=None):
    """Build a minimal receipt dict as produced by parse_upload."""
    return {'filename': filename, 'text': text, 'b64': b64,
            'mimetype': mimetype, 'date_from_name': date_from_name,
            'sha256': 'abc'}


@pytest.mark.asyncio
async def test_batch_parse_single_llm_call_for_multiple_receipts():
    """N text receipts must result in exactly 1 LLM call (batched prompt)."""
    agent = _make_agent()
    receipts = [
        _make_receipt('a.txt', 'Shell Gas\nTotal: $45.00'),
        _make_receipt('b.txt', 'Marriott Hotel\nAmount Due: $180.00'),
        _make_receipt('c.txt', 'Chipotle\nTotal: $12.75'),
    ]
    products = [{'id': 1, 'name': 'Meals'}, {'id': 2, 'name': 'Travel'}, {'id': 3, 'name': 'Fuel'}]

    llm_resp = MagicMock()
    llm_resp.content = (
        '[{"vendor":"Shell","product_name":"Fuel"},'
        '{"vendor":"Marriott","product_name":"Travel"},'
        '{"vendor":"Chipotle","product_name":"Meals"}]'
    )
    agent._llm.submit = AsyncMock(return_value=llm_resp)

    with patch('agent_service.agents.expenses_agent._get_vision_mode', return_value='text'):
        results = await agent._batch_parse_receipts(receipts, products)

    agent._llm.submit.assert_called_once()
    assert len(results) == 3
    assert results[0]['vendor'] == 'Shell'
    assert results[1]['vendor'] == 'Marriott'
    assert results[2]['vendor'] == 'Chipotle'
    # Batch items are applied positionally, so product names must line up too.
    assert [r['product_name'] for r in results] == ['Fuel', 'Travel', 'Meals']


@pytest.mark.asyncio
async def test_batch_parse_amounts_from_regex_not_llm():
    """Amounts must come from regex (Phase 1), not from the LLM batch response."""
    agent = _make_agent()
    receipts = [_make_receipt('r.txt', 'Acme Store\nTotal: $99.99')]
    products = [{'id': 1, 'name': 'Supplies'}]

    llm_resp = MagicMock()
    llm_resp.content = '[{"vendor":"Acme","product_name":"Supplies"}]'
    agent._llm.submit = AsyncMock(return_value=llm_resp)

    with patch('agent_service.agents.expenses_agent._get_vision_mode', return_value='text'):
        results = await agent._batch_parse_receipts(receipts, products)

    assert results[0]['amount'] == 99.99


@pytest.mark.asyncio
async def test_batch_parse_no_private_keys_in_results():
    """Internal _-prefixed keys must be stripped from every result dict."""
    agent = _make_agent()
    receipts = [_make_receipt('r.txt', 'Acme\nTotal: $10.00')]
    products = [{'id': 1, 'name': 'Meals'}]

    llm_resp = MagicMock()
    llm_resp.content = '[{"vendor":"Acme","product_name":"Meals"}]'
    agent._llm.submit = AsyncMock(return_value=llm_resp)

    with patch('agent_service.agents.expenses_agent._get_vision_mode', return_value='text'):
        results = await agent._batch_parse_receipts(receipts, products)

    for entry in results:
        private = [k for k in entry if k.startswith('_')]
        assert private == [], f'Private keys not cleaned up: {private}'


@pytest.mark.asyncio
async def test_batch_parse_bank_statement_skipped_no_llm():
    """Bank statements inside a batch must be skipped; no LLM call for them."""
    agent = _make_agent()
    # 12 transaction lines → flagged as bank statement
    stmt = '\n'.join(f'05/{i+1:02d} MERCHANT {i} ${10 + i}.99' for i in range(12))
    receipts = [
        _make_receipt('stmt.pdf', stmt),
        _make_receipt('real.txt', 'Shell Gas\nTotal: $45.00'),
    ]
    products = [{'id': 1, 'name': 'Fuel'}]

    llm_resp = MagicMock()
    llm_resp.content = '[{"vendor":"Shell","product_name":"Fuel"}]'
    agent._llm.submit = AsyncMock(return_value=llm_resp)

    with patch('agent_service.agents.expenses_agent._get_vision_mode', return_value='text'):
        results = await agent._batch_parse_receipts(receipts, products)

    # Only 1 item sent to LLM (the real receipt, not the statement)
    agent._llm.submit.assert_called_once()
    # The batched prompt labels each receipt with its filename, so the
    # statement's filename must be absent and the real receipt's present.
    prompt = agent._llm.submit.call_args.args[0][0]['content']
    assert 'real.txt' in prompt
    assert 'stmt.pdf' not in prompt
    # Statement entry has skip=True
    assert results[0].get('skip') is True
    assert results[0]['amount'] == 0.0
    # Real receipt parsed normally
    assert results[1]['vendor'] == 'Shell'


@pytest.mark.asyncio
async def test_batch_parse_falls_back_on_malformed_json():
    """When the batch LLM returns malformed JSON, falls back to individual calls."""
    agent = _make_agent()
    receipts = [
        _make_receipt('a.txt', 'Shell\nTotal: $45.00'),
        _make_receipt('b.txt', 'Marriott\nTotal: $180.00'),
    ]
    products = [{'id': 1, 'name': 'Travel'}]

    call_count = [0]
    individual_resp = MagicMock()
    individual_resp.content = '{"vendor":"Shell","product_name":"Travel"}'

    async def _side_effect(messages, caller=None):
        # First call is the batch attempt; every later call is a fallback.
        call_count[0] += 1
        if call_count[0] == 1:
            bad = MagicMock()
            bad.content = 'not valid json at all'
            return bad
        return individual_resp

    agent._llm.submit = _side_effect

    with patch('agent_service.agents.expenses_agent._get_vision_mode', return_value='text'):
        results = await agent._batch_parse_receipts(receipts, products)

    # 1 batch attempt + 2 individual fallback calls = 3
    assert call_count[0] == 3
    assert len(results) == 2


@pytest.mark.asyncio
async def test_batch_parse_falls_back_on_wrong_item_count():
    """When the LLM returns a JSON array with wrong length, falls back."""
    agent = _make_agent()
    receipts = [
        _make_receipt('a.txt', 'Shell\nTotal: $45.00'),
        _make_receipt('b.txt', 'Marriott\nTotal: $180.00'),
    ]
    products = [{'id': 1, 'name': 'Travel'}]

    call_count = [0]
    fallback_resp = MagicMock()
    fallback_resp.content = '{"vendor":"Shell","product_name":"Travel"}'

    async def _side_effect(messages, caller=None):
        # First call is the batch attempt; every later call is a fallback.
        call_count[0] += 1
        if call_count[0] == 1:
            # Returns only 1 item, expected 2
            wrong = MagicMock()
            wrong.content = '[{"vendor":"Shell","product_name":"Travel"}]'
            return wrong
        return fallback_resp

    agent._llm.submit = _side_effect

    with patch('agent_service.agents.expenses_agent._get_vision_mode', return_value='text'):
        results = await agent._batch_parse_receipts(receipts, products)

    # 1 batch attempt + 2 individual fallback calls = 3
    assert call_count[0] == 3
    assert len(results) == 2


@pytest.mark.asyncio
async def test_batch_parse_no_products_skips_llm():
    """When there are no expense products, the LLM is not called."""
    agent = _make_agent()
    receipts = [_make_receipt('r.txt', 'Acme\nTotal: $10.00')]
    agent._llm.submit = AsyncMock()

    with patch('agent_service.agents.expenses_agent._get_vision_mode', return_value='text'):
        results = await agent._batch_parse_receipts(receipts, [])

    agent._llm.submit.assert_not_called()
    assert len(results) == 1
    # Without the LLM the entry keeps its regex-only defaults.
    assert results[0]['vendor'] == 'r.txt'
    assert results[0]['product_name'] == ''


# ---------------------------------------------------------------------------
# parse_upload — receipt_parser.py
# ---------------------------------------------------------------------------