anderson-ufrj committed on
Commit
60716dd
·
1 Parent(s): 6eb1f60

test(multiagent): implement agent coordination tests

Browse files

- Test agent message system
- Test agent context management
- Test base agent functionality
- Test reflective agent capabilities
- Test master agent orchestration
- Add end-to-end investigation tests

tests/multiagent/test_agent_coordination.py ADDED
@@ -0,0 +1,605 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Multi-agent coordination tests."""
2
+ import pytest
3
+ import asyncio
4
+ from unittest.mock import MagicMock, patch, AsyncMock
5
+ from datetime import datetime
6
+ import json
7
+
8
+ from src.agents.deodoro import (
9
+ BaseAgent, ReflectiveAgent, AgentContext,
10
+ AgentMessage, AgentResponse, AgentStatus
11
+ )
12
+ from src.agents.abaporu import MasterAgent
13
+ from src.agents.zumbi import InvestigatorAgent
14
+ from src.agents.anita import AnalystAgent
15
+ from src.agents.tiradentes import ReporterAgent
16
+ from src.agents.ayrton_senna import SemanticRouterAgent
17
+ from src.infrastructure.orchestrator import AgentOrchestrator
18
+ from src.infrastructure.agent_pool import AgentPool
19
+
20
+
21
class TestAgentMessage:
    """Tests covering the inter-agent message primitives."""

    def test_create_agent_message(self):
        """A freshly built message exposes its fields and sane defaults."""
        msg = AgentMessage(
            sender="agent1",
            recipient="agent2",
            action="analyze",
            payload={"data": "test"},
            context={"investigation_id": "inv-123"},
        )

        assert msg.sender == "agent1"
        assert msg.recipient == "agent2"
        assert msg.action == "analyze"
        assert msg.payload["data"] == "test"
        # requires_response defaults to True, and every message gets an id.
        assert msg.requires_response is True
        assert msg.message_id is not None

    def test_message_serialization(self):
        """Round-tripping through to_dict/from_dict preserves the message."""
        original = AgentMessage(
            sender="test_agent",
            recipient="target_agent",
            action="process",
            payload={"value": 100},
        )

        as_dict = original.to_dict()

        assert isinstance(as_dict, dict)
        assert as_dict["sender"] == "test_agent"
        assert "timestamp" in as_dict

        rebuilt = AgentMessage.from_dict(as_dict)
        assert rebuilt.sender == original.sender
        assert rebuilt.action == original.action
61
+
62
+
63
class TestAgentContext:
    """Tests for the shared per-investigation context object."""

    def test_agent_context_creation(self):
        """A new context carries its ids/parameters and starts with no shared data."""
        ctx = AgentContext(
            investigation_id="inv-456",
            user_id="user-789",
            parameters={"threshold": 0.8},
        )

        assert ctx.investigation_id == "inv-456"
        assert ctx.user_id == "user-789"
        assert ctx.parameters["threshold"] == 0.8
        assert ctx.shared_data == {}

    def test_context_data_sharing(self):
        """Data written by one agent is readable by another via the context."""
        ctx = AgentContext(investigation_id="inv-123")

        # One agent publishes its findings...
        ctx.add_data("anomalies", [{"type": "price", "score": 0.9}])

        # ...and a different agent can read them back.
        found = ctx.get_data("anomalies")
        assert len(found) == 1
        assert found[0]["type"] == "price"

    def test_context_history_tracking(self):
        """The context records an ordered history of agent activity."""
        ctx = AgentContext(investigation_id="inv-123")

        ctx.add_history("agent1", "Started analysis")
        ctx.add_history("agent2", "Found 3 anomalies")

        entries = ctx.get_history()
        assert len(entries) == 2
        assert entries[0]["agent"] == "agent1"
        assert entries[1]["message"] == "Found 3 anomalies"
103
+
104
+
105
class TestBaseAgent:
    """Tests for the BaseAgent lifecycle: construction, execution, retry."""

    @pytest.fixture
    def test_agent(self):
        """Provide a minimal concrete BaseAgent implementation."""

        class _EchoAgent(BaseAgent):
            def __init__(self):
                super().__init__(
                    agent_id="test_agent",
                    name="Test Agent",
                    description="Agent for testing",
                )

            async def process(self, message: AgentMessage) -> AgentResponse:
                return AgentResponse(
                    agent_id=self.agent_id,
                    status=AgentStatus.COMPLETED,
                    result={"processed": True},
                )

        return _EchoAgent()

    @pytest.mark.asyncio
    async def test_agent_initialization(self, test_agent):
        """A new agent starts idle, with identity and metrics in place."""
        assert test_agent.agent_id == "test_agent"
        assert test_agent.name == "Test Agent"
        assert test_agent.status == AgentStatus.IDLE
        assert test_agent.metrics is not None

    @pytest.mark.asyncio
    async def test_agent_execute(self, test_agent):
        """execute() runs process() and updates the execution metrics."""
        request = AgentMessage(
            sender="orchestrator",
            recipient="test_agent",
            action="test",
            payload={"data": "test"},
        )

        reply = await test_agent.execute(request)

        assert reply.status == AgentStatus.COMPLETED
        assert reply.result["processed"] is True
        assert test_agent.metrics["executions"] == 1

    @pytest.mark.asyncio
    async def test_agent_retry_on_failure(self, test_agent):
        """A transient process() failure is retried until it succeeds."""
        attempts = []

        async def flaky_process(message):
            # Fail on the first call only; succeed from the second onward.
            attempts.append(message)
            if len(attempts) == 1:
                raise Exception("Temporary failure")
            return AgentResponse(
                agent_id="test_agent",
                status=AgentStatus.COMPLETED,
                result={"success": True},
            )

        test_agent.process = flaky_process

        request = AgentMessage(
            sender="test",
            recipient="test_agent",
            action="retry_test",
        )

        reply = await test_agent.execute(request)

        assert reply.status == AgentStatus.COMPLETED
        assert len(attempts) == 2  # one failure, one successful retry
181
+
182
+
183
class TestReflectiveAgent:
    """Tests for the self-reflection loop of ReflectiveAgent."""

    @pytest.fixture
    def reflective_agent(self):
        """Provide a ReflectiveAgent whose reflection lifts quality to 0.9."""

        class _SelfImprovingAgent(ReflectiveAgent):
            def __init__(self):
                super().__init__(
                    agent_id="reflective_test",
                    name="Reflective Test Agent",
                    description="Agent with reflection",
                )

            async def process(self, message: AgentMessage) -> AgentResponse:
                # The first pass deliberately lands under the quality threshold.
                return AgentResponse(
                    agent_id=self.agent_id,
                    status=AgentStatus.COMPLETED,
                    result={
                        "analysis": "Basic analysis",
                        "quality_score": 0.6,
                    },
                )

            async def reflect(self, result: dict, context: dict) -> dict:
                # Reflection returns a strictly better result.
                return {
                    "analysis": "Improved analysis with more detail",
                    "quality_score": 0.9,
                }

        return _SelfImprovingAgent()

    @pytest.mark.asyncio
    async def test_reflection_improves_quality(self, reflective_agent):
        """Reflection should lift the result above the quality threshold."""
        request = AgentMessage(
            sender="test",
            recipient="reflective_test",
            action="analyze_with_reflection",
        )

        reply = await reflective_agent.execute(request)

        assert reply.result["quality_score"] > 0.8
        assert "Improved analysis" in reply.result["analysis"]
        assert reflective_agent.metrics["reflections"] > 0

    @pytest.mark.asyncio
    async def test_reflection_loop_limit(self, reflective_agent):
        """Reflection must stop at max_reflection_loops even without improvement."""

        async def never_improves(result, context):
            # Quality stays below threshold forever.
            return {"quality_score": 0.5}

        reflective_agent.reflect = never_improves
        reflective_agent.max_reflection_loops = 3

        request = AgentMessage(
            sender="test",
            recipient="reflective_test",
            action="test",
        )

        await reflective_agent.execute(request)

        # The loop bailed out after exactly the configured number of passes.
        assert reflective_agent.metrics["reflections"] == 3
253
+
254
+
255
class TestMasterAgentOrchestration:
    """Tests for MasterAgent fan-out to specialist agents."""

    @pytest.fixture
    def master_agent(self):
        """Provide a real MasterAgent instance."""
        return MasterAgent()

    @pytest.fixture
    def mock_agents(self):
        """Async-mocked Zumbi and Anita with canned COMPLETED responses."""

        def _canned_agent(agent_id, result):
            stub = AsyncMock()
            stub.agent_id = agent_id
            stub.execute.return_value = AgentResponse(
                agent_id=agent_id,
                status=AgentStatus.COMPLETED,
                result=result,
            )
            return stub

        return {
            "zumbi": _canned_agent(
                "zumbi", {"anomalies": [{"type": "price", "score": 0.8}]}
            ),
            "anita": _canned_agent(
                "anita", {"patterns": [{"type": "temporal", "confidence": 0.7}]}
            ),
        }

    @pytest.mark.asyncio
    async def test_master_agent_coordination(self, master_agent, mock_agents):
        """The master agent delegates to each registered specialist exactly once."""
        with patch.object(master_agent, "_agent_registry", mock_agents):
            request = AgentMessage(
                sender="user",
                recipient="abaporu",
                action="investigate",
                payload={
                    "query": "Analyze Ministry of Health contracts",
                    "parameters": {"year": 2024},
                },
            )

            reply = await master_agent.execute(request)

            assert reply.status == AgentStatus.COMPLETED
            assert "anomalies" in reply.result
            assert "patterns" in reply.result

            mock_agents["zumbi"].execute.assert_called_once()
            mock_agents["anita"].execute.assert_called_once()

    @pytest.mark.asyncio
    async def test_master_agent_error_handling(self, master_agent, mock_agents):
        """One failing specialist yields partial results plus recorded errors."""
        mock_agents["zumbi"].execute.side_effect = Exception("Agent error")

        with patch.object(master_agent, "_agent_registry", mock_agents):
            request = AgentMessage(
                sender="user",
                recipient="abaporu",
                action="investigate",
                payload={"query": "Test investigation"},
            )

            reply = await master_agent.execute(request)

            # The run still completes on the surviving agent's output.
            assert reply.status == AgentStatus.COMPLETED
            assert "patterns" in reply.result  # analyst output survives
            assert reply.result.get("errors") is not None
329
+
330
+
331
class TestSemanticRouter:
    """Tests for the Ayrton Senna semantic routing agent."""

    @pytest.fixture
    def router(self):
        """Provide a SemanticRouterAgent instance."""
        return SemanticRouterAgent()

    @pytest.mark.asyncio
    async def test_rule_based_routing(self, router):
        """An anomaly-hunting query is routed to Zumbi with high confidence."""
        request = AgentMessage(
            sender="user",
            recipient="ayrton_senna",
            action="route",
            payload={"query": "Find anomalies in contracts"},
        )

        reply = await router.execute(request)

        assert reply.status == AgentStatus.COMPLETED
        assert reply.result["target_agent"] == "zumbi"
        assert reply.result["confidence"] > 0.8

    @pytest.mark.asyncio
    async def test_semantic_routing_fallback(self, router):
        """Queries the rules cannot place fall through to semantic routing."""
        request = AgentMessage(
            sender="user",
            recipient="ayrton_senna",
            action="route",
            payload={
                "query": "I need a comprehensive analysis of spending patterns"
            },
        )

        with patch.object(router, "_semantic_route") as fake_semantic:
            fake_semantic.return_value = ("anita", 0.85)

            reply = await router.execute(request)

            assert reply.result["target_agent"] == "anita"
            assert reply.result["routing_method"] == "semantic"
376
+
377
+
378
class TestAgentOrchestrator:
    """Tests for the orchestrator's sequencing and parallelism."""

    @pytest.fixture
    def orchestrator(self):
        """Provide an AgentOrchestrator instance."""
        return AgentOrchestrator()

    @pytest.mark.asyncio
    async def test_orchestrate_investigation(self, orchestrator):
        """A full orchestration merges specialist results and a timeline."""
        ctx = AgentContext(
            investigation_id="orch-test-123",
            parameters={
                "entity": "Ministry of Health",
                "year": 2024,
            },
        )

        pool = MagicMock()
        orchestrator.agent_pool = pool

        # Canned (agent_id, result) pairs per requested agent type.
        canned = {
            "investigator": ("zumbi", {"anomalies": [{"type": "price"}]}),
            "analyst": ("anita", {"patterns": [{"type": "seasonal"}]}),
        }

        async def fake_get_agent(agent_type):
            if agent_type not in canned:
                return None
            agent_id, result = canned[agent_type]
            stub = AsyncMock()
            stub.execute.return_value = AgentResponse(
                agent_id=agent_id,
                status=AgentStatus.COMPLETED,
                result=result,
            )
            return stub

        pool.get_agent = fake_get_agent

        outcome = await orchestrator.orchestrate_investigation(ctx)

        assert outcome["status"] == "completed"
        assert "anomalies" in outcome
        assert "patterns" in outcome
        assert len(outcome["execution_timeline"]) > 0

    @pytest.mark.asyncio
    async def test_parallel_agent_execution(self, orchestrator):
        """Independent agents should be started concurrently, not serially."""
        workers = []
        events = []

        for idx in range(3):
            stub = AsyncMock()
            stub.agent_id = f"agent_{idx}"

            async def timed_execute(msg, agent_id=idx):
                # Record start/end markers around a small simulated workload.
                events.append(f"start_{agent_id}")
                await asyncio.sleep(0.1)
                events.append(f"end_{agent_id}")
                return AgentResponse(
                    agent_id=f"agent_{agent_id}",
                    status=AgentStatus.COMPLETED,
                    result={"data": f"result_{agent_id}"},
                )

            stub.execute = timed_execute
            workers.append(stub)

        requests = [
            AgentMessage(sender="orch", recipient=f"agent_{idx}", action="test")
            for idx in range(3)
        ]

        results = await orchestrator.execute_parallel(workers, requests)

        assert len(results) == 3
        assert all(r.status == AgentStatus.COMPLETED for r in results)

        # With true overlap, all three starts appear among the first events.
        starts = [events.index(f"start_{idx}") for idx in range(3)]
        assert max(starts) - min(starts) < 3
