diff --git a/agent_service/agents/base_agent.py b/agent_service/agents/base_agent.py index 2ad2564..3bea3dd 100644 --- a/agent_service/agents/base_agent.py +++ b/agent_service/agents/base_agent.py @@ -11,14 +11,18 @@ logger = logging.getLogger(__name__) @dataclass class AgentDirective: directive_id: str - agent: str - task: str - params: dict - context: object # DirectiveContext - authorized_actions: list - constraints: dict + agent: str = '' + task: str = '' + params: dict = field(default_factory=dict) + context: object = field(default_factory=dict) # DirectiveContext or dict + authorized_actions: list = field(default_factory=list) + constraints: dict = field(default_factory=dict) approved: bool = False approval_item_id: object = None + # Convenience aliases used by agent _plan() methods + intent: str = '' + user_id: str = '' + agent_name: str = '' @dataclass @@ -32,10 +36,10 @@ class DirectiveContext: @dataclass class AgentReport: - directive_id: str agent: str - status: str # complete | partial | failed | escalated summary: str + directive_id: str = '' + status: str = 'complete' # complete | partial | failed | escalated actions_taken: list = field(default_factory=list) escalations: list = field(default_factory=list) peer_calls_made: list = field(default_factory=list) @@ -48,8 +52,11 @@ class AgentReport: class SweepReport: agent: str findings: list = field(default_factory=list) + actions: list = field(default_factory=list) actions_taken: list = field(default_factory=list) recommendations: list = field(default_factory=list) + summary: str = '' + error: object = None @dataclass diff --git a/tests/test_accounting_agent.py b/tests/test_accounting_agent.py new file mode 100644 index 0000000..860c35d --- /dev/null +++ b/tests/test_accounting_agent.py @@ -0,0 +1,247 @@ +"""Unit tests for AccountingAgent — plan, gather, reason, report, peer_bus, sweep.""" +import pytest +from unittest.mock import AsyncMock, MagicMock +from agent_service.agents.accounting_agent import AccountingAgent, ACCOUNTING_TOOLS +from agent_service.agents.base_agent import AgentDirective, AgentReport, SweepReport + + +def _directive(intent='', context=None): + return AgentDirective( + directive_id='test-d1', user_id='1', intent=intent, + context=context or {}, agent_name='accounting_agent', + ) + + +def _make_agent(): + odoo = MagicMock() + llm = MagicMock() + agent = AccountingAgent(odoo=odoo, llm=llm) + agent._at = MagicMock() + agent._at.get_trial_balance = AsyncMock(return_value=[]) + agent._at.get_tax_summary = AsyncMock(return_value={}) + agent._at.get_journal_entries = AsyncMock(return_value=[]) + agent._at.get_chart_of_accounts = AsyncMock(return_value=[]) + agent._at.get_account_balance = AsyncMock(return_value={'debit': 0, 'credit': 0}) + agent._at.flag_for_review = AsyncMock(return_value=True) + agent._at.post_chatter_note = AsyncMock(return_value=True) + return agent + + +# ── Meta ──────────────────────────────────────────────────────────────────── + +def test_tool_count(): + assert len(ACCOUNTING_TOOLS) <= 8 + +def test_agent_name(): + assert AccountingAgent.name == 'accounting_agent' + +def test_hipaa_locked(): + from agent_service.llm.llm_router import HIPAA_LOCKED_AGENTS + assert 'accounting_agent' in HIPAA_LOCKED_AGENTS + + +# ── _plan ─────────────────────────────────────────────────────────────────── + +@pytest.mark.asyncio +async def test_plan_trial_balance_intent(): + agent = _make_agent() + plan = await agent._plan(_directive(intent='show trial balance')) + assert plan['fetch_trial_balance'] is True + +@pytest.mark.asyncio +async def test_plan_tax_intent(): + agent = _make_agent() + plan = await agent._plan(_directive(intent='get tax summary for Q1')) + assert plan['fetch_tax'] is True + +@pytest.mark.asyncio +async def test_plan_journal_entries_intent(): + agent = _make_agent() + plan = await agent._plan(_directive(intent='list journal entries')) + assert plan['fetch_entries'] is True + +@pytest.mark.asyncio +async def test_plan_unknown_intent_all_false(): + agent = _make_agent() + plan = await agent._plan(_directive(intent='random question')) + assert plan['fetch_trial_balance'] is False + assert plan['fetch_tax'] is False + assert plan['fetch_entries'] is False + +@pytest.mark.asyncio +async def test_plan_propagates_dates(): + agent = _make_agent() + plan = await agent._plan(_directive(context={'date_from': '2026-01-01', 'date_to': '2026-03-31'})) + assert plan['date_from'] == '2026-01-01' + assert plan['date_to'] == '2026-03-31' + + +# ── _gather ───────────────────────────────────────────────────────────────── + +@pytest.mark.asyncio +async def test_gather_fetches_trial_balance(): + agent = _make_agent() + agent._at.get_trial_balance = AsyncMock(return_value=[{'account_name': 'Cash', 'balance': 5000}]) + ctx = {'plan': {'fetch_trial_balance': True, 'fetch_tax': False, 'fetch_entries': False, + 'date_from': None, 'date_to': None}} + data = await agent._gather(ctx) + assert 'trial_balance' in data + agent._at.get_trial_balance.assert_awaited_once() + +@pytest.mark.asyncio +async def test_gather_fetches_tax_summary(): + agent = _make_agent() + agent._at.get_tax_summary = AsyncMock(return_value={'total_tax_amount': 1500.0}) + ctx = {'plan': {'fetch_trial_balance': False, 'fetch_tax': True, 'fetch_entries': False, + 'date_from': None, 'date_to': None}} + data = await agent._gather(ctx) + assert 'tax_summary' in data + +@pytest.mark.asyncio +async def test_gather_falls_back_to_entries_when_nothing_else(): + agent = _make_agent() + ctx = {'plan': {'fetch_trial_balance': False, 'fetch_tax': False, 'fetch_entries': False, + 'date_from': None, 'date_to': None}} + data = await agent._gather(ctx) + assert 'entries' in data + agent._at.get_journal_entries.assert_awaited() + + +# ── _reason ───────────────────────────────────────────────────────────────── + +@pytest.mark.asyncio +async def test_reason_flags_large_balance(): + agent = _make_agent() + agent._gathered_data = { + 'trial_balance': [{'account_name': 'Suspense', 'balance': 200000.0}] + } + analysis = await agent._reason({}) + assert len(analysis['flags']) == 1 + assert analysis['flags'][0]['balance'] == 200000.0 + +@pytest.mark.asyncio +async def test_reason_no_flags_below_threshold(): + agent = _make_agent() + agent._gathered_data = {'trial_balance': [{'account_name': 'Cash', 'balance': 50000.0}]} + analysis = await agent._reason({}) + assert analysis['flags'] == [] + + +# ── _report ───────────────────────────────────────────────────────────────── + +@pytest.mark.asyncio +async def test_report_mentions_trial_balance_count(): + agent = _make_agent() + agent._gathered_data = { + 'trial_balance': [{'account_name': 'A', 'balance': 1000}, {'account_name': 'B', 'balance': 2000}] + } + agent._escalations_list = [] + report = await agent._report({}) + assert isinstance(report, AgentReport) + assert '2' in report.summary + +@pytest.mark.asyncio +async def test_report_mentions_tax_total(): + agent = _make_agent() + agent._gathered_data = {'tax_summary': {'total_tax_amount': 4500.0, 'total_tax_lines': 12}} + agent._escalations_list = [] + report = await agent._report({}) + assert '4500' in report.summary + +@pytest.mark.asyncio +async def test_report_fallback_message(): + agent = _make_agent() + agent._gathered_data = {} + agent._escalations_list = [] + report = await agent._report({}) + assert 'complete' in report.summary.lower() + + +# ── _dispatch_tool ─────────────────────────────────────────────────────────── + +@pytest.mark.asyncio +async def test_dispatch_tool_get_journal_entries(): + agent = _make_agent() + agent._at.get_journal_entries = AsyncMock(return_value=[{'id': 1}]) + result = await agent._dispatch_tool('get_journal_entries', {'limit': 5}) + agent._at.get_journal_entries.assert_awaited_once_with(limit=5) + +@pytest.mark.asyncio +async def test_dispatch_tool_unknown_raises(): + agent = _make_agent() + with pytest.raises(ValueError, match='Unknown tool'): + await agent._dispatch_tool('nonexistent', {}) + +@pytest.mark.asyncio +async def test_dispatch_tool_flag_for_review(): + agent = _make_agent() + await agent._dispatch_tool('flag_for_review', + {'model': 'account.move', 'record_id': 1, 'reason': 'anomaly'}) + agent._at.flag_for_review.assert_awaited_once() + + +# ── handle_peer_request ────────────────────────────────────────────────────── + +@pytest.mark.asyncio +async def test_peer_request_trial_balance(): + agent = _make_agent() + agent._at.get_trial_balance = AsyncMock(return_value=[{'account_name': 'Cash'}]) + result = await agent.handle_peer_request('trial_balance', {}, 'dir-1') + assert 'trial_balance' in result + +@pytest.mark.asyncio +async def test_peer_request_account_balance(): + agent = _make_agent() + agent._at.get_account_balance = AsyncMock(return_value={'debit': 100, 'credit': 50}) + result = await agent.handle_peer_request('account_balance', {'account_id': 42}, 'dir-1') + assert 'debit' in result + +@pytest.mark.asyncio +async def test_peer_request_tax_summary(): + agent = _make_agent() + agent._at.get_tax_summary = AsyncMock(return_value={'total_tax_amount': 0}) + result = await agent.handle_peer_request('tax_summary', {}, 'dir-1') + assert 'total_tax_amount' in result + +@pytest.mark.asyncio +async def test_peer_request_unknown_returns_error(): + agent = _make_agent() + result = await agent.handle_peer_request('nonexistent', {}, 'dir-1') + assert 'error' in result + + +# ── sweep ──────────────────────────────────────────────────────────────────── + +@pytest.mark.asyncio +async def test_sweep_returns_sweep_report(): + agent = _make_agent() + agent._at.get_trial_balance = AsyncMock(return_value=[]) + result = await agent.sweep() + assert isinstance(result, SweepReport) + +@pytest.mark.asyncio +async def test_sweep_finds_large_balance(): + agent = _make_agent() + agent._at.get_trial_balance = AsyncMock(return_value=[ + {'account_name': 'Suspense', 'balance': 600000.0} + ]) + result = await agent.sweep() + assert len(result.findings) == 1 + assert result.findings[0]['type'] == 'large_balance' + +@pytest.mark.asyncio +async def test_sweep_below_threshold_no_findings(): + agent = _make_agent() + agent._at.get_trial_balance = AsyncMock(return_value=[ + {'account_name': 'Cash', 'balance': 100.0} + ]) + result = await agent.sweep() + assert result.findings == [] + +@pytest.mark.asyncio +async def test_sweep_handles_exception(): + agent = _make_agent() + agent._at.get_trial_balance = AsyncMock(side_effect=Exception('DB error')) + result = await agent.sweep() + assert isinstance(result, SweepReport) + assert result.error is not None diff --git a/tests/test_accounting_tools.py b/tests/test_accounting_tools.py new file mode 100644 index 0000000..4d377f2 --- /dev/null +++ b/tests/test_accounting_tools.py @@ -0,0 +1,193 @@ +"""Unit tests for AccountingTools.""" +import pytest +from unittest.mock import AsyncMock, MagicMock +from agent_service.tools.accounting_tools import AccountingTools +from agent_service.tools.odoo_client import WriteResult + + +def _make_tools(): + odoo = MagicMock() + odoo.search_read = AsyncMock(return_value=[]) + odoo.write = AsyncMock(return_value=WriteResult( + success=True, model='', record_id=None, action='write')) + odoo.create = AsyncMock(return_value=WriteResult( + success=True, model='', record_id=42, action='create')) + odoo.call = AsyncMock(return_value=True) + return AccountingTools(odoo) + + +# ── get_journal_entries ─────────────────────────────────────────────────────── + +@pytest.mark.asyncio +async def test_get_journal_entries_default(): + t = _make_tools() + result = await t.get_journal_entries() + t._o.search_read.assert_awaited_once() + assert isinstance(result, list) + + +@pytest.mark.asyncio +async def test_get_journal_entries_with_filters(): + t = _make_tools() + t._o.search_read = AsyncMock(return_value=[{'id': 1, 'name': 'JE/001'}]) + result = await t.get_journal_entries(journal_id=5, date_from='2026-01-01', state='draft') + assert len(result) == 1 + call_args = t._o.search_read.call_args + domain = call_args[0][1] + assert ('journal_id', '=', 5) in domain + assert ('date', '>=', '2026-01-01') in domain + + +@pytest.mark.asyncio +async def test_get_journal_entries_date_to_filter(): + t = _make_tools() + await t.get_journal_entries(date_to='2026-01-31') + domain = t._o.search_read.call_args[0][1] + assert ('date', '<=', '2026-01-31') in domain + + +# ── get_chart_of_accounts ───────────────────────────────────────────────────── + +@pytest.mark.asyncio +async def test_get_chart_of_accounts_default(): + t = _make_tools() + result = await t.get_chart_of_accounts() + t._o.search_read.assert_awaited_once() + assert isinstance(result, list) + + +@pytest.mark.asyncio +async def test_get_chart_of_accounts_with_type(): + t = _make_tools() + await t.get_chart_of_accounts(account_type='asset_current') + domain = t._o.search_read.call_args[0][1] + assert ('account_type', '=', 'asset_current') in domain + + +# ── get_account_balance ─────────────────────────────────────────────────────── + +@pytest.mark.asyncio +async def test_get_account_balance_found(): + t = _make_tools() + t._o.search_read = AsyncMock(return_value=[ + {'id': 1, 'code': '1000', 'name': 'Cash', 'balance': 5000.0} + ]) + result = await t.get_account_balance(account_id=1) + assert result['balance'] == 5000.0 + assert result['code'] == '1000' + + +@pytest.mark.asyncio +async def test_get_account_balance_not_found(): + t = _make_tools() + t._o.search_read = AsyncMock(return_value=[]) + result = await t.get_account_balance(account_id=999) + assert result == {} + + +# ── get_trial_balance ───────────────────────────────────────────────────────── + +@pytest.mark.asyncio +async def test_get_trial_balance_aggregates_lines(): + t = _make_tools() + t._o.search_read = AsyncMock(return_value=[ + {'account_id': [1, 'Cash'], 'debit': 1000.0, 'credit': 0.0, 'balance': 1000.0}, + {'account_id': [1, 'Cash'], 'debit': 500.0, 'credit': 0.0, 'balance': 500.0}, + {'account_id': [2, 'AP'], 'debit': 0.0, 'credit': 2000.0, 'balance': -2000.0}, + ]) + result = await t.get_trial_balance() + assert isinstance(result, list) + assert len(result) == 2 + cash_entry = next(r for r in result if r['account_id'] == 1) + assert cash_entry['debit'] == 1500.0 + + +@pytest.mark.asyncio +async def test_get_trial_balance_balance_computed(): + t = _make_tools() + t._o.search_read = AsyncMock(return_value=[ + {'account_id': [1, 'Cash'], 'debit': 1000.0, 'credit': 200.0, 'balance': 800.0}, + ]) + result = await t.get_trial_balance() + assert result[0]['balance'] == 800.0 + + +@pytest.mark.asyncio +async def test_get_trial_balance_empty_returns_empty(): + t = _make_tools() + result = await t.get_trial_balance() + assert result == [] + + +# ── get_tax_summary ─────────────────────────────────────────────────────────── + +@pytest.mark.asyncio +async def test_get_tax_summary_returns_dict(): + t = _make_tools() + t._o.search_read = AsyncMock(return_value=[ + {'tax_ids': [1], 'debit': 100.0, 'credit': 0.0, 'balance': 100.0}, + {'tax_ids': [2], 'debit': 50.0, 'credit': 0.0, 'balance': 50.0}, + ]) + result = await t.get_tax_summary() + assert 'total_tax_lines' in result + assert result['total_tax_lines'] == 2 + assert 'total_tax_amount' in result + assert result['total_tax_amount'] == 150.0 + + +@pytest.mark.asyncio +async def test_get_tax_summary_with_date_range(): + t = _make_tools() + await t.get_tax_summary(date_from='2026-01-01', date_to='2026-01-31') + domain = t._o.search_read.call_args[0][1] + assert ('date', '>=', '2026-01-01') in domain + assert ('date', '<=', '2026-01-31') in domain + + +@pytest.mark.asyncio +async def test_get_tax_summary_includes_period(): + t = _make_tools() + result = await t.get_tax_summary(date_from='2026-01-01', date_to='2026-01-31') + assert result['period_from'] == '2026-01-01' + assert result['period_to'] == '2026-01-31' + + +# ── flag_for_review ─────────────────────────────────────────────────────────── + +@pytest.mark.asyncio +async def test_flag_for_review_calls_message_post(): + t = _make_tools() + result = await t.flag_for_review('account.move', 1, 'Test reason', severity='high') + assert result is True + t._o.call.assert_awaited_once() + call_args = t._o.call.call_args + assert 'message_post' in str(call_args) + kwargs = call_args[0][3] + assert '[AI FLAG - HIGH]' in kwargs['body'] + assert 'Test reason' in kwargs['body'] + + +@pytest.mark.asyncio +async def test_flag_for_review_default_severity(): + t = _make_tools() + await t.flag_for_review('account.move', 1, 'reason') + kwargs = t._o.call.call_args[0][3] + assert 'MEDIUM' in kwargs['body'] + + +# ── post_chatter_note ───────────────────────────────────────────────────────── + +@pytest.mark.asyncio +async def test_post_chatter_note_returns_true(): + t = _make_tools() + result = await t.post_chatter_note('account.move', 1, 'Test note') + assert result is True + t._o.call.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_post_chatter_note_includes_body(): + t = _make_tools() + await t.post_chatter_note('account.move', 5, 'My note text') + kwargs = t._o.call.call_args[0][3] + assert kwargs['body'] == 'My note text' diff --git a/tests/test_crm_agent.py b/tests/test_crm_agent.py new file mode 100644 index 0000000..13aa8a2 --- /dev/null +++ b/tests/test_crm_agent.py @@ -0,0 +1,218 @@ +"""Unit tests for CrmAgent — plan, gather, reason, report, peer_bus, sweep.""" +import pytest +from unittest.mock import AsyncMock, MagicMock +from agent_service.agents.crm_agent import CrmAgent, CRM_TOOLS +from agent_service.agents.base_agent import AgentDirective, AgentReport, SweepReport + + +def _directive(intent='', context=None): + return AgentDirective( + directive_id='crm-d1', user_id='1', intent=intent, + context=context or {}, agent_name='crm_agent', + ) + + +def _make_agent(): + agent = CrmAgent(odoo=MagicMock(), llm=MagicMock()) + agent._ct = MagicMock() + agent._ct.get_pipeline_summary = AsyncMock(return_value={ + 'total_opportunities': 10, 'weighted_pipeline': 50000.0, + }) + agent._ct.get_leads = AsyncMock(return_value=[]) + agent._ct.get_opportunities = AsyncMock(return_value=[]) + agent._ct.get_won_lost_analysis = AsyncMock(return_value={'won_count': 5, 'lost_count': 2}) + agent._ct.update_lead_stage = AsyncMock(return_value=True) + agent._ct.assign_lead = AsyncMock(return_value=True) + agent._ct.log_activity = AsyncMock(return_value=True) + agent._ct.post_chatter_note = AsyncMock(return_value=True) + return agent + + +# ── Meta ──────────────────────────────────────────────────────────────────── + +def test_tool_count(): + assert len(CRM_TOOLS) <= 8 + +def test_agent_name(): + assert CrmAgent.name == 'crm_agent' + + +# ── _plan ─────────────────────────────────────────────────────────────────── + +@pytest.mark.asyncio +async def test_plan_pipeline_intent(): + agent = _make_agent() + plan = await agent._plan(_directive(intent='show pipeline summary')) + assert plan['fetch_pipeline'] is True + +@pytest.mark.asyncio +async def test_plan_leads_intent(): + agent = _make_agent() + plan = await agent._plan(_directive(intent='list my leads')) + assert plan['fetch_leads'] is True + +@pytest.mark.asyncio +async def test_plan_opportunities_intent(): + agent = _make_agent() + plan = await agent._plan(_directive(intent='show opportunities')) + assert plan['fetch_opportunities'] is True + +@pytest.mark.asyncio +async def test_plan_won_lost_intent(): + agent = _make_agent() + plan = await agent._plan(_directive(intent='show won and lost analysis')) + assert plan['fetch_won_lost'] is True + +@pytest.mark.asyncio +async def test_plan_propagates_user_id(): + agent = _make_agent() + plan = await agent._plan(_directive(context={'user_id': 7})) + assert plan['user_id'] == 7 + + +# ── _gather ───────────────────────────────────────────────────────────────── + +@pytest.mark.asyncio +async def test_gather_pipeline_fetched_by_default(): + agent = _make_agent() + ctx = {'plan': {'fetch_pipeline': True, 'fetch_leads': False, + 'fetch_opportunities': False, 'fetch_won_lost': False, 'user_id': None}} + data = await agent._gather(ctx) + assert 'pipeline' in data + agent._ct.get_pipeline_summary.assert_awaited_once() + +@pytest.mark.asyncio +async def test_gather_leads_when_requested(): + agent = _make_agent() + agent._ct.get_leads = AsyncMock(return_value=[{'id': 1, 'name': 'Lead A'}]) + ctx = {'plan': {'fetch_pipeline': False, 'fetch_leads': True, + 'fetch_opportunities': False, 'fetch_won_lost': False, 'user_id': None}} + data = await agent._gather(ctx) + assert 'leads' in data + +@pytest.mark.asyncio +async def test_gather_won_lost_when_requested(): + agent = _make_agent() + ctx = {'plan': {'fetch_pipeline': False, 'fetch_leads': False, + 'fetch_opportunities': False, 'fetch_won_lost': True, 'user_id': None}} + data = await agent._gather(ctx) + assert 'won_lost' in data + + +# ── _reason ───────────────────────────────────────────────────────────────── + +@pytest.mark.asyncio +async def test_reason_low_pipeline_escalates(): + agent = _make_agent() + agent._gathered_data = {'pipeline': {'weighted_pipeline': 500.0, 'total_opportunities': 2}} + analysis = await agent._reason({}) + assert len(analysis['escalations']) == 1 + assert 'Low weighted pipeline' in analysis['escalations'][0] + +@pytest.mark.asyncio +async def test_reason_healthy_pipeline_no_escalation(): + agent = _make_agent() + agent._gathered_data = {'pipeline': {'weighted_pipeline': 100000.0, 'total_opportunities': 20}} + analysis = await agent._reason({}) + assert analysis['escalations'] == [] + + +# ── _report ───────────────────────────────────────────────────────────────── + +@pytest.mark.asyncio +async def test_report_includes_pipeline_stats(): + agent = _make_agent() + agent._gathered_data = { + 'pipeline': {'total_opportunities': 15, 'weighted_pipeline': 75000.0} + } + agent._escalations_list = [] + report = await agent._report({}) + assert isinstance(report, AgentReport) + assert '15' in report.summary + +@pytest.mark.asyncio +async def test_report_includes_won_lost(): + agent = _make_agent() + agent._gathered_data = {'won_lost': {'won_count': 8, 'lost_count': 3}} + agent._escalations_list = [] + report = await agent._report({}) + assert '8' in report.summary + assert '3' in report.summary + +@pytest.mark.asyncio +async def test_report_fallback_message(): + agent = _make_agent() + agent._gathered_data = {} + agent._escalations_list = [] + report = await agent._report({}) + assert 'crm' in report.summary.lower() or 'complete' in report.summary.lower() + + +# ── _dispatch_tool ─────────────────────────────────────────────────────────── + +@pytest.mark.asyncio +async def test_dispatch_get_pipeline_summary(): + agent = _make_agent() + await agent._dispatch_tool('get_pipeline_summary', {}) + agent._ct.get_pipeline_summary.assert_awaited_once() + +@pytest.mark.asyncio +async def test_dispatch_update_lead_stage(): + agent = _make_agent() + await agent._dispatch_tool('update_lead_stage', {'lead_id': 1, 'stage_id': 3}) + agent._ct.update_lead_stage.assert_awaited_once_with(lead_id=1, stage_id=3) + +@pytest.mark.asyncio +async def test_dispatch_unknown_tool_raises(): + agent = _make_agent() + with pytest.raises(ValueError, match='Unknown tool'): + await agent._dispatch_tool('nonexistent', {}) + + +# ── handle_peer_request ────────────────────────────────────────────────────── + +@pytest.mark.asyncio +async def test_peer_pipeline_summary(): + agent = _make_agent() + result = await agent.handle_peer_request('pipeline_summary', {}, 'dir-1') + assert 'total_opportunities' in result or 'weighted_pipeline' in result + +@pytest.mark.asyncio +async def test_peer_opportunities(): + agent = _make_agent() + agent._ct.get_opportunities = AsyncMock(return_value=[{'id': 1}]) + result = await agent.handle_peer_request('opportunities', {'user_id': 5}, 'dir-1') + assert 'opportunities' in result + +@pytest.mark.asyncio +async def test_peer_unknown_returns_error(): + agent = _make_agent() + result = await agent.handle_peer_request('bad_type', {}, 'dir-1') + assert 'error' in result + + +# ── sweep ──────────────────────────────────────────────────────────────────── + +@pytest.mark.asyncio +async def test_sweep_low_pipeline_generates_finding(): + agent = _make_agent() + agent._ct.get_pipeline_summary = AsyncMock(return_value={'weighted_pipeline': 100.0}) + result = await agent.sweep() + assert isinstance(result, SweepReport) + assert len(result.findings) == 1 + assert result.findings[0]['type'] == 'low_pipeline' + +@pytest.mark.asyncio +async def test_sweep_healthy_pipeline_no_findings(): + agent = _make_agent() + agent._ct.get_pipeline_summary = AsyncMock(return_value={'weighted_pipeline': 99999.0}) + result = await agent.sweep() + assert result.findings == [] + +@pytest.mark.asyncio +async def test_sweep_handles_exception(): + agent = _make_agent() + agent._ct.get_pipeline_summary = AsyncMock(side_effect=Exception('network error')) + result = await agent.sweep() + assert isinstance(result, SweepReport) + assert result.error is not None diff --git a/tests/test_crm_tools.py b/tests/test_crm_tools.py new file mode 100644 index 0000000..643b470 --- /dev/null +++ b/tests/test_crm_tools.py @@ -0,0 +1,192 @@ +"""Unit tests for CrmTools.""" +import pytest +from unittest.mock import AsyncMock, MagicMock +from agent_service.tools.crm_tools import CrmTools +from agent_service.tools.odoo_client import WriteResult + + +def _make_tools(): + odoo = MagicMock() + odoo.search_read = AsyncMock(return_value=[]) + odoo.write = AsyncMock(return_value=WriteResult( + success=True, model='', record_id=None, action='write')) + odoo.create = AsyncMock(return_value=WriteResult( + success=True, model='', record_id=42, action='create')) + odoo.call = AsyncMock(return_value=True) + return CrmTools(odoo) + + +# ── get_leads ──────────────────────────────────────────────────────────────── + +@pytest.mark.asyncio +async def test_get_leads_default(): + t = _make_tools() + result = await t.get_leads() + t._o.search_read.assert_awaited_once() + domain = t._o.search_read.call_args[0][1] + assert ('type', '=', 'lead') in domain + assert isinstance(result, list) + + +@pytest.mark.asyncio +async def test_get_leads_with_stage_filter(): + t = _make_tools() + await t.get_leads(stage_id=3) + domain = t._o.search_read.call_args[0][1] + assert ('stage_id', '=', 3) in domain + + +@pytest.mark.asyncio +async def test_get_leads_with_user_filter(): + t = _make_tools() + await t.get_leads(user_id=5) + domain = t._o.search_read.call_args[0][1] + assert ('user_id', '=', 5) in domain + + +# ── get_opportunities ──────────────────────────────────────────────────────── + +@pytest.mark.asyncio +async def test_get_opportunities_default(): + t = _make_tools() + result = await t.get_opportunities() + domain = t._o.search_read.call_args[0][1] + assert ('type', '=', 'opportunity') in domain + assert isinstance(result, list) + + +@pytest.mark.asyncio +async def test_get_opportunities_filters(): + t = _make_tools() + await t.get_opportunities(stage_id=2, user_id=7) + domain = t._o.search_read.call_args[0][1] + assert ('stage_id', '=', 2) in domain + assert ('user_id', '=', 7) in domain + + +# ── get_pipeline_summary ───────────────────────────────────────────────────── + +@pytest.mark.asyncio +async def test_get_pipeline_summary_empty(): + t = _make_tools() + result = await t.get_pipeline_summary() + assert 'total_opportunities' in result + assert result['total_opportunities'] == 0 + assert 'weighted_pipeline' in result + + +@pytest.mark.asyncio +async def test_get_pipeline_summary_aggregates_by_stage(): + t = _make_tools() + t._o.search_read = AsyncMock(side_effect=[ + [{'id': 1, 'name': 'New', 'sequence': 1}], # stages + [ + {'stage_id': [1, 'New'], 'expected_revenue': 10000.0, 'probability': 20}, + {'stage_id': [1, 'New'], 'expected_revenue': 5000.0, 'probability': 50}, + ], + ]) + result = await t.get_pipeline_summary() + assert result['total_opportunities'] == 2 + assert result['total_pipeline'] == 15000.0 + assert len(result['stages']) == 1 + assert result['stages'][0]['count'] == 2 + + +@pytest.mark.asyncio +async def test_get_pipeline_summary_weighted_calculation(): + t = _make_tools() + t._o.search_read = AsyncMock(side_effect=[ + [], + [{'stage_id': [1, 'Stage'], 'expected_revenue': 1000.0, 'probability': 50}], + ]) + result = await t.get_pipeline_summary() + assert result['weighted_pipeline'] == 500.0 + + +# ── update_lead_stage ───────────────────────────────────────────────────────── + +@pytest.mark.asyncio +async def test_update_lead_stage_success(): + t = _make_tools() + result = await t.update_lead_stage(lead_id=1, stage_id=3) + assert result is True + t._o.write.assert_awaited_once_with('crm.lead', [1], {'stage_id': 3}) + + +@pytest.mark.asyncio +async def test_update_lead_stage_failure(): + t = _make_tools() + t._o.write = AsyncMock(return_value=WriteResult( + success=False, model='', record_id=None, action='write', error='permission denied')) + result = await t.update_lead_stage(lead_id=1, stage_id=3) + assert result is False + + +# ── assign_lead ─────────────────────────────────────────────────────────────── + +@pytest.mark.asyncio +async def test_assign_lead_success(): + t = _make_tools() + result = await t.assign_lead(lead_id=5, user_id=10) + assert result is True + t._o.write.assert_awaited_once_with('crm.lead', [5], {'user_id': 10}) + + +# ── log_activity ────────────────────────────────────────────────────────────── + +@pytest.mark.asyncio +async def test_log_activity_without_type_record(): + t = _make_tools() + t._o.search_read = AsyncMock(return_value=[]) + result = await t.log_activity(lead_id=1, activity_type='call', note='Called customer') + assert result is True + t._o.call.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_log_activity_with_type_record(): + t = _make_tools() + t._o.search_read = AsyncMock(return_value=[{'id': 7}]) + await t.log_activity(lead_id=1, activity_type='email', note='Sent email', date_deadline='2026-06-01') + call_args = t._o.call.call_args[0] + vals = call_args[2][0] + assert vals['activity_type_id'] == 7 + assert vals['date_deadline'] == '2026-06-01' + assert vals['note'] == 'Sent email' + + +# ── get_won_lost_analysis ───────────────────────────────────────────────────── + +@pytest.mark.asyncio +async def test_get_won_lost_analysis_counts(): + t = _make_tools() + t._o.search_read = AsyncMock(side_effect=[ + [{'expected_revenue': 5000.0}, {'expected_revenue': 3000.0}], # won + [{'expected_revenue': 1000.0}], # lost + ]) + result = await t.get_won_lost_analysis() + assert result['won_count'] == 2 + assert result['won_revenue'] == 8000.0 + assert result['lost_count'] == 1 + assert result['lost_revenue'] == 1000.0 + + +@pytest.mark.asyncio +async def test_get_won_lost_analysis_with_dates(): + t = _make_tools() + await t.get_won_lost_analysis(date_from='2026-01-01', date_to='2026-01-31') + first_call = t._o.search_read.call_args_list[0][0][1] + assert ('date_closed', '>=', '2026-01-01') in first_call + + +# ── post_chatter_note ───────────────────────────────────────────────────────── + +@pytest.mark.asyncio +async def test_post_chatter_note_calls_message_post(): + t = _make_tools() + result = await t.post_chatter_note('crm.lead', 1, 'Note text') + assert result is True + t._o.call.assert_awaited_once() + call_args = t._o.call.call_args[0] + assert call_args[0] == 'crm.lead' + assert call_args[1] == 'message_post' diff --git a/tests/test_elearning_agent.py b/tests/test_elearning_agent.py new file mode 100644 index 0000000..a3ae66f --- /dev/null +++ b/tests/test_elearning_agent.py @@ -0,0 +1,317 @@ +"""Unit tests for ElearningAgent — tool count, plan, gather, reason, act, report, peer_bus, sweep.""" +import pytest +from unittest.mock import AsyncMock, MagicMock +from agent_service.agents.elearning_agent import ElearningAgent, ELEARNING_TOOLS +from agent_service.agents.base_agent import AgentReport, SweepReport + + +def _make_agent(): + agent = ElearningAgent(odoo=MagicMock(), llm=MagicMock()) + agent._el = MagicMock() + agent._el.get_courses = AsyncMock(return_value=[ + {'id': 1, 'name': 'Python 101', 'completion_rate': 80.0}, + {'id': 2, 'name': 'HR Basics', 'completion_rate': 20.0}, + ]) + agent._el.get_course_details = AsyncMock(return_value={ + 'channel': {'name': 'Python 101'}, 'slide_count': 10, + 'enrolled_users': [], 'slide_completion': [], + }) + agent._el.get_learning_summary = AsyncMock(return_value={ + 'total_courses': 5, 'total_enrollments': 50, 'avg_completion': 65.0, + 'low_completion_courses': [{'id': 2, 'name': 'HR Basics', 'completion_rate': 20.0}], + }) + agent._el.create_course = AsyncMock(return_value={'id': 10, 'name': 'New Course', 'success': True}) + agent._el.update_course = AsyncMock(return_value={'success': True}) + agent._el.add_section = AsyncMock(return_value={'id': 20, 'name': 'Section 1', 'success': True}) + agent._el.create_slide = AsyncMock(return_value={'id': 30, 'name': 'Slide 1', 'slide_type': 'webpage', 'success': True}) + agent._el.enroll_user = AsyncMock(return_value={'success': True}) + agent._el.post_chatter_note = AsyncMock(return_value=True) + agent._el.suggest_next_course = AsyncMock(return_value=[{'id': 3, 'name': 'Advanced Python'}]) + return agent + + +# ── Meta ──────────────────────────────────────────────────────────────────── + +def test_tool_count_exactly_8(): + assert len(ELEARNING_TOOLS) == 8 + +def test_tool_names(): + names = {t['name'] for t in ELEARNING_TOOLS} + assert 'get_courses' in names + assert 'get_course_details' in names + assert 'create_course' in names + assert 'update_course' in names + assert 'add_section' in names + assert 'create_slide' in names + assert 'enroll_user' in names + assert 'post_chatter_note' in names + +def test_removed_tools_not_present(): + names = {t['name'] for t in ELEARNING_TOOLS} + assert 'get_course_stats' not in names + assert 'get_enrolled_users' not in names + assert 'get_slide_completion' not in names + assert 'get_learning_summary' not in names + assert 'publish_course' not in names + assert 'flag_low_completion' not in names + assert 'suggest_next_course' not in names + +def test_update_course_has_website_published_param(): + update = next(t for t in ELEARNING_TOOLS if t['name'] == 'update_course') + assert 'website_published' in update['parameters'] + +def test_agent_name(): + assert ElearningAgent.name == 'elearning_agent' + + +# ── _plan ─────────────────────────────────────────────────────────────────── + +@pytest.mark.asyncio +async def test_plan_create_intent(): + agent = _make_agent() + agent._directive = MagicMock() + agent._directive.task = 'create a new course on Python' + agent._directive.params = {} + plan = await agent._plan() + assert plan['intent'] == 'create' + +@pytest.mark.asyncio +async def test_plan_build_intent(): + agent = _make_agent() + agent._directive = MagicMock() + agent._directive.task = 'build me a course on compliance' + agent._directive.params = {} + plan = await agent._plan() + assert plan['intent'] == 'create' + +@pytest.mark.asyncio +async def test_plan_enroll_intent(): + agent = _make_agent() + agent._directive = MagicMock() + agent._directive.task = 'enroll Alice in Python 101' + agent._directive.params = {} + plan = await agent._plan() + assert plan['intent'] == 'enroll' + +@pytest.mark.asyncio +async def test_plan_read_intent_default(): + agent = _make_agent() + agent._directive = MagicMock() + agent._directive.task = 'show me the course list' + agent._directive.params = {} + plan = await agent._plan() + assert plan['intent'] == 'read' + + +# ── _gather ───────────────────────────────────────────────────────────────── + +@pytest.mark.asyncio +async def test_gather_reads_summary_when_no_channel(): + agent = _make_agent() + await agent._gather({'intent': 'read', 'channel_id': None}) + agent._el.get_learning_summary.assert_awaited_once() + +@pytest.mark.asyncio +async def test_gather_reads_course_details_when_channel_given(): + agent = _make_agent() + await agent._gather({'intent': 'read', 'channel_id': 1}) + agent._el.get_course_details.assert_awaited_once_with(channel_id=1) + + +# ── _reason ───────────────────────────────────────────────────────────────── + +@pytest.mark.asyncio +async def test_reason_read_returns_low_completion_list(): + agent = _make_agent() + agent._gathered = { + 'intent': 'read', + 'summary': { + 'low_completion_courses': [{'id': 2, 'completion_rate': 20.0}] + } + } + reasoning = await agent._reason() + assert reasoning['intent'] == 'read' + assert len(reasoning['low_completion']) == 1 + + +# ── _act ───────────────────────────────────────────────────────────────────── + +@pytest.mark.asyncio +async def test_act_read_intent_does_nothing(): + agent = _make_agent() + await agent._act({'intent': 'read', 'low_completion': []}) + agent._el.post_chatter_note.assert_not_called() + +@pytest.mark.asyncio +async def test_act_flags_low_completion_courses(): + agent = _make_agent() + await agent._act({ + 'intent': 'read', + 'low_completion': [{'id': 2, 'completion_rate': 15.0}], + }) + agent._el.post_chatter_note.assert_awaited_once() + call_kwargs = agent._el.post_chatter_note.call_args.kwargs + assert call_kwargs['model'] == 'slide.channel' + assert call_kwargs['record_id'] == 2 + assert '[AI FLAG]' in call_kwargs['note'] + +@pytest.mark.asyncio +async def test_act_caps_at_3_flags(): + agent = _make_agent() + low_courses = [{'id': i, 'completion_rate': 10.0} for i in range(5)] + await agent._act({'intent': 'read', 'low_completion': low_courses}) + assert agent._el.post_chatter_note.await_count == 3 + + +# ── _report ───────────────────────────────────────────────────────────────── + +@pytest.mark.asyncio +async def test_report_includes_summary_stats(): + agent = _make_agent() + agent._directive = MagicMock() + agent._directive.directive_id = 'test-d1' + agent._gathered = { + 'summary': {'total_courses': 5, 'total_enrollments': 50, 'avg_completion': 65.0} + } + agent._actions_taken = [] + report = await agent._report() + assert isinstance(report, AgentReport) + assert '5' in report.summary + +@pytest.mark.asyncio +async def test_report_includes_course_details(): + agent = _make_agent() + agent._directive = MagicMock() + agent._directive.directive_id = 'test-d1' + agent._gathered = { + 'course_details': {'channel': {'name': 'Python 101'}, 'slide_count': 10} + } + agent._actions_taken = [] + report = await agent._report() + assert 'Python 101' in report.summary + assert '10' in report.summary + +@pytest.mark.asyncio +async def test_report_fallback(): + agent = _make_agent() + agent._directive = MagicMock() + agent._directive.directive_id = 'test-d1' + agent._gathered = {} + agent._actions_taken = [] + report = await agent._report() + assert 'complete' in report.summary.lower() or 'elearning' in report.summary.lower() + + +# ── tool dispatchers ───────────────────────────────────────────────────────── + +@pytest.mark.asyncio +async def test_tool_get_courses(): + agent = _make_agent() + result = await agent._tool_get_courses() + agent._el.get_courses.assert_awaited_once() + assert len(result) == 2 + +@pytest.mark.asyncio +async def test_tool_get_course_details(): + agent = _make_agent() + result = await agent._tool_get_course_details(channel_id=1) + agent._el.get_course_details.assert_awaited_once_with(channel_id=1) + assert 'slide_count' in result + +@pytest.mark.asyncio +async def test_tool_create_course_records_action(): + agent = _make_agent() + result = await agent._tool_create_course(name='New Course') + assert result['success'] is True + assert any(a['action'] == 'create_course' for a in agent._actions_taken) + +@pytest.mark.asyncio +async def test_tool_update_course_with_publish(): + agent = _make_agent() + await agent._tool_update_course(channel_id=1, website_published=True) + agent._el.update_course.assert_awaited_once_with( + channel_id=1, name=None, description=None, enroll_policy=None, website_published=True) + +@pytest.mark.asyncio +async def test_tool_add_section_records_action(): + agent = _make_agent() + result = await agent._tool_add_section(channel_id=1, name='Section A', sequence=10) + assert result['success'] is True + assert any(a['action'] == 'add_section' for a in agent._actions_taken) + +@pytest.mark.asyncio +async def test_tool_create_slide_records_action(): + agent = _make_agent() + result = await agent._tool_create_slide(channel_id=1, name='Slide 1') + assert result['success'] is True + assert any(a['action'] == 'create_slide' for a in agent._actions_taken) + +@pytest.mark.asyncio +async def test_tool_enroll_user(): + agent = _make_agent() + result = await agent._tool_enroll_user(channel_id=1, partner_id=5) + agent._el.enroll_user.assert_awaited_once_with(channel_id=1, partner_id=5) + +@pytest.mark.asyncio +async def test_tool_post_chatter_note(): + agent = _make_agent() + result = await agent._tool_post_chatter_note( + model='slide.channel', record_id=1, note='Test note') + assert result['success'] is True + + +# ── handle_peer_request ────────────────────────────────────────────────────── + +@pytest.mark.asyncio +async def test_peer_learning_summary(): + agent = _make_agent() + result = await agent.handle_peer_request('learning_summary', {}, 'dir-1') + assert result['success'] is True + assert 'total_courses' in result + +@pytest.mark.asyncio +async def test_peer_suggest_courses(): + agent = _make_agent() + result = await agent.handle_peer_request('suggest_courses', {'partner_id': 3}, 'dir-1') + assert result['success'] is True + assert 'courses' in result + +@pytest.mark.asyncio +async def test_peer_suggest_courses_missing_partner_id(): + agent = _make_agent() + result = await agent.handle_peer_request('suggest_courses', {}, 'dir-1') + assert result['success'] is False + +@pytest.mark.asyncio +async def test_peer_unknown_returns_error(): + agent = _make_agent() + result = await agent.handle_peer_request('bad_type', {}, 'dir-1') + assert result['success'] is False + + +# ── sweep ──────────────────────────────────────────────────────────────────── + +@pytest.mark.asyncio +async def test_sweep_finds_low_completion_course(): + agent = _make_agent() + result = await agent.sweep() + assert isinstance(result, SweepReport) + assert any(f['issue'] == 'low_completion' for f in result.findings) + +@pytest.mark.asyncio +async def test_sweep_no_findings_when_all_complete(): + agent = _make_agent() + agent._el.get_learning_summary = AsyncMock(return_value={ + 'total_courses': 3, 'total_enrollments': 30, + 'avg_completion': 90.0, 'low_completion_courses': [], + }) + result = await agent.sweep() + assert result.findings == [] + +@pytest.mark.asyncio +async def test_sweep_handles_exception(): + agent = _make_agent() + agent._el.get_learning_summary = AsyncMock(side_effect=Exception('timeout')) + result = await agent.sweep() + assert isinstance(result, SweepReport) + assert result.findings[0]['issue'] == 'sweep_failed' diff --git a/tests/test_elearning_tools.py b/tests/test_elearning_tools.py new file mode 100644 index 0000000..a838d04 --- /dev/null +++ b/tests/test_elearning_tools.py @@ -0,0 +1,301 @@ +"""Unit tests for ElearningTools.""" +import pytest +from unittest.mock import AsyncMock, MagicMock, call +from agent_service.tools.elearning_tools import ElearningTools +from agent_service.tools.odoo_client import WriteResult + + +def _make_tools(): + odoo = MagicMock() + odoo.search_read = AsyncMock(return_value=[]) + odoo.write = AsyncMock(return_value=WriteResult( + success=True, model='', record_id=None, action='write')) + odoo.create = AsyncMock(return_value=WriteResult( + success=True, model='slide.channel', record_id=10, action='create')) + odoo.call = AsyncMock(return_value=True) + return ElearningTools(odoo) + + +# ── get_courses ─────────────────────────────────────────────────────────────── + +@pytest.mark.asyncio +async def test_get_courses_default(): + t = _make_tools() + result = await t.get_courses() + t._o.search_read.assert_awaited_once() + domain = t._o.search_read.call_args[0][1] + assert ('active', '=', True) in domain + assert isinstance(result, list) + + +@pytest.mark.asyncio +async def test_get_courses_inactive(): + t = _make_tools() + await t.get_courses(active=False) + domain = t._o.search_read.call_args[0][1] + assert ('active', '=', False) in domain + + +# ── get_course_stats ────────────────────────────────────────────────────────── + +@pytest.mark.asyncio +async def test_get_course_stats_found(): + t = _make_tools() + t._o.search_read = AsyncMock(side_effect=[ + [{'name': 'Python 101', 'total_slides': 10, 'members_count': 5, + 'completion_rate': 80.0, 'total_time': 300}], # channel + [{'name': 'Slide 1', 'slide_type': 'video', 'completion_rate': 90.0, + 'likes': 3, 'dislikes': 0, 'view_count': 50}], # slides + ]) + result = await t.get_course_stats(channel_id=1) + assert 'channel' in result + assert result['slide_count'] == 1 + assert 'avg_slide_completion' in result + assert 'total_views' in result + + +@pytest.mark.asyncio +async def test_get_course_stats_not_found(): + t = _make_tools() + t._o.search_read = AsyncMock(return_value=[]) + result = await t.get_course_stats(channel_id=999) + assert result == {} + + +# ── get_enrolled_users ──────────────────────────────────────────────────────── + +@pytest.mark.asyncio +async def test_get_enrolled_users(): + t = _make_tools() + t._o.search_read = AsyncMock(return_value=[ + {'partner_id': [1, 'Alice'], 'completion': 80, 'channel_completion': 80.0} + ]) + result = await t.get_enrolled_users(channel_id=1) + assert len(result) == 1 + domain = t._o.search_read.call_args[0][1] + assert ('channel_id', '=', 1) in domain + + +# ── get_slide_completion ────────────────────────────────────────────────────── + +@pytest.mark.asyncio +async def test_get_slide_completion_default(): + t = _make_tools() + await t.get_slide_completion(channel_id=1) + domain = t._o.search_read.call_args[0][1] + assert ('channel_id', '=', 1) in domain + assert ('channel_completion', '>=', 0.0) in domain + + +@pytest.mark.asyncio +async def test_get_slide_completion_min_filter(): + t = _make_tools() + await t.get_slide_completion(channel_id=1, min_completion=50.0) + domain = t._o.search_read.call_args[0][1] + assert ('channel_completion', '>=', 50.0) in domain + + +# ── get_course_details ──────────────────────────────────────────────────────── + +@pytest.mark.asyncio +async def test_get_course_details_combines_data(): + t = _make_tools() + t._o.search_read = AsyncMock(side_effect=[ + [{'name': 'Python 101', 'total_slides': 5, 'members_count': 10, + 'completion_rate': 70.0, 'total_time': 120}], + [{'name': 'Slide A', 'slide_type': 'video', 'completion_rate': 80.0, + 'likes': 1, 'dislikes': 0, 'view_count': 20}], + [{'partner_id': [1, 'Alice'], 'completion': 80, 'channel_completion': 80.0}], + [{'partner_id': [1, 'Alice'], 'channel_completion': 80.0}], + ]) + result = await t.get_course_details(channel_id=1) + assert 'channel' in result + assert 'enrolled_users' in result + assert 'slide_completion' in result + + +# ── get_learning_summary ────────────────────────────────────────────────────── + +@pytest.mark.asyncio +async def test_get_learning_summary_structure(): + t = _make_tools() + t._o.search_read = AsyncMock(return_value=[ + {'name': 'Course A', 'members_count': 10, 'completion_rate': 80.0}, + {'name': 'Course B', 'members_count': 5, 'completion_rate': 20.0}, + ]) + result = await t.get_learning_summary() + assert result['total_courses'] == 2 + assert result['total_enrollments'] == 15 + assert 'avg_completion' in result + assert 'low_completion_courses' in result + + +@pytest.mark.asyncio +async def test_get_learning_summary_flags_low_completion(): + t = _make_tools() + t._o.search_read = AsyncMock(return_value=[ + {'name': 'Bad Course', 'members_count': 5, 'completion_rate': 10.0}, + ]) + result = await t.get_learning_summary() + assert len(result['low_completion_courses']) == 1 + + +# ── create_course ───────────────────────────────────────────────────────────── + +@pytest.mark.asyncio +async def test_create_course_success(): + t = _make_tools() + result = await t.create_course(name='New Course') + assert result['success'] is True + assert result['id'] == 10 + assert result['name'] == 'New Course' + + +@pytest.mark.asyncio +async def test_create_course_failure(): + t = _make_tools() + t._o.create = AsyncMock(return_value=WriteResult( + success=False, model='', record_id=None, action='create', error='DB error')) + result = await t.create_course(name='Bad Course') + assert result['success'] is False + assert 'error' in result + + +@pytest.mark.asyncio +async def test_create_course_with_options(): + t = _make_tools() + await t.create_course(name='Paid Course', enroll_policy='paid', website_published=True) + vals = t._o.create.call_args[0][1] + assert vals['enroll_policy'] == 'paid' + assert vals['website_published'] is True + + +# ── update_course ───────────────────────────────────────────────────────────── + +@pytest.mark.asyncio +async def test_update_course_name(): + t = _make_tools() + result = await t.update_course(channel_id=1, name='Updated Name') + assert result['success'] is True + t._o.write.assert_awaited_once_with('slide.channel', [1], {'name': 'Updated Name'}) + + +@pytest.mark.asyncio +async def test_update_course_publish(): + t = _make_tools() + await t.update_course(channel_id=1, website_published=True) + vals = t._o.write.call_args[0][2] + assert vals['website_published'] is True + + +@pytest.mark.asyncio +async def test_update_course_no_values(): + t = _make_tools() + result = await t.update_course(channel_id=1) + assert result['success'] is False + assert 'error' in result + + +# ── add_section ─────────────────────────────────────────────────────────────── + +@pytest.mark.asyncio +async def test_add_section_success(): + t = _make_tools() + t._o.create = AsyncMock(return_value=WriteResult( + success=True, model='slide.slide', record_id=20, action='create')) + result = await t.add_section(channel_id=1, name='Section 1', sequence=5) + assert result['success'] is True + assert result['id'] == 20 + vals = t._o.create.call_args[0][1] + assert vals['is_category'] is True + assert vals['sequence'] == 5 + + +# ── create_slide ────────────────────────────────────────────────────────────── + +@pytest.mark.asyncio +async def test_create_slide_success(): + t = _make_tools() + t._o.create = AsyncMock(return_value=WriteResult( + success=True, model='slide.slide', record_id=30, action='create')) + result = await t.create_slide(channel_id=1, name='My Slide', slide_type='video') + assert result['success'] is True + assert result['id'] == 30 + assert result['slide_type'] == 'video' + + +@pytest.mark.asyncio +async def test_create_slide_with_content(): + t = _make_tools() + await t.create_slide(channel_id=1, name='HTML Slide', html_content='

