diff --git a/agent_service/agents/expenses_agent.py b/agent_service/agents/expenses_agent.py index a808177..2957479 100644 --- a/agent_service/agents/expenses_agent.py +++ b/agent_service/agents/expenses_agent.py @@ -317,13 +317,63 @@ class ExpensesAgent(BaseAgent): return None + @staticmethod + def _match_category(category: str, expense_products: list) -> str: + """Map a vision-model category label to the nearest expense product name. + + Tries exact/substring match first, then a fuzzy SequenceMatcher pass. + Returns empty string when no reasonable match is found. + """ + if not expense_products or not category: + return '' + cat = category.lower().strip() + # Exact or substring match + for p in expense_products: + name = p['name'].lower() + if cat == name or cat in name or name in cat: + return p['name'] + # Fuzzy fallback (ratio >= 0.4) + names_lower = [p['name'].lower() for p in expense_products] + matches = difflib.get_close_matches(cat, names_lower, n=1, cutoff=0.4) + if matches: + for p in expense_products: + if p['name'].lower() == matches[0]: + return p['name'] + return '' + async def _parse_receipt_text(self, text: str, filename: str, expense_products: list = None, date_hint: str = None) -> dict: today = _date.today().isoformat() fallback = {'vendor': filename, 'amount': 0.0, 'date': date_hint or today, 'time': None, 'product_name': ''} - ocr_failed = not text or text.startswith('[') + + # ── Fast path: vision model already returned structured JSON ────────── + # receipt_parser._ocr_image_vision() returns a JSON string directly + # when a vision model is configured. Skip the second LLM call entirely. + stripped = (text or '').strip() + if stripped.startswith('{'): + try: + data = json.loads(stripped) + if 'amount' in data: + logger.debug('expenses_agent: using vision pre-extracted JSON for %s', filename) + # Map the vision category label → expense product name + product_name = self._match_category( + data.get('category', ''), expense_products or []) + # Vision model sometimes returns the string "null" instead of JSON null + raw_time = data.get('time') + time_val = None if raw_time in (None, 'null', 'None', '') else str(raw_time) + return { + 'vendor': str(data.get('vendor') or filename), + 'amount': float(data.get('amount', 0.0)), + 'date': str(data.get('date') or date_hint or today), + 'time': time_val, + 'product_name': product_name, + } + except (json.JSONDecodeError, ValueError, TypeError): + pass # not clean JSON — fall through to LLM path + + ocr_failed = not stripped or stripped.startswith('[') product_list = '' if expense_products: @@ -341,6 +391,13 @@ class ExpensesAgent(BaseAgent): f'Return ONLY valid JSON: {{"product_name": "..."}}' ) else: + # Keep both the header (vendor/date) and footer (totals) of the receipt. + # A plain [:N] cut discards the bottom of long receipts where the grand + # total lives — the primary cause of amount=0 extraction errors. + if len(stripped) > 3000: + receipt_text = stripped[:1500] + '\n[...]\n' + stripped[-1500:] + else: + receipt_text = stripped prompt = ( 'Extract expense details from the following receipt text. ' 'Return ONLY valid JSON with these keys:\n' @@ -354,7 +411,7 @@ class ExpensesAgent(BaseAgent): '"time" (string HH:MM in 24-hour format — the transaction time printed on the receipt; ' 'null if not present),\n' f'"product_name" (string, pick the best match from [{product_list}] or empty string).\n\n' - f'Receipt text:\n{text[:2000]}\n\nJSON only:' + f'Receipt text:\n{receipt_text}\n\nJSON only:' ) try: resp = await self._llm.submit( diff --git a/agent_service/tools/receipt_parser.py b/agent_service/tools/receipt_parser.py index 33479d6..6e49588 100644 --- a/agent_service/tools/receipt_parser.py +++ b/agent_service/tools/receipt_parser.py @@ -98,7 +98,13 @@ def _ocr_image(data: bytes, filename: str) -> str: def _ocr_image_vision(data: bytes, filename: str, ollama_url: str, model: str) -> str: - """Use an Ollama vision model to read a receipt image.""" + """Use an Ollama vision model to extract receipt data directly as JSON. + + Returns a JSON string {vendor, amount, date, time, category} so the + expenses agent can skip the second LLM extraction step entirely. + Returns empty string on any failure so the caller falls back to Tesseract. + """ + import json as _json try: import ollama as _ollama client = _ollama.Client(host=ollama_url) @@ -107,22 +113,41 @@ def _ocr_image_vision(data: bytes, filename: str, ollama_url: str, model: str) - messages=[{ 'role': 'user', 'content': ( - 'This is a photo of a paper receipt. ' - 'Transcribe ALL text exactly as it appears on the receipt. ' - 'Preserve every line in order: store name, address, date, time, ' - 'each line item with price, subtotal, tax, tip if present, and ' - 'the final total. Output the raw text only — no commentary, ' - 'no markdown, no explanations.' + 'This is a photo of a receipt. Extract these fields:\n' + '- vendor: the store or restaurant name\n' + '- amount: the FINAL total the customer paid. Look for a line ' + 'labeled "Total", "Grand Total", "Amount Due", or "Balance Due". ' + 'Do NOT use subtotal, tax, or tip. Return 0 if you cannot find ' + 'a clear final total.\n' + '- date: transaction date in YYYY-MM-DD format\n' + '- time: transaction time in HH:MM 24-hour format, or null\n' + '- category: one word describing the expense type — one of: ' + 'meals, fuel, hotel, office, transport, other\n\n' + 'Return ONLY a valid JSON object, no commentary, no markdown:\n' + '{"vendor":"...","amount":0.00,"date":"YYYY-MM-DD",' + '"time":"HH:MM or null","category":"..."}' ), 'images': [data], }], ) if isinstance(response, dict): - text = (response.get('message', {}).get('content') or '').strip() + raw = (response.get('message', {}).get('content') or '').strip() else: - text = (response.message.content or '').strip() - logger.debug('Vision OCR %s (%s): %d chars', filename, model, len(text)) - return text + raw = (response.message.content or '').strip() + + # Must contain a JSON object, not prose + first, last = raw.find('{'), raw.rfind('}') + if first == -1 or last <= first: + logger.warning('Vision OCR %s: model returned prose, falling back to Tesseract', + filename) + return '' + json_str = raw[first:last + 1] + parsed = _json.loads(json_str) + if 'amount' not in parsed: + logger.warning('Vision OCR %s: JSON missing amount field, falling back', filename) + return '' + logger.debug('Vision OCR %s (%s): extracted JSON ok', filename, model) + return json_str except ImportError: logger.warning('ollama package not installed — vision OCR unavailable for %s', filename) return '' diff --git a/tests/test_expenses_agent.py b/tests/test_expenses_agent.py index 9d1b920..ddfe128 100644 --- a/tests/test_expenses_agent.py +++ b/tests/test_expenses_agent.py @@ -289,8 +289,13 @@ async def test_plan_task_field_also_checked(): # --------------------------------------------------------------------------- @pytest.mark.asyncio -async def test_act_enters_awaiting_confirmation_on_first_pass(): - """First call with receipts and no confirm → mode becomes awaiting_confirmation.""" +async def test_act_creates_expenses_immediately(): + """Expenses are created in draft immediately — no confirmation gate. + + The old two-step confirm flow was removed because receipts are only + available in the initial /upload request, making a follow-up confirmation + turn impossible. _act() now creates draft expenses straight away. + """ agent = _make_agent() fake_receipt = { @@ -308,18 +313,20 @@ async def test_act_enters_awaiting_confirmation_on_first_pass(): parsed_result = {'vendor': 'Acme', 'amount': 10.00, 'date': '2026-05-09', 'time': None, 'product_name': ''} + sheet_result = MagicMock(success=True, record_id=42) + expense_result = MagicMock(success=True, record_id=99) + agent._et.get_employee_id_for_user = AsyncMock(return_value=1) - agent._et.get_expense_products = AsyncMock(return_value=[ - {'id': 1, 'name': 'Meals'} - ]) + agent._et.get_expense_products = AsyncMock(return_value=[{'id': 1, 'name': 'Meals'}]) + agent._et.create_expense_sheet = AsyncMock(return_value=sheet_result) + agent._et.create_expense = AsyncMock(return_value=expense_result) with patch.object(agent, '_parse_receipt_text', new=AsyncMock(return_value=parsed_result)): - result = await agent._act({}) + actions = await agent._act({}) - assert result == [] - assert agent._gathered_data['mode'] == 'awaiting_confirmation' - assert len(agent._confirmation_items) == 1 - vendor, parsed, is_dup = agent._confirmation_items[0] + assert any('Created expense sheet' in a for a in actions) + agent._et.create_expense_sheet.assert_called_once() + agent._et.create_expense.assert_called_once() @pytest.mark.asyncio @@ -415,6 +422,103 @@ async def test_act_no_employee_returns_empty_and_escalates(): assert any('No employee record' in e for e in agent._escalations_list) +# --------------------------------------------------------------------------- +# _match_category +# --------------------------------------------------------------------------- + +class TestMatchCategory: + PRODUCTS = [ + {'id': 1, 'name': 'Meals'}, + {'id': 2, 'name': 'Fuel'}, + {'id': 3, 'name': 'Hotel'}, + {'id': 4, 'name': 'Office Supplies'}, + {'id': 5, 'name': 'Transport'}, + {'id': 6, 'name': 'Other'}, + ] + + def test_exact_match(self): + assert ExpensesAgent._match_category('Meals', self.PRODUCTS) == 'Meals' + + def test_case_insensitive(self): + assert ExpensesAgent._match_category('meals', self.PRODUCTS) == 'Meals' + assert ExpensesAgent._match_category('FUEL', self.PRODUCTS) == 'Fuel' + + def test_substring_match(self): + # 'office' is a substring of 'Office Supplies' + assert ExpensesAgent._match_category('office', self.PRODUCTS) == 'Office Supplies' + + def test_fuzzy_match(self): + # 'transport' is close to 'Transport' + assert ExpensesAgent._match_category('transport', self.PRODUCTS) == 'Transport' + + def test_no_match_returns_empty(self): + assert ExpensesAgent._match_category('zxqwerty', self.PRODUCTS) == '' + + def test_empty_category(self): + assert ExpensesAgent._match_category('', self.PRODUCTS) == '' + + def test_empty_products(self): + assert ExpensesAgent._match_category('meals', []) == '' + + +# --------------------------------------------------------------------------- +# _parse_receipt_text — vision JSON fast path +# --------------------------------------------------------------------------- + +@pytest.mark.asyncio +async def test_parse_vision_json_fast_path(): + """When text is pre-extracted JSON from vision model, skip LLM call.""" + agent = _make_agent() + agent._llm.submit = AsyncMock() # should NOT be called + + vision_json = ('{"vendor":"McDonald\'s","amount":12.50,' + '"date":"2026-05-09","time":"13:30","category":"meals"}') + products = [{'id': 1, 'name': 'Meals'}, {'id': 2, 'name': 'Fuel'}] + + result = await agent._parse_receipt_text(vision_json, 'receipt.jpg', + expense_products=products) + + assert result['vendor'] == "McDonald's" + assert result['amount'] == 12.50 + assert result['date'] == '2026-05-09' + assert result['time'] == '13:30' + assert result['product_name'] == 'Meals' + agent._llm.submit.assert_not_called() + + +@pytest.mark.asyncio +async def test_parse_vision_json_null_time(): + """Vision model may return the string 'null' for time — normalise to None.""" + agent = _make_agent() + agent._llm.submit = AsyncMock() + + vision_json = '{"vendor":"Shell","amount":45.00,"date":"2026-05-09","time":"null","category":"fuel"}' + products = [{'id': 1, 'name': 'Meals'}, {'id': 2, 'name': 'Fuel'}] + + result = await agent._parse_receipt_text(vision_json, 'shell.jpg', + expense_products=products) + assert result['time'] is None + assert result['product_name'] == 'Fuel' + agent._llm.submit.assert_not_called() + + +@pytest.mark.asyncio +async def test_parse_non_json_text_falls_through_to_llm(): + """Plain OCR text (not JSON) should go through the LLM extraction path.""" + agent = _make_agent() + llm_resp = MagicMock() + llm_resp.content = '{"vendor":"Acme","amount":9.99,"date":"2026-05-09","time":null,"product_name":"Meals"}' + agent._llm.submit = AsyncMock(return_value=llm_resp) + + result = await agent._parse_receipt_text( + 'Acme Store\nTotal: $9.99', 'receipt.jpg', + expense_products=[{'id': 1, 'name': 'Meals'}], + ) + assert result['vendor'] == 'Acme' + assert result['amount'] == 9.99 + agent._llm.submit.assert_called_once() + + # --------------------------------------------------------------------------- # parse_upload — receipt_parser.py # ---------------------------------------------------------------------------