470
+
471
+
472
class TestAgentPool:
    """Tests for pool sizing, demand scaling and health checks."""

    @pytest.fixture
    def agent_pool(self):
        """Provide a pool bounded at 2..10 agents per type."""
        return AgentPool(min_agents=2, max_agents=10)

    @pytest.mark.asyncio
    async def test_agent_pool_initialization(self, agent_pool):
        """initialize() pre-warms at least min_agents of every type."""
        await agent_pool.initialize()

        for role in ("investigator", "analyst", "reporter"):
            assert len(agent_pool._agents[role]) >= 2

    @pytest.mark.asyncio
    async def test_agent_pool_scaling(self, agent_pool):
        """Checking out more agents than the minimum grows the pool."""
        await agent_pool.initialize()

        # Lease enough agents to exceed the warm minimum.
        leased = [await agent_pool.get_agent("investigator") for _ in range(5)]

        assert len(agent_pool._agents["investigator"]) > 2

        # Hand every leased agent back.
        for worker in leased:
            await agent_pool.return_agent(worker)

    @pytest.mark.asyncio
    async def test_agent_health_monitoring(self, agent_pool):
        """health_check() replaces agents returned in an ERROR state."""
        await agent_pool.initialize()

        # Check out an agent and force it into a failed state.
        worker = await agent_pool.get_agent("investigator")
        worker.status = AgentStatus.ERROR
        worker.metrics["errors"] = 5

        await agent_pool.return_agent(worker)
        await agent_pool.health_check()

        # No agent left in the pool should still be in ERROR.
        assert all(
            candidate.status != AgentStatus.ERROR
            for candidate in agent_pool._agents["investigator"]
        )