hello

') + vals = t._o.create.call_args[0][1] + assert vals['html_content'] == '

hello

' + + +# ── enroll_user ─────────────────────────────────────────────────────────────── + +@pytest.mark.asyncio +async def test_enroll_user_new(): + t = _make_tools() + t._o.search_read = AsyncMock(return_value=[]) + t._o.create = AsyncMock(return_value=WriteResult( + success=True, model='slide.channel.partner', record_id=5, action='create')) + result = await t.enroll_user(channel_id=1, partner_id=10) + assert result['success'] is True + t._o.create.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_enroll_user_already_enrolled(): + t = _make_tools() + t._o.search_read = AsyncMock(return_value=[{'id': 5}]) + result = await t.enroll_user(channel_id=1, partner_id=10) + assert result['success'] is True + assert result['already_enrolled'] is True + t._o.create.assert_not_awaited() + + +# ── suggest_next_course ─────────────────────────────────────────────────────── + +@pytest.mark.asyncio +async def test_suggest_next_course_excludes_completed(): + t = _make_tools() + t._o.search_read = AsyncMock(side_effect=[ + [{'channel_id': [1, 'Python 101']}], # completed + [{'id': 2, 'name': 'Advanced Python'}], # suggestions + ]) + result = await t.suggest_next_course(partner_id=5) + domain = t._o.search_read.call_args_list[1][0][1] + assert ('id', 'not in', [1]) in domain + + +@pytest.mark.asyncio +async def test_suggest_next_course_no_completed(): + t = _make_tools() + t._o.search_read = AsyncMock(side_effect=[ + [], + [{'id': 1, 'name': 'Python 101'}], + ]) + result = await t.suggest_next_course(partner_id=5) + assert isinstance(result, list) + + +# ── post_chatter_note ───────────────────────────────────────────────────────── + +@pytest.mark.asyncio +async def test_post_chatter_note(): + t = _make_tools() + result = await t.post_chatter_note('slide.channel', 1, 'Test note') + assert result is True + t._o.call.assert_awaited_once() + + +# ── flag_low_completion ─────────────────────────────────────────────────────── + +@pytest.mark.asyncio +async def test_flag_low_completion(): + t = _make_tools() + result = await t.flag_low_completion(channel_id=1, reason='Only 10% completion') + assert result is True + call_kwargs = t._o.call.call_args[0][3] + assert '[AI FLAG]' in call_kwargs['body'] diff --git a/tests/test_employees_agent.py b/tests/test_employees_agent.py new file mode 100644 index 0000000..7d122f4 --- /dev/null +++ b/tests/test_employees_agent.py @@ -0,0 +1,260 @@ +"""Unit tests for EmployeesAgent — plan, gather, reason, report, peer_bus, sweep.""" +import datetime +import pytest +from unittest.mock import AsyncMock, MagicMock +from agent_service.agents.employees_agent import EmployeesAgent, EMPLOYEES_TOOLS +from agent_service.agents.base_agent import AgentDirective, AgentReport, SweepReport + + +def _directive(intent='', context=None): + return AgentDirective( + directive_id='emp-d1', user_id='1', intent=intent, + context=context or {}, agent_name='employees_agent', + ) + + +def _make_agent(): + agent = EmployeesAgent(odoo=MagicMock(), llm=MagicMock()) + agent._ht = MagicMock() + agent._ht.get_employees = AsyncMock(return_value=[ + {'id': 1, 'name': 'Alice'}, {'id': 2, 'name': 'Bob'} + ]) + agent._ht.get_leaves = AsyncMock(return_value=[]) + agent._ht.get_contracts = AsyncMock(return_value=[]) + agent._ht.get_employee_profile = AsyncMock(return_value={'id': 1, 'name': 'Alice'}) + agent._ht.get_department_summary = AsyncMock(return_value={'headcount': 5, 'avg_wage': 50000.0}) + agent._ht.get_attendance_summary = AsyncMock(return_value={'total_hours': 160.0}) + agent._ht.flag_for_review = AsyncMock(return_value=True) + agent._ht.post_chatter_note = AsyncMock(return_value=True) + return agent + + +# ── Meta ──────────────────────────────────────────────────────────────────── + +def test_tool_count(): + assert len(EMPLOYEES_TOOLS) <= 8 + +def test_agent_name(): + assert EmployeesAgent.name == 'employees_agent' + +def test_hipaa_locked(): + from agent_service.llm.llm_router import HIPAA_LOCKED_AGENTS + assert 'employees_agent' in HIPAA_LOCKED_AGENTS + + +# ── _plan ─────────────────────────────────────────────────────────────────── + +@pytest.mark.asyncio +async def test_plan_employee_intent(): + agent = _make_agent() + plan = await agent._plan(_directive(intent='show all employees')) + assert plan['fetch_employees'] is True + +@pytest.mark.asyncio +async def test_plan_headcount_intent(): + agent = _make_agent() + plan = await agent._plan(_directive(intent='what is the headcount')) + assert plan['fetch_employees'] is True + +@pytest.mark.asyncio +async def test_plan_leave_intent(): + agent = _make_agent() + plan = await agent._plan(_directive(intent='show pending leave requests')) + assert plan['fetch_leaves'] is True + +@pytest.mark.asyncio +async def test_plan_vacation_intent(): + agent = _make_agent() + plan = await agent._plan(_directive(intent='list vacation requests')) + assert plan['fetch_leaves'] is True + +@pytest.mark.asyncio +async def test_plan_contract_intent(): + agent = _make_agent() + plan = await agent._plan(_directive(intent='check contracts')) + assert plan['fetch_contracts'] is True + +@pytest.mark.asyncio +async def test_plan_propagates_context(): + agent = _make_agent() + plan = await agent._plan(_directive(context={'department_id': 3, 'employee_id': 7})) + assert plan['department_id'] == 3 + assert plan['employee_id'] == 7 + + +# ── _gather ───────────────────────────────────────────────────────────────── + +@pytest.mark.asyncio +async def test_gather_employees_by_default(): + agent = _make_agent() + ctx = {'plan': {'fetch_employees': True, 'fetch_leaves': False, + 'fetch_contracts': False, 'department_id': None, 'employee_id': None}} + data = await agent._gather(ctx) + assert 'employees' in data + +@pytest.mark.asyncio +async def test_gather_department_summary_when_dept_id(): + agent = _make_agent() + ctx = {'plan': {'fetch_employees': True, 'fetch_leaves': False, + 'fetch_contracts': False, 'department_id': 5, 'employee_id': None}} + data = await agent._gather(ctx) + assert 'dept_summary' in data + agent._ht.get_department_summary.assert_awaited_once_with(5) + +@pytest.mark.asyncio +async def test_gather_leaves(): + agent = _make_agent() + agent._ht.get_leaves = AsyncMock(return_value=[{'id': 1, 'name': 'Alice Leave'}]) + ctx = {'plan': {'fetch_employees': False, 'fetch_leaves': True, + 'fetch_contracts': False, 'department_id': None, 'employee_id': None}} + data = await agent._gather(ctx) + assert 'leaves' in data + +@pytest.mark.asyncio +async def test_gather_contracts(): + agent = _make_agent() + agent._ht.get_contracts = AsyncMock(return_value=[{'id': 10, 'date_end': '2025-01-01'}]) + ctx = {'plan': {'fetch_employees': False, 'fetch_leaves': False, + 'fetch_contracts': True, 'department_id': None, 'employee_id': None}} + data = await agent._gather(ctx) + assert 'contracts' in data + + +# ── _reason ───────────────────────────────────────────────────────────────── + +@pytest.mark.asyncio +async def test_reason_flags_expired_contracts(): + agent = _make_agent() + yesterday = str((datetime.date.today() - datetime.timedelta(days=1))) + agent._gathered_data = { + 'contracts': [{'id': 1, 'date_end': yesterday, 'employee_id': [1, 'Alice']}] + } + analysis = await agent._reason({}) + assert len(analysis['escalations']) == 1 + assert 'expired' in analysis['escalations'][0].lower() + +@pytest.mark.asyncio +async def test_reason_active_contracts_no_escalation(): + agent = _make_agent() + future = str((datetime.date.today() + datetime.timedelta(days=30))) + agent._gathered_data = { + 'contracts': [{'id': 1, 'date_end': future}] + } + analysis = await agent._reason({}) + assert analysis['escalations'] == [] + +@pytest.mark.asyncio +async def test_reason_no_contracts_no_escalation(): + agent = _make_agent() + agent._gathered_data = {'employees': [{'id': 1}]} + analysis = await agent._reason({}) + assert analysis['escalations'] == [] + + +# ── _report ───────────────────────────────────────────────────────────────── + +@pytest.mark.asyncio +async def test_report_employee_count(): + agent = _make_agent() + agent._gathered_data = {'employees': [{'id': 1}, {'id': 2}, {'id': 3}]} + agent._escalations_list = [] + report = await agent._report({}) + assert isinstance(report, AgentReport) + assert '3' in report.summary + +@pytest.mark.asyncio +async def test_report_dept_summary(): + agent = _make_agent() + agent._gathered_data = {'dept_summary': {'headcount': 8, 'avg_wage': 60000.0}} + agent._escalations_list = [] + report = await agent._report({}) + assert '8' in report.summary + +@pytest.mark.asyncio +async def test_report_fallback_message(): + agent = _make_agent() + agent._gathered_data = {} + agent._escalations_list = [] + report = await agent._report({}) + assert 'complete' in report.summary.lower() or 'hr' in report.summary.lower() + + +# ── _dispatch_tool ─────────────────────────────────────────────────────────── + +@pytest.mark.asyncio +async def test_dispatch_get_employees(): + agent = _make_agent() + await agent._dispatch_tool('get_employees', {}) + agent._ht.get_employees.assert_awaited_once() + +@pytest.mark.asyncio +async def test_dispatch_get_employee_profile(): + agent = _make_agent() + await agent._dispatch_tool('get_employee_profile', {'employee_id': 42}) + agent._ht.get_employee_profile.assert_awaited_once_with(employee_id=42) + +@pytest.mark.asyncio +async def test_dispatch_unknown_raises(): + agent = _make_agent() + with pytest.raises(ValueError, match='Unknown tool'): + await agent._dispatch_tool('nonexistent', {}) + + +# ── handle_peer_request ────────────────────────────────────────────────────── + +@pytest.mark.asyncio +async def test_peer_employee_list(): + agent = _make_agent() + result = await agent.handle_peer_request('employee_list', {}, 'dir-1') + assert 'employees' in result + assert len(result['employees']) == 2 + +@pytest.mark.asyncio +async def test_peer_employee_profile(): + agent = _make_agent() + result = await agent.handle_peer_request('employee_profile', {'employee_id': 1}, 'dir-1') + assert 'name' in result + +@pytest.mark.asyncio +async def test_peer_headcount(): + agent = _make_agent() + result = await agent.handle_peer_request('headcount', {}, 'dir-1') + assert result['headcount'] == 2 + +@pytest.mark.asyncio +async def test_peer_unknown_returns_error(): + agent = _make_agent() + result = await agent.handle_peer_request('bad_type', {}, 'dir-1') + assert 'error' in result + + +# ── sweep ──────────────────────────────────────────────────────────────────── + +@pytest.mark.asyncio +async def test_sweep_finds_expired_contracts(): + agent = _make_agent() + yesterday = str((datetime.date.today() - datetime.timedelta(days=1))) + agent._ht.get_contracts = AsyncMock(return_value=[ + {'id': 1, 'date_end': yesterday, 'employee_id': [1, 'Alice']} + ]) + result = await agent.sweep() + assert isinstance(result, SweepReport) + assert len(result.findings) == 1 + assert result.findings[0]['type'] == 'expired_contract' + +@pytest.mark.asyncio +async def test_sweep_active_contracts_no_findings(): + agent = _make_agent() + future = str((datetime.date.today() + datetime.timedelta(days=60))) + agent._ht.get_contracts = AsyncMock(return_value=[ + {'id': 1, 'date_end': future} + ]) + result = await agent.sweep() + assert result.findings == [] + +@pytest.mark.asyncio +async def test_sweep_handles_exception(): + agent = _make_agent() + agent._ht.get_contracts = AsyncMock(side_effect=Exception('DB down')) + result = await agent.sweep() + assert result.error is not None diff --git a/tests/test_employees_tools.py b/tests/test_employees_tools.py new file mode 100644 index 0000000..0c6396f --- /dev/null +++ b/tests/test_employees_tools.py @@ -0,0 +1,202 @@ +"""Unit tests for EmployeesTools.""" +import pytest +from unittest.mock import AsyncMock, MagicMock +from agent_service.tools.employees_tools import EmployeesTools +from agent_service.tools.odoo_client import WriteResult + + +def _make_tools(): + odoo = MagicMock() + odoo.search_read = AsyncMock(return_value=[]) + odoo.write = AsyncMock(return_value=WriteResult( + success=True, model='', record_id=None, action='write')) + odoo.create = AsyncMock(return_value=WriteResult( + success=True, model='', record_id=42, action='create')) + odoo.call = AsyncMock(return_value=True) + return EmployeesTools(odoo) + + +# ── get_employees ──────────────────────────────────────────────────────────── + +@pytest.mark.asyncio +async def test_get_employees_default(): + t = _make_tools() + result = await t.get_employees() + t._o.search_read.assert_awaited_once() + domain = t._o.search_read.call_args[0][1] + assert ('active', '=', True) in domain + assert isinstance(result, list) + + +@pytest.mark.asyncio +async def test_get_employees_with_department(): + t = _make_tools() + await t.get_employees(department_id=5) + domain = t._o.search_read.call_args[0][1] + assert ('department_id', '=', 5) in domain + + +@pytest.mark.asyncio +async def test_get_employees_inactive(): + t = _make_tools() + await t.get_employees(active=False) + domain = t._o.search_read.call_args[0][1] + assert ('active', '=', False) in domain + + +# ── get_employee_profile ────────────────────────────────────────────────────── + +@pytest.mark.asyncio +async def test_get_employee_profile_found(): + t = _make_tools() + t._o.search_read = AsyncMock(return_value=[ + {'id': 1, 'name': 'Alice', 'job_title': 'Dev', 'work_email': 'alice@co.com'} + ]) + result = await t.get_employee_profile(employee_id=1) + assert result['name'] == 'Alice' + domain = t._o.search_read.call_args[0][1] + assert ('id', '=', 1) in domain + + +@pytest.mark.asyncio +async def test_get_employee_profile_not_found(): + t = _make_tools() + t._o.search_read = AsyncMock(return_value=[]) + result = await t.get_employee_profile(employee_id=999) + assert result == {} + + +# ── get_leaves ──────────────────────────────────────────────────────────────── + +@pytest.mark.asyncio +async def test_get_leaves_default(): + t = _make_tools() + result = await t.get_leaves() + t._o.search_read.assert_awaited_once() + assert isinstance(result, list) + + +@pytest.mark.asyncio +async def test_get_leaves_with_employee(): + t = _make_tools() + await t.get_leaves(employee_id=3) + domain = t._o.search_read.call_args[0][1] + assert ('employee_id', '=', 3) in domain + + +@pytest.mark.asyncio +async def test_get_leaves_with_state(): + t = _make_tools() + await t.get_leaves(state='validate') + domain = t._o.search_read.call_args[0][1] + assert ('state', '=', 'validate') in domain + + +@pytest.mark.asyncio +async def test_get_leaves_with_date_from(): + t = _make_tools() + await t.get_leaves(date_from='2026-01-01') + domain = t._o.search_read.call_args[0][1] + assert ('date_from', '>=', '2026-01-01') in domain + + +# ── get_contracts ───────────────────────────────────────────────────────────── + +@pytest.mark.asyncio +async def test_get_contracts_default_state_open(): + t = _make_tools() + await t.get_contracts() + domain = t._o.search_read.call_args[0][1] + assert ('state', '=', 'open') in domain + + +@pytest.mark.asyncio +async def test_get_contracts_with_employee(): + t = _make_tools() + await t.get_contracts(employee_id=7) + domain = t._o.search_read.call_args[0][1] + assert ('employee_id', '=', 7) in domain + + +@pytest.mark.asyncio +async def test_get_contracts_returns_list(): + t = _make_tools() + t._o.search_read = AsyncMock(return_value=[ + {'id': 1, 'name': 'Contract A', 'wage': 50000.0} + ]) + result = await t.get_contracts() + assert len(result) == 1 + + +# ── get_attendance_summary ──────────────────────────────────────────────────── + +@pytest.mark.asyncio +async def test_get_attendance_summary_totals(): + t = _make_tools() + t._o.search_read = AsyncMock(return_value=[ + {'worked_hours': 8.0, 'check_in': '2026-01-05 09:00:00'}, + {'worked_hours': 7.5, 'check_in': '2026-01-06 09:00:00'}, + ]) + result = await t.get_attendance_summary(employee_id=1, date_from='2026-01-01', date_to='2026-01-31') + assert result['total_hours'] == 15.5 + assert result['days_present'] == 2 + assert result['attendance_records'] == 2 + + +@pytest.mark.asyncio +async def test_get_attendance_summary_structure(): + t = _make_tools() + result = await t.get_attendance_summary(employee_id=1, date_from='2026-01-01', date_to='2026-01-31') + assert 'total_hours' in result + assert 'days_present' in result + assert result['employee_id'] == 1 + + +# ── get_department_summary ──────────────────────────────────────────────────── + +@pytest.mark.asyncio +async def test_get_department_summary_headcount(): + t = _make_tools() + t._o.search_read = AsyncMock(side_effect=[ + [{'id': 1, 'name': 'Alice'}, {'id': 2, 'name': 'Bob'}], # employees + [{'id': 10, 'employee_id': [1, 'Alice'], 'wage': 60000.0}, + {'id': 11, 'employee_id': [2, 'Bob'], 'wage': 55000.0}], # contracts + ]) + result = await t.get_department_summary(department_id=3) + assert result['headcount'] == 2 + assert result['department_id'] == 3 + + +@pytest.mark.asyncio +async def test_get_department_summary_avg_wage(): + t = _make_tools() + t._o.search_read = AsyncMock(side_effect=[ + [{'id': 1, 'name': 'Alice'}], + [{'id': 10, 'employee_id': [1, 'Alice'], 'wage': 60000.0}], + ]) + result = await t.get_department_summary(department_id=3) + assert result['avg_wage'] == 60000.0 + + +# ── flag_for_review ─────────────────────────────────────────────────────────── + +@pytest.mark.asyncio +async def test_flag_for_review(): + t = _make_tools() + result = await t.flag_for_review('hr.employee', 1, 'Contract expired', severity='high') + assert result is True + call_args = t._o.call.call_args[0] + assert call_args[1] == 'message_post' + assert '[AI FLAG - HIGH]' in call_args[3]['body'] + + +# ── post_chatter_note ───────────────────────────────────────────────────────── + +@pytest.mark.asyncio +async def test_post_chatter_note(): + t = _make_tools() + result = await t.post_chatter_note('hr.employee', 1, 'Note text') + assert result is True + t._o.call.assert_awaited_once() + call_args = t._o.call.call_args[0] + assert call_args[3]['body'] == 'Note text' diff --git a/tests/test_expenses_tools.py b/tests/test_expenses_tools.py new file mode 100644 index 0000000..979ec95 --- /dev/null +++ b/tests/test_expenses_tools.py @@ -0,0 +1,290 @@ +"""Unit tests for ExpensesTools.""" +import pytest +from unittest.mock import AsyncMock, MagicMock +from agent_service.tools.expenses_tools import ExpensesTools +from agent_service.tools.odoo_client import WriteResult + + +def _make_tools(): + odoo = MagicMock() + odoo.search_read = AsyncMock(return_value=[]) + odoo.write = AsyncMock(return_value=WriteResult( + success=True, model='', record_id=None, action='write')) + odoo.create = AsyncMock(return_value=WriteResult( + success=True, model='', record_id=42, action='create')) + odoo.call = AsyncMock(return_value=True) + return ExpensesTools(odoo) + + +# ── get_expenses ────────────────────────────────────────────────────────────── + +@pytest.mark.asyncio +async def test_get_expenses_default(): + t = _make_tools() + result = await t.get_expenses() + t._o.search_read.assert_awaited_once() + assert isinstance(result, list) + + +@pytest.mark.asyncio +async def test_get_expenses_with_employee(): + t = _make_tools() + await t.get_expenses(employee_id=5) + domain = t._o.search_read.call_args[0][1] + assert ('employee_id', '=', 5) in domain + + +@pytest.mark.asyncio +async def test_get_expenses_with_state(): + t = _make_tools() + await t.get_expenses(state='draft') + domain = t._o.search_read.call_args[0][1] + assert ('state', '=', 'draft') in domain + + +@pytest.mark.asyncio +async def test_get_expenses_with_date_range(): + t = _make_tools() + await t.get_expenses(date_from='2026-01-01', date_to='2026-01-31') + domain = t._o.search_read.call_args[0][1] + assert ('date', '>=', '2026-01-01') in domain + assert ('date', '<=', '2026-01-31') in domain + + +# ── get_expense_sheets ──────────────────────────────────────────────────────── + +@pytest.mark.asyncio +async def test_get_expense_sheets_default(): + t = _make_tools() + result = await t.get_expense_sheets() + t._o.search_read.assert_awaited_once() + assert isinstance(result, list) + + +@pytest.mark.asyncio +async def test_get_expense_sheets_with_state(): + t = _make_tools() + await t.get_expense_sheets(state='submit') + domain = t._o.search_read.call_args[0][1] + assert ('state', '=', 'submit') in domain + + +@pytest.mark.asyncio +async def test_get_expense_sheets_with_employee(): + t = _make_tools() + await t.get_expense_sheets(employee_id=3) + domain = t._o.search_read.call_args[0][1] + assert ('employee_id', '=', 3) in domain + + +# ── get_pending_approvals ───────────────────────────────────────────────────── + +@pytest.mark.asyncio +async def test_get_pending_approvals_uses_submit_state(): + t = _make_tools() + t._o.search_read = AsyncMock(return_value=[ + {'id': 1, 'name': 'Sheet 1', 'total_amount': 500.0} + ]) + result = await t.get_pending_approvals() + domain = t._o.search_read.call_args[0][1] + assert ('state', '=', 'submit') in domain + assert len(result) == 1 + + +# ── approve_expense_sheet ───────────────────────────────────────────────────── + +@pytest.mark.asyncio +async def test_approve_expense_sheet_success(): + t = _make_tools() + result = await t.approve_expense_sheet(sheet_id=1) + assert result is True + t._o.call.assert_awaited_once_with( + 'hr.expense.sheet', 'approve_expense_sheets', [[1]]) + + +@pytest.mark.asyncio +async def test_approve_expense_sheet_handles_exception(): + t = _make_tools() + t._o.call = AsyncMock(side_effect=Exception('permission denied')) + result = await t.approve_expense_sheet(sheet_id=1) + assert result is False + + +# ── get_expenses_summary ────────────────────────────────────────────────────── + +@pytest.mark.asyncio +async def test_get_expenses_summary_structure(): + t = _make_tools() + t._o.search_read = AsyncMock(side_effect=[ + [{'total_amount': 100.0}, {'total_amount': 200.0}], # expenses + [], # pending_approvals + ]) + result = await t.get_expenses_summary() + assert 'total_expenses' in result + assert result['total_expenses'] == 2 + assert result['total_amount'] == 300.0 + assert 'pending_approval_count' in result + + +@pytest.mark.asyncio +async def test_get_expenses_summary_with_dates(): + t = _make_tools() + await t.get_expenses_summary(date_from='2026-01-01', date_to='2026-01-31') + domain = t._o.search_read.call_args_list[0][0][1] + assert ('date', '>=', '2026-01-01') in domain + + +# ── get_expense_by_employee ─────────────────────────────────────────────────── + +@pytest.mark.asyncio +async def test_get_expense_by_employee(): + t = _make_tools() + t._o.search_read = AsyncMock(return_value=[ + {'id': 1, 'name': 'Flight', 'total_amount': 500.0} + ]) + result = await t.get_expense_by_employee(employee_id=5) + domain = t._o.search_read.call_args[0][1] + assert ('employee_id', '=', 5) in domain + assert len(result) == 1 + + +# ── get_employee_id_for_user ────────────────────────────────────────────────── + +@pytest.mark.asyncio +async def test_get_employee_id_for_user_found(): + t = _make_tools() + t._o.search_read = AsyncMock(return_value=[{'id': 7, 'name': 'Alice'}]) + result = await t.get_employee_id_for_user(user_id=3) + assert result == 7 + + +@pytest.mark.asyncio +async def test_get_employee_id_for_user_not_found(): + t = _make_tools() + t._o.search_read = AsyncMock(return_value=[]) + result = await t.get_employee_id_for_user(user_id=99) + assert result is None + + +@pytest.mark.asyncio +async def test_get_employee_id_for_user_none_input(): + t = _make_tools() + result = await t.get_employee_id_for_user(user_id=None) + assert result is None + + +# ── get_default_expense_product ─────────────────────────────────────────────── + +@pytest.mark.asyncio +async def test_get_default_expense_product_found(): + t = _make_tools() + t._o.search_read = AsyncMock(return_value=[{'id': 99, 'name': 'Expense'}]) + result = await t.get_default_expense_product() + assert result == 99 + + +@pytest.mark.asyncio +async def test_get_default_expense_product_not_found(): + t = _make_tools() + result = await t.get_default_expense_product() + assert result is None + + +# ── get_expense_products ────────────────────────────────────────────────────── + +@pytest.mark.asyncio +async def test_get_expense_products(): + t = _make_tools() + t._o.search_read = AsyncMock(return_value=[ + {'id': 1, 'name': 'Travel'}, {'id': 2, 'name': 'Meals'} + ]) + result = await t.get_expense_products() + assert len(result) == 2 + + +@pytest.mark.asyncio +async def test_get_expense_products_returns_empty_on_exception(): + t = _make_tools() + t._o.search_read = AsyncMock(side_effect=Exception('error')) + result = await t.get_expense_products() + assert result == [] + + +# ── create_expense_sheet ────────────────────────────────────────────────────── + +@pytest.mark.asyncio +async def test_create_expense_sheet(): + t = _make_tools() + result = await t.create_expense_sheet(name='Q1 Expenses', employee_id=5) + t._o.create.assert_awaited_once_with('hr.expense.sheet', { + 'name': 'Q1 Expenses', 'employee_id': 5, + }) + + +# ── create_expense ──────────────────────────────────────────────────────────── + +@pytest.mark.asyncio +async def test_create_expense_basic(): + t = _make_tools() + await t.create_expense( + sheet_id=1, employee_id=5, name='Flight', + total_amount=500.0, date='2026-01-15', + ) + vals = t._o.create.call_args[0][1] + assert vals['name'] == 'Flight' + assert vals['total_amount'] == 500.0 + assert vals['date'] == '2026-01-15' + + +@pytest.mark.asyncio +async def test_create_expense_with_product(): + t = _make_tools() + await t.create_expense( + sheet_id=1, employee_id=5, name='Hotel', + total_amount=200.0, date='2026-01-16', product_id=10, + ) + vals = t._o.create.call_args[0][1] + assert vals['product_id'] == 10 + + +# ── attach_receipt ──────────────────────────────────────────────────────────── + +@pytest.mark.asyncio +async def test_attach_receipt_success(): + t = _make_tools() + result = await t.attach_receipt( + model='hr.expense', record_id=1, + filename='receipt.pdf', file_b64='base64data', mimetype='application/pdf', + ) + assert result is True + t._o.create.assert_awaited_once() + vals = t._o.create.call_args[0][1] + assert vals['name'] == 'receipt.pdf' + assert vals['res_model'] == 'hr.expense' + + +@pytest.mark.asyncio +async def test_attach_receipt_handles_exception(): + t = _make_tools() + t._o.create = AsyncMock(side_effect=Exception('error')) + result = await t.attach_receipt('hr.expense', 1, 'f.pdf', 'b64', 'application/pdf') + assert result is False + + +# ── flag_for_review / post_chatter_note ────────────────────────────────────── + +@pytest.mark.asyncio +async def test_flag_for_review(): + t = _make_tools() + result = await t.flag_for_review('hr.expense', 1, 'Suspicious amount') + assert result is True + call_args = t._o.call.call_args[0] + assert 'MEDIUM' in call_args[3]['body'] + + +@pytest.mark.asyncio +async def test_post_chatter_note(): + t = _make_tools() + result = await t.post_chatter_note('hr.expense', 1, 'Expense verified') + assert result is True + t._o.call.assert_awaited_once() diff --git a/tests/test_finance_agent.py b/tests/test_finance_agent.py index feb4a55..2ffd0cf 100644 --- a/tests/test_finance_agent.py +++ b/tests/test_finance_agent.py @@ -123,7 +123,7 @@ async def test_sweep_returns_findings(agent, mock_ft): @pytest.mark.asyncio async def test_handle_peer_request_overdue_summary(agent, mock_ft): agent._ft = mock_ft - result = await agent.handle_peer_request({'type': 'overdue_summary'}) + result = await agent.handle_peer_request('overdue_summary', {}, 'dir-1') assert 'overdue_count' in result diff --git a/tests/test_finance_tools.py b/tests/test_finance_tools.py new file mode 100644 index 0000000..902c1e2 --- /dev/null +++ b/tests/test_finance_tools.py @@ -0,0 +1,213 @@ +"""Unit tests for FinanceTools.""" +import pytest +from unittest.mock import AsyncMock, MagicMock +from agent_service.tools.finance_tools import FinanceTools + + +def _make_tools(): + odoo = MagicMock() + odoo.search_read = AsyncMock(return_value=[]) + odoo.read = AsyncMock(return_value=[]) + odoo.call = AsyncMock(return_value=True) + odoo.post_chatter = AsyncMock(return_value=99) + return FinanceTools(odoo) + + +# ── get_invoices ────────────────────────────────────────────────────────────── + +@pytest.mark.asyncio +async def test_get_invoices_default(): + t = _make_tools() + result = await t.get_invoices() + t._odoo.search_read.assert_awaited_once() + assert isinstance(result, list) + + +@pytest.mark.asyncio +async def test_get_invoices_state_filter(): + t = _make_tools() + await t.get_invoices(state='posted') + domain = t._odoo.search_read.call_args[0][1] + assert ['state', '=', 'posted'] in domain + + +@pytest.mark.asyncio +async def test_get_invoices_move_type_filter(): + t = _make_tools() + await t.get_invoices(move_type='out_invoice') + domain = t._odoo.search_read.call_args[0][1] + assert ['move_type', '=', 'out_invoice'] in domain + + +@pytest.mark.asyncio +async def test_get_invoices_partner_filter(): + t = _make_tools() + await t.get_invoices(partner_id=5) + domain = t._odoo.search_read.call_args[0][1] + assert ['partner_id', '=', 5] in domain + + +@pytest.mark.asyncio +async def test_get_invoices_date_range(): + t = _make_tools() + await t.get_invoices(date_from='2026-01-01', date_to='2026-01-31') + domain = t._odoo.search_read.call_args[0][1] + assert ['invoice_date', '>=', '2026-01-01'] in domain + assert ['invoice_date', '<=', '2026-01-31'] in domain + + +# ── get_overdue_invoices ────────────────────────────────────────────────────── + +@pytest.mark.asyncio +async def test_get_overdue_invoices_default(): + t = _make_tools() + result = await t.get_overdue_invoices() + domain = t._odoo.search_read.call_args[0][1] + assert ['move_type', 'in', ['out_invoice', 'out_refund']] in domain + assert ['state', '=', 'posted'] in domain + assert ['payment_state', 'in', ['not_paid', 'partial']] in domain + + +@pytest.mark.asyncio +async def test_get_overdue_invoices_partner_filter(): + t = _make_tools() + await t.get_overdue_invoices(partner_id=7) + domain = t._odoo.search_read.call_args[0][1] + assert ['partner_id', '=', 7] in domain + + +# ── get_unreconciled_statements ─────────────────────────────────────────────── + +@pytest.mark.asyncio +async def test_get_unreconciled_statements(): + t = _make_tools() + await t.get_unreconciled_statements(journal_id=3) + domain = t._odoo.search_read.call_args[0][1] + assert ['journal_id', '=', 3] in domain + assert ['is_reconciled', '=', False] in domain + + +@pytest.mark.asyncio +async def test_get_unreconciled_statements_date_range(): + t = _make_tools() + await t.get_unreconciled_statements(journal_id=3, date_from='2026-01-01', date_to='2026-01-31') + domain = t._odoo.search_read.call_args[0][1] + assert ['date', '>=', '2026-01-01'] in domain + + +# ── send_payment_reminder ───────────────────────────────────────────────────── + +@pytest.mark.asyncio +async def test_send_payment_reminder_success(): + t = _make_tools() + t._odoo.read = AsyncMock(return_value=[{ + 'name': 'INV/001', 'partner_id': [1, 'ACME'], + 'amount_residual': 500.0, 'invoice_date_due': '2026-01-01', + }]) + result = await t.send_payment_reminder(invoice_id=1) + assert result['success'] is True + assert result['invoice'] == 'INV/001' + t._odoo.post_chatter.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_send_payment_reminder_invoice_not_found(): + t = _make_tools() + t._odoo.read = AsyncMock(return_value=[]) + result = await t.send_payment_reminder(invoice_id=999) + assert result['success'] is False + assert 'error' in result + + +@pytest.mark.asyncio +async def test_send_payment_reminder_custom_message(): + t = _make_tools() + t._odoo.read = AsyncMock(return_value=[{ + 'name': 'INV/001', 'partner_id': [1, 'ACME'], + 'amount_residual': 500.0, 'invoice_date_due': '2026-01-01', + }]) + await t.send_payment_reminder(invoice_id=1, custom_message='Please pay!') + call_args = t._odoo.post_chatter.call_args[0] + assert call_args[2] == 'Please pay!' + + +# ── get_financial_summary ───────────────────────────────────────────────────── + +@pytest.mark.asyncio +async def test_get_financial_summary_current_period(): + t = _make_tools() + t._odoo.search_read = AsyncMock(return_value=[ + {'amount_total': 1000.0, 'amount_residual': 200.0, 'payment_state': 'partial'}, + {'amount_total': 2000.0, 'amount_residual': 0.0, 'payment_state': 'paid'}, + ]) + result = await t.get_financial_summary() + assert result['total_invoiced'] == 3000.0 + assert result['invoice_count'] == 2 + assert result['paid_count'] == 1 + assert 'collection_rate' in result + + +@pytest.mark.asyncio +async def test_get_financial_summary_explicit_period(): + t = _make_tools() + await t.get_financial_summary(period='2026-01') + domain = t._odoo.search_read.call_args[0][1] + assert ['invoice_date', '>=', '2026-01-01'] in domain + assert ['invoice_date', '<=', '2026-01-31'] in domain + + +@pytest.mark.asyncio +async def test_get_financial_summary_empty(): + t = _make_tools() + result = await t.get_financial_summary() + assert result['total_invoiced'] == 0 + assert result['collection_rate'] == 0 + + +# ── get_payment_history ─────────────────────────────────────────────────────── + +@pytest.mark.asyncio +async def test_get_payment_history(): + t = _make_tools() + t._odoo.search_read = AsyncMock(return_value=[ + {'name': 'PAY/001', 'amount': 500.0, 'state': 'posted'} + ]) + result = await t.get_payment_history(partner_id=5) + domain = t._odoo.search_read.call_args[0][1] + assert ['partner_id', '=', 5] in domain + assert ['state', '=', 'posted'] in domain + assert len(result) == 1 + + +# ── flag_for_review ─────────────────────────────────────────────────────────── + +@pytest.mark.asyncio +async def test_flag_for_review(): + t = _make_tools() + result = await t.flag_for_review('account.move', 1, 'Overdue 90 days') + assert result['flagged'] is True + assert result['model'] == 'account.move' + assert result['record_id'] == 1 + t._odoo.post_chatter.assert_awaited_once() + note = t._odoo.post_chatter.call_args[0][2] + assert 'MEDIUM' in note + + +@pytest.mark.asyncio +async def test_flag_for_review_high_severity(): + t = _make_tools() + result = await t.flag_for_review('account.move', 1, 'reason', severity='high') + assert result['severity'] == 'high' + note = t._odoo.post_chatter.call_args[0][2] + assert 'HIGH' in note + + +# ── post_chatter_note ───────────────────────────────────────────────────────── + +@pytest.mark.asyncio +async def test_post_chatter_note(): + t = _make_tools() + result = await t.post_chatter_note('account.move', 1, 'Payment received') + assert result['success'] is True + assert result['message_id'] == 99 + t._odoo.post_chatter.assert_awaited_once_with('account.move', 1, 'Payment received') diff --git a/tests/test_peer_bus.py b/tests/test_peer_bus.py index e6ffe65..81a8da9 100644 --- a/tests/test_peer_bus.py +++ b/tests/test_peer_bus.py @@ -1,69 +1,171 @@ +"""Unit tests for PeerBus — request routing, circular detection, depth limit, timeout.""" import asyncio import pytest from unittest.mock import AsyncMock, MagicMock -from agent_service.agents.peer_bus import PeerBus, PeerResponse +from agent_service.agents.peer_bus import PeerBus, PeerResponse, PeerCircularRequestError -class MockAgent: - name = 'mock_agent' +def _make_registry(agents=None, active=None): + registry = MagicMock() + agents = agents or {} + active_set = set(active) if active is not None else set(agents.keys()) - async def handle_peer_request(self, request: dict) -> dict: - return {'answer': 42, 'echo': request.get('data')} + async def is_active(key): + return key in active_set + + registry.is_active = is_active + registry.get_agent_instance = lambda key: agents.get(key) + return registry -class SlowAgent: - name = 'slow_agent' +def _make_agent(response=None): + agent = MagicMock() + agent.handle_peer_request = AsyncMock( + return_value=response if response is not None else {'answer': 42, 'success': True} + ) + return agent - async def handle_peer_request(self, request: dict) -> dict: - await asyncio.sleep(35) - return {} +# ── basic request routing ──────────────────────────────────────────────────── @pytest.mark.asyncio -async def test_peer_bus_register_and_call(): - bus = PeerBus() - bus.register('mock_agent', MockAgent()) - resp = await bus.call('mock_agent', {'data': 'hello'}) +async def test_request_routes_to_agent(): + agent = _make_agent({'answer': 42, 'success': True}) + reg = _make_registry(agents={'mock_agent': agent}) + bus = PeerBus(reg, directive_id='d1') + resp = await bus.request('from_agent', 'mock_agent', 'some_request', {}, 'reason') assert isinstance(resp, PeerResponse) assert resp.available is True + assert resp.success is True assert resp.data.get('answer') == 42 @pytest.mark.asyncio -async def test_peer_bus_unknown_agent(): - bus = PeerBus() - resp = await bus.call('nonexistent_agent', {}) +async def test_request_passes_correct_args_to_handler(): + agent = _make_agent() + reg = _make_registry(agents={'mock_agent': agent}) + bus = PeerBus(reg, directive_id='d1') + await bus.request('from_agent', 'mock_agent', 'query_type', {'key': 'val'}, 'reason') + agent.handle_peer_request.assert_awaited_once_with('query_type', {'key': 'val'}, 'd1') + + +@pytest.mark.asyncio +async def test_request_records_call_log(): + agent = _make_agent() + reg = _make_registry(agents={'mock_agent': agent}) + bus = PeerBus(reg, directive_id='d1') + await bus.request('from_agent', 'mock_agent', 'some_type', {}, 'reason') + assert len(bus.call_log) == 1 + assert bus.call_log[0]['from'] == 'from_agent' + assert bus.call_log[0]['to'] == 'mock_agent' + assert bus.call_log[0]['type'] == 'some_type' + + +# ── inactive / missing agent ───────────────────────────────────────────────── + +@pytest.mark.asyncio +async def test_request_inactive_agent_returns_unavailable(): + reg = _make_registry(agents={}, active=set()) + bus = PeerBus(reg, directive_id='d1') + resp = await bus.request('from_agent', 'inactive_agent', 'some_type', {}, 'reason') assert resp.available is False @pytest.mark.asyncio -async def test_peer_bus_max_depth(): - bus = PeerBus() - bus.register('mock_agent', MockAgent()) - call_chain = ['a', 'b', 'c'] - resp = await bus.call('mock_agent', {}, _call_chain=call_chain) +async def test_request_no_instance_returns_unavailable(): + reg = _make_registry(agents={}, active={'ghost_agent'}) + bus = PeerBus(reg, directive_id='d1') + resp = await bus.request('from_agent', 'ghost_agent', 'some_type', {}, 'reason') assert resp.available is False +# ── circular detection ─────────────────────────────────────────────────────── + +@pytest.mark.asyncio +async def test_circular_request_raises(): + agent = _make_agent() + reg = _make_registry(agents={'mock_agent': agent}) + bus = PeerBus(reg, directive_id='d1') + bus._call_chain = ['mock_agent'] + with pytest.raises(PeerCircularRequestError): + await bus.request('some_agent', 'mock_agent', 'some_type', {}, 'reason') + + +# ── max depth limit ────────────────────────────────────────────────────────── + +@pytest.mark.asyncio +async def test_max_depth_returns_unavailable(): + agent = _make_agent() + reg = _make_registry(agents={'mock_agent': agent}) + bus = PeerBus(reg, directive_id='d1') + bus._call_chain = ['a', 'b', 'c'] # already at MAX_DEPTH=3 + resp = await bus.request('some_agent', 'mock_agent', 'some_type', {}, 'reason') + assert resp.available is False + assert 'depth' in str(resp.error).lower() or 'max' in str(resp.error).lower() + + +# ── timeout ────────────────────────────────────────────────────────────────── + +@pytest.mark.asyncio +async def test_timeout_returns_success_false(): + agent = MagicMock() + + async def slow_handler(request_type, params, directive_id): + await asyncio.sleep(35) + return {} + + agent.handle_peer_request = slow_handler + reg = _make_registry(agents={'slow_agent': agent}) + bus = PeerBus(reg, directive_id='d1') + + import unittest.mock as um + with um.patch('asyncio.wait_for', side_effect=asyncio.TimeoutError): + resp = await bus.request('from_agent', 'slow_agent', 'some_type', {}, 'reason') + assert resp.available is True + assert resp.success is False + assert 'timeout' in str(resp.error).lower() + + +# ── exception from agent ───────────────────────────────────────────────────── + +@pytest.mark.asyncio +async def test_agent_exception_returns_success_false(): + agent = MagicMock() + agent.handle_peer_request = AsyncMock(side_effect=RuntimeError('boom')) + reg = _make_registry(agents={'bad_agent': agent}) + bus = PeerBus(reg, directive_id='d1') + resp = await bus.request('from_agent', 'bad_agent', 'some_type', {}, 'reason') + assert resp.available is True + assert resp.success is False + assert 'boom' in str(resp.error) + + +# ── call_log ───────────────────────────────────────────────────────────────── + +@pytest.mark.asyncio +async def test_call_log_grows_with_requests(): + agent = _make_agent() + reg = _make_registry(agents={'mock_agent': agent}) + bus = PeerBus(reg, directive_id='d1') + await bus.request('a', 'mock_agent', 't1', {}, 'r1') + await bus.request('b', 'mock_agent', 't2', {}, 'r2') + assert len(bus.call_log) == 2 + + @pytest.mark.asyncio -async def test_peer_bus_timeout(): - bus = PeerBus() - bus.register('slow_agent', SlowAgent()) - resp = await bus.call('slow_agent', {}) - assert resp.available is False +async def test_call_log_empty_initially(): + reg = _make_registry() + bus = PeerBus(reg, directive_id='d1') + assert bus.call_log == [] +# ── PeerResponse fields ────────────────────────────────────────────────────── + @pytest.mark.asyncio -async def test_peer_bus_circular_detection(): - bus = PeerBus() - bus.register('mock_agent', MockAgent()) - resp = await bus.call('mock_agent', {}, _call_chain=['mock_agent']) - assert resp.available is False - - -def test_peer_bus_get_agent(): - bus = PeerBus() - agent = MockAgent() - bus.register('mock_agent', agent) - assert bus.get_agent('mock_agent') is agent - assert bus.get_agent('missing') is None +async def test_response_has_agent_and_request_type(): + agent = _make_agent() + reg = _make_registry(agents={'mock_agent': agent}) + bus = PeerBus(reg, directive_id='d1') + resp = await bus.request('from_agent', 'mock_agent', 'my_type', {}, 'reason') + assert resp.agent == 'mock_agent' + assert resp.request_type == 'my_type' diff --git a/tests/test_project_agent.py b/tests/test_project_agent.py new file mode 100644 index 0000000..bb037a9 --- /dev/null +++ b/tests/test_project_agent.py @@ -0,0 +1,238 @@ +"""Unit tests for ProjectAgent — plan, gather, reason, report, peer_bus, sweep.""" +import datetime +import pytest +from unittest.mock import AsyncMock, MagicMock +from agent_service.agents.project_agent import ProjectAgent, PROJECT_TOOLS +from agent_service.agents.base_agent import AgentDirective, AgentReport, SweepReport + + +def _directive(intent='', context=None): + return AgentDirective( + directive_id='proj-d1', user_id='1', intent=intent, + context=context or {}, agent_name='project_agent', + ) + + +def _make_agent(): + agent = ProjectAgent(odoo=MagicMock(), llm=MagicMock()) + agent._pt = MagicMock() + agent._pt.get_projects = AsyncMock(return_value=[ + {'id': 1, 'name': 'Alpha'}, {'id': 2, 'name': 'Beta'} + ]) + agent._pt.get_tasks = AsyncMock(return_value=[]) + agent._pt.get_project_summary = AsyncMock(return_value={'task_count': 5}) + agent._pt.update_task_stage = AsyncMock(return_value=True) + agent._pt.assign_task = AsyncMock(return_value=True) + agent._pt.create_task = AsyncMock(return_value=99) + agent._pt.log_timesheet = AsyncMock(return_value=True) + agent._pt.post_chatter_note = AsyncMock(return_value=True) + return agent + + +# ── Meta ──────────────────────────────────────────────────────────────────── + +def test_tool_count(): + assert len(PROJECT_TOOLS) <= 8 + +def test_agent_name(): + assert ProjectAgent.name == 'project_agent' + + +# ── _plan ─────────────────────────────────────────────────────────────────── + +@pytest.mark.asyncio +async def test_plan_project_intent(): + agent = _make_agent() + plan = await agent._plan(_directive(intent='show all projects overview')) + assert plan['fetch_projects'] is True + +@pytest.mark.asyncio +async def test_plan_task_intent(): + agent = _make_agent() + plan = await agent._plan(_directive(intent='list my tasks')) + assert plan['fetch_tasks'] is True + +@pytest.mark.asyncio +async def test_plan_propagates_project_id(): + agent = _make_agent() + plan = await agent._plan(_directive(context={'project_id': 3})) + assert plan['project_id'] == 3 + + +# ── _gather ───────────────────────────────────────────────────────────────── + +@pytest.mark.asyncio +async def test_gather_projects_by_default(): + agent = _make_agent() + ctx = {'plan': {'fetch_projects': True, 'fetch_tasks': False, + 'project_id': None, 'user_id': None}} + data = await agent._gather(ctx) + assert 'projects' in data + assert len(data['projects']) == 2 + +@pytest.mark.asyncio +async def test_gather_tasks_when_requested(): + agent = _make_agent() + agent._pt.get_tasks = AsyncMock(return_value=[{'id': 10, 'kanban_state': 'normal'}]) + ctx = {'plan': {'fetch_projects': False, 'fetch_tasks': True, + 'project_id': None, 'user_id': None}} + data = await agent._gather(ctx) + assert 'tasks' in data + +@pytest.mark.asyncio +async def test_gather_tasks_when_project_id_set(): + agent = _make_agent() + ctx = {'plan': {'fetch_projects': False, 'fetch_tasks': False, + 'project_id': 1, 'user_id': None}} + data = await agent._gather(ctx) + assert 'tasks' in data + + +# ── _reason ───────────────────────────────────────────────────────────────── + +@pytest.mark.asyncio +async def test_reason_flags_blocked_tasks(): + agent = _make_agent() + agent._gathered_data = { + 'tasks': [ + {'id': i, 'kanban_state': 'blocked', 'name': f'Task {i}'} + for i in range(6) + ] + } + analysis = await agent._reason({}) + assert len(analysis['blocked_tasks']) == 6 + assert len(analysis['escalations']) == 1 + assert '6' in analysis['escalations'][0] + +@pytest.mark.asyncio +async def test_reason_few_blocked_no_escalation(): + agent = _make_agent() + agent._gathered_data = { + 'tasks': [{'id': 1, 'kanban_state': 'blocked', 'name': 'T1'}] + } + analysis = await agent._reason({}) + assert len(analysis['blocked_tasks']) == 1 + assert analysis['escalations'] == [] + +@pytest.mark.asyncio +async def test_reason_no_tasks_no_escalation(): + agent = _make_agent() + agent._gathered_data = {'projects': [{'id': 1}]} + analysis = await agent._reason({}) + assert analysis['escalations'] == [] + + +# ── _report ───────────────────────────────────────────────────────────────── + +@pytest.mark.asyncio +async def test_report_project_count(): + agent = _make_agent() + agent._gathered_data = {'projects': [{'id': 1}, {'id': 2}, {'id': 3}]} + agent._escalations_list = [] + report = await agent._report({}) + assert isinstance(report, AgentReport) + assert '3' in report.summary + +@pytest.mark.asyncio +async def test_report_task_and_blocked_count(): + agent = _make_agent() + agent._gathered_data = { + 'tasks': [{'id': i, 'kanban_state': 'blocked' if i < 2 else 'normal'} + for i in range(5)] + } + agent._escalations_list = [] + report = await agent._report({'analysis': {'blocked_tasks': [0, 1]}}) + assert '5' in report.summary + assert '2' in report.summary + +@pytest.mark.asyncio +async def test_report_fallback_message(): + agent = _make_agent() + agent._gathered_data = {} + agent._escalations_list = [] + report = await agent._report({}) + assert 'complete' in report.summary.lower() or 'project' in report.summary.lower() + + +# ── _dispatch_tool ─────────────────────────────────────────────────────────── + +@pytest.mark.asyncio +async def test_dispatch_get_projects(): + agent = _make_agent() + await agent._dispatch_tool('get_projects', {}) + agent._pt.get_projects.assert_awaited_once() + +@pytest.mark.asyncio +async def test_dispatch_create_task(): + agent = _make_agent() + await agent._dispatch_tool('create_task', {'project_id': 1, 'name': 'New Task'}) + agent._pt.create_task.assert_awaited_once_with(project_id=1, name='New Task') + +@pytest.mark.asyncio +async def test_dispatch_unknown_raises(): + agent = _make_agent() + with pytest.raises(ValueError, match='Unknown tool'): + await agent._dispatch_tool('nonexistent', {}) + + +# ── handle_peer_request ────────────────────────────────────────────────────── + +@pytest.mark.asyncio +async def test_peer_project_list(): + agent = _make_agent() + result = await agent.handle_peer_request('project_list', {}, 'dir-1') + assert 'projects' in result + assert len(result['projects']) == 2 + +@pytest.mark.asyncio +async def test_peer_task_count(): + agent = _make_agent() + agent._pt.get_tasks = AsyncMock(return_value=[{'id': 1}, {'id': 2}]) + result = await agent.handle_peer_request('task_count', {'project_id': 1}, 'dir-1') + assert result['count'] == 2 + +@pytest.mark.asyncio +async def test_peer_unknown_returns_error(): + agent = _make_agent() + result = await agent.handle_peer_request('bad_type', {}, 'dir-1') + assert 'error' in result + + +# ── sweep ──────────────────────────────────────────────────────────────────── + +@pytest.mark.asyncio +async def test_sweep_finds_blocked_tasks(): + agent = _make_agent() + agent._pt.get_tasks = AsyncMock(return_value=[ + {'id': 1, 'name': 'Stuck', 'kanban_state': 'blocked', 'date_deadline': None} + ]) + result = await agent.sweep() + assert isinstance(result, SweepReport) + assert any(f['type'] == 'blocked_task' for f in result.findings) + +@pytest.mark.asyncio +async def test_sweep_finds_overdue_tasks(): + agent = _make_agent() + yesterday = str((datetime.date.today() - datetime.timedelta(days=1))) + agent._pt.get_tasks = AsyncMock(return_value=[ + {'id': 2, 'name': 'Late', 'kanban_state': 'normal', 'date_deadline': yesterday} + ]) + result = await agent.sweep() + assert any(f['type'] == 'overdue_task' for f in result.findings) + +@pytest.mark.asyncio +async def test_sweep_healthy_no_findings(): + agent = _make_agent() + future = str((datetime.date.today() + datetime.timedelta(days=7))) + agent._pt.get_tasks = AsyncMock(return_value=[ + {'id': 1, 'name': 'On track', 'kanban_state': 'normal', 'date_deadline': future} + ]) + result = await agent.sweep() + assert result.findings == [] + +@pytest.mark.asyncio +async def test_sweep_handles_exception(): + agent = _make_agent() + agent._pt.get_tasks = AsyncMock(side_effect=Exception('timeout')) + result = await agent.sweep() + assert result.error is not None diff --git a/tests/test_project_tools.py b/tests/test_project_tools.py new file mode 100644 index 0000000..ed1d626 --- /dev/null +++ b/tests/test_project_tools.py @@ -0,0 +1,233 @@ +"""Unit tests for ProjectTools.""" +import pytest +from unittest.mock import AsyncMock, MagicMock +from agent_service.tools.project_tools import ProjectTools +from agent_service.tools.odoo_client import WriteResult + + +def _make_tools(): + odoo = MagicMock() + odoo.search_read = AsyncMock(return_value=[]) + odoo.write = AsyncMock(return_value=WriteResult( + success=True, model='', record_id=None, action='write')) + odoo.create = AsyncMock(return_value=WriteResult( + success=True, model='', record_id=42, action='create')) + odoo.call = AsyncMock(return_value=99) + return ProjectTools(odoo) + + +# ── get_projects ────────────────────────────────────────────────────────────── + +@pytest.mark.asyncio +async def test_get_projects_default(): + t = _make_tools() + result = await t.get_projects() + t._o.search_read.assert_awaited_once() + domain = t._o.search_read.call_args[0][1] + assert ('active', '=', True) in domain + assert isinstance(result, list) + + +@pytest.mark.asyncio +async def test_get_projects_inactive(): + t = _make_tools() + await t.get_projects(active=False) + domain = t._o.search_read.call_args[0][1] + assert ('active', '=', False) in domain + + +@pytest.mark.asyncio +async def test_get_projects_returns_data(): + t = _make_tools() + t._o.search_read = AsyncMock(return_value=[ + {'id': 1, 'name': 'Alpha', 'task_count': 10}, + {'id': 2, 'name': 'Beta', 'task_count': 5}, + ]) + result = await t.get_projects() + assert len(result) == 2 + + +# ── get_tasks ───────────────────────────────────────────────────────────────── + +@pytest.mark.asyncio +async def test_get_tasks_default(): + t = _make_tools() + result = await t.get_tasks() + domain = t._o.search_read.call_args[0][1] + assert ('active', '=', True) in domain + assert isinstance(result, list) + + +@pytest.mark.asyncio +async def test_get_tasks_with_project(): + t = _make_tools() + await t.get_tasks(project_id=3) + domain = t._o.search_read.call_args[0][1] + assert ('project_id', '=', 3) in domain + + +@pytest.mark.asyncio +async def test_get_tasks_with_stage(): + t = _make_tools() + await t.get_tasks(stage_id=2) + domain = t._o.search_read.call_args[0][1] + assert ('stage_id', '=', 2) in domain + + +@pytest.mark.asyncio +async def test_get_tasks_with_user(): + t = _make_tools() + await t.get_tasks(user_id=5) + domain = t._o.search_read.call_args[0][1] + assert ('user_ids', 'in', [5]) in domain + + +# ── get_project_summary ─────────────────────────────────────────────────────── + +@pytest.mark.asyncio +async def test_get_project_summary_structure(): + t = _make_tools() + result = await t.get_project_summary(project_id=1) + assert 'project_id' in result + assert result['project_id'] == 1 + assert 'total_tasks' in result + assert 'blocked_tasks' in result + assert 'overdue_tasks' in result + + +@pytest.mark.asyncio +async def test_get_project_summary_counts_blocked(): + import datetime + t = _make_tools() + future = str(datetime.date.today().replace(year=2030)) + t._o.search_read = AsyncMock(return_value=[ + {'kanban_state': 'blocked', 'date_deadline': future, 'user_ids': []}, + {'kanban_state': 'normal', 'date_deadline': future, 'user_ids': []}, + {'kanban_state': 'blocked', 'date_deadline': future, 'user_ids': []}, + ]) + result = await t.get_project_summary(project_id=1) + assert result['total_tasks'] == 3 + assert result['blocked_tasks'] == 2 + + +@pytest.mark.asyncio +async def test_get_project_summary_counts_overdue(): + import datetime + t = _make_tools() + yesterday = str(datetime.date.today() - datetime.timedelta(days=1)) + t._o.search_read = AsyncMock(return_value=[ + {'kanban_state': 'normal', 'date_deadline': yesterday, 'user_ids': []}, + {'kanban_state': 'normal', 'date_deadline': None, 'user_ids': []}, + ]) + result = await t.get_project_summary(project_id=1) + assert result['overdue_tasks'] == 1 + + +# ── update_task_stage ───────────────────────────────────────────────────────── + +@pytest.mark.asyncio +async def test_update_task_stage_success(): + t = _make_tools() + result = await t.update_task_stage(task_id=5, stage_id=3) + assert result is True + t._o.write.assert_awaited_once_with('project.task', [5], {'stage_id': 3}) + + +@pytest.mark.asyncio +async def test_update_task_stage_failure(): + t = _make_tools() + t._o.write = AsyncMock(return_value=WriteResult( + success=False, model='', record_id=None, action='write')) + result = await t.update_task_stage(task_id=5, stage_id=3) + assert result is False + + +# ── assign_task ─────────────────────────────────────────────────────────────── + +@pytest.mark.asyncio +async def test_assign_task_success(): + t = _make_tools() + result = await t.assign_task(task_id=5, user_id=10) + assert result is True + t._o.write.assert_awaited_once_with('project.task', [5], {'user_ids': [(4, 10)]}) + + +# ── create_task ─────────────────────────────────────────────────────────────── + +@pytest.mark.asyncio +async def test_create_task_basic(): + t = _make_tools() + result = await t.create_task(project_id=1, name='New Task') + assert result == 99 + call_args = t._o.call.call_args[0] + assert call_args[0] == 'project.task' + assert call_args[1] == 'create' + vals = call_args[2][0] + assert vals['name'] == 'New Task' + assert vals['project_id'] == 1 + + +@pytest.mark.asyncio +async def test_create_task_with_all_options(): + t = _make_tools() + await t.create_task( + project_id=1, name='Full Task', + description='Description text', user_id=5, date_deadline='2026-06-01', + ) + vals = t._o.call.call_args[0][2][0] + assert vals['description'] == 'Description text' + assert vals['user_ids'] == [(4, 5)] + assert vals['date_deadline'] == '2026-06-01' + + +@pytest.mark.asyncio +async def test_create_task_no_optional_fields(): + t = _make_tools() + await t.create_task(project_id=1, name='Simple Task') + vals = t._o.call.call_args[0][2][0] + assert 'description' not in vals + assert 'user_ids' not in vals + assert 'date_deadline' not in vals + + +# ── log_timesheet ───────────────────────────────────────────────────────────── + +@pytest.mark.asyncio +async def test_log_timesheet_basic(): + t = _make_tools() + result = await t.log_timesheet(task_id=5, employee_id=3, hours=8.0) + assert result == 99 + call_args = t._o.call.call_args[0] + assert call_args[0] == 'account.analytic.line' + vals = call_args[2][0] + assert vals['task_id'] == 5 + assert vals['employee_id'] == 3 + assert vals['unit_amount'] == 8.0 + + +@pytest.mark.asyncio +async def test_log_timesheet_with_date(): + t = _make_tools() + await t.log_timesheet(task_id=5, employee_id=3, hours=4.0, date='2026-01-15') + vals = t._o.call.call_args[0][2][0] + assert vals['date'] == '2026-01-15' + + +@pytest.mark.asyncio +async def test_log_timesheet_default_description(): + t = _make_tools() + await t.log_timesheet(task_id=5, employee_id=3, hours=2.0) + vals = t._o.call.call_args[0][2][0] + assert 'AI' in vals['name'] or vals['name'] + + +# ── post_chatter_note ───────────────────────────────────────────────────────── + +@pytest.mark.asyncio +async def test_post_chatter_note(): + t = _make_tools() + result = await t.post_chatter_note('project.task', 1, 'Task blocked by dep') + assert result is True + t._o.call.assert_awaited_once() + call_args = t._o.call.call_args[0] + assert call_args[3]['body'] == 'Task blocked by dep' diff --git a/tests/test_registry.py b/tests/test_registry.py new file mode 100644 index 0000000..4772d28 --- /dev/null +++ b/tests/test_registry.py @@ -0,0 +1,186 @@ +"""Unit tests for AgentRegistry.""" +import pytest +from unittest.mock import AsyncMock, MagicMock +from agent_service.agents.registry import AgentRegistry + + +def _make_registry(): + return AgentRegistry() + + +def _make_agent(name='test_agent'): + agent = MagicMock() + agent.name = name + return agent + + +# ── register / get_agent_instance ─────────────────────────────────────────── + +def test_register_stores_instance(): + reg = _make_registry() + agent = _make_agent() + reg.register('test_agent', agent) + assert reg.get_agent_instance('test_agent') is agent + + +def test_register_missing_returns_none(): + reg = _make_registry() + assert reg.get_agent_instance('nonexistent') is None + + +def test_register_activates_agent(): + reg = _make_registry() + reg.register('test_agent', _make_agent()) + assert 'test_agent' in reg._active + + +def test_register_assigns_default_capability(): + reg = _make_registry() + reg.register('test_agent', _make_agent()) + assert reg._capabilities.get('test_agent') + + +def test_register_known_agent_uses_default_description(): + reg = _make_registry() + reg.register('crm_agent', _make_agent('crm_agent')) + assert 'crm' in reg._capabilities['crm_agent'].lower() or \ + 'lead' in reg._capabilities['crm_agent'].lower() + + +def test_register_multiple_agents(): + reg = _make_registry() + reg.register('crm_agent', _make_agent('crm_agent')) + reg.register('sales_agent', _make_agent('sales_agent')) + assert len(reg._agents) == 2 + + +# ── is_active ──────────────────────────────────────────────────────────────── + +@pytest.mark.asyncio +async def test_is_active_after_register(): + reg = _make_registry() + reg.register('crm_agent', _make_agent()) + assert await reg.is_active('crm_agent') is True + + +@pytest.mark.asyncio +async def test_is_active_unknown_agent(): + reg = _make_registry() + assert await reg.is_active('ghost_agent') is False + + +# ── get_active_agents ──────────────────────────────────────────────────────── + +@pytest.mark.asyncio +async def test_get_active_agents_returns_only_registered(): + reg = _make_registry() + reg.register('crm_agent', _make_agent()) + reg._active.add('ghost_agent') # active but no instance + result = await reg.get_active_agents() + keys = {r['agent_key'] for r in result} + assert 'crm_agent' in keys + assert 'ghost_agent' not in keys + + +@pytest.mark.asyncio +async def test_get_active_agents_includes_capabilities(): + reg = _make_registry() + reg.register('crm_agent', _make_agent()) + result = await reg.get_active_agents() + assert any('capabilities_summary' in r for r in result) + + +@pytest.mark.asyncio +async def test_get_active_agents_empty_registry(): + reg = _make_registry() + result = await reg.get_active_agents() + assert result == [] + + +# ── get_all ────────────────────────────────────────────────────────────────── + +@pytest.mark.asyncio +async def test_get_all_includes_all_registered(): + reg = _make_registry() + reg.register('crm_agent', _make_agent()) + reg.register('sales_agent', _make_agent()) + result = await reg.get_all() + names = {r['name'] for r in result} + assert 'crm_agent' in names + assert 'sales_agent' in names + + +@pytest.mark.asyncio +async def test_get_all_marks_active_flag(): + reg = _make_registry() + reg.register('crm_agent', _make_agent()) + result = await reg.get_all() + crm = next(r for r in result if r['name'] == 'crm_agent') + assert crm['active'] is True + assert crm['has_instance'] is True + + +# ── sync ───────────────────────────────────────────────────────────────────── + +@pytest.mark.asyncio +async def test_sync_replaces_active_set(): + reg = _make_registry() + reg.register('crm_agent', _make_agent()) + reg.register('sales_agent', _make_agent()) + await reg.sync(['sales_agent']) + assert await reg.is_active('sales_agent') is True + assert await reg.is_active('crm_agent') is False + + +@pytest.mark.asyncio +async def test_sync_with_empty_list_deactivates_all(): + reg = _make_registry() + reg.register('crm_agent', _make_agent()) + await reg.sync([]) + assert await reg.is_active('crm_agent') is False + + +# ── load_from_odoo ─────────────────────────────────────────────────────────── + +@pytest.mark.asyncio +async def test_load_from_odoo_sets_active_agents(): + reg = _make_registry() + odoo = MagicMock() + odoo.search_read = AsyncMock(return_value=[ + {'agent_name': 'crm_agent', 'domain': 'CRM leads', 'backend': 'local'}, + {'agent_name': 'sales_agent', 'domain': '', 'backend': 'cloud'}, + ]) + await reg.load_from_odoo(odoo) + assert 'crm_agent' in reg._active + assert 'sales_agent' in reg._active + + +@pytest.mark.asyncio +async def test_load_from_odoo_uses_domain_when_present(): + reg = _make_registry() + odoo = MagicMock() + odoo.search_read = AsyncMock(return_value=[ + {'agent_name': 'crm_agent', 'domain': 'Custom CRM description', 'backend': 'local'}, + ]) + await reg.load_from_odoo(odoo) + assert reg._capabilities['crm_agent'] == 'Custom CRM description' + + +@pytest.mark.asyncio +async def test_load_from_odoo_falls_back_to_default_when_no_domain(): + reg = _make_registry() + odoo = MagicMock() + odoo.search_read = AsyncMock(return_value=[ + {'agent_name': 'crm_agent', 'domain': '', 'backend': 'local'}, + ]) + await reg.load_from_odoo(odoo) + assert reg._capabilities.get('crm_agent') + + +@pytest.mark.asyncio +async def test_load_from_odoo_handles_exception_gracefully(): + reg = _make_registry() + odoo = MagicMock() + odoo.search_read = AsyncMock(side_effect=Exception('DB error')) + await reg.load_from_odoo(odoo) + assert reg._active == set() diff --git a/tests/test_sales_agent.py b/tests/test_sales_agent.py new file mode 100644 index 0000000..e97a48f --- /dev/null +++ b/tests/test_sales_agent.py @@ -0,0 +1,227 @@ +"""Unit tests for SalesAgent — plan, gather, reason, report, peer_bus, sweep.""" +import datetime +import pytest +from unittest.mock import AsyncMock, MagicMock +from agent_service.agents.sales_agent import SalesAgent, SALES_TOOLS +from agent_service.agents.base_agent import AgentDirective, AgentReport, SweepReport + + +def _directive(intent='', context=None): + return AgentDirective( + directive_id='sales-d1', user_id='1', intent=intent, + context=context or {}, agent_name='sales_agent', + ) + + +def _make_agent(): + agent = SalesAgent(odoo=MagicMock(), llm=MagicMock()) + agent._st = MagicMock() + agent._st.get_sales_summary = AsyncMock(return_value={ + 'order_count': 12, 'total_revenue': 85000.0 + }) + agent._st.get_quotations = AsyncMock(return_value=[]) + agent._st.get_sales_orders = AsyncMock(return_value=[]) + agent._st.get_customer_orders = AsyncMock(return_value=[]) + agent._st.confirm_quotation = AsyncMock(return_value=True) + agent._st.update_order_note = AsyncMock(return_value=True) + agent._st.flag_for_review = AsyncMock(return_value=True) + agent._st.post_chatter_note = AsyncMock(return_value=True) + return agent + + +# ── Meta ──────────────────────────────────────────────────────────────────── + +def test_tool_count(): + assert len(SALES_TOOLS) <= 8 + +def test_agent_name(): + assert SalesAgent.name == 'sales_agent' + + +# ── _plan ─────────────────────────────────────────────────────────────────── + +@pytest.mark.asyncio +async def test_plan_summary_intent(): + agent = _make_agent() + plan = await agent._plan(_directive(intent='show sales summary')) + assert plan['fetch_summary'] is True + +@pytest.mark.asyncio +async def test_plan_revenue_intent(): + agent = _make_agent() + plan = await agent._plan(_directive(intent='what is our revenue this month')) + assert plan['fetch_summary'] is True + +@pytest.mark.asyncio +async def test_plan_quotation_intent(): + agent = _make_agent() + plan = await agent._plan(_directive(intent='list open quotations')) + assert plan['fetch_quotations'] is True + +@pytest.mark.asyncio +async def test_plan_order_intent(): + agent = _make_agent() + plan = await agent._plan(_directive(intent='show all orders')) + assert plan['fetch_orders'] is True + +@pytest.mark.asyncio +async def test_plan_propagates_partner_and_dates(): + agent = _make_agent() + plan = await agent._plan(_directive(context={'partner_id': 5, 'date_from': '2026-01-01'})) + assert plan['partner_id'] == 5 + assert plan['date_from'] == '2026-01-01' + + +# ── _gather ───────────────────────────────────────────────────────────────── + +@pytest.mark.asyncio +async def test_gather_summary_by_default(): + agent = _make_agent() + ctx = {'plan': {'fetch_summary': True, 'fetch_quotations': False, + 'fetch_orders': False, 'partner_id': None, + 'date_from': None, 'date_to': None}} + data = await agent._gather(ctx) + assert 'summary' in data + agent._st.get_sales_summary.assert_awaited_once() + +@pytest.mark.asyncio +async def test_gather_quotations(): + agent = _make_agent() + agent._st.get_quotations = AsyncMock(return_value=[{'id': 1}]) + ctx = {'plan': {'fetch_summary': False, 'fetch_quotations': True, + 'fetch_orders': False, 'partner_id': None, + 'date_from': None, 'date_to': None}} + data = await agent._gather(ctx) + assert 'quotations' in data + +@pytest.mark.asyncio +async def test_gather_orders(): + agent = _make_agent() + agent._st.get_sales_orders = AsyncMock(return_value=[{'id': 1}]) + ctx = {'plan': {'fetch_summary': False, 'fetch_quotations': False, + 'fetch_orders': True, 'partner_id': None, + 'date_from': None, 'date_to': None}} + data = await agent._gather(ctx) + assert 'orders' in data + + +# ── _reason ───────────────────────────────────────────────────────────────── + +@pytest.mark.asyncio +async def test_reason_zero_revenue_escalates(): + agent = _make_agent() + agent._gathered_data = {'summary': {'order_count': 0, 'total_revenue': 0.0}} + analysis = await agent._reason({}) + assert len(analysis['escalations']) == 1 + assert 'no confirmed sales' in analysis['escalations'][0].lower() + +@pytest.mark.asyncio +async def test_reason_positive_revenue_no_escalation(): + agent = _make_agent() + agent._gathered_data = {'summary': {'order_count': 5, 'total_revenue': 10000.0}} + analysis = await agent._reason({}) + assert analysis['escalations'] == [] + +@pytest.mark.asyncio +async def test_reason_no_summary_no_escalation(): + agent = _make_agent() + agent._gathered_data = {'quotations': []} + analysis = await agent._reason({}) + assert analysis['escalations'] == [] + + +# ── _report ───────────────────────────────────────────────────────────────── + +@pytest.mark.asyncio +async def test_report_includes_revenue(): + agent = _make_agent() + agent._gathered_data = {'summary': {'order_count': 12, 'total_revenue': 85000.0}} + agent._escalations_list = [] + report = await agent._report({}) + assert isinstance(report, AgentReport) + assert '12' in report.summary + assert '85000' in report.summary + +@pytest.mark.asyncio +async def test_report_fallback_message(): + agent = _make_agent() + agent._gathered_data = {} + agent._escalations_list = [] + report = await agent._report({}) + assert 'complete' in report.summary.lower() or 'sales' in report.summary.lower() + + +# ── _dispatch_tool ─────────────────────────────────────────────────────────── + +@pytest.mark.asyncio +async def test_dispatch_get_sales_summary(): + agent = _make_agent() + await agent._dispatch_tool('get_sales_summary', {}) + agent._st.get_sales_summary.assert_awaited_once() + +@pytest.mark.asyncio +async def test_dispatch_confirm_quotation(): + agent = _make_agent() + await agent._dispatch_tool('confirm_quotation', {'order_id': 7}) + agent._st.confirm_quotation.assert_awaited_once_with(order_id=7) + +@pytest.mark.asyncio +async def test_dispatch_unknown_raises(): + agent = _make_agent() + with pytest.raises(ValueError, match='Unknown tool'): + await agent._dispatch_tool('nonexistent', {}) + + +# ── handle_peer_request ────────────────────────────────────────────────────── + +@pytest.mark.asyncio +async def test_peer_sales_summary(): + agent = _make_agent() + result = await agent.handle_peer_request('sales_summary', {}, 'dir-1') + assert 'order_count' in result or 'total_revenue' in result + +@pytest.mark.asyncio +async def test_peer_customer_orders(): + agent = _make_agent() + agent._st.get_customer_orders = AsyncMock(return_value=[{'id': 1}]) + result = await agent.handle_peer_request('customer_orders', {'partner_id': 10}, 'dir-1') + assert 'orders' in result + assert len(result['orders']) == 1 + +@pytest.mark.asyncio +async def test_peer_unknown_returns_error(): + agent = _make_agent() + result = await agent.handle_peer_request('bad_type', {}, 'dir-1') + assert 'error' in result + + +# ── sweep ──────────────────────────────────────────────────────────────────── + +@pytest.mark.asyncio +async def test_sweep_finds_expired_quotations(): + agent = _make_agent() + yesterday = str((datetime.date.today() - datetime.timedelta(days=1))) + agent._st.get_quotations = AsyncMock(return_value=[ + {'id': 1, 'partner_id': [5, 'ACME'], 'validity_date': yesterday} + ]) + result = await agent.sweep() + assert isinstance(result, SweepReport) + assert len(result.findings) == 1 + assert result.findings[0]['type'] == 'expired_quotation' + +@pytest.mark.asyncio +async def test_sweep_valid_quotations_no_findings(): + agent = _make_agent() + future = str((datetime.date.today() + datetime.timedelta(days=14))) + agent._st.get_quotations = AsyncMock(return_value=[ + {'id': 1, 'partner_id': [5, 'ACME'], 'validity_date': future} + ]) + result = await agent.sweep() + assert result.findings == [] + +@pytest.mark.asyncio +async def test_sweep_handles_exception(): + agent = _make_agent() + agent._st.get_quotations = AsyncMock(side_effect=Exception('network')) + result = await agent.sweep() + assert result.error is not None diff --git a/tests/test_sales_tools.py b/tests/test_sales_tools.py new file mode 100644 index 0000000..0b8f5a9 --- /dev/null +++ b/tests/test_sales_tools.py @@ -0,0 +1,200 @@ +"""Unit tests for SalesTools.""" +import pytest +from unittest.mock import AsyncMock, MagicMock +from agent_service.tools.sales_tools import SalesTools +from agent_service.tools.odoo_client import WriteResult + + +def _make_tools(): + odoo = MagicMock() + odoo.search_read = AsyncMock(return_value=[]) + odoo.write = AsyncMock(return_value=WriteResult( + success=True, model='', record_id=None, action='write')) + odoo.create = AsyncMock(return_value=WriteResult( + success=True, model='', record_id=42, action='create')) + odoo.call = AsyncMock(return_value=True) + return SalesTools(odoo) + + +# ── get_sales_orders ────────────────────────────────────────────────────────── + +@pytest.mark.asyncio +async def test_get_sales_orders_default(): + t = _make_tools() + result = await t.get_sales_orders() + domain = t._o.search_read.call_args[0][1] + assert ('state', '=', 'sale') in domain + assert isinstance(result, list) + + +@pytest.mark.asyncio +async def test_get_sales_orders_with_partner(): + t = _make_tools() + await t.get_sales_orders(partner_id=5) + domain = t._o.search_read.call_args[0][1] + assert ('partner_id', '=', 5) in domain + + +@pytest.mark.asyncio +async def test_get_sales_orders_with_date_range(): + t = _make_tools() + await t.get_sales_orders(date_from='2026-01-01', date_to='2026-01-31') + domain = t._o.search_read.call_args[0][1] + assert ('date_order', '>=', '2026-01-01') in domain + assert ('date_order', '<=', '2026-01-31') in domain + + +@pytest.mark.asyncio +async def test_get_sales_orders_custom_state(): + t = _make_tools() + await t.get_sales_orders(state='draft') + domain = t._o.search_read.call_args[0][1] + assert ('state', '=', 'draft') in domain + + +# ── get_quotations ──────────────────────────────────────────────────────────── + +@pytest.mark.asyncio +async def test_get_quotations_default(): + t = _make_tools() + result = await t.get_quotations() + domain = t._o.search_read.call_args[0][1] + assert ('state', 'in', ['draft', 'sent']) in domain + assert isinstance(result, list) + + +@pytest.mark.asyncio +async def test_get_quotations_with_partner(): + t = _make_tools() + await t.get_quotations(partner_id=7) + domain = t._o.search_read.call_args[0][1] + assert ('partner_id', '=', 7) in domain + + +# ── get_sales_summary ───────────────────────────────────────────────────────── + +@pytest.mark.asyncio +async def test_get_sales_summary_empty(): + t = _make_tools() + result = await t.get_sales_summary() + assert result['order_count'] == 0 + assert result['total_revenue'] == 0 + + +@pytest.mark.asyncio +async def test_get_sales_summary_aggregates(): + t = _make_tools() + t._o.search_read = AsyncMock(return_value=[ + {'amount_total': 1000.0, 'user_id': [1, 'Alice']}, + {'amount_total': 2000.0, 'user_id': [1, 'Alice']}, + {'amount_total': 500.0, 'user_id': [2, 'Bob']}, + ]) + result = await t.get_sales_summary() + assert result['order_count'] == 3 + assert result['total_revenue'] == 3500.0 + assert 'by_sales_rep' in result + + +@pytest.mark.asyncio +async def test_get_sales_summary_by_rep(): + t = _make_tools() + t._o.search_read = AsyncMock(return_value=[ + {'amount_total': 5000.0, 'user_id': [1, 'Alice']}, + {'amount_total': 3000.0, 'user_id': [2, 'Bob']}, + ]) + result = await t.get_sales_summary() + reps = result['by_sales_rep'] + assert len(reps) == 2 + assert reps[0]['total'] >= reps[1]['total'] # sorted descending + + +@pytest.mark.asyncio +async def test_get_sales_summary_with_dates(): + t = _make_tools() + await t.get_sales_summary(date_from='2026-01-01', date_to='2026-01-31') + domain = t._o.search_read.call_args[0][1] + assert ('date_order', '>=', '2026-01-01') in domain + + +# ── get_customer_orders ─────────────────────────────────────────────────────── + +@pytest.mark.asyncio +async def test_get_customer_orders(): + t = _make_tools() + t._o.search_read = AsyncMock(return_value=[ + {'id': 1, 'name': 'SO/001', 'state': 'sale'} + ]) + result = await t.get_customer_orders(partner_id=5) + domain = t._o.search_read.call_args[0][1] + assert ('partner_id', '=', 5) in domain + assert len(result) == 1 + + +# ── confirm_quotation ───────────────────────────────────────────────────────── + +@pytest.mark.asyncio +async def test_confirm_quotation_success(): + t = _make_tools() + result = await t.confirm_quotation(order_id=3) + assert result is True + t._o.call.assert_awaited_once_with('sale.order', 'action_confirm', [[3]]) + + +@pytest.mark.asyncio +async def test_confirm_quotation_handles_exception(): + t = _make_tools() + t._o.call = AsyncMock(side_effect=Exception('access denied')) + result = await t.confirm_quotation(order_id=3) + assert result is False + + +# ── update_order_note ───────────────────────────────────────────────────────── + +@pytest.mark.asyncio +async def test_update_order_note_success(): + t = _make_tools() + result = await t.update_order_note(order_id=1, note='Special instructions') + assert result is True + t._o.write.assert_awaited_once_with('sale.order', [1], {'note': 'Special instructions'}) + + +@pytest.mark.asyncio +async def test_update_order_note_failure(): + t = _make_tools() + t._o.write = AsyncMock(return_value=WriteResult( + success=False, model='', record_id=None, action='write')) + result = await t.update_order_note(order_id=1, note='note') + assert result is False + + +# ── flag_for_review ─────────────────────────────────────────────────────────── + +@pytest.mark.asyncio +async def test_flag_for_review(): + t = _make_tools() + result = await t.flag_for_review('sale.order', 1, 'Large discount applied') + assert result is True + call_args = t._o.call.call_args[0] + assert 'message_post' in str(call_args) + body = call_args[3]['body'] + assert '[AI FLAG - MEDIUM]' in body + assert 'Large discount applied' in body + + +@pytest.mark.asyncio +async def test_flag_for_review_high_severity(): + t = _make_tools() + await t.flag_for_review('sale.order', 1, 'reason', severity='high') + body = t._o.call.call_args[0][3]['body'] + assert 'HIGH' in body + + +# ── post_chatter_note ───────────────────────────────────────────────────────── + +@pytest.mark.asyncio +async def test_post_chatter_note(): + t = _make_tools() + result = await t.post_chatter_note('sale.order', 1, 'Order confirmed by AI') + assert result is True + call_args = t._o.call.call_args[0] + assert call_args[3]['body'] == 'Order confirmed by AI' diff --git a/tests/test_tool_validator.py b/tests/test_tool_validator.py index 116ad7b..80c0ff6 100644 --- a/tests/test_tool_validator.py +++ b/tests/test_tool_validator.py @@ -1,5 +1,8 @@ +"""Unit tests for ToolCallValidator and validate_agent_tools.""" import pytest -from agent_service.llm.tool_validator import ToolCallValidator, AgentConfigError +from agent_service.llm.tool_validator import ( + ToolCallValidator, AgentConfigError, ToolValidationError, validate_agent_tools, +) SAMPLE_TOOLS = [ {'name': 'get_invoices', 'parameters': { @@ -16,20 +19,25 @@ SAMPLE_TOOLS = [ def test_validator_init(): v = ToolCallValidator(SAMPLE_TOOLS) - assert 'get_invoices' in v._tool_map + assert 'get_invoices' in v._tools def test_raises_on_too_many_tools(): many_tools = [{'name': f'tool_{i}', 'parameters': {}} for i in range(9)] - with pytest.raises(AgentConfigError, match='MAX_TOOLS_PER_AGENT'): - ToolCallValidator(many_tools) + with pytest.raises(AgentConfigError, match='max is 8'): + validate_agent_tools(many_tools, 'test_agent') + + +def test_validate_agent_tools_ok_for_8(): + tools_8 = [{'name': f'tool_{i}', 'parameters': {}} for i in range(8)] + validate_agent_tools(tools_8, 'test_agent') # should not raise def test_valid_tool_call(): v = ToolCallValidator(SAMPLE_TOOLS) result = v.validate({'name': 'get_invoices', 'arguments': {'partner_id': 42}}) - assert result.name == 'get_invoices' - assert result.arguments['partner_id'] == 42 + assert result['name'] == 'get_invoices' + assert result['arguments']['partner_id'] == 42 def test_strips_hallucinated_params(): @@ -37,37 +45,71 @@ def test_strips_hallucinated_params(): result = v.validate({'name': 'get_invoices', 'arguments': { 'partner_id': 1, 'nonexistent_param': 'bad', }}) - assert 'nonexistent_param' not in result.arguments + assert 'nonexistent_param' not in result['arguments'] def test_missing_required_param_raises(): v = ToolCallValidator(SAMPLE_TOOLS) - with pytest.raises(ValueError, match='partner_id'): + with pytest.raises(ToolValidationError, match='partner_id'): v.validate({'name': 'get_invoices', 'arguments': {}}) def test_type_coercion_int(): v = ToolCallValidator(SAMPLE_TOOLS) result = v.validate({'name': 'get_invoices', 'arguments': {'partner_id': '42'}}) - assert result.arguments['partner_id'] == 42 + assert result['arguments']['partner_id'] == 42 def test_unknown_tool_raises(): v = ToolCallValidator(SAMPLE_TOOLS) - with pytest.raises(ValueError, match='Unknown tool'): + with pytest.raises(ToolValidationError, match='Unknown tool'): v.validate({'name': 'nonexistent_tool', 'arguments': {}}) -def test_parse_or_fallback_returns_none_on_bad_json(): +def test_parse_or_fallback_returns_none_on_unknown_tool(): v = ToolCallValidator(SAMPLE_TOOLS) - result = v.parse_or_fallback('not json at all', context='test') + result = v.parse_or_fallback({'name': 'bad_tool', 'arguments': {}}) assert result is None -def test_parse_or_fallback_valid_json(): +def test_parse_or_fallback_returns_none_on_missing_required(): v = ToolCallValidator(SAMPLE_TOOLS) - import json - raw = json.dumps({'name': 'send_reminder', 'arguments': {'invoice_id': 5}}) - result = v.parse_or_fallback(raw, context='test') + result = v.parse_or_fallback({'name': 'get_invoices', 'arguments': {}}) + assert result is None + + +def test_parse_or_fallback_valid_call(): + v = ToolCallValidator(SAMPLE_TOOLS) + result = v.parse_or_fallback({'name': 'send_reminder', 'arguments': {'invoice_id': 5}}) assert result is not None - assert result.name == 'send_reminder' + assert result['name'] == 'send_reminder' + + +def test_parse_or_fallback_none_input(): + v = ToolCallValidator(SAMPLE_TOOLS) + result = v.parse_or_fallback(None) + assert result is None + + +def test_optional_param_not_required(): + v = ToolCallValidator(SAMPLE_TOOLS) + result = v.validate({'name': 'get_invoices', 'arguments': {'partner_id': 1}}) + assert result['name'] == 'get_invoices' + + +def test_enum_validation_rejects_invalid_value(): + tools = [{'name': 'set_state', 'parameters': { + 'state': {'type': 'string', 'enum': ['draft', 'posted']}, + }}] + v = ToolCallValidator(tools) + with pytest.raises(ToolValidationError): + v.validate({'name': 'set_state', 'arguments': {'state': 'invalid_state'}}) + + +def test_enum_validation_accepts_valid_value(): + tools = [{'name': 'set_state', 'parameters': { + 'state': {'type': 'string', 'enum': ['draft', 'posted']}, + }}] + v = ToolCallValidator(tools) + result = v.validate({'name': 'set_state', 'arguments': {'state': 'posted'}}) + assert result['arguments']['state'] == 'posted'