530
+
531
+
532
class TestMultiAgentScenarios:
    """End-to-end flows exercising several agents together."""

    @pytest.mark.asyncio
    async def test_end_to_end_investigation(self):
        """A full investigation over mocked contract data finds anomalies."""
        orchestrator = AgentOrchestrator()
        await orchestrator.initialize()

        ctx = AgentContext(
            investigation_id="e2e-test",
            parameters={
                "query": "Analyze health ministry contracts for anomalies",
                "entity": "26000",  # health ministry entity code
                "year": 2024,
                "threshold": 100000,
            },
        )

        contracts = [
            {
                "id": "contract-1",
                "valor": 500000,
                "fornecedor": "Company A",
                "data_assinatura": "2024-01-15",
            },
            {
                "id": "contract-2",
                "valor": 2000000,  # outlier the investigator should flag
                "fornecedor": "Company B",
                "data_assinatura": "2024-01-16",
            },
        ]

        # Keep the external transparency API out of the loop.
        with patch(
            "src.tools.transparency_api.TransparencyAPIClient.get_contracts"
        ) as fake_api:
            fake_api.return_value = {"data": contracts}

            outcome = await orchestrator.run_investigation(ctx)

            assert outcome["status"] == "completed"
            assert len(outcome["anomalies"]) > 0
            assert outcome["summary"] is not None
            assert outcome["confidence_score"] > 0.5

    @pytest.mark.asyncio
    async def test_agent_failure_recovery(self):
        """A crashing agent downgrades the run instead of aborting it."""
        orchestrator = AgentOrchestrator()

        class _AlwaysFails(BaseAgent):
            async def process(self, message):
                raise Exception("Simulated failure")

        # Swap the investigator slot for the always-failing agent.
        orchestrator._agent_registry["investigator"] = _AlwaysFails(
            agent_id="failing_agent",
            name="Failing Agent",
        )

        ctx = AgentContext(investigation_id="failure-test")

        outcome = await orchestrator.orchestrate_investigation(ctx)

        assert outcome["status"] in ["partial", "completed"]
        assert "errors" in outcome
        assert len(outcome["errors"]) > 0