Akashmj22122002 committed
Commit 14edff4 · verified · 1 Parent(s): 0d46d1f

Upload folder using huggingface_hub

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. README.md +3 -9
  2. agents/__init__.py +319 -0
  3. agents/__pycache__/__init__.cpython-312.pyc +0 -0
  4. agents/__pycache__/_config.cpython-312.pyc +0 -0
  5. agents/__pycache__/_debug.cpython-312.pyc +0 -0
  6. agents/__pycache__/_run_impl.cpython-312.pyc +0 -0
  7. agents/__pycache__/agent.cpython-312.pyc +0 -0
  8. agents/__pycache__/agent_output.cpython-312.pyc +0 -0
  9. agents/__pycache__/computer.cpython-312.pyc +0 -0
  10. agents/__pycache__/exceptions.cpython-312.pyc +0 -0
  11. agents/__pycache__/function_schema.cpython-312.pyc +0 -0
  12. agents/__pycache__/guardrail.cpython-312.pyc +0 -0
  13. agents/__pycache__/handoffs.cpython-312.pyc +0 -0
  14. agents/__pycache__/items.cpython-312.pyc +0 -0
  15. agents/__pycache__/lifecycle.cpython-312.pyc +0 -0
  16. agents/__pycache__/logger.cpython-312.pyc +0 -0
  17. agents/__pycache__/model_settings.cpython-312.pyc +0 -0
  18. agents/__pycache__/prompts.cpython-312.pyc +0 -0
  19. agents/__pycache__/repl.cpython-312.pyc +0 -0
  20. agents/__pycache__/result.cpython-312.pyc +0 -0
  21. agents/__pycache__/run.cpython-312.pyc +0 -0
  22. agents/__pycache__/run_context.cpython-312.pyc +0 -0
  23. agents/__pycache__/stream_events.cpython-312.pyc +0 -0
  24. agents/__pycache__/strict_schema.cpython-312.pyc +0 -0
  25. agents/__pycache__/tool.cpython-312.pyc +0 -0
  26. agents/__pycache__/tool_context.cpython-312.pyc +0 -0
  27. agents/__pycache__/tool_guardrails.cpython-312.pyc +0 -0
  28. agents/__pycache__/usage.cpython-312.pyc +0 -0
  29. agents/__pycache__/version.cpython-312.pyc +0 -0
  30. agents/_config.py +26 -0
  31. agents/_debug.py +28 -0
  32. agents/_run_impl.py +1442 -0
  33. agents/agent.py +476 -0
  34. agents/agent_output.py +194 -0
  35. agents/computer.py +107 -0
  36. agents/exceptions.py +131 -0
  37. agents/extensions/__init__.py +0 -0
  38. agents/extensions/handoff_filters.py +70 -0
  39. agents/extensions/handoff_prompt.py +19 -0
  40. agents/extensions/memory/__init__.py +65 -0
  41. agents/extensions/memory/advanced_sqlite_session.py +1285 -0
  42. agents/extensions/memory/encrypt_session.py +185 -0
  43. agents/extensions/memory/redis_session.py +267 -0
  44. agents/extensions/memory/sqlalchemy_session.py +312 -0
  45. agents/extensions/models/__init__.py +0 -0
  46. agents/extensions/models/litellm_model.py +601 -0
  47. agents/extensions/models/litellm_provider.py +23 -0
  48. agents/extensions/visualization.py +165 -0
  49. agents/function_schema.py +398 -0
  50. agents/guardrail.py +329 -0
README.md CHANGED
@@ -1,12 +1,6 @@
 ---
-title: Deep Research-personal
-emoji: 🔥
-colorFrom: gray
-colorTo: blue
-sdk: gradio
-sdk_version: 6.1.0
+title: deep_research-personal
 app_file: app.py
-pinned: false
+sdk: gradio
+sdk_version: 6.0.2
 ---
-
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
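The updated front matter tells Hugging Face to launch this Space with Gradio 6.0.2, using app.py as the entry point. That entry point is not among the files shown in this view; purely as a hypothetical illustration of what a minimal Gradio app_file can look like:

import gradio as gr

# Hypothetical stand-in for the Space's real app.py, which this view does not show.
def answer(question: str) -> str:
    return f"You asked: {question}"

demo = gr.Interface(fn=answer, inputs="text", outputs="text")

if __name__ == "__main__":
    demo.launch()
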
agents/__init__.py ADDED
@@ -0,0 +1,319 @@
+import logging
+import sys
+from typing import Literal
+
+from openai import AsyncOpenAI
+
+from . import _config
+from .agent import (
+    Agent,
+    AgentBase,
+    StopAtTools,
+    ToolsToFinalOutputFunction,
+    ToolsToFinalOutputResult,
+)
+from .agent_output import AgentOutputSchema, AgentOutputSchemaBase
+from .computer import AsyncComputer, Button, Computer, Environment
+from .exceptions import (
+    AgentsException,
+    InputGuardrailTripwireTriggered,
+    MaxTurnsExceeded,
+    ModelBehaviorError,
+    OutputGuardrailTripwireTriggered,
+    RunErrorDetails,
+    ToolInputGuardrailTripwireTriggered,
+    ToolOutputGuardrailTripwireTriggered,
+    UserError,
+)
+from .guardrail import (
+    GuardrailFunctionOutput,
+    InputGuardrail,
+    InputGuardrailResult,
+    OutputGuardrail,
+    OutputGuardrailResult,
+    input_guardrail,
+    output_guardrail,
+)
+from .handoffs import Handoff, HandoffInputData, HandoffInputFilter, handoff
+from .items import (
+    HandoffCallItem,
+    HandoffOutputItem,
+    ItemHelpers,
+    MessageOutputItem,
+    ModelResponse,
+    ReasoningItem,
+    RunItem,
+    ToolCallItem,
+    ToolCallOutputItem,
+    TResponseInputItem,
+)
+from .lifecycle import AgentHooks, RunHooks
+from .memory import OpenAIConversationsSession, Session, SessionABC, SQLiteSession
+from .model_settings import ModelSettings
+from .models.interface import Model, ModelProvider, ModelTracing
+from .models.multi_provider import MultiProvider
+from .models.openai_chatcompletions import OpenAIChatCompletionsModel
+from .models.openai_provider import OpenAIProvider
+from .models.openai_responses import OpenAIResponsesModel
+from .prompts import DynamicPromptFunction, GenerateDynamicPromptData, Prompt
+from .repl import run_demo_loop
+from .result import RunResult, RunResultStreaming
+from .run import RunConfig, Runner
+from .run_context import RunContextWrapper, TContext
+from .stream_events import (
+    AgentUpdatedStreamEvent,
+    RawResponsesStreamEvent,
+    RunItemStreamEvent,
+    StreamEvent,
+)
+from .tool import (
+    CodeInterpreterTool,
+    ComputerTool,
+    FileSearchTool,
+    FunctionTool,
+    FunctionToolResult,
+    HostedMCPTool,
+    ImageGenerationTool,
+    LocalShellCommandRequest,
+    LocalShellExecutor,
+    LocalShellTool,
+    MCPToolApprovalFunction,
+    MCPToolApprovalFunctionResult,
+    MCPToolApprovalRequest,
+    Tool,
+    WebSearchTool,
+    default_tool_error_function,
+    function_tool,
+)
+from .tool_guardrails import (
+    ToolGuardrailFunctionOutput,
+    ToolInputGuardrail,
+    ToolInputGuardrailData,
+    ToolInputGuardrailResult,
+    ToolOutputGuardrail,
+    ToolOutputGuardrailData,
+    ToolOutputGuardrailResult,
+    tool_input_guardrail,
+    tool_output_guardrail,
+)
+from .tracing import (
+    AgentSpanData,
+    CustomSpanData,
+    FunctionSpanData,
+    GenerationSpanData,
+    GuardrailSpanData,
+    HandoffSpanData,
+    MCPListToolsSpanData,
+    Span,
+    SpanData,
+    SpanError,
+    SpeechGroupSpanData,
+    SpeechSpanData,
+    Trace,
+    TracingProcessor,
+    TranscriptionSpanData,
+    add_trace_processor,
+    agent_span,
+    custom_span,
+    function_span,
+    gen_span_id,
+    gen_trace_id,
+    generation_span,
+    get_current_span,
+    get_current_trace,
+    guardrail_span,
+    handoff_span,
+    mcp_tools_span,
+    set_trace_processors,
+    set_trace_provider,
+    set_tracing_disabled,
+    set_tracing_export_api_key,
+    speech_group_span,
+    speech_span,
+    trace,
+    transcription_span,
+)
+from .usage import Usage
+from .version import __version__
+
+
+def set_default_openai_key(key: str, use_for_tracing: bool = True) -> None:
+    """Set the default OpenAI API key to use for LLM requests (and optionally tracing). This is
+    only necessary if the OPENAI_API_KEY environment variable is not already set.
+
+    If provided, this key will be used instead of the OPENAI_API_KEY environment variable.
+
+    Args:
+        key: The OpenAI key to use.
+        use_for_tracing: Whether to also use this key to send traces to OpenAI. Defaults to True.
+            If False, you'll either need to set the OPENAI_API_KEY environment variable or call
+            set_tracing_export_api_key() with the API key you want to use for tracing.
+    """
+    _config.set_default_openai_key(key, use_for_tracing)
+
+
+def set_default_openai_client(client: AsyncOpenAI, use_for_tracing: bool = True) -> None:
+    """Set the default OpenAI client to use for LLM requests and/or tracing. If provided, this
+    client will be used instead of the default OpenAI client.
+
+    Args:
+        client: The OpenAI client to use.
+        use_for_tracing: Whether to use the API key from this client for uploading traces. If False,
+            you'll either need to set the OPENAI_API_KEY environment variable or call
+            set_tracing_export_api_key() with the API key you want to use for tracing.
+    """
+    _config.set_default_openai_client(client, use_for_tracing)
+
+
+def set_default_openai_api(api: Literal["chat_completions", "responses"]) -> None:
+    """Set the default API to use for OpenAI LLM requests. By default, we will use the responses API
+    but you can set this to use the chat completions API instead.
+    """
+    _config.set_default_openai_api(api)
+
+
+def enable_verbose_stdout_logging():
+    """Enables verbose logging to stdout. This is useful for debugging."""
+    logger = logging.getLogger("openai.agents")
+    logger.setLevel(logging.DEBUG)
+    logger.addHandler(logging.StreamHandler(sys.stdout))
+
+
+__all__ = [
+    "Agent",
+    "AgentBase",
+    "StopAtTools",
+    "ToolsToFinalOutputFunction",
+    "ToolsToFinalOutputResult",
+    "Runner",
+    "run_demo_loop",
+    "Model",
+    "ModelProvider",
+    "ModelTracing",
+    "ModelSettings",
+    "OpenAIChatCompletionsModel",
+    "MultiProvider",
+    "OpenAIProvider",
+    "OpenAIResponsesModel",
+    "AgentOutputSchema",
+    "AgentOutputSchemaBase",
+    "Computer",
+    "AsyncComputer",
+    "Environment",
+    "Button",
+    "AgentsException",
+    "InputGuardrailTripwireTriggered",
+    "OutputGuardrailTripwireTriggered",
+    "ToolInputGuardrailTripwireTriggered",
+    "ToolOutputGuardrailTripwireTriggered",
+    "DynamicPromptFunction",
+    "GenerateDynamicPromptData",
+    "Prompt",
+    "MaxTurnsExceeded",
+    "ModelBehaviorError",
+    "UserError",
+    "InputGuardrail",
+    "InputGuardrailResult",
+    "OutputGuardrail",
+    "OutputGuardrailResult",
+    "GuardrailFunctionOutput",
+    "input_guardrail",
+    "output_guardrail",
+    "ToolInputGuardrail",
+    "ToolOutputGuardrail",
+    "ToolGuardrailFunctionOutput",
+    "ToolInputGuardrailData",
+    "ToolInputGuardrailResult",
+    "ToolOutputGuardrailData",
+    "ToolOutputGuardrailResult",
+    "tool_input_guardrail",
+    "tool_output_guardrail",
+    "handoff",
+    "Handoff",
+    "HandoffInputData",
+    "HandoffInputFilter",
+    "TResponseInputItem",
+    "MessageOutputItem",
+    "ModelResponse",
+    "RunItem",
+    "HandoffCallItem",
+    "HandoffOutputItem",
+    "ToolCallItem",
+    "ToolCallOutputItem",
+    "ReasoningItem",
+    "ItemHelpers",
+    "RunHooks",
+    "AgentHooks",
+    "Session",
+    "SessionABC",
+    "SQLiteSession",
+    "OpenAIConversationsSession",
+    "RunContextWrapper",
+    "TContext",
+    "RunErrorDetails",
+    "RunResult",
+    "RunResultStreaming",
+    "RunConfig",
+    "RawResponsesStreamEvent",
+    "RunItemStreamEvent",
+    "AgentUpdatedStreamEvent",
+    "StreamEvent",
+    "FunctionTool",
+    "FunctionToolResult",
+    "ComputerTool",
+    "FileSearchTool",
+    "CodeInterpreterTool",
+    "ImageGenerationTool",
+    "LocalShellCommandRequest",
+    "LocalShellExecutor",
+    "LocalShellTool",
+    "Tool",
+    "WebSearchTool",
+    "HostedMCPTool",
+    "MCPToolApprovalFunction",
+    "MCPToolApprovalRequest",
+    "MCPToolApprovalFunctionResult",
+    "function_tool",
+    "Usage",
+    "add_trace_processor",
+    "agent_span",
+    "custom_span",
+    "function_span",
+    "generation_span",
+    "get_current_span",
+    "get_current_trace",
+    "guardrail_span",
+    "handoff_span",
+    "set_trace_processors",
+    "set_trace_provider",
+    "set_tracing_disabled",
+    "speech_group_span",
+    "transcription_span",
+    "speech_span",
+    "mcp_tools_span",
+    "trace",
+    "Trace",
+    "TracingProcessor",
+    "SpanError",
+    "Span",
+    "SpanData",
+    "AgentSpanData",
+    "CustomSpanData",
+    "FunctionSpanData",
+    "GenerationSpanData",
+    "GuardrailSpanData",
+    "HandoffSpanData",
+    "SpeechGroupSpanData",
+    "SpeechSpanData",
+    "MCPListToolsSpanData",
+    "TranscriptionSpanData",
+    "set_default_openai_key",
+    "set_default_openai_client",
+    "set_default_openai_api",
+    "set_tracing_export_api_key",
+    "enable_verbose_stdout_logging",
+    "gen_trace_id",
+    "gen_span_id",
+    "default_tool_error_function",
+    "__version__",
+]
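
This module re-exports the package's public surface. A minimal sketch of how the exported pieces fit together, based only on the names exported above (the key string is a placeholder):

import asyncio

from agents import Agent, Runner, function_tool, set_default_openai_key

@function_tool
def add(a: int, b: int) -> int:
    """Add two integers."""
    return a + b

async def main() -> None:
    set_default_openai_key("sk-...")  # placeholder; or rely on OPENAI_API_KEY instead
    agent = Agent(name="Calculator", instructions="Use the add tool for arithmetic.", tools=[add])
    result = await Runner.run(agent, "What is 2 + 3?")
    print(result.final_output)

asyncio.run(main())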
agents/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (7.96 kB).

agents/__pycache__/_config.cpython-312.pyc ADDED
Binary file (1.36 kB).

agents/__pycache__/_debug.cpython-312.pyc ADDED
Binary file (1.1 kB).

agents/__pycache__/_run_impl.cpython-312.pyc ADDED
Binary file (54.7 kB).

agents/__pycache__/agent.cpython-312.pyc ADDED
Binary file (18.7 kB).

agents/__pycache__/agent_output.cpython-312.pyc ADDED
Binary file (8.27 kB).

agents/__pycache__/computer.cpython-312.pyc ADDED
Binary file (5.74 kB).

agents/__pycache__/exceptions.cpython-312.pyc ADDED
Binary file (6.27 kB).

agents/__pycache__/function_schema.cpython-312.pyc ADDED
Binary file (13.9 kB).

agents/__pycache__/guardrail.cpython-312.pyc ADDED
Binary file (10 kB).

agents/__pycache__/handoffs.cpython-312.pyc ADDED
Binary file (10.6 kB).

agents/__pycache__/items.cpython-312.pyc ADDED
Binary file (10.6 kB).

agents/__pycache__/lifecycle.cpython-312.pyc ADDED
Binary file (5.49 kB).

agents/__pycache__/logger.cpython-312.pyc ADDED
Binary file (271 Bytes).

agents/__pycache__/model_settings.cpython-312.pyc ADDED
Binary file (6.23 kB).

agents/__pycache__/prompts.cpython-312.pyc ADDED
Binary file (2.76 kB).

agents/__pycache__/repl.cpython-312.pyc ADDED
Binary file (3.44 kB).

agents/__pycache__/result.cpython-312.pyc ADDED
Binary file (14.3 kB).

agents/__pycache__/run.cpython-312.pyc ADDED
Binary file (57.2 kB).

agents/__pycache__/run_context.cpython-312.pyc ADDED
Binary file (1.12 kB).

agents/__pycache__/stream_events.cpython-312.pyc ADDED
Binary file (2.15 kB).

agents/__pycache__/strict_schema.cpython-312.pyc ADDED
Binary file (5.98 kB).

agents/__pycache__/tool.cpython-312.pyc ADDED
Binary file (18.3 kB).

agents/__pycache__/tool_context.cpython-312.pyc ADDED
Binary file (2.57 kB).

agents/__pycache__/tool_guardrails.cpython-312.pyc ADDED
Binary file (10.4 kB).

agents/__pycache__/usage.cpython-312.pyc ADDED
Binary file (2.31 kB).

agents/__pycache__/version.cpython-312.pyc ADDED
Binary file (477 Bytes).
 
agents/_config.py ADDED
@@ -0,0 +1,26 @@
+from openai import AsyncOpenAI
+from typing_extensions import Literal
+
+from .models import _openai_shared
+from .tracing import set_tracing_export_api_key
+
+
+def set_default_openai_key(key: str, use_for_tracing: bool) -> None:
+    _openai_shared.set_default_openai_key(key)
+
+    if use_for_tracing:
+        set_tracing_export_api_key(key)
+
+
+def set_default_openai_client(client: AsyncOpenAI, use_for_tracing: bool) -> None:
+    _openai_shared.set_default_openai_client(client)
+
+    if use_for_tracing:
+        set_tracing_export_api_key(client.api_key)
+
+
+def set_default_openai_api(api: Literal["chat_completions", "responses"]) -> None:
+    if api == "chat_completions":
+        _openai_shared.set_use_responses_by_default(False)
+    else:
+        _openai_shared.set_use_responses_by_default(True)
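
Because set_default_openai_client() forwards the client's api_key to the tracing exporter whenever use_for_tracing is true, callers pointing the SDK at a non-OpenAI endpoint will usually want to opt out of tracing reuse. A minimal sketch (the base_url is a placeholder):

from openai import AsyncOpenAI

from agents import set_default_openai_api, set_default_openai_client

# Placeholder endpoint; its key is not valid for OpenAI's trace ingestion,
# so don't reuse it for tracing.
client = AsyncOpenAI(base_url="https://my-gateway.example/v1", api_key="...")
set_default_openai_client(client, use_for_tracing=False)

# set_default_openai_api() flips a single module-level default; the last call wins.
set_default_openai_api("chat_completions")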
agents/_debug.py ADDED
@@ -0,0 +1,28 @@
+import os
+
+
+def _debug_flag_enabled(flag: str, default: bool = False) -> bool:
+    flag_value = os.getenv(flag)
+    if flag_value is None:
+        return default
+    else:
+        return flag_value == "1" or flag_value.lower() == "true"
+
+
+def _load_dont_log_model_data() -> bool:
+    return _debug_flag_enabled("OPENAI_AGENTS_DONT_LOG_MODEL_DATA", default=True)
+
+
+def _load_dont_log_tool_data() -> bool:
+    return _debug_flag_enabled("OPENAI_AGENTS_DONT_LOG_TOOL_DATA", default=True)
+
+
+DONT_LOG_MODEL_DATA = _load_dont_log_model_data()
+"""By default we don't log LLM inputs/outputs, to prevent exposing sensitive information. Set this
+flag to enable logging them.
+"""
+
+DONT_LOG_TOOL_DATA = _load_dont_log_tool_data()
+"""By default we don't log tool call inputs/outputs, to prevent exposing sensitive information. Set
+this flag to enable logging them.
+"""
agents/_run_impl.py ADDED
@@ -0,0 +1,1442 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import asyncio
4
+ import dataclasses
5
+ import inspect
6
+ from collections.abc import Awaitable
7
+ from dataclasses import dataclass, field
8
+ from typing import TYPE_CHECKING, Any, cast
9
+
10
+ from openai.types.responses import (
11
+ ResponseComputerToolCall,
12
+ ResponseFileSearchToolCall,
13
+ ResponseFunctionToolCall,
14
+ ResponseFunctionWebSearch,
15
+ ResponseOutputMessage,
16
+ )
17
+ from openai.types.responses.response_code_interpreter_tool_call import (
18
+ ResponseCodeInterpreterToolCall,
19
+ )
20
+ from openai.types.responses.response_computer_tool_call import (
21
+ ActionClick,
22
+ ActionDoubleClick,
23
+ ActionDrag,
24
+ ActionKeypress,
25
+ ActionMove,
26
+ ActionScreenshot,
27
+ ActionScroll,
28
+ ActionType,
29
+ ActionWait,
30
+ )
31
+ from openai.types.responses.response_input_item_param import (
32
+ ComputerCallOutputAcknowledgedSafetyCheck,
33
+ )
34
+ from openai.types.responses.response_input_param import ComputerCallOutput, McpApprovalResponse
35
+ from openai.types.responses.response_output_item import (
36
+ ImageGenerationCall,
37
+ LocalShellCall,
38
+ McpApprovalRequest,
39
+ McpCall,
40
+ McpListTools,
41
+ )
42
+ from openai.types.responses.response_reasoning_item import ResponseReasoningItem
43
+
44
+ from .agent import Agent, ToolsToFinalOutputResult
45
+ from .agent_output import AgentOutputSchemaBase
46
+ from .computer import AsyncComputer, Computer
47
+ from .exceptions import (
48
+ AgentsException,
49
+ ModelBehaviorError,
50
+ ToolInputGuardrailTripwireTriggered,
51
+ ToolOutputGuardrailTripwireTriggered,
52
+ UserError,
53
+ )
54
+ from .guardrail import InputGuardrail, InputGuardrailResult, OutputGuardrail, OutputGuardrailResult
55
+ from .handoffs import Handoff, HandoffInputData
56
+ from .items import (
57
+ HandoffCallItem,
58
+ HandoffOutputItem,
59
+ ItemHelpers,
60
+ MCPApprovalRequestItem,
61
+ MCPApprovalResponseItem,
62
+ MCPListToolsItem,
63
+ MessageOutputItem,
64
+ ModelResponse,
65
+ ReasoningItem,
66
+ RunItem,
67
+ ToolCallItem,
68
+ ToolCallOutputItem,
69
+ TResponseInputItem,
70
+ )
71
+ from .lifecycle import RunHooks
72
+ from .logger import logger
73
+ from .model_settings import ModelSettings
74
+ from .models.interface import ModelTracing
75
+ from .run_context import RunContextWrapper, TContext
76
+ from .stream_events import RunItemStreamEvent, StreamEvent
77
+ from .tool import (
78
+ ComputerTool,
79
+ ComputerToolSafetyCheckData,
80
+ FunctionTool,
81
+ FunctionToolResult,
82
+ HostedMCPTool,
83
+ LocalShellCommandRequest,
84
+ LocalShellTool,
85
+ MCPToolApprovalRequest,
86
+ Tool,
87
+ )
88
+ from .tool_context import ToolContext
89
+ from .tool_guardrails import (
90
+ ToolInputGuardrailData,
91
+ ToolInputGuardrailResult,
92
+ ToolOutputGuardrailData,
93
+ ToolOutputGuardrailResult,
94
+ )
95
+ from .tracing import (
96
+ SpanError,
97
+ Trace,
98
+ function_span,
99
+ get_current_trace,
100
+ guardrail_span,
101
+ handoff_span,
102
+ trace,
103
+ )
104
+ from .util import _coro, _error_tracing
105
+
106
+ if TYPE_CHECKING:
107
+ from .run import RunConfig
108
+
109
+
110
+ class QueueCompleteSentinel:
111
+ pass
112
+
113
+
114
+ QUEUE_COMPLETE_SENTINEL = QueueCompleteSentinel()
115
+
116
+ _NOT_FINAL_OUTPUT = ToolsToFinalOutputResult(is_final_output=False, final_output=None)
117
+
118
+
119
+ @dataclass
120
+ class AgentToolUseTracker:
121
+ agent_to_tools: list[tuple[Agent, list[str]]] = field(default_factory=list)
122
+ """Tuple of (agent, list of tools used). Can't use a dict because agents aren't hashable."""
123
+
124
+ def add_tool_use(self, agent: Agent[Any], tool_names: list[str]) -> None:
125
+ existing_data = next((item for item in self.agent_to_tools if item[0] == agent), None)
126
+ if existing_data:
127
+ existing_data[1].extend(tool_names)
128
+ else:
129
+ self.agent_to_tools.append((agent, tool_names))
130
+
131
+ def has_used_tools(self, agent: Agent[Any]) -> bool:
132
+ existing_data = next((item for item in self.agent_to_tools if item[0] == agent), None)
133
+ return existing_data is not None and len(existing_data[1]) > 0
134
+
135
+
136
+ @dataclass
137
+ class ToolRunHandoff:
138
+ handoff: Handoff
139
+ tool_call: ResponseFunctionToolCall
140
+
141
+
142
+ @dataclass
143
+ class ToolRunFunction:
144
+ tool_call: ResponseFunctionToolCall
145
+ function_tool: FunctionTool
146
+
147
+
148
+ @dataclass
149
+ class ToolRunComputerAction:
150
+ tool_call: ResponseComputerToolCall
151
+ computer_tool: ComputerTool
152
+
153
+
154
+ @dataclass
155
+ class ToolRunMCPApprovalRequest:
156
+ request_item: McpApprovalRequest
157
+ mcp_tool: HostedMCPTool
158
+
159
+
160
+ @dataclass
161
+ class ToolRunLocalShellCall:
162
+ tool_call: LocalShellCall
163
+ local_shell_tool: LocalShellTool
164
+
165
+
166
+ @dataclass
167
+ class ProcessedResponse:
168
+ new_items: list[RunItem]
169
+ handoffs: list[ToolRunHandoff]
170
+ functions: list[ToolRunFunction]
171
+ computer_actions: list[ToolRunComputerAction]
172
+ local_shell_calls: list[ToolRunLocalShellCall]
173
+ tools_used: list[str] # Names of all tools used, including hosted tools
174
+ mcp_approval_requests: list[ToolRunMCPApprovalRequest] # Only requests with callbacks
175
+
176
+ def has_tools_or_approvals_to_run(self) -> bool:
177
+ # Handoffs, functions and computer actions need local processing
178
+ # Hosted tools have already run, so there's nothing to do.
179
+ return any(
180
+ [
181
+ self.handoffs,
182
+ self.functions,
183
+ self.computer_actions,
184
+ self.local_shell_calls,
185
+ self.mcp_approval_requests,
186
+ ]
187
+ )
188
+
189
+
190
+ @dataclass
191
+ class NextStepHandoff:
192
+ new_agent: Agent[Any]
193
+
194
+
195
+ @dataclass
196
+ class NextStepFinalOutput:
197
+ output: Any
198
+
199
+
200
+ @dataclass
201
+ class NextStepRunAgain:
202
+ pass
203
+
204
+
205
+ @dataclass
206
+ class SingleStepResult:
207
+ original_input: str | list[TResponseInputItem]
208
+ """The input items i.e. the items before run() was called. May be mutated by handoff input
209
+ filters."""
210
+
211
+ model_response: ModelResponse
212
+ """The model response for the current step."""
213
+
214
+ pre_step_items: list[RunItem]
215
+ """Items generated before the current step."""
216
+
217
+ new_step_items: list[RunItem]
218
+ """Items generated during this current step."""
219
+
220
+ next_step: NextStepHandoff | NextStepFinalOutput | NextStepRunAgain
221
+ """The next step to take."""
222
+
223
+ tool_input_guardrail_results: list[ToolInputGuardrailResult]
224
+ """Tool input guardrail results from this step."""
225
+
226
+ tool_output_guardrail_results: list[ToolOutputGuardrailResult]
227
+ """Tool output guardrail results from this step."""
228
+
229
+ @property
230
+ def generated_items(self) -> list[RunItem]:
231
+ """Items generated during the agent run (i.e. everything generated after
232
+ `original_input`)."""
233
+ return self.pre_step_items + self.new_step_items
234
+
235
+
236
+ def get_model_tracing_impl(
237
+ tracing_disabled: bool, trace_include_sensitive_data: bool
238
+ ) -> ModelTracing:
239
+ if tracing_disabled:
240
+ return ModelTracing.DISABLED
241
+ elif trace_include_sensitive_data:
242
+ return ModelTracing.ENABLED
243
+ else:
244
+ return ModelTracing.ENABLED_WITHOUT_DATA
245
+
246
+
247
+ class RunImpl:
248
+ @classmethod
249
+ async def execute_tools_and_side_effects(
250
+ cls,
251
+ *,
252
+ agent: Agent[TContext],
253
+ # The original input to the Runner
254
+ original_input: str | list[TResponseInputItem],
255
+ # Everything generated by Runner since the original input, but before the current step
256
+ pre_step_items: list[RunItem],
257
+ new_response: ModelResponse,
258
+ processed_response: ProcessedResponse,
259
+ output_schema: AgentOutputSchemaBase | None,
260
+ hooks: RunHooks[TContext],
261
+ context_wrapper: RunContextWrapper[TContext],
262
+ run_config: RunConfig,
263
+ ) -> SingleStepResult:
264
+ # Make a copy of the generated items
265
+ pre_step_items = list(pre_step_items)
266
+
267
+ new_step_items: list[RunItem] = []
268
+ new_step_items.extend(processed_response.new_items)
269
+
270
+ # First, lets run the tool calls - function tools and computer actions
271
+ (
272
+ (function_results, tool_input_guardrail_results, tool_output_guardrail_results),
273
+ computer_results,
274
+ ) = await asyncio.gather(
275
+ cls.execute_function_tool_calls(
276
+ agent=agent,
277
+ tool_runs=processed_response.functions,
278
+ hooks=hooks,
279
+ context_wrapper=context_wrapper,
280
+ config=run_config,
281
+ ),
282
+ cls.execute_computer_actions(
283
+ agent=agent,
284
+ actions=processed_response.computer_actions,
285
+ hooks=hooks,
286
+ context_wrapper=context_wrapper,
287
+ config=run_config,
288
+ ),
289
+ )
290
+ new_step_items.extend([result.run_item for result in function_results])
291
+ new_step_items.extend(computer_results)
292
+
293
+ # Next, run the MCP approval requests
294
+ if processed_response.mcp_approval_requests:
295
+ approval_results = await cls.execute_mcp_approval_requests(
296
+ agent=agent,
297
+ approval_requests=processed_response.mcp_approval_requests,
298
+ context_wrapper=context_wrapper,
299
+ )
300
+ new_step_items.extend(approval_results)
301
+
302
+ # Next, check if there are any handoffs
303
+ if run_handoffs := processed_response.handoffs:
304
+ return await cls.execute_handoffs(
305
+ agent=agent,
306
+ original_input=original_input,
307
+ pre_step_items=pre_step_items,
308
+ new_step_items=new_step_items,
309
+ new_response=new_response,
310
+ run_handoffs=run_handoffs,
311
+ hooks=hooks,
312
+ context_wrapper=context_wrapper,
313
+ run_config=run_config,
314
+ )
315
+
316
+ # Next, we'll check if the tool use should result in a final output
317
+ check_tool_use = await cls._check_for_final_output_from_tools(
318
+ agent=agent,
319
+ tool_results=function_results,
320
+ context_wrapper=context_wrapper,
321
+ config=run_config,
322
+ )
323
+
324
+ if check_tool_use.is_final_output:
325
+ # If the output type is str, then let's just stringify it
326
+ if not agent.output_type or agent.output_type is str:
327
+ check_tool_use.final_output = str(check_tool_use.final_output)
328
+
329
+ if check_tool_use.final_output is None:
330
+ logger.error(
331
+ "Model returned a final output of None. Not raising an error because we assume"
332
+ "you know what you're doing."
333
+ )
334
+
335
+ return await cls.execute_final_output(
336
+ agent=agent,
337
+ original_input=original_input,
338
+ new_response=new_response,
339
+ pre_step_items=pre_step_items,
340
+ new_step_items=new_step_items,
341
+ final_output=check_tool_use.final_output,
342
+ hooks=hooks,
343
+ context_wrapper=context_wrapper,
344
+ tool_input_guardrail_results=tool_input_guardrail_results,
345
+ tool_output_guardrail_results=tool_output_guardrail_results,
346
+ )
347
+
348
+ # Now we can check if the model also produced a final output
349
+ message_items = [item for item in new_step_items if isinstance(item, MessageOutputItem)]
350
+
351
+ # We'll use the last content output as the final output
352
+ potential_final_output_text = (
353
+ ItemHelpers.extract_last_text(message_items[-1].raw_item) if message_items else None
354
+ )
355
+
356
+ # Generate final output only when there are no pending tool calls or approval requests.
357
+ if not processed_response.has_tools_or_approvals_to_run():
358
+ if output_schema and not output_schema.is_plain_text() and potential_final_output_text:
359
+ final_output = output_schema.validate_json(potential_final_output_text)
360
+ return await cls.execute_final_output(
361
+ agent=agent,
362
+ original_input=original_input,
363
+ new_response=new_response,
364
+ pre_step_items=pre_step_items,
365
+ new_step_items=new_step_items,
366
+ final_output=final_output,
367
+ hooks=hooks,
368
+ context_wrapper=context_wrapper,
369
+ tool_input_guardrail_results=tool_input_guardrail_results,
370
+ tool_output_guardrail_results=tool_output_guardrail_results,
371
+ )
372
+ elif not output_schema or output_schema.is_plain_text():
373
+ return await cls.execute_final_output(
374
+ agent=agent,
375
+ original_input=original_input,
376
+ new_response=new_response,
377
+ pre_step_items=pre_step_items,
378
+ new_step_items=new_step_items,
379
+ final_output=potential_final_output_text or "",
380
+ hooks=hooks,
381
+ context_wrapper=context_wrapper,
382
+ tool_input_guardrail_results=tool_input_guardrail_results,
383
+ tool_output_guardrail_results=tool_output_guardrail_results,
384
+ )
385
+
386
+ # If there's no final output, we can just run again
387
+ return SingleStepResult(
388
+ original_input=original_input,
389
+ model_response=new_response,
390
+ pre_step_items=pre_step_items,
391
+ new_step_items=new_step_items,
392
+ next_step=NextStepRunAgain(),
393
+ tool_input_guardrail_results=tool_input_guardrail_results,
394
+ tool_output_guardrail_results=tool_output_guardrail_results,
395
+ )
396
+
397
+ @classmethod
398
+ def maybe_reset_tool_choice(
399
+ cls, agent: Agent[Any], tool_use_tracker: AgentToolUseTracker, model_settings: ModelSettings
400
+ ) -> ModelSettings:
401
+ """Resets tool choice to None if the agent has used tools and the agent's reset_tool_choice
402
+ flag is True."""
403
+
404
+ if agent.reset_tool_choice is True and tool_use_tracker.has_used_tools(agent):
405
+ return dataclasses.replace(model_settings, tool_choice=None)
406
+
407
+ return model_settings
408
+
409
+ @classmethod
410
+ def process_model_response(
411
+ cls,
412
+ *,
413
+ agent: Agent[Any],
414
+ all_tools: list[Tool],
415
+ response: ModelResponse,
416
+ output_schema: AgentOutputSchemaBase | None,
417
+ handoffs: list[Handoff],
418
+ ) -> ProcessedResponse:
419
+ items: list[RunItem] = []
420
+
421
+ run_handoffs = []
422
+ functions = []
423
+ computer_actions = []
424
+ local_shell_calls = []
425
+ mcp_approval_requests = []
426
+ tools_used: list[str] = []
427
+ handoff_map = {handoff.tool_name: handoff for handoff in handoffs}
428
+ function_map = {tool.name: tool for tool in all_tools if isinstance(tool, FunctionTool)}
429
+ computer_tool = next((tool for tool in all_tools if isinstance(tool, ComputerTool)), None)
430
+ local_shell_tool = next(
431
+ (tool for tool in all_tools if isinstance(tool, LocalShellTool)), None
432
+ )
433
+ hosted_mcp_server_map = {
434
+ tool.tool_config["server_label"]: tool
435
+ for tool in all_tools
436
+ if isinstance(tool, HostedMCPTool)
437
+ }
438
+
439
+ for output in response.output:
440
+ if isinstance(output, ResponseOutputMessage):
441
+ items.append(MessageOutputItem(raw_item=output, agent=agent))
442
+ elif isinstance(output, ResponseFileSearchToolCall):
443
+ items.append(ToolCallItem(raw_item=output, agent=agent))
444
+ tools_used.append("file_search")
445
+ elif isinstance(output, ResponseFunctionWebSearch):
446
+ items.append(ToolCallItem(raw_item=output, agent=agent))
447
+ tools_used.append("web_search")
448
+ elif isinstance(output, ResponseReasoningItem):
449
+ items.append(ReasoningItem(raw_item=output, agent=agent))
450
+ elif isinstance(output, ResponseComputerToolCall):
451
+ items.append(ToolCallItem(raw_item=output, agent=agent))
452
+ tools_used.append("computer_use")
453
+ if not computer_tool:
454
+ _error_tracing.attach_error_to_current_span(
455
+ SpanError(
456
+ message="Computer tool not found",
457
+ data={},
458
+ )
459
+ )
460
+ raise ModelBehaviorError(
461
+ "Model produced computer action without a computer tool."
462
+ )
463
+ computer_actions.append(
464
+ ToolRunComputerAction(tool_call=output, computer_tool=computer_tool)
465
+ )
466
+ elif isinstance(output, McpApprovalRequest):
467
+ items.append(MCPApprovalRequestItem(raw_item=output, agent=agent))
468
+ if output.server_label not in hosted_mcp_server_map:
469
+ _error_tracing.attach_error_to_current_span(
470
+ SpanError(
471
+ message="MCP server label not found",
472
+ data={"server_label": output.server_label},
473
+ )
474
+ )
475
+ raise ModelBehaviorError(f"MCP server label {output.server_label} not found")
476
+ else:
477
+ server = hosted_mcp_server_map[output.server_label]
478
+ if server.on_approval_request:
479
+ mcp_approval_requests.append(
480
+ ToolRunMCPApprovalRequest(
481
+ request_item=output,
482
+ mcp_tool=server,
483
+ )
484
+ )
485
+ else:
486
+ logger.warning(
487
+ f"MCP server {output.server_label} has no on_approval_request hook"
488
+ )
489
+ elif isinstance(output, McpListTools):
490
+ items.append(MCPListToolsItem(raw_item=output, agent=agent))
491
+ elif isinstance(output, McpCall):
492
+ items.append(ToolCallItem(raw_item=output, agent=agent))
493
+ tools_used.append("mcp")
494
+ elif isinstance(output, ImageGenerationCall):
495
+ items.append(ToolCallItem(raw_item=output, agent=agent))
496
+ tools_used.append("image_generation")
497
+ elif isinstance(output, ResponseCodeInterpreterToolCall):
498
+ items.append(ToolCallItem(raw_item=output, agent=agent))
499
+ tools_used.append("code_interpreter")
500
+ elif isinstance(output, LocalShellCall):
501
+ items.append(ToolCallItem(raw_item=output, agent=agent))
502
+ tools_used.append("local_shell")
503
+ if not local_shell_tool:
504
+ _error_tracing.attach_error_to_current_span(
505
+ SpanError(
506
+ message="Local shell tool not found",
507
+ data={},
508
+ )
509
+ )
510
+ raise ModelBehaviorError(
511
+ "Model produced local shell call without a local shell tool."
512
+ )
513
+ local_shell_calls.append(
514
+ ToolRunLocalShellCall(tool_call=output, local_shell_tool=local_shell_tool)
515
+ )
516
+
517
+ elif not isinstance(output, ResponseFunctionToolCall):
518
+ logger.warning(f"Unexpected output type, ignoring: {type(output)}")
519
+ continue
520
+
521
+ # At this point we know it's a function tool call
522
+ if not isinstance(output, ResponseFunctionToolCall):
523
+ continue
524
+
525
+ tools_used.append(output.name)
526
+
527
+ # Handoffs
528
+ if output.name in handoff_map:
529
+ items.append(HandoffCallItem(raw_item=output, agent=agent))
530
+ handoff = ToolRunHandoff(
531
+ tool_call=output,
532
+ handoff=handoff_map[output.name],
533
+ )
534
+ run_handoffs.append(handoff)
535
+ # Regular function tool call
536
+ else:
537
+ if output.name not in function_map:
538
+ if output_schema is not None and output.name == "json_tool_call":
539
+ # LiteLLM could generate non-existent tool calls for structured outputs
540
+ items.append(ToolCallItem(raw_item=output, agent=agent))
541
+ functions.append(
542
+ ToolRunFunction(
543
+ tool_call=output,
544
+ # this tool does not exist in function_map, so generate ad-hoc one,
545
+ # which just parses the input if it's a string, and returns the
546
+ # value otherwise
547
+ function_tool=_build_litellm_json_tool_call(output),
548
+ )
549
+ )
550
+ continue
551
+ else:
552
+ _error_tracing.attach_error_to_current_span(
553
+ SpanError(
554
+ message="Tool not found",
555
+ data={"tool_name": output.name},
556
+ )
557
+ )
558
+ error = f"Tool {output.name} not found in agent {agent.name}"
559
+ raise ModelBehaviorError(error)
560
+
561
+ items.append(ToolCallItem(raw_item=output, agent=agent))
562
+ functions.append(
563
+ ToolRunFunction(
564
+ tool_call=output,
565
+ function_tool=function_map[output.name],
566
+ )
567
+ )
568
+
569
+ return ProcessedResponse(
570
+ new_items=items,
571
+ handoffs=run_handoffs,
572
+ functions=functions,
573
+ computer_actions=computer_actions,
574
+ local_shell_calls=local_shell_calls,
575
+ tools_used=tools_used,
576
+ mcp_approval_requests=mcp_approval_requests,
577
+ )
578
+
579
+ @classmethod
580
+ async def _execute_input_guardrails(
581
+ cls,
582
+ *,
583
+ func_tool: FunctionTool,
584
+ tool_context: ToolContext[TContext],
585
+ agent: Agent[TContext],
586
+ tool_input_guardrail_results: list[ToolInputGuardrailResult],
587
+ ) -> str | None:
588
+ """Execute input guardrails for a tool.
589
+
590
+ Args:
591
+ func_tool: The function tool being executed.
592
+ tool_context: The tool execution context.
593
+ agent: The agent executing the tool.
594
+ tool_input_guardrail_results: List to append guardrail results to.
595
+
596
+ Returns:
597
+ None if tool execution should proceed, or a message string if execution should be
598
+ skipped.
599
+
600
+ Raises:
601
+ ToolInputGuardrailTripwireTriggered: If a guardrail triggers an exception.
602
+ """
603
+ if not func_tool.tool_input_guardrails:
604
+ return None
605
+
606
+ for guardrail in func_tool.tool_input_guardrails:
607
+ gr_out = await guardrail.run(
608
+ ToolInputGuardrailData(
609
+ context=tool_context,
610
+ agent=agent,
611
+ )
612
+ )
613
+
614
+ # Store the guardrail result
615
+ tool_input_guardrail_results.append(
616
+ ToolInputGuardrailResult(
617
+ guardrail=guardrail,
618
+ output=gr_out,
619
+ )
620
+ )
621
+
622
+ # Handle different behavior types
623
+ if gr_out.behavior["type"] == "raise_exception":
624
+ raise ToolInputGuardrailTripwireTriggered(guardrail=guardrail, output=gr_out)
625
+ elif gr_out.behavior["type"] == "reject_content":
626
+ # Set final_result to the message and skip tool execution
627
+ return gr_out.behavior["message"]
628
+ elif gr_out.behavior["type"] == "allow":
629
+ # Continue to next guardrail or tool execution
630
+ continue
631
+
632
+ return None
633
+
634
+ @classmethod
635
+ async def _execute_output_guardrails(
636
+ cls,
637
+ *,
638
+ func_tool: FunctionTool,
639
+ tool_context: ToolContext[TContext],
640
+ agent: Agent[TContext],
641
+ real_result: Any,
642
+ tool_output_guardrail_results: list[ToolOutputGuardrailResult],
643
+ ) -> Any:
644
+ """Execute output guardrails for a tool.
645
+
646
+ Args:
647
+ func_tool: The function tool being executed.
648
+ tool_context: The tool execution context.
649
+ agent: The agent executing the tool.
650
+ real_result: The actual result from the tool execution.
651
+ tool_output_guardrail_results: List to append guardrail results to.
652
+
653
+ Returns:
654
+ The final result after guardrail processing (may be modified).
655
+
656
+ Raises:
657
+ ToolOutputGuardrailTripwireTriggered: If a guardrail triggers an exception.
658
+ """
659
+ if not func_tool.tool_output_guardrails:
660
+ return real_result
661
+
662
+ final_result = real_result
663
+ for output_guardrail in func_tool.tool_output_guardrails:
664
+ gr_out = await output_guardrail.run(
665
+ ToolOutputGuardrailData(
666
+ context=tool_context,
667
+ agent=agent,
668
+ output=real_result,
669
+ )
670
+ )
671
+
672
+ # Store the guardrail result
673
+ tool_output_guardrail_results.append(
674
+ ToolOutputGuardrailResult(
675
+ guardrail=output_guardrail,
676
+ output=gr_out,
677
+ )
678
+ )
679
+
680
+ # Handle different behavior types
681
+ if gr_out.behavior["type"] == "raise_exception":
682
+ raise ToolOutputGuardrailTripwireTriggered(
683
+ guardrail=output_guardrail, output=gr_out
684
+ )
685
+ elif gr_out.behavior["type"] == "reject_content":
686
+ # Override the result with the guardrail message
687
+ final_result = gr_out.behavior["message"]
688
+ break
689
+ elif gr_out.behavior["type"] == "allow":
690
+ # Continue to next guardrail
691
+ continue
692
+
693
+ return final_result
694
+
695
+ @classmethod
696
+ async def _execute_tool_with_hooks(
697
+ cls,
698
+ *,
699
+ func_tool: FunctionTool,
700
+ tool_context: ToolContext[TContext],
701
+ agent: Agent[TContext],
702
+ hooks: RunHooks[TContext],
703
+ tool_call: ResponseFunctionToolCall,
704
+ ) -> Any:
705
+ """Execute the core tool function with before/after hooks.
706
+
707
+ Args:
708
+ func_tool: The function tool being executed.
709
+ tool_context: The tool execution context.
710
+ agent: The agent executing the tool.
711
+ hooks: The run hooks to execute.
712
+ tool_call: The tool call details.
713
+
714
+ Returns:
715
+ The result from the tool execution.
716
+ """
717
+ await asyncio.gather(
718
+ hooks.on_tool_start(tool_context, agent, func_tool),
719
+ (
720
+ agent.hooks.on_tool_start(tool_context, agent, func_tool)
721
+ if agent.hooks
722
+ else _coro.noop_coroutine()
723
+ ),
724
+ )
725
+
726
+ return await func_tool.on_invoke_tool(tool_context, tool_call.arguments)
727
+
728
+ @classmethod
729
+ async def execute_function_tool_calls(
730
+ cls,
731
+ *,
732
+ agent: Agent[TContext],
733
+ tool_runs: list[ToolRunFunction],
734
+ hooks: RunHooks[TContext],
735
+ context_wrapper: RunContextWrapper[TContext],
736
+ config: RunConfig,
737
+ ) -> tuple[
738
+ list[FunctionToolResult], list[ToolInputGuardrailResult], list[ToolOutputGuardrailResult]
739
+ ]:
740
+ # Collect guardrail results
741
+ tool_input_guardrail_results: list[ToolInputGuardrailResult] = []
742
+ tool_output_guardrail_results: list[ToolOutputGuardrailResult] = []
743
+
744
+ async def run_single_tool(
745
+ func_tool: FunctionTool, tool_call: ResponseFunctionToolCall
746
+ ) -> Any:
747
+ with function_span(func_tool.name) as span_fn:
748
+ tool_context = ToolContext.from_agent_context(
749
+ context_wrapper,
750
+ tool_call.call_id,
751
+ tool_call=tool_call,
752
+ )
753
+ if config.trace_include_sensitive_data:
754
+ span_fn.span_data.input = tool_call.arguments
755
+ try:
756
+ # 1) Run input tool guardrails, if any
757
+ rejected_message = await cls._execute_input_guardrails(
758
+ func_tool=func_tool,
759
+ tool_context=tool_context,
760
+ agent=agent,
761
+ tool_input_guardrail_results=tool_input_guardrail_results,
762
+ )
763
+
764
+ if rejected_message is not None:
765
+ # Input guardrail rejected the tool call
766
+ final_result = rejected_message
767
+ else:
768
+ # 2) Actually run the tool
769
+ real_result = await cls._execute_tool_with_hooks(
770
+ func_tool=func_tool,
771
+ tool_context=tool_context,
772
+ agent=agent,
773
+ hooks=hooks,
774
+ tool_call=tool_call,
775
+ )
776
+
777
+ # 3) Run output tool guardrails, if any
778
+ final_result = await cls._execute_output_guardrails(
779
+ func_tool=func_tool,
780
+ tool_context=tool_context,
781
+ agent=agent,
782
+ real_result=real_result,
783
+ tool_output_guardrail_results=tool_output_guardrail_results,
784
+ )
785
+
786
+ # 4) Tool end hooks (with final result, which may have been overridden)
787
+ await asyncio.gather(
788
+ hooks.on_tool_end(tool_context, agent, func_tool, final_result),
789
+ (
790
+ agent.hooks.on_tool_end(
791
+ tool_context, agent, func_tool, final_result
792
+ )
793
+ if agent.hooks
794
+ else _coro.noop_coroutine()
795
+ ),
796
+ )
797
+ result = final_result
798
+ except Exception as e:
799
+ _error_tracing.attach_error_to_current_span(
800
+ SpanError(
801
+ message="Error running tool",
802
+ data={"tool_name": func_tool.name, "error": str(e)},
803
+ )
804
+ )
805
+ if isinstance(e, AgentsException):
806
+ raise e
807
+ raise UserError(f"Error running tool {func_tool.name}: {e}") from e
808
+
809
+ if config.trace_include_sensitive_data:
810
+ span_fn.span_data.output = result
811
+ return result
812
+
813
+ tasks = []
814
+ for tool_run in tool_runs:
815
+ function_tool = tool_run.function_tool
816
+ tasks.append(run_single_tool(function_tool, tool_run.tool_call))
817
+
818
+ results = await asyncio.gather(*tasks)
819
+
820
+ function_tool_results = [
821
+ FunctionToolResult(
822
+ tool=tool_run.function_tool,
823
+ output=result,
824
+ run_item=ToolCallOutputItem(
825
+ output=result,
826
+ raw_item=ItemHelpers.tool_call_output_item(tool_run.tool_call, str(result)),
827
+ agent=agent,
828
+ ),
829
+ )
830
+ for tool_run, result in zip(tool_runs, results)
831
+ ]
832
+
833
+ return function_tool_results, tool_input_guardrail_results, tool_output_guardrail_results
834
+
835
+ @classmethod
836
+ async def execute_local_shell_calls(
837
+ cls,
838
+ *,
839
+ agent: Agent[TContext],
840
+ calls: list[ToolRunLocalShellCall],
841
+ context_wrapper: RunContextWrapper[TContext],
842
+ hooks: RunHooks[TContext],
843
+ config: RunConfig,
844
+ ) -> list[RunItem]:
845
+ results: list[RunItem] = []
846
+ # Need to run these serially, because each call can affect the local shell state
847
+ for call in calls:
848
+ results.append(
849
+ await LocalShellAction.execute(
850
+ agent=agent,
851
+ call=call,
852
+ hooks=hooks,
853
+ context_wrapper=context_wrapper,
854
+ config=config,
855
+ )
856
+ )
857
+ return results
858
+
859
+ @classmethod
860
+ async def execute_computer_actions(
861
+ cls,
862
+ *,
863
+ agent: Agent[TContext],
864
+ actions: list[ToolRunComputerAction],
865
+ hooks: RunHooks[TContext],
866
+ context_wrapper: RunContextWrapper[TContext],
867
+ config: RunConfig,
868
+ ) -> list[RunItem]:
869
+ results: list[RunItem] = []
870
+ # Need to run these serially, because each action can affect the computer state
871
+ for action in actions:
872
+ acknowledged: list[ComputerCallOutputAcknowledgedSafetyCheck] | None = None
873
+ if action.tool_call.pending_safety_checks and action.computer_tool.on_safety_check:
874
+ acknowledged = []
875
+ for check in action.tool_call.pending_safety_checks:
876
+ data = ComputerToolSafetyCheckData(
877
+ ctx_wrapper=context_wrapper,
878
+ agent=agent,
879
+ tool_call=action.tool_call,
880
+ safety_check=check,
881
+ )
882
+ maybe = action.computer_tool.on_safety_check(data)
883
+ ack = await maybe if inspect.isawaitable(maybe) else maybe
884
+ if ack:
885
+ acknowledged.append(
886
+ ComputerCallOutputAcknowledgedSafetyCheck(
887
+ id=check.id,
888
+ code=check.code,
889
+ message=check.message,
890
+ )
891
+ )
892
+ else:
893
+ raise UserError("Computer tool safety check was not acknowledged")
894
+
895
+ results.append(
896
+ await ComputerAction.execute(
897
+ agent=agent,
898
+ action=action,
899
+ hooks=hooks,
900
+ context_wrapper=context_wrapper,
901
+ config=config,
902
+ acknowledged_safety_checks=acknowledged,
903
+ )
904
+ )
905
+
906
+ return results
907
+
908
+ @classmethod
909
+ async def execute_handoffs(
910
+ cls,
911
+ *,
912
+ agent: Agent[TContext],
913
+ original_input: str | list[TResponseInputItem],
914
+ pre_step_items: list[RunItem],
915
+ new_step_items: list[RunItem],
916
+ new_response: ModelResponse,
917
+ run_handoffs: list[ToolRunHandoff],
918
+ hooks: RunHooks[TContext],
919
+ context_wrapper: RunContextWrapper[TContext],
920
+ run_config: RunConfig,
921
+ ) -> SingleStepResult:
922
+ # If there is more than one handoff, add tool responses that reject those handoffs
923
+ multiple_handoffs = len(run_handoffs) > 1
924
+ if multiple_handoffs:
925
+ output_message = "Multiple handoffs detected, ignoring this one."
926
+ new_step_items.extend(
927
+ [
928
+ ToolCallOutputItem(
929
+ output=output_message,
930
+ raw_item=ItemHelpers.tool_call_output_item(
931
+ handoff.tool_call, output_message
932
+ ),
933
+ agent=agent,
934
+ )
935
+ for handoff in run_handoffs[1:]
936
+ ]
937
+ )
938
+
939
+ actual_handoff = run_handoffs[0]
940
+ with handoff_span(from_agent=agent.name) as span_handoff:
941
+ handoff = actual_handoff.handoff
942
+ new_agent: Agent[Any] = await handoff.on_invoke_handoff(
943
+ context_wrapper, actual_handoff.tool_call.arguments
944
+ )
945
+ span_handoff.span_data.to_agent = new_agent.name
946
+ if multiple_handoffs:
947
+ requested_agents = [handoff.handoff.agent_name for handoff in run_handoffs]
948
+ span_handoff.set_error(
949
+ SpanError(
950
+ message="Multiple handoffs requested",
951
+ data={
952
+ "requested_agents": requested_agents,
953
+ },
954
+ )
955
+ )
956
+
957
+ # Append a tool output item for the handoff
958
+ new_step_items.append(
959
+ HandoffOutputItem(
960
+ agent=agent,
961
+ raw_item=ItemHelpers.tool_call_output_item(
962
+ actual_handoff.tool_call,
963
+ handoff.get_transfer_message(new_agent),
964
+ ),
965
+ source_agent=agent,
966
+ target_agent=new_agent,
967
+ )
968
+ )
969
+
970
+ # Execute handoff hooks
971
+ await asyncio.gather(
972
+ hooks.on_handoff(
973
+ context=context_wrapper,
974
+ from_agent=agent,
975
+ to_agent=new_agent,
976
+ ),
977
+ (
978
+ agent.hooks.on_handoff(
979
+ context_wrapper,
980
+ agent=new_agent,
981
+ source=agent,
982
+ )
983
+ if agent.hooks
984
+ else _coro.noop_coroutine()
985
+ ),
986
+ )
987
+
988
+ # If there's an input filter, filter the input for the next agent
989
+ input_filter = handoff.input_filter or (
990
+ run_config.handoff_input_filter if run_config else None
991
+ )
992
+ if input_filter:
993
+ logger.debug("Filtering inputs for handoff")
994
+ handoff_input_data = HandoffInputData(
995
+ input_history=tuple(original_input)
996
+ if isinstance(original_input, list)
997
+ else original_input,
998
+ pre_handoff_items=tuple(pre_step_items),
999
+ new_items=tuple(new_step_items),
1000
+ run_context=context_wrapper,
1001
+ )
1002
+ if not callable(input_filter):
1003
+ _error_tracing.attach_error_to_span(
1004
+ span_handoff,
1005
+ SpanError(
1006
+ message="Invalid input filter",
1007
+ data={"details": "not callable()"},
1008
+ ),
1009
+ )
1010
+ raise UserError(f"Invalid input filter: {input_filter}")
1011
+ filtered = input_filter(handoff_input_data)
1012
+ if inspect.isawaitable(filtered):
1013
+ filtered = await filtered
1014
+ if not isinstance(filtered, HandoffInputData):
1015
+ _error_tracing.attach_error_to_span(
1016
+ span_handoff,
1017
+ SpanError(
1018
+ message="Invalid input filter result",
1019
+ data={"details": "not a HandoffInputData"},
1020
+ ),
1021
+ )
1022
+ raise UserError(f"Invalid input filter result: {filtered}")
1023
+
1024
+ original_input = (
1025
+ filtered.input_history
1026
+ if isinstance(filtered.input_history, str)
1027
+ else list(filtered.input_history)
1028
+ )
1029
+ pre_step_items = list(filtered.pre_handoff_items)
1030
+ new_step_items = list(filtered.new_items)
1031
+
1032
+ return SingleStepResult(
1033
+ original_input=original_input,
1034
+ model_response=new_response,
1035
+ pre_step_items=pre_step_items,
1036
+ new_step_items=new_step_items,
1037
+ next_step=NextStepHandoff(new_agent),
1038
+ tool_input_guardrail_results=[],
1039
+ tool_output_guardrail_results=[],
1040
+ )
1041
+
1042
+ @classmethod
1043
+ async def execute_mcp_approval_requests(
1044
+ cls,
1045
+ *,
1046
+ agent: Agent[TContext],
1047
+ approval_requests: list[ToolRunMCPApprovalRequest],
1048
+ context_wrapper: RunContextWrapper[TContext],
1049
+ ) -> list[RunItem]:
1050
+ async def run_single_approval(approval_request: ToolRunMCPApprovalRequest) -> RunItem:
1051
+ callback = approval_request.mcp_tool.on_approval_request
1052
+ assert callback is not None, "Callback is required for MCP approval requests"
1053
+ maybe_awaitable_result = callback(
1054
+ MCPToolApprovalRequest(context_wrapper, approval_request.request_item)
1055
+ )
1056
+ if inspect.isawaitable(maybe_awaitable_result):
1057
+ result = await maybe_awaitable_result
1058
+ else:
1059
+ result = maybe_awaitable_result
1060
+ reason = result.get("reason", None)
1061
+ raw_item: McpApprovalResponse = {
1062
+ "approval_request_id": approval_request.request_item.id,
1063
+ "approve": result["approve"],
1064
+ "type": "mcp_approval_response",
1065
+ }
1066
+ if not result["approve"] and reason:
1067
+ raw_item["reason"] = reason
1068
+ return MCPApprovalResponseItem(
1069
+ raw_item=raw_item,
1070
+ agent=agent,
1071
+ )
1072
+
1073
+ tasks = [run_single_approval(approval_request) for approval_request in approval_requests]
1074
+ return await asyncio.gather(*tasks)
1075
+
1076
+ @classmethod
1077
+ async def execute_final_output(
1078
+ cls,
1079
+ *,
1080
+ agent: Agent[TContext],
1081
+ original_input: str | list[TResponseInputItem],
1082
+ new_response: ModelResponse,
1083
+ pre_step_items: list[RunItem],
1084
+ new_step_items: list[RunItem],
1085
+ final_output: Any,
1086
+ hooks: RunHooks[TContext],
1087
+ context_wrapper: RunContextWrapper[TContext],
1088
+ tool_input_guardrail_results: list[ToolInputGuardrailResult],
1089
+ tool_output_guardrail_results: list[ToolOutputGuardrailResult],
1090
+ ) -> SingleStepResult:
1091
+ # Run the on_end hooks
1092
+ await cls.run_final_output_hooks(agent, hooks, context_wrapper, final_output)
1093
+
1094
+ return SingleStepResult(
1095
+ original_input=original_input,
1096
+ model_response=new_response,
1097
+ pre_step_items=pre_step_items,
1098
+ new_step_items=new_step_items,
1099
+ next_step=NextStepFinalOutput(final_output),
1100
+ tool_input_guardrail_results=tool_input_guardrail_results,
1101
+ tool_output_guardrail_results=tool_output_guardrail_results,
1102
+ )
1103
+
1104
+ @classmethod
1105
+ async def run_final_output_hooks(
1106
+ cls,
1107
+ agent: Agent[TContext],
1108
+ hooks: RunHooks[TContext],
1109
+ context_wrapper: RunContextWrapper[TContext],
1110
+ final_output: Any,
1111
+ ):
1112
+ await asyncio.gather(
1113
+ hooks.on_agent_end(context_wrapper, agent, final_output),
1114
+ agent.hooks.on_end(context_wrapper, agent, final_output)
1115
+ if agent.hooks
1116
+ else _coro.noop_coroutine(),
1117
+ )
1118
+
1119
+ @classmethod
1120
+ async def run_single_input_guardrail(
1121
+ cls,
1122
+ agent: Agent[Any],
1123
+ guardrail: InputGuardrail[TContext],
1124
+ input: str | list[TResponseInputItem],
1125
+ context: RunContextWrapper[TContext],
1126
+ ) -> InputGuardrailResult:
1127
+ with guardrail_span(guardrail.get_name()) as span_guardrail:
1128
+ result = await guardrail.run(agent, input, context)
1129
+ span_guardrail.span_data.triggered = result.output.tripwire_triggered
1130
+ return result
1131
+
1132
+ @classmethod
1133
+ async def run_single_output_guardrail(
1134
+ cls,
1135
+ guardrail: OutputGuardrail[TContext],
1136
+ agent: Agent[Any],
1137
+ agent_output: Any,
1138
+ context: RunContextWrapper[TContext],
1139
+ ) -> OutputGuardrailResult:
1140
+ with guardrail_span(guardrail.get_name()) as span_guardrail:
1141
+ result = await guardrail.run(agent=agent, agent_output=agent_output, context=context)
1142
+ span_guardrail.span_data.triggered = result.output.tripwire_triggered
1143
+ return result
1144
+
1145
+ @classmethod
1146
+ def stream_step_items_to_queue(
1147
+ cls,
1148
+ new_step_items: list[RunItem],
1149
+ queue: asyncio.Queue[StreamEvent | QueueCompleteSentinel],
1150
+ ):
1151
+ for item in new_step_items:
1152
+ if isinstance(item, MessageOutputItem):
1153
+ event = RunItemStreamEvent(item=item, name="message_output_created")
1154
+ elif isinstance(item, HandoffCallItem):
1155
+ event = RunItemStreamEvent(item=item, name="handoff_requested")
1156
+ elif isinstance(item, HandoffOutputItem):
1157
+ event = RunItemStreamEvent(item=item, name="handoff_occured")  # intentional spelling; matches the SDK's public event-name literal
1158
+ elif isinstance(item, ToolCallItem):
1159
+ event = RunItemStreamEvent(item=item, name="tool_called")
1160
+ elif isinstance(item, ToolCallOutputItem):
1161
+ event = RunItemStreamEvent(item=item, name="tool_output")
1162
+ elif isinstance(item, ReasoningItem):
1163
+ event = RunItemStreamEvent(item=item, name="reasoning_item_created")
1164
+ elif isinstance(item, MCPApprovalRequestItem):
1165
+ event = RunItemStreamEvent(item=item, name="mcp_approval_requested")
1166
+ elif isinstance(item, MCPListToolsItem):
1167
+ event = RunItemStreamEvent(item=item, name="mcp_list_tools")
1168
+
1169
+ else:
1170
+ logger.warning(f"Unexpected item type: {type(item)}")
1171
+ event = None
1172
+
1173
+ if event:
1174
+ queue.put_nowait(event)
1175
+
1176
+ @classmethod
1177
+ def stream_step_result_to_queue(
1178
+ cls,
1179
+ step_result: SingleStepResult,
1180
+ queue: asyncio.Queue[StreamEvent | QueueCompleteSentinel],
1181
+ ):
1182
+ cls.stream_step_items_to_queue(step_result.new_step_items, queue)
1183
+
1184
+ @classmethod
1185
+ async def _check_for_final_output_from_tools(
1186
+ cls,
1187
+ *,
1188
+ agent: Agent[TContext],
1189
+ tool_results: list[FunctionToolResult],
1190
+ context_wrapper: RunContextWrapper[TContext],
1191
+ config: RunConfig,
1192
+ ) -> ToolsToFinalOutputResult:
1193
+ """Determine if tool results should produce a final output.
1194
+ Returns:
1195
+ ToolsToFinalOutputResult: Indicates whether final output is ready, and the output value.
1196
+ """
1197
+ if not tool_results:
1198
+ return _NOT_FINAL_OUTPUT
1199
+
1200
+ if agent.tool_use_behavior == "run_llm_again":
1201
+ return _NOT_FINAL_OUTPUT
1202
+ elif agent.tool_use_behavior == "stop_on_first_tool":
1203
+ return ToolsToFinalOutputResult(
1204
+ is_final_output=True, final_output=tool_results[0].output
1205
+ )
1206
+ elif isinstance(agent.tool_use_behavior, dict):
1207
+ names = agent.tool_use_behavior.get("stop_at_tool_names", [])
1208
+ for tool_result in tool_results:
1209
+ if tool_result.tool.name in names:
1210
+ return ToolsToFinalOutputResult(
1211
+ is_final_output=True, final_output=tool_result.output
1212
+ )
1213
+ return ToolsToFinalOutputResult(is_final_output=False, final_output=None)
1214
+ elif callable(agent.tool_use_behavior):
1215
+ if inspect.iscoroutinefunction(agent.tool_use_behavior):
1216
+ return await cast(
1217
+ Awaitable[ToolsToFinalOutputResult],
1218
+ agent.tool_use_behavior(context_wrapper, tool_results),
1219
+ )
1220
+ else:
1221
+ return cast(
1222
+ ToolsToFinalOutputResult, agent.tool_use_behavior(context_wrapper, tool_results)
1223
+ )
1224
+
1225
+ logger.error(f"Invalid tool_use_behavior: {agent.tool_use_behavior}")
1226
+ raise UserError(f"Invalid tool_use_behavior: {agent.tool_use_behavior}")
1227
+
1228
+
1229
+ class TraceCtxManager:
1230
+ """Creates a trace only if there is no current trace, and manages the trace lifecycle."""
1231
+
1232
+ def __init__(
1233
+ self,
1234
+ workflow_name: str,
1235
+ trace_id: str | None,
1236
+ group_id: str | None,
1237
+ metadata: dict[str, Any] | None,
1238
+ disabled: bool,
1239
+ ):
1240
+ self.trace: Trace | None = None
1241
+ self.workflow_name = workflow_name
1242
+ self.trace_id = trace_id
1243
+ self.group_id = group_id
1244
+ self.metadata = metadata
1245
+ self.disabled = disabled
1246
+
1247
+ def __enter__(self) -> TraceCtxManager:
1248
+ current_trace = get_current_trace()
1249
+ if not current_trace:
1250
+ self.trace = trace(
1251
+ workflow_name=self.workflow_name,
1252
+ trace_id=self.trace_id,
1253
+ group_id=self.group_id,
1254
+ metadata=self.metadata,
1255
+ disabled=self.disabled,
1256
+ )
1257
+ self.trace.start(mark_as_current=True)
1258
+
1259
+ return self
1260
+
1261
+ def __exit__(self, exc_type, exc_val, exc_tb):
1262
+ if self.trace:
1263
+ self.trace.finish(reset_current=True)
1264
+
1265
+
1266
+ class ComputerAction:
1267
+ @classmethod
1268
+ async def execute(
1269
+ cls,
1270
+ *,
1271
+ agent: Agent[TContext],
1272
+ action: ToolRunComputerAction,
1273
+ hooks: RunHooks[TContext],
1274
+ context_wrapper: RunContextWrapper[TContext],
1275
+ config: RunConfig,
1276
+ acknowledged_safety_checks: list[ComputerCallOutputAcknowledgedSafetyCheck] | None = None,
1277
+ ) -> RunItem:
1278
+ output_func = (
1279
+ cls._get_screenshot_async(action.computer_tool.computer, action.tool_call)
1280
+ if isinstance(action.computer_tool.computer, AsyncComputer)
1281
+ else cls._get_screenshot_sync(action.computer_tool.computer, action.tool_call)
1282
+ )
1283
+
1284
+ _, _, output = await asyncio.gather(
1285
+ hooks.on_tool_start(context_wrapper, agent, action.computer_tool),
1286
+ (
1287
+ agent.hooks.on_tool_start(context_wrapper, agent, action.computer_tool)
1288
+ if agent.hooks
1289
+ else _coro.noop_coroutine()
1290
+ ),
1291
+ output_func,
1292
+ )
1293
+
1294
+ await asyncio.gather(
1295
+ hooks.on_tool_end(context_wrapper, agent, action.computer_tool, output),
1296
+ (
1297
+ agent.hooks.on_tool_end(context_wrapper, agent, action.computer_tool, output)
1298
+ if agent.hooks
1299
+ else _coro.noop_coroutine()
1300
+ ),
1301
+ )
1302
+
1303
+ # TODO: don't send a screenshot every single time, use references
1304
+ image_url = f"data:image/png;base64,{output}"
1305
+ return ToolCallOutputItem(
1306
+ agent=agent,
1307
+ output=image_url,
1308
+ raw_item=ComputerCallOutput(
1309
+ call_id=action.tool_call.call_id,
1310
+ output={
1311
+ "type": "computer_screenshot",
1312
+ "image_url": image_url,
1313
+ },
1314
+ type="computer_call_output",
1315
+ acknowledged_safety_checks=acknowledged_safety_checks,
1316
+ ),
1317
+ )
1318
+
1319
+ @classmethod
1320
+ async def _get_screenshot_sync(
1321
+ cls,
1322
+ computer: Computer,
1323
+ tool_call: ResponseComputerToolCall,
1324
+ ) -> str:
1325
+ action = tool_call.action
1326
+ if isinstance(action, ActionClick):
1327
+ computer.click(action.x, action.y, action.button)
1328
+ elif isinstance(action, ActionDoubleClick):
1329
+ computer.double_click(action.x, action.y)
1330
+ elif isinstance(action, ActionDrag):
1331
+ computer.drag([(p.x, p.y) for p in action.path])
1332
+ elif isinstance(action, ActionKeypress):
1333
+ computer.keypress(action.keys)
1334
+ elif isinstance(action, ActionMove):
1335
+ computer.move(action.x, action.y)
1336
+ elif isinstance(action, ActionScreenshot):
1337
+ computer.screenshot()
1338
+ elif isinstance(action, ActionScroll):
1339
+ computer.scroll(action.x, action.y, action.scroll_x, action.scroll_y)
1340
+ elif isinstance(action, ActionType):
1341
+ computer.type(action.text)
1342
+ elif isinstance(action, ActionWait):
1343
+ computer.wait()
1344
+
1345
+ return computer.screenshot()
1346
+
1347
+ @classmethod
1348
+ async def _get_screenshot_async(
1349
+ cls,
1350
+ computer: AsyncComputer,
1351
+ tool_call: ResponseComputerToolCall,
1352
+ ) -> str:
1353
+ action = tool_call.action
1354
+ if isinstance(action, ActionClick):
1355
+ await computer.click(action.x, action.y, action.button)
1356
+ elif isinstance(action, ActionDoubleClick):
1357
+ await computer.double_click(action.x, action.y)
1358
+ elif isinstance(action, ActionDrag):
1359
+ await computer.drag([(p.x, p.y) for p in action.path])
1360
+ elif isinstance(action, ActionKeypress):
1361
+ await computer.keypress(action.keys)
1362
+ elif isinstance(action, ActionMove):
1363
+ await computer.move(action.x, action.y)
1364
+ elif isinstance(action, ActionScreenshot):
1365
+ await computer.screenshot()
1366
+ elif isinstance(action, ActionScroll):
1367
+ await computer.scroll(action.x, action.y, action.scroll_x, action.scroll_y)
1368
+ elif isinstance(action, ActionType):
1369
+ await computer.type(action.text)
1370
+ elif isinstance(action, ActionWait):
1371
+ await computer.wait()
1372
+
1373
+ return await computer.screenshot()
1374
+
1375
+
1376
+ class LocalShellAction:
1377
+ @classmethod
1378
+ async def execute(
1379
+ cls,
1380
+ *,
1381
+ agent: Agent[TContext],
1382
+ call: ToolRunLocalShellCall,
1383
+ hooks: RunHooks[TContext],
1384
+ context_wrapper: RunContextWrapper[TContext],
1385
+ config: RunConfig,
1386
+ ) -> RunItem:
1387
+ await asyncio.gather(
1388
+ hooks.on_tool_start(context_wrapper, agent, call.local_shell_tool),
1389
+ (
1390
+ agent.hooks.on_tool_start(context_wrapper, agent, call.local_shell_tool)
1391
+ if agent.hooks
1392
+ else _coro.noop_coroutine()
1393
+ ),
1394
+ )
1395
+
1396
+ request = LocalShellCommandRequest(
1397
+ ctx_wrapper=context_wrapper,
1398
+ data=call.tool_call,
1399
+ )
1400
+ output = call.local_shell_tool.executor(request)
1401
+ if inspect.isawaitable(output):
1402
+ result = await output
1403
+ else:
1404
+ result = output
1405
+
1406
+ await asyncio.gather(
1407
+ hooks.on_tool_end(context_wrapper, agent, call.local_shell_tool, result),
1408
+ (
1409
+ agent.hooks.on_tool_end(context_wrapper, agent, call.local_shell_tool, result)
1410
+ if agent.hooks
1411
+ else _coro.noop_coroutine()
1412
+ ),
1413
+ )
1414
+
1415
+ return ToolCallOutputItem(
1416
+ agent=agent,
1417
+ output=result,  # use the awaited value; "output" may still be an unawaited awaitable here
1418
+ raw_item={
1419
+ "type": "local_shell_call_output",
1420
+ "id": call.tool_call.call_id,
1421
+ "output": result,
1422
+ # "id": "out" + call.tool_call.id, # TODO remove this, it should be optional
1423
+ },
1424
+ )
1425
+
1426
+
1427
+ def _build_litellm_json_tool_call(output: ResponseFunctionToolCall) -> FunctionTool:
1428
+ async def on_invoke_tool(_ctx: ToolContext[Any], value: Any) -> Any:
1429
+ if isinstance(value, str):
1430
+ import json
1431
+
1432
+ return json.loads(value)
1433
+ return value
1434
+
1435
+ return FunctionTool(
1436
+ name=output.name,
1437
+ description=output.name,
1438
+ params_json_schema={},
1439
+ on_invoke_tool=on_invoke_tool,
1440
+ strict_json_schema=True,
1441
+ is_enabled=True,
1442
+ )
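The `tool_use_behavior` resolution in `_check_for_final_output_from_tools` above accepts a plain string, a `StopAtTools` dict, or a callable. A minimal sketch of the callable form, assuming the import paths match the module layout in this commit:

```python
from typing import Any

from agents.agent import Agent, ToolsToFinalOutputResult
from agents.run_context import RunContextWrapper
from agents.tool import FunctionToolResult


async def stop_on_error(
    ctx: RunContextWrapper[Any], results: list[FunctionToolResult]
) -> ToolsToFinalOutputResult:
    # Surface the first error-looking tool output as the final answer;
    # otherwise hand the results back to the LLM for another turn.
    for r in results:
        if isinstance(r.output, str) and "error" in r.output.lower():
            return ToolsToFinalOutputResult(is_final_output=True, final_output=r.output)
    return ToolsToFinalOutputResult(is_final_output=False, final_output=None)


agent = Agent(name="worker", tool_use_behavior=stop_on_error)
```

Similarly, `execute_mcp_approval_requests` expects the `on_approval_request` callback to return a dict with an `approve` flag and an optional `reason` (the exact shapes consumed above). A hedged callback sketch, assuming `MCPToolApprovalRequest` lives in `agents/tool.py` and exposes the raw request as `data` with a `name` field, as in the released SDK:

```python
from agents.tool import MCPToolApprovalRequest


def approve_reads_only(request: MCPToolApprovalRequest) -> dict:
    # Auto-approve read-only MCP tools; reject everything else with a reason.
    if request.data.name.startswith("read_"):  # attribute names are assumptions
        return {"approve": True}
    return {"approve": False, "reason": "Only read-only MCP tools are auto-approved."}
```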
agents/agent.py ADDED
@@ -0,0 +1,476 @@
1
+ from __future__ import annotations
2
+
3
+ import asyncio
4
+ import dataclasses
5
+ import inspect
6
+ from collections.abc import Awaitable
7
+ from dataclasses import dataclass, field
8
+ from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, cast
9
+
10
+ from openai.types.responses.response_prompt_param import ResponsePromptParam
11
+ from typing_extensions import NotRequired, TypeAlias, TypedDict
12
+
13
+ from .agent_output import AgentOutputSchemaBase
14
+ from .guardrail import InputGuardrail, OutputGuardrail
15
+ from .handoffs import Handoff
16
+ from .items import ItemHelpers
17
+ from .logger import logger
18
+ from .mcp import MCPUtil
19
+ from .model_settings import ModelSettings
20
+ from .models.default_models import (
21
+ get_default_model_settings,
22
+ gpt_5_reasoning_settings_required,
23
+ is_gpt_5_default,
24
+ )
25
+ from .models.interface import Model
26
+ from .prompts import DynamicPromptFunction, Prompt, PromptUtil
27
+ from .run_context import RunContextWrapper, TContext
28
+ from .tool import FunctionTool, FunctionToolResult, Tool, function_tool
29
+ from .util import _transforms
30
+ from .util._types import MaybeAwaitable
31
+
32
+ if TYPE_CHECKING:
33
+ from .lifecycle import AgentHooks, RunHooks
34
+ from .mcp import MCPServer
35
+ from .memory.session import Session
36
+ from .result import RunResult
37
+ from .run import RunConfig
38
+
39
+
40
+ @dataclass
41
+ class ToolsToFinalOutputResult:
42
+ is_final_output: bool
43
+ """Whether this is the final output. If False, the LLM will run again and receive the tool call
44
+ output.
45
+ """
46
+
47
+ final_output: Any | None = None
48
+ """The final output. Can be None if `is_final_output` is False, otherwise must match the
49
+ `output_type` of the agent.
50
+ """
51
+
52
+
53
+ ToolsToFinalOutputFunction: TypeAlias = Callable[
54
+ [RunContextWrapper[TContext], list[FunctionToolResult]],
55
+ MaybeAwaitable[ToolsToFinalOutputResult],
56
+ ]
57
+ """A function that takes a run context and a list of tool results, and returns a
58
+ `ToolsToFinalOutputResult`.
59
+ """
60
+
61
+
62
+ class StopAtTools(TypedDict):
63
+ stop_at_tool_names: list[str]
64
+ """A list of tool names, any of which will stop the agent from running further."""
65
+
66
+
67
+ class MCPConfig(TypedDict):
68
+ """Configuration for MCP servers."""
69
+
70
+ convert_schemas_to_strict: NotRequired[bool]
71
+ """If True, we will attempt to convert the MCP schemas to strict-mode schemas. This is a
72
+ best-effort conversion, so some schemas may not be convertible. Defaults to False.
73
+ """
74
+
75
+
76
+ @dataclass
77
+ class AgentBase(Generic[TContext]):
78
+ """Base class for `Agent` and `RealtimeAgent`."""
79
+
80
+ name: str
81
+ """The name of the agent."""
82
+
83
+ handoff_description: str | None = None
84
+ """A description of the agent. This is used when the agent is used as a handoff, so that an
85
+ LLM knows what it does and when to invoke it.
86
+ """
87
+
88
+ tools: list[Tool] = field(default_factory=list)
89
+ """A list of tools that the agent can use."""
90
+
91
+ mcp_servers: list[MCPServer] = field(default_factory=list)
92
+ """A list of [Model Context Protocol](https://modelcontextprotocol.io/) servers that
93
+ the agent can use. Every time the agent runs, it will include tools from these servers in the
94
+ list of available tools.
95
+
96
+ NOTE: You are expected to manage the lifecycle of these servers. Specifically, you must call
97
+ `server.connect()` before passing it to the agent, and `server.cleanup()` when the server is no
98
+ longer needed.
99
+ """
100
+
101
+ mcp_config: MCPConfig = field(default_factory=lambda: MCPConfig())
102
+ """Configuration for MCP servers."""
103
+
104
+ async def get_mcp_tools(self, run_context: RunContextWrapper[TContext]) -> list[Tool]:
105
+ """Fetches the available tools from the MCP servers."""
106
+ convert_schemas_to_strict = self.mcp_config.get("convert_schemas_to_strict", False)
107
+ return await MCPUtil.get_all_function_tools(
108
+ self.mcp_servers, convert_schemas_to_strict, run_context, self
109
+ )
110
+
111
+ async def get_all_tools(self, run_context: RunContextWrapper[TContext]) -> list[Tool]:
112
+ """All agent tools, including MCP tools and function tools."""
113
+ mcp_tools = await self.get_mcp_tools(run_context)
114
+
115
+ async def _check_tool_enabled(tool: Tool) -> bool:
116
+ if not isinstance(tool, FunctionTool):
117
+ return True
118
+
119
+ attr = tool.is_enabled
120
+ if isinstance(attr, bool):
121
+ return attr
122
+ res = attr(run_context, self)
123
+ if inspect.isawaitable(res):
124
+ return bool(await res)
125
+ return bool(res)
126
+
127
+ results = await asyncio.gather(*(_check_tool_enabled(t) for t in self.tools))
128
+ enabled: list[Tool] = [t for t, ok in zip(self.tools, results) if ok]
129
+ return [*mcp_tools, *enabled]
130
+
131
+
132
+ @dataclass
133
+ class Agent(AgentBase, Generic[TContext]):
134
+ """An agent is an AI model configured with instructions, tools, guardrails, handoffs and more.
135
+
136
+ We strongly recommend passing `instructions`, which is the "system prompt" for the agent. In
137
+ addition, you can pass `handoff_description`, which is a human-readable description of the
138
+ agent, used when the agent is used inside tools/handoffs.
139
+
140
+ Agents are generic on the context type. The context is a (mutable) object you create. It is
141
+ passed to tool functions, handoffs, guardrails, etc.
142
+
143
+ See `AgentBase` for base parameters that are shared with `RealtimeAgent`s.
144
+ """
145
+
146
+ instructions: (
147
+ str
148
+ | Callable[
149
+ [RunContextWrapper[TContext], Agent[TContext]],
150
+ MaybeAwaitable[str],
151
+ ]
152
+ | None
153
+ ) = None
154
+ """The instructions for the agent. Will be used as the "system prompt" when this agent is
155
+ invoked. Describes what the agent should do, and how it responds.
156
+
157
+ Can either be a string, or a function that dynamically generates instructions for the agent. If
158
+ you provide a function, it will be called with the context and the agent instance. It must
159
+ return a string.
160
+ """
161
+
162
+ prompt: Prompt | DynamicPromptFunction | None = None
163
+ """A prompt object (or a function that returns a Prompt). Prompts allow you to dynamically
164
+ configure the instructions, tools and other config for an agent outside of your code. Only
165
+ usable with OpenAI models, using the Responses API.
166
+ """
167
+
168
+ handoffs: list[Agent[Any] | Handoff[TContext, Any]] = field(default_factory=list)
169
+ """Handoffs are sub-agents that the agent can delegate to. You can provide a list of handoffs,
170
+ and the agent can choose to delegate to them if relevant. Allows for separation of concerns and
171
+ modularity.
172
+ """
173
+
174
+ model: str | Model | None = None
175
+ """The model implementation to use when invoking the LLM.
176
+
177
+ By default, if not set, the agent will use the default model configured in
178
+ `agents.models.get_default_model()` (currently "gpt-4.1").
179
+ """
180
+
181
+ model_settings: ModelSettings = field(default_factory=get_default_model_settings)
182
+ """Configures model-specific tuning parameters (e.g. temperature, top_p).
183
+ """
184
+
185
+ input_guardrails: list[InputGuardrail[TContext]] = field(default_factory=list)
186
+ """A list of checks that run in parallel to the agent's execution, before generating a
187
+ response. Runs only if the agent is the first agent in the chain.
188
+ """
189
+
190
+ output_guardrails: list[OutputGuardrail[TContext]] = field(default_factory=list)
191
+ """A list of checks that run on the final output of the agent, after generating a response.
192
+ Runs only if the agent produces a final output.
193
+ """
194
+
195
+ output_type: type[Any] | AgentOutputSchemaBase | None = None
196
+ """The type of the output object. If not provided, the output will be `str`. In most cases,
197
+ you should pass a regular Python type (e.g. a dataclass, Pydantic model, TypedDict, etc).
198
+ You can customize this in two ways:
199
+ 1. If you want non-strict schemas, pass `AgentOutputSchema(MyClass, strict_json_schema=False)`.
200
+ 2. If you want to use a custom JSON schema (i.e. without using the SDK's automatic schema)
201
+ creation, subclass and pass an `AgentOutputSchemaBase` subclass.
202
+ """
203
+
204
+ hooks: AgentHooks[TContext] | None = None
205
+ """A class that receives callbacks on various lifecycle events for this agent.
206
+ """
207
+
208
+ tool_use_behavior: (
209
+ Literal["run_llm_again", "stop_on_first_tool"] | StopAtTools | ToolsToFinalOutputFunction
210
+ ) = "run_llm_again"
211
+ """
212
+ This lets you configure how tool use is handled.
213
+ - "run_llm_again": The default behavior. Tools are run, and then the LLM receives the results
214
+ and gets to respond.
215
+ - "stop_on_first_tool": The output from the first tool call is treated as the final result.
216
+ In other words, it isn’t sent back to the LLM for further processing but is used directly
217
+ as the final output.
218
+ - A StopAtTools object: The agent will stop running if any of the tools listed in
219
+ `stop_at_tool_names` is called.
220
+ The final output will be the output of the first matching tool call.
221
+ The LLM does not process the result of the tool call.
222
+ - A function: If you pass a function, it will be called with the run context and the list of
223
+ tool results. It must return a `ToolsToFinalOutputResult`, which determines whether the tool
224
+ calls result in a final output.
225
+
226
+ NOTE: This configuration is specific to FunctionTools. Hosted tools, such as file search,
227
+ web search, etc. are always processed by the LLM.
228
+ """
229
+
230
+ reset_tool_choice: bool = True
231
+ """Whether to reset the tool choice to the default value after a tool has been called. Defaults
232
+ to True. This ensures that the agent doesn't enter an infinite loop of tool usage."""
233
+
234
+ def __post_init__(self):
235
+ from typing import get_origin
236
+
237
+ if not isinstance(self.name, str):
238
+ raise TypeError(f"Agent name must be a string, got {type(self.name).__name__}")
239
+
240
+ if self.handoff_description is not None and not isinstance(self.handoff_description, str):
241
+ raise TypeError(
242
+ f"Agent handoff_description must be a string or None, "
243
+ f"got {type(self.handoff_description).__name__}"
244
+ )
245
+
246
+ if not isinstance(self.tools, list):
247
+ raise TypeError(f"Agent tools must be a list, got {type(self.tools).__name__}")
248
+
249
+ if not isinstance(self.mcp_servers, list):
250
+ raise TypeError(
251
+ f"Agent mcp_servers must be a list, got {type(self.mcp_servers).__name__}"
252
+ )
253
+
254
+ if not isinstance(self.mcp_config, dict):
255
+ raise TypeError(
256
+ f"Agent mcp_config must be a dict, got {type(self.mcp_config).__name__}"
257
+ )
258
+
259
+ if (
260
+ self.instructions is not None
261
+ and not isinstance(self.instructions, str)
262
+ and not callable(self.instructions)
263
+ ):
264
+ raise TypeError(
265
+ f"Agent instructions must be a string, callable, or None, "
266
+ f"got {type(self.instructions).__name__}"
267
+ )
268
+
269
+ if (
270
+ self.prompt is not None
271
+ and not callable(self.prompt)
272
+ and not hasattr(self.prompt, "get")
273
+ ):
274
+ raise TypeError(
275
+ f"Agent prompt must be a Prompt, DynamicPromptFunction, or None, "
276
+ f"got {type(self.prompt).__name__}"
277
+ )
278
+
279
+ if not isinstance(self.handoffs, list):
280
+ raise TypeError(f"Agent handoffs must be a list, got {type(self.handoffs).__name__}")
281
+
282
+ if self.model is not None and not isinstance(self.model, str):
283
+ from .models.interface import Model
284
+
285
+ if not isinstance(self.model, Model):
286
+ raise TypeError(
287
+ f"Agent model must be a string, Model, or None, got {type(self.model).__name__}"
288
+ )
289
+
290
+ if not isinstance(self.model_settings, ModelSettings):
291
+ raise TypeError(
292
+ f"Agent model_settings must be a ModelSettings instance, "
293
+ f"got {type(self.model_settings).__name__}"
294
+ )
295
+
296
+ if (
297
+ # The user sets a non-default model
298
+ self.model is not None
299
+ and (
300
+ # The default model is gpt-5
301
+ is_gpt_5_default() is True
302
+ # However, the specified model is not a gpt-5 model
303
+ and (
304
+ isinstance(self.model, str) is False
305
+ or gpt_5_reasoning_settings_required(self.model) is False # type: ignore
306
+ )
307
+ # The model settings are not customized for the specified model
308
+ and self.model_settings == get_default_model_settings()
309
+ )
310
+ ):
311
+ # In this scenario, we should use a generic model settings
312
+ # because non-gpt-5 models are not compatible with the default gpt-5 model settings.
313
+ # This is a best-effort attempt to make the agent work with non-gpt-5 models.
314
+ self.model_settings = ModelSettings()
315
+
316
+ if not isinstance(self.input_guardrails, list):
317
+ raise TypeError(
318
+ f"Agent input_guardrails must be a list, got {type(self.input_guardrails).__name__}"
319
+ )
320
+
321
+ if not isinstance(self.output_guardrails, list):
322
+ raise TypeError(
323
+ f"Agent output_guardrails must be a list, "
324
+ f"got {type(self.output_guardrails).__name__}"
325
+ )
326
+
327
+ if self.output_type is not None:
328
+ from .agent_output import AgentOutputSchemaBase
329
+
330
+ if not (
331
+ isinstance(self.output_type, (type, AgentOutputSchemaBase))
332
+ or get_origin(self.output_type) is not None
333
+ ):
334
+ raise TypeError(
335
+ f"Agent output_type must be a type, AgentOutputSchemaBase, or None, "
336
+ f"got {type(self.output_type).__name__}"
337
+ )
338
+
339
+ if self.hooks is not None:
340
+ from .lifecycle import AgentHooksBase
341
+
342
+ if not isinstance(self.hooks, AgentHooksBase):
343
+ raise TypeError(
344
+ f"Agent hooks must be an AgentHooks instance or None, "
345
+ f"got {type(self.hooks).__name__}"
346
+ )
347
+
348
+ if (
349
+ not (
350
+ isinstance(self.tool_use_behavior, str)
351
+ and self.tool_use_behavior in ["run_llm_again", "stop_on_first_tool"]
352
+ )
353
+ and not isinstance(self.tool_use_behavior, dict)
354
+ and not callable(self.tool_use_behavior)
355
+ ):
356
+ raise TypeError(
357
+ f"Agent tool_use_behavior must be 'run_llm_again', 'stop_on_first_tool', "
358
+ f"StopAtTools dict, or callable, got {type(self.tool_use_behavior).__name__}"
359
+ )
360
+
361
+ if not isinstance(self.reset_tool_choice, bool):
362
+ raise TypeError(
363
+ f"Agent reset_tool_choice must be a boolean, "
364
+ f"got {type(self.reset_tool_choice).__name__}"
365
+ )
366
+
367
+ def clone(self, **kwargs: Any) -> Agent[TContext]:
368
+ """Make a copy of the agent, with the given arguments changed.
369
+ Notes:
370
+ - Uses `dataclasses.replace`, which performs a **shallow copy**.
371
+ - Mutable attributes like `tools` and `handoffs` are shallow-copied:
372
+ new list objects are created only if overridden, but their contents
373
+ (tool functions and handoff objects) are shared with the original.
374
+ - To modify these independently, pass new lists when calling `clone()`.
375
+ Example:
376
+ ```python
377
+ new_agent = agent.clone(instructions="New instructions")
378
+ ```
379
+ """
380
+ return dataclasses.replace(self, **kwargs)
381
+
382
+ def as_tool(
383
+ self,
384
+ tool_name: str | None,
385
+ tool_description: str | None,
386
+ custom_output_extractor: Callable[[RunResult], Awaitable[str]] | None = None,
387
+ is_enabled: bool
388
+ | Callable[[RunContextWrapper[Any], AgentBase[Any]], MaybeAwaitable[bool]] = True,
389
+ run_config: RunConfig | None = None,
390
+ max_turns: int | None = None,
391
+ hooks: RunHooks[TContext] | None = None,
392
+ previous_response_id: str | None = None,
393
+ conversation_id: str | None = None,
394
+ session: Session | None = None,
395
+ ) -> Tool:
396
+ """Transform this agent into a tool, callable by other agents.
397
+
398
+ This is different from handoffs in two ways:
399
+ 1. In handoffs, the new agent receives the conversation history. In this tool, the new agent
400
+ receives generated input.
401
+ 2. In handoffs, the new agent takes over the conversation. In this tool, the new agent is
402
+ called as a tool, and the conversation is continued by the original agent.
403
+
404
+ Args:
405
+ tool_name: The name of the tool. If not provided, the agent's name will be used.
406
+ tool_description: The description of the tool, which should indicate what it does and
407
+ when to use it.
408
+ custom_output_extractor: A function that extracts the output from the agent. If not
409
+ provided, the last message from the agent will be used.
410
+ is_enabled: Whether the tool is enabled. Can be a bool or a callable that takes the run
411
+ context and agent and returns whether the tool is enabled. Disabled tools are hidden
412
+ from the LLM at runtime.
413
+ """
414
+
415
+ @function_tool(
416
+ name_override=tool_name or _transforms.transform_string_function_style(self.name),
417
+ description_override=tool_description or "",
418
+ is_enabled=is_enabled,
419
+ )
420
+ async def run_agent(context: RunContextWrapper, input: str) -> str:
421
+ from .run import DEFAULT_MAX_TURNS, Runner
422
+
423
+ resolved_max_turns = max_turns if max_turns is not None else DEFAULT_MAX_TURNS
424
+
425
+ output = await Runner.run(
426
+ starting_agent=self,
427
+ input=input,
428
+ context=context.context,
429
+ run_config=run_config,
430
+ max_turns=resolved_max_turns,
431
+ hooks=hooks,
432
+ previous_response_id=previous_response_id,
433
+ conversation_id=conversation_id,
434
+ session=session,
435
+ )
436
+ if custom_output_extractor:
437
+ return await custom_output_extractor(output)
438
+
439
+ return ItemHelpers.text_message_outputs(output.new_items)
440
+
441
+ return run_agent
442
+
443
+ async def get_system_prompt(self, run_context: RunContextWrapper[TContext]) -> str | None:
444
+ if isinstance(self.instructions, str):
445
+ return self.instructions
446
+ elif callable(self.instructions):
447
+ # Inspect the signature of the instructions function
448
+ sig = inspect.signature(self.instructions)
449
+ params = list(sig.parameters.values())
450
+
451
+ # Enforce exactly 2 parameters
452
+ if len(params) != 2:
453
+ raise TypeError(
454
+ f"'instructions' callable must accept exactly 2 arguments (context, agent), "
455
+ f"but got {len(params)}: {[p.name for p in params]}"
456
+ )
457
+
458
+ # Call the instructions function properly
459
+ if inspect.iscoroutinefunction(self.instructions):
460
+ return await cast(Awaitable[str], self.instructions(run_context, self))
461
+ else:
462
+ return cast(str, self.instructions(run_context, self))
463
+
464
+ elif self.instructions is not None:
465
+ logger.error(
466
+ f"Instructions must be a string or a callable function, "
467
+ f"got {type(self.instructions).__name__}"
468
+ )
469
+
470
+ return None
471
+
472
+ async def get_prompt(
473
+ self, run_context: RunContextWrapper[TContext]
474
+ ) -> ResponsePromptParam | None:
475
+ """Get the prompt for the agent."""
476
+ return await PromptUtil.to_model_input(self.prompt, run_context, self)
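To make `as_tool` concrete, a short usage sketch; the agent names and instruction strings are placeholders, and running it requires model credentials at runtime:

```python
import asyncio

from agents.agent import Agent
from agents.run import Runner

summarizer = Agent(
    name="summarizer",
    instructions="Summarize the given text in one sentence.",
)

orchestrator = Agent(
    name="orchestrator",
    instructions="Call the summarize tool when the user pastes long text.",
    tools=[
        summarizer.as_tool(
            tool_name="summarize",
            tool_description="Summarize a block of text in one sentence.",
        )
    ],
)


async def main() -> None:
    result = await Runner.run(orchestrator, "Summarize: <long text here>")
    print(result.final_output)


asyncio.run(main())
```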
agents/agent_output.py ADDED
@@ -0,0 +1,194 @@
1
+ import abc
2
+ from dataclasses import dataclass
3
+ from typing import Any
4
+
5
+ from pydantic import BaseModel, TypeAdapter
6
+ from typing_extensions import TypedDict, get_args, get_origin
7
+
8
+ from .exceptions import ModelBehaviorError, UserError
9
+ from .strict_schema import ensure_strict_json_schema
10
+ from .tracing import SpanError
11
+ from .util import _error_tracing, _json
12
+
13
+ _WRAPPER_DICT_KEY = "response"
14
+
15
+
16
+ class AgentOutputSchemaBase(abc.ABC):
17
+ """An object that captures the JSON schema of the output, as well as validating/parsing JSON
18
+ produced by the LLM into the output type.
19
+ """
20
+
21
+ @abc.abstractmethod
22
+ def is_plain_text(self) -> bool:
23
+ """Whether the output type is plain text (versus a JSON object)."""
24
+ pass
25
+
26
+ @abc.abstractmethod
27
+ def name(self) -> str:
28
+ """The name of the output type."""
29
+ pass
30
+
31
+ @abc.abstractmethod
32
+ def json_schema(self) -> dict[str, Any]:
33
+ """Returns the JSON schema of the output. Will only be called if the output type is not
34
+ plain text.
35
+ """
36
+ pass
37
+
38
+ @abc.abstractmethod
39
+ def is_strict_json_schema(self) -> bool:
40
+ """Whether the JSON schema is in strict mode. Strict mode constrains the JSON schema
41
+ features, but guarantees valid JSON. See here for details:
42
+ https://platform.openai.com/docs/guides/structured-outputs#supported-schemas
43
+ """
44
+ pass
45
+
46
+ @abc.abstractmethod
47
+ def validate_json(self, json_str: str) -> Any:
48
+ """Validate a JSON string against the output type. You must return the validated object,
49
+ or raise a `ModelBehaviorError` if the JSON is invalid.
50
+ """
51
+ pass
52
+
53
+
54
+ @dataclass(init=False)
55
+ class AgentOutputSchema(AgentOutputSchemaBase):
56
+ """An object that captures the JSON schema of the output, as well as validating/parsing JSON
57
+ produced by the LLM into the output type.
58
+ """
59
+
60
+ output_type: type[Any]
61
+ """The type of the output."""
62
+
63
+ _type_adapter: TypeAdapter[Any]
64
+ """A type adapter that wraps the output type, so that we can validate JSON."""
65
+
66
+ _is_wrapped: bool
67
+ """Whether the output type is wrapped in a dictionary. This is generally done if the base
68
+ output type cannot be represented as a JSON Schema object.
69
+ """
70
+
71
+ _output_schema: dict[str, Any]
72
+ """The JSON schema of the output."""
73
+
74
+ _strict_json_schema: bool
75
+ """Whether the JSON schema is in strict mode. We **strongly** recommend setting this to True,
76
+ as it increases the likelihood of correct JSON input.
77
+ """
78
+
79
+ def __init__(self, output_type: type[Any], strict_json_schema: bool = True):
80
+ """
81
+ Args:
82
+ output_type: The type of the output.
83
+ strict_json_schema: Whether the JSON schema is in strict mode. We **strongly** recommend
84
+ setting this to True, as it increases the likelihood of correct JSON input.
85
+ """
86
+ self.output_type = output_type
87
+ self._strict_json_schema = strict_json_schema
88
+
89
+ if output_type is None or output_type is str:
90
+ self._is_wrapped = False
91
+ self._type_adapter = TypeAdapter(output_type)
92
+ self._output_schema = self._type_adapter.json_schema()
93
+ return
94
+
95
+ # We should wrap for things that are not plain text, and for things that would definitely
96
+ # not be a JSON Schema object.
97
+ self._is_wrapped = not _is_subclass_of_base_model_or_dict(output_type)
98
+
99
+ if self._is_wrapped:
100
+ OutputType = TypedDict(
101
+ "OutputType",
102
+ {
103
+ _WRAPPER_DICT_KEY: output_type, # type: ignore
104
+ },
105
+ )
106
+ self._type_adapter = TypeAdapter(OutputType)
107
+ self._output_schema = self._type_adapter.json_schema()
108
+ else:
109
+ self._type_adapter = TypeAdapter(output_type)
110
+ self._output_schema = self._type_adapter.json_schema()
111
+
112
+ if self._strict_json_schema:
113
+ try:
114
+ self._output_schema = ensure_strict_json_schema(self._output_schema)
115
+ except UserError as e:
116
+ raise UserError(
117
+ "Strict JSON schema is enabled, but the output type is not valid. "
118
+ "Either make the output type strict, "
119
+ "or wrap your type with AgentOutputSchema(YourType, strict_json_schema=False)"
120
+ ) from e
121
+
122
+ def is_plain_text(self) -> bool:
123
+ """Whether the output type is plain text (versus a JSON object)."""
124
+ return self.output_type is None or self.output_type is str
125
+
126
+ def is_strict_json_schema(self) -> bool:
127
+ """Whether the JSON schema is in strict mode."""
128
+ return self._strict_json_schema
129
+
130
+ def json_schema(self) -> dict[str, Any]:
131
+ """The JSON schema of the output type."""
132
+ if self.is_plain_text():
133
+ raise UserError("Output type is plain text, so no JSON schema is available")
134
+ return self._output_schema
135
+
136
+ def validate_json(self, json_str: str) -> Any:
137
+ """Validate a JSON string against the output type. Returns the validated object, or raises
138
+ a `ModelBehaviorError` if the JSON is invalid.
139
+ """
140
+ validated = _json.validate_json(json_str, self._type_adapter, partial=False)
141
+ if self._is_wrapped:
142
+ if not isinstance(validated, dict):
143
+ _error_tracing.attach_error_to_current_span(
144
+ SpanError(
145
+ message="Invalid JSON",
146
+ data={"details": f"Expected a dict, got {type(validated)}"},
147
+ )
148
+ )
149
+ raise ModelBehaviorError(
150
+ f"Expected a dict, got {type(validated)} for JSON: {json_str}"
151
+ )
152
+
153
+ if _WRAPPER_DICT_KEY not in validated:
154
+ _error_tracing.attach_error_to_current_span(
155
+ SpanError(
156
+ message="Invalid JSON",
157
+ data={"details": f"Could not find key {_WRAPPER_DICT_KEY} in JSON"},
158
+ )
159
+ )
160
+ raise ModelBehaviorError(
161
+ f"Could not find key {_WRAPPER_DICT_KEY} in JSON: {json_str}"
162
+ )
163
+ return validated[_WRAPPER_DICT_KEY]
164
+ return validated
165
+
166
+ def name(self) -> str:
167
+ """The name of the output type."""
168
+ return _type_to_str(self.output_type)
169
+
170
+
171
+ def _is_subclass_of_base_model_or_dict(t: Any) -> bool:
172
+ if not isinstance(t, type):
173
+ return False
174
+
175
+ # If it's a generic alias, 'origin' will be the actual type, e.g. 'list'
176
+ origin = get_origin(t)
177
+
178
+ allowed_types = (BaseModel, dict)
179
+ # If it's a generic alias e.g. list[str], then we should check the origin type i.e. list
180
+ return issubclass(origin or t, allowed_types)
181
+
182
+
183
+ def _type_to_str(t: type[Any]) -> str:
184
+ origin = get_origin(t)
185
+ args = get_args(t)
186
+
187
+ if origin is None:
188
+ # It's a simple type like `str`, `int`, etc.
189
+ return t.__name__
190
+ elif args:
191
+ args_str = ", ".join(_type_to_str(arg) for arg in args)
192
+ return f"{origin.__name__}[{args_str}]"
193
+ else:
194
+ return str(t)
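A brief sketch of the wrapping behavior implemented above: types that are not `BaseModel`/`dict` subclasses (e.g. `list[str]`) get wrapped under the `"response"` key, and `validate_json` unwraps them again:

```python
from agents.agent_output import AgentOutputSchema

schema = AgentOutputSchema(list[str])
assert not schema.is_plain_text()
assert schema.is_strict_json_schema()

# The model is asked to emit the wrapped shape; validate_json unwraps it.
assert schema.validate_json('{"response": ["a", "b"]}') == ["a", "b"]
```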
agents/computer.py ADDED
@@ -0,0 +1,107 @@
1
+ import abc
2
+ from typing import Literal
3
+
4
+ Environment = Literal["mac", "windows", "ubuntu", "browser"]
5
+ Button = Literal["left", "right", "wheel", "back", "forward"]
6
+
7
+
8
+ class Computer(abc.ABC):
9
+ """A computer implemented with sync operations. The Computer interface abstracts the
10
+ operations needed to control a computer or browser."""
11
+
12
+ @property
13
+ @abc.abstractmethod
14
+ def environment(self) -> Environment:
15
+ pass
16
+
17
+ @property
18
+ @abc.abstractmethod
19
+ def dimensions(self) -> tuple[int, int]:
20
+ pass
21
+
22
+ @abc.abstractmethod
23
+ def screenshot(self) -> str:
24
+ pass
25
+
26
+ @abc.abstractmethod
27
+ def click(self, x: int, y: int, button: Button) -> None:
28
+ pass
29
+
30
+ @abc.abstractmethod
31
+ def double_click(self, x: int, y: int) -> None:
32
+ pass
33
+
34
+ @abc.abstractmethod
35
+ def scroll(self, x: int, y: int, scroll_x: int, scroll_y: int) -> None:
36
+ pass
37
+
38
+ @abc.abstractmethod
39
+ def type(self, text: str) -> None:
40
+ pass
41
+
42
+ @abc.abstractmethod
43
+ def wait(self) -> None:
44
+ pass
45
+
46
+ @abc.abstractmethod
47
+ def move(self, x: int, y: int) -> None:
48
+ pass
49
+
50
+ @abc.abstractmethod
51
+ def keypress(self, keys: list[str]) -> None:
52
+ pass
53
+
54
+ @abc.abstractmethod
55
+ def drag(self, path: list[tuple[int, int]]) -> None:
56
+ pass
57
+
58
+
59
+ class AsyncComputer(abc.ABC):
60
+ """A computer implemented with async operations. The Computer interface abstracts the
61
+ operations needed to control a computer or browser."""
62
+
63
+ @property
64
+ @abc.abstractmethod
65
+ def environment(self) -> Environment:
66
+ pass
67
+
68
+ @property
69
+ @abc.abstractmethod
70
+ def dimensions(self) -> tuple[int, int]:
71
+ pass
72
+
73
+ @abc.abstractmethod
74
+ async def screenshot(self) -> str:
75
+ pass
76
+
77
+ @abc.abstractmethod
78
+ async def click(self, x: int, y: int, button: Button) -> None:
79
+ pass
80
+
81
+ @abc.abstractmethod
82
+ async def double_click(self, x: int, y: int) -> None:
83
+ pass
84
+
85
+ @abc.abstractmethod
86
+ async def scroll(self, x: int, y: int, scroll_x: int, scroll_y: int) -> None:
87
+ pass
88
+
89
+ @abc.abstractmethod
90
+ async def type(self, text: str) -> None:
91
+ pass
92
+
93
+ @abc.abstractmethod
94
+ async def wait(self) -> None:
95
+ pass
96
+
97
+ @abc.abstractmethod
98
+ async def move(self, x: int, y: int) -> None:
99
+ pass
100
+
101
+ @abc.abstractmethod
102
+ async def keypress(self, keys: list[str]) -> None:
103
+ pass
104
+
105
+ @abc.abstractmethod
106
+ async def drag(self, path: list[tuple[int, int]]) -> None:
107
+ pass
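A minimal sketch of satisfying the `AsyncComputer` interface with no-op operations, e.g. for tests; the screenshot bytes are a placeholder:

```python
import base64

from agents.computer import AsyncComputer, Button, Environment


class StubComputer(AsyncComputer):
    """Implements every abstract method as a no-op."""

    @property
    def environment(self) -> Environment:
        return "browser"

    @property
    def dimensions(self) -> tuple[int, int]:
        return (1280, 720)

    async def screenshot(self) -> str:
        return base64.b64encode(b"fake-png-bytes").decode()

    async def click(self, x: int, y: int, button: Button) -> None: ...
    async def double_click(self, x: int, y: int) -> None: ...
    async def scroll(self, x: int, y: int, scroll_x: int, scroll_y: int) -> None: ...
    async def type(self, text: str) -> None: ...
    async def wait(self) -> None: ...
    async def move(self, x: int, y: int) -> None: ...
    async def keypress(self, keys: list[str]) -> None: ...
    async def drag(self, path: list[tuple[int, int]]) -> None: ...
```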
agents/exceptions.py ADDED
@@ -0,0 +1,131 @@
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+ from typing import TYPE_CHECKING, Any
5
+
6
+ if TYPE_CHECKING:
7
+ from .agent import Agent
8
+ from .guardrail import InputGuardrailResult, OutputGuardrailResult
9
+ from .items import ModelResponse, RunItem, TResponseInputItem
10
+ from .run_context import RunContextWrapper
11
+ from .tool_guardrails import (
12
+ ToolGuardrailFunctionOutput,
13
+ ToolInputGuardrail,
14
+ ToolOutputGuardrail,
15
+ )
16
+
17
+ from .util._pretty_print import pretty_print_run_error_details
18
+
19
+
20
+ @dataclass
21
+ class RunErrorDetails:
22
+ """Data collected from an agent run when an exception occurs."""
23
+
24
+ input: str | list[TResponseInputItem]
25
+ new_items: list[RunItem]
26
+ raw_responses: list[ModelResponse]
27
+ last_agent: Agent[Any]
28
+ context_wrapper: RunContextWrapper[Any]
29
+ input_guardrail_results: list[InputGuardrailResult]
30
+ output_guardrail_results: list[OutputGuardrailResult]
31
+
32
+ def __str__(self) -> str:
33
+ return pretty_print_run_error_details(self)
34
+
35
+
36
+ class AgentsException(Exception):
37
+ """Base class for all exceptions in the Agents SDK."""
38
+
39
+ run_data: RunErrorDetails | None
40
+
41
+ def __init__(self, *args: object) -> None:
42
+ super().__init__(*args)
43
+ self.run_data = None
44
+
45
+
46
+ class MaxTurnsExceeded(AgentsException):
47
+ """Exception raised when the maximum number of turns is exceeded."""
48
+
49
+ message: str
50
+
51
+ def __init__(self, message: str):
52
+ self.message = message
53
+ super().__init__(message)
54
+
55
+
56
+ class ModelBehaviorError(AgentsException):
57
+ """Exception raised when the model does something unexpected, e.g. calling a tool that doesn't
58
+ exist, or providing malformed JSON.
59
+ """
60
+
61
+ message: str
62
+
63
+ def __init__(self, message: str):
64
+ self.message = message
65
+ super().__init__(message)
66
+
67
+
68
+ class UserError(AgentsException):
69
+ """Exception raised when the user makes an error using the SDK."""
70
+
71
+ message: str
72
+
73
+ def __init__(self, message: str):
74
+ self.message = message
75
+ super().__init__(message)
76
+
77
+
78
+ class InputGuardrailTripwireTriggered(AgentsException):
79
+ """Exception raised when a guardrail tripwire is triggered."""
80
+
81
+ guardrail_result: InputGuardrailResult
82
+ """The result data of the guardrail that was triggered."""
83
+
84
+ def __init__(self, guardrail_result: InputGuardrailResult):
85
+ self.guardrail_result = guardrail_result
86
+ super().__init__(
87
+ f"Guardrail {guardrail_result.guardrail.__class__.__name__} triggered tripwire"
88
+ )
89
+
90
+
91
+ class OutputGuardrailTripwireTriggered(AgentsException):
92
+ """Exception raised when a guardrail tripwire is triggered."""
93
+
94
+ guardrail_result: OutputGuardrailResult
95
+ """The result data of the guardrail that was triggered."""
96
+
97
+ def __init__(self, guardrail_result: OutputGuardrailResult):
98
+ self.guardrail_result = guardrail_result
99
+ super().__init__(
100
+ f"Guardrail {guardrail_result.guardrail.__class__.__name__} triggered tripwire"
101
+ )
102
+
103
+
104
+ class ToolInputGuardrailTripwireTriggered(AgentsException):
105
+ """Exception raised when a tool input guardrail tripwire is triggered."""
106
+
107
+ guardrail: ToolInputGuardrail[Any]
108
+ """The guardrail that was triggered."""
109
+
110
+ output: ToolGuardrailFunctionOutput
111
+ """The output from the guardrail function."""
112
+
113
+ def __init__(self, guardrail: ToolInputGuardrail[Any], output: ToolGuardrailFunctionOutput):
114
+ self.guardrail = guardrail
115
+ self.output = output
116
+ super().__init__(f"Tool input guardrail {guardrail.__class__.__name__} triggered tripwire")
117
+
118
+
119
+ class ToolOutputGuardrailTripwireTriggered(AgentsException):
120
+ """Exception raised when a tool output guardrail tripwire is triggered."""
121
+
122
+ guardrail: ToolOutputGuardrail[Any]
123
+ """The guardrail that was triggered."""
124
+
125
+ output: ToolGuardrailFunctionOutput
126
+ """The output from the guardrail function."""
127
+
128
+ def __init__(self, guardrail: ToolOutputGuardrail[Any], output: ToolGuardrailFunctionOutput):
129
+ self.guardrail = guardrail
130
+ self.output = output
131
+ super().__init__(f"Tool output guardrail {guardrail.__class__.__name__} triggered tripwire")
agents/extensions/__init__.py ADDED
File without changes
agents/extensions/handoff_filters.py ADDED
@@ -0,0 +1,70 @@
1
+ from __future__ import annotations
2
+
3
+ from ..handoffs import HandoffInputData
4
+ from ..items import (
5
+ HandoffCallItem,
6
+ HandoffOutputItem,
7
+ ReasoningItem,
8
+ RunItem,
9
+ ToolCallItem,
10
+ ToolCallOutputItem,
11
+ TResponseInputItem,
12
+ )
13
+
14
+ """Contains common handoff input filters, for convenience. """
15
+
16
+
17
+ def remove_all_tools(handoff_input_data: HandoffInputData) -> HandoffInputData:
18
+ """Filters out all tool items: file search, web search and function calls+output."""
19
+
20
+ history = handoff_input_data.input_history
21
+ new_items = handoff_input_data.new_items
22
+
23
+ filtered_history = (
24
+ _remove_tool_types_from_input(history) if isinstance(history, tuple) else history
25
+ )
26
+ filtered_pre_handoff_items = _remove_tools_from_items(handoff_input_data.pre_handoff_items)
27
+ filtered_new_items = _remove_tools_from_items(new_items)
28
+
29
+ return HandoffInputData(
30
+ input_history=filtered_history,
31
+ pre_handoff_items=filtered_pre_handoff_items,
32
+ new_items=filtered_new_items,
33
+ run_context=handoff_input_data.run_context,
34
+ )
35
+
36
+
37
+ def _remove_tools_from_items(items: tuple[RunItem, ...]) -> tuple[RunItem, ...]:
38
+ filtered_items = []
39
+ for item in items:
40
+ if (
41
+ isinstance(item, HandoffCallItem)
42
+ or isinstance(item, HandoffOutputItem)
43
+ or isinstance(item, ToolCallItem)
44
+ or isinstance(item, ToolCallOutputItem)
45
+ or isinstance(item, ReasoningItem)
46
+ ):
47
+ continue
48
+ filtered_items.append(item)
49
+ return tuple(filtered_items)
50
+
51
+
52
+ def _remove_tool_types_from_input(
53
+ items: tuple[TResponseInputItem, ...],
54
+ ) -> tuple[TResponseInputItem, ...]:
55
+ tool_types = [
56
+ "function_call",
57
+ "function_call_output",
58
+ "computer_call",
59
+ "computer_call_output",
60
+ "file_search_call",
61
+ "web_search_call",
62
+ ]
63
+
64
+ filtered_items: list[TResponseInputItem] = []
65
+ for item in items:
66
+ itype = item.get("type")
67
+ if itype in tool_types:
68
+ continue
69
+ filtered_items.append(item)
70
+ return tuple(filtered_items)
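Wiring `remove_all_tools` into a handoff via the `input_filter` hook consumed in `_run_impl.py` earlier in this commit; a minimal sketch, assuming the `handoff()` factory in `agents/handoffs.py` accepts an `input_filter` argument as in the released SDK:

```python
from agents.agent import Agent
from agents.extensions.handoff_filters import remove_all_tools
from agents.handoffs import handoff

specialist = Agent(name="specialist", instructions="Handle escalations.")

# The specialist receives the conversation with all tool items stripped out.
triage = Agent(
    name="triage",
    instructions="Escalate hard questions.",
    handoffs=[handoff(specialist, input_filter=remove_all_tools)],
)
```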
agents/extensions/handoff_prompt.py ADDED
@@ -0,0 +1,19 @@
1
+ # A recommended prompt prefix for agents that use handoffs. We recommend including this or
2
+ # similar instructions in any agents that use handoffs.
3
+ RECOMMENDED_PROMPT_PREFIX = (
4
+ "# System context\n"
5
+ "You are part of a multi-agent system called the Agents SDK, designed to make agent "
6
+ "coordination and execution easy. Agents uses two primary abstraction: **Agents** and "
7
+ "**Handoffs**. An agent encompasses instructions and tools and can hand off a "
8
+ "conversation to another agent when appropriate. "
9
+ "Handoffs are achieved by calling a handoff function, generally named "
10
+ "`transfer_to_<agent_name>`. Transfers between agents are handled seamlessly in the background;"
11
+ " do not mention or draw attention to these transfers in your conversation with the user.\n"
12
+ )
13
+
14
+
15
+ def prompt_with_handoff_instructions(prompt: str) -> str:
16
+ """
17
+ Add recommended instructions to the prompt for agents that use handoffs.
18
+ """
19
+ return f"{RECOMMENDED_PROMPT_PREFIX}\n\n{prompt}"
agents/extensions/memory/__init__.py ADDED
@@ -0,0 +1,65 @@
1
+ """Session memory backends living in the extensions namespace.
2
+
3
+ This package contains optional, production-grade session implementations that
4
+ introduce extra third-party dependencies (database drivers, ORMs, etc.). They
5
+ conform to the :class:`agents.memory.session.Session` protocol so they can be
6
+ used as a drop-in replacement for :class:`agents.memory.session.SQLiteSession`.
7
+ """
8
+
9
+ from __future__ import annotations
10
+
11
+ from typing import Any
12
+
13
+ __all__: list[str] = [
14
+ "EncryptedSession",
15
+ "RedisSession",
16
+ "SQLAlchemySession",
17
+ "AdvancedSQLiteSession",
18
+ ]
19
+
20
+
21
+ def __getattr__(name: str) -> Any:
22
+ if name == "EncryptedSession":
23
+ try:
24
+ from .encrypt_session import EncryptedSession # noqa: F401
25
+
26
+ return EncryptedSession
27
+ except ModuleNotFoundError as e:
28
+ raise ImportError(
29
+ "EncryptedSession requires the 'cryptography' extra. "
30
+ "Install it with: pip install openai-agents[encrypt]"
31
+ ) from e
32
+
33
+ if name == "RedisSession":
34
+ try:
35
+ from .redis_session import RedisSession # noqa: F401
36
+
37
+ return RedisSession
38
+ except ModuleNotFoundError as e:
39
+ raise ImportError(
40
+ "RedisSession requires the 'redis' extra. "
41
+ "Install it with: pip install openai-agents[redis]"
42
+ ) from e
43
+
44
+ if name == "SQLAlchemySession":
45
+ try:
46
+ from .sqlalchemy_session import SQLAlchemySession # noqa: F401
47
+
48
+ return SQLAlchemySession
49
+ except ModuleNotFoundError as e:
50
+ raise ImportError(
51
+ "SQLAlchemySession requires the 'sqlalchemy' extra. "
52
+ "Install it with: pip install openai-agents[sqlalchemy]"
53
+ ) from e
54
+
55
+ if name == "AdvancedSQLiteSession":
56
+ try:
57
+ from .advanced_sqlite_session import AdvancedSQLiteSession # noqa: F401
58
+
59
+ return AdvancedSQLiteSession
60
+ except ModuleNotFoundError as e:
61
+ raise ImportError(
62
+ f"Failed to import AdvancedSQLiteSession: {e}"
63
+ ) from e
64
+
65
+ raise AttributeError(f"module {__name__} has no attribute {name}")
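Because of the lazy `__getattr__` above, each optional dependency is imported only on first attribute access; a sketch of the failure mode when an extra is missing:

```python
# Succeeds only if the matching extra is installed, e.g.
#   pip install "openai-agents[sqlalchemy]"
try:
    from agents.extensions.memory import SQLAlchemySession
except ImportError as exc:
    # The message names the missing extra, per the __getattr__ hook above.
    print(f"Optional backend unavailable: {exc}")
```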
agents/extensions/memory/advanced_sqlite_session.py ADDED
@@ -0,0 +1,1285 @@
+ from __future__ import annotations
+
+ import asyncio
+ import json
+ import logging
+ import threading
+ from contextlib import closing
+ from pathlib import Path
+ from typing import Any, Union, cast
+
+ from agents.result import RunResult
+ from agents.usage import Usage
+
+ from ...items import TResponseInputItem
+ from ...memory import SQLiteSession
+
+
+ class AdvancedSQLiteSession(SQLiteSession):
+     """Enhanced SQLite session with conversation branching and usage analytics."""
+
+     def __init__(
+         self,
+         *,
+         session_id: str,
+         db_path: str | Path = ":memory:",
+         create_tables: bool = False,
+         logger: logging.Logger | None = None,
+         **kwargs,
+     ):
+         """Initialize the AdvancedSQLiteSession.
+
+         Args:
+             session_id: The ID of the session
+             db_path: The path to the SQLite database file. Defaults to `:memory:` for in-memory storage
+             create_tables: Whether to create the structure tables
+             logger: The logger to use. Defaults to the module logger
+             **kwargs: Additional keyword arguments to pass to the superclass
+         """  # noqa: E501
+         super().__init__(session_id, db_path, **kwargs)
+         if create_tables:
+             self._init_structure_tables()
+         self._current_branch_id = "main"
+         self._logger = logger or logging.getLogger(__name__)
+
+     def _init_structure_tables(self):
+         """Add structure and usage tracking tables.
+
+         Creates the message_structure and turn_usage tables with appropriate
+         indexes for conversation branching and usage analytics.
+         """
+         conn = self._get_connection()
+
+         # Message structure with branch support
+         conn.execute("""
+             CREATE TABLE IF NOT EXISTS message_structure (
+                 id INTEGER PRIMARY KEY AUTOINCREMENT,
+                 session_id TEXT NOT NULL,
+                 message_id INTEGER NOT NULL,
+                 branch_id TEXT NOT NULL DEFAULT 'main',
+                 message_type TEXT NOT NULL,
+                 sequence_number INTEGER NOT NULL,
+                 user_turn_number INTEGER,
+                 branch_turn_number INTEGER,
+                 tool_name TEXT,
+                 created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
+                 FOREIGN KEY (session_id) REFERENCES agent_sessions(session_id) ON DELETE CASCADE,
+                 FOREIGN KEY (message_id) REFERENCES agent_messages(id) ON DELETE CASCADE
+             )
+         """)
+
+         # Turn-level usage tracking with branch support and full JSON details
+         conn.execute("""
+             CREATE TABLE IF NOT EXISTS turn_usage (
+                 id INTEGER PRIMARY KEY AUTOINCREMENT,
+                 session_id TEXT NOT NULL,
+                 branch_id TEXT NOT NULL DEFAULT 'main',
+                 user_turn_number INTEGER NOT NULL,
+                 requests INTEGER DEFAULT 0,
+                 input_tokens INTEGER DEFAULT 0,
+                 output_tokens INTEGER DEFAULT 0,
+                 total_tokens INTEGER DEFAULT 0,
+                 input_tokens_details JSON,
+                 output_tokens_details JSON,
+                 created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
+                 FOREIGN KEY (session_id) REFERENCES agent_sessions(session_id) ON DELETE CASCADE,
+                 UNIQUE(session_id, branch_id, user_turn_number)
+             )
+         """)
+
+         # Indexes
+         conn.execute("""
+             CREATE INDEX IF NOT EXISTS idx_structure_session_seq
+             ON message_structure(session_id, sequence_number)
+         """)
+         conn.execute("""
+             CREATE INDEX IF NOT EXISTS idx_structure_branch
+             ON message_structure(session_id, branch_id)
+         """)
+         conn.execute("""
+             CREATE INDEX IF NOT EXISTS idx_structure_turn
+             ON message_structure(session_id, branch_id, user_turn_number)
+         """)
+         conn.execute("""
+             CREATE INDEX IF NOT EXISTS idx_structure_branch_seq
+             ON message_structure(session_id, branch_id, sequence_number)
+         """)
+         conn.execute("""
+             CREATE INDEX IF NOT EXISTS idx_turn_usage_session_turn
+             ON turn_usage(session_id, branch_id, user_turn_number)
+         """)
+
+         conn.commit()
+
+     async def add_items(self, items: list[TResponseInputItem]) -> None:
+         """Add items to the session.
+
+         Args:
+             items: The items to add to the session
+         """
+         # Add to base table first
+         await super().add_items(items)
+
+         # Extract structure metadata with precise sequencing
+         if items:
+             await self._add_structure_metadata(items)
+
+     async def get_items(
+         self,
+         limit: int | None = None,
+         branch_id: str | None = None,
+     ) -> list[TResponseInputItem]:
+         """Get items from current or specified branch.
+
+         Args:
+             limit: Maximum number of items to return. If None, returns all items.
+             branch_id: Branch to get items from. If None, uses current branch.
+
+         Returns:
+             List of conversation items from the specified branch.
+         """
+         if branch_id is None:
+             branch_id = self._current_branch_id
+
+         # Get all items for this branch
+         def _get_all_items_sync():
+             """Synchronous helper to get all items for a branch."""
+             conn = self._get_connection()
+             # TODO: Refactor SQLiteSession to use asyncio.Lock instead of threading.Lock and update this code  # noqa: E501
+             with self._lock if self._is_memory_db else threading.Lock():
+                 with closing(conn.cursor()) as cursor:
+                     if limit is None:
+                         cursor.execute(
+                             """
+                             SELECT m.message_data
+                             FROM agent_messages m
+                             JOIN message_structure s ON m.id = s.message_id
+                             WHERE m.session_id = ? AND s.branch_id = ?
+                             ORDER BY s.sequence_number ASC
+                             """,
+                             (self.session_id, branch_id),
+                         )
+                     else:
+                         cursor.execute(
+                             """
+                             SELECT m.message_data
+                             FROM agent_messages m
+                             JOIN message_structure s ON m.id = s.message_id
+                             WHERE m.session_id = ? AND s.branch_id = ?
+                             ORDER BY s.sequence_number DESC
+                             LIMIT ?
+                             """,
+                             (self.session_id, branch_id, limit),
+                         )
+
+                     rows = cursor.fetchall()
+                     if limit is not None:
+                         rows = list(reversed(rows))
+
+                     items = []
+                     for (message_data,) in rows:
+                         try:
+                             item = json.loads(message_data)
+                             items.append(item)
+                         except json.JSONDecodeError:
+                             continue
+                     return items
+
+         return await asyncio.to_thread(_get_all_items_sync)
+
+     async def store_run_usage(self, result: RunResult) -> None:
+         """Store usage data for the current conversation turn.
+
+         This is designed to be called after `Runner.run()` completes.
+         Session-level usage can be aggregated from turn data when needed.
+
+         Args:
+             result: The result from the run
+         """
+         try:
+             if result.context_wrapper.usage is not None:
+                 # Get the current turn number for this branch
+                 current_turn = self._get_current_turn_number()
+                 # Only update turn-level usage - session usage is aggregated on demand
+                 await self._update_turn_usage_internal(current_turn, result.context_wrapper.usage)
+         except Exception as e:
+             self._logger.error(f"Failed to store usage for session {self.session_id}: {e}")
+
+     def _get_next_turn_number(self, branch_id: str) -> int:
+         """Get the next turn number for a specific branch.
+
+         Args:
+             branch_id: The branch ID to get the next turn number for.
+
+         Returns:
+             The next available turn number for the specified branch.
+         """
+         conn = self._get_connection()
+         with closing(conn.cursor()) as cursor:
+             cursor.execute(
+                 """
+                 SELECT COALESCE(MAX(user_turn_number), 0)
+                 FROM message_structure
+                 WHERE session_id = ? AND branch_id = ?
+                 """,
+                 (self.session_id, branch_id),
+             )
+             result = cursor.fetchone()
+             max_turn = result[0] if result else 0
+             return max_turn + 1
+
+     def _get_next_branch_turn_number(self, branch_id: str) -> int:
+         """Get the next branch turn number for a specific branch.
+
+         Args:
+             branch_id: The branch ID to get the next branch turn number for.
+
+         Returns:
+             The next available branch turn number for the specified branch.
+         """
+         conn = self._get_connection()
+         with closing(conn.cursor()) as cursor:
+             cursor.execute(
+                 """
+                 SELECT COALESCE(MAX(branch_turn_number), 0)
+                 FROM message_structure
+                 WHERE session_id = ? AND branch_id = ?
+                 """,
+                 (self.session_id, branch_id),
+             )
+             result = cursor.fetchone()
+             max_turn = result[0] if result else 0
+             return max_turn + 1
+
+     def _get_current_turn_number(self) -> int:
+         """Get the current turn number for the current branch.
+
+         Returns:
+             The current turn number for the active branch.
+         """
+         conn = self._get_connection()
+         with closing(conn.cursor()) as cursor:
+             cursor.execute(
+                 """
+                 SELECT COALESCE(MAX(user_turn_number), 0)
+                 FROM message_structure
+                 WHERE session_id = ? AND branch_id = ?
+                 """,
+                 (self.session_id, self._current_branch_id),
+             )
+             result = cursor.fetchone()
+             return result[0] if result else 0
+
+     async def _add_structure_metadata(self, items: list[TResponseInputItem]) -> None:
+         """Extract structure metadata with branch-aware turn tracking.
+
+         This method:
+         - Assigns turn numbers per branch (not globally)
+         - Assigns explicit sequence numbers for precise ordering
+         - Links messages to their database IDs for structure tracking
+         - Handles multiple user messages in a single batch correctly
+
+         Args:
+             items: The items to add to the session
+         """
+
+         def _add_structure_sync():
+             """Synchronous helper to add structure metadata to database."""
+             conn = self._get_connection()
+             # TODO: Refactor SQLiteSession to use asyncio.Lock instead of threading.Lock and update this code  # noqa: E501
+             with self._lock if self._is_memory_db else threading.Lock():
+                 # Get the IDs of messages we just inserted, in order
+                 with closing(conn.cursor()) as cursor:
+                     cursor.execute(
+                         f"SELECT id FROM {self.messages_table} "
+                         f"WHERE session_id = ? ORDER BY id DESC LIMIT ?",
+                         (self.session_id, len(items)),
+                     )
+                     message_ids = [row[0] for row in cursor.fetchall()]
+                     message_ids.reverse()  # Match order of items
+
+                 # Get current max sequence number (global)
+                 with closing(conn.cursor()) as cursor:
+                     cursor.execute(
+                         """
+                         SELECT COALESCE(MAX(sequence_number), 0)
+                         FROM message_structure
+                         WHERE session_id = ?
+                         """,
+                         (self.session_id,),
+                     )
+                     seq_start = cursor.fetchone()[0]
+
+                 # Get current turn numbers atomically with a single query
+                 with closing(conn.cursor()) as cursor:
+                     cursor.execute(
+                         """
+                         SELECT
+                             COALESCE(MAX(user_turn_number), 0) as max_global_turn,
+                             COALESCE(MAX(branch_turn_number), 0) as max_branch_turn
+                         FROM message_structure
+                         WHERE session_id = ? AND branch_id = ?
+                         """,
+                         (self.session_id, self._current_branch_id),
+                     )
+                     result = cursor.fetchone()
+                     current_turn = result[0] if result else 0
+                     current_branch_turn = result[1] if result else 0
+
+                 # Process items and assign turn numbers correctly
+                 structure_data = []
+                 user_message_count = 0
+
+                 for i, (item, msg_id) in enumerate(zip(items, message_ids)):
+                     msg_type = self._classify_message_type(item)
+                     tool_name = self._extract_tool_name(item)
+
+                     # If this is a user message, increment turn counters
+                     if self._is_user_message(item):
+                         user_message_count += 1
+                         item_turn = current_turn + user_message_count
+                         item_branch_turn = current_branch_turn + user_message_count
+                     else:
+                         # Non-user messages inherit the turn number of the most recent user message
+                         item_turn = current_turn + user_message_count
+                         item_branch_turn = current_branch_turn + user_message_count
+
+                     structure_data.append(
+                         (
+                             self.session_id,
+                             msg_id,
+                             self._current_branch_id,
+                             msg_type,
+                             seq_start + i + 1,  # Global sequence
+                             item_turn,  # Global turn number
+                             item_branch_turn,  # Branch-specific turn number
+                             tool_name,
+                         )
+                     )
+
+                 with closing(conn.cursor()) as cursor:
+                     cursor.executemany(
+                         """
+                         INSERT INTO message_structure
+                         (session_id, message_id, branch_id, message_type, sequence_number,
+                          user_turn_number, branch_turn_number, tool_name)
+                         VALUES (?, ?, ?, ?, ?, ?, ?, ?)
+                         """,
+                         structure_data,
+                     )
+                     conn.commit()
+
+         try:
+             await asyncio.to_thread(_add_structure_sync)
+         except Exception as e:
+             self._logger.error(
+                 f"Failed to add structure metadata for session {self.session_id}: {e}"
+             )
+             # Try to clean up any orphaned messages to maintain consistency
+             try:
+                 await self._cleanup_orphaned_messages()
+             except Exception as cleanup_error:
+                 self._logger.error(f"Failed to cleanup orphaned messages: {cleanup_error}")
+             # Don't re-raise - structure metadata is supplementary
+
+     async def _cleanup_orphaned_messages(self) -> None:
+         """Remove messages that exist in agent_messages but not in message_structure.
+
+         This can happen if _add_structure_metadata fails after super().add_items() succeeds.
+         Used for maintaining data consistency.
+         """
+
+         def _cleanup_sync():
+             """Synchronous helper to cleanup orphaned messages."""
+             conn = self._get_connection()
+             # TODO: Refactor SQLiteSession to use asyncio.Lock instead of threading.Lock and update this code  # noqa: E501
+             with self._lock if self._is_memory_db else threading.Lock():
+                 with closing(conn.cursor()) as cursor:
+                     # Find messages without structure metadata
+                     cursor.execute(
+                         """
+                         SELECT am.id
+                         FROM agent_messages am
+                         LEFT JOIN message_structure ms ON am.id = ms.message_id
+                         WHERE am.session_id = ? AND ms.message_id IS NULL
+                         """,
+                         (self.session_id,),
+                     )
+
+                     orphaned_ids = [row[0] for row in cursor.fetchall()]
+
+                     if orphaned_ids:
+                         # Delete orphaned messages
+                         placeholders = ",".join("?" * len(orphaned_ids))
+                         cursor.execute(
+                             f"DELETE FROM agent_messages WHERE id IN ({placeholders})", orphaned_ids
+                         )
+
+                         deleted_count = cursor.rowcount
+                         conn.commit()
+
+                         self._logger.info(f"Cleaned up {deleted_count} orphaned messages")
+                         return deleted_count
+
+                     return 0
+
+         return await asyncio.to_thread(_cleanup_sync)
+
+     def _classify_message_type(self, item: TResponseInputItem) -> str:
+         """Classify the type of a message item.
+
+         Args:
+             item: The message item to classify.
+
+         Returns:
+             String representing the message type (user, assistant, etc.).
+         """
+         if isinstance(item, dict):
+             if item.get("role") == "user":
+                 return "user"
+             elif item.get("role") == "assistant":
+                 return "assistant"
+             elif item.get("type"):
+                 return str(item.get("type"))
+         return "other"
+
+     def _extract_tool_name(self, item: TResponseInputItem) -> str | None:
+         """Extract tool name if this is a tool call/output.
+
+         Args:
+             item: The message item to extract tool name from.
+
+         Returns:
+             Tool name if item is a tool call, None otherwise.
+         """
+         if isinstance(item, dict):
+             item_type = item.get("type")
+
+             # For MCP tools, try to extract from server_label if available
+             if item_type in {"mcp_call", "mcp_approval_request"} and "server_label" in item:
+                 server_label = item.get("server_label")
+                 tool_name = item.get("name")
+                 if tool_name and server_label:
+                     return f"{server_label}.{tool_name}"
+                 elif server_label:
+                     return str(server_label)
+                 elif tool_name:
+                     return str(tool_name)
+
+             # For tool types without a 'name' field, derive from the type
+             elif item_type in {
+                 "computer_call",
+                 "file_search_call",
+                 "web_search_call",
+                 "code_interpreter_call",
+             }:
+                 return item_type
+
+             # Most other tool calls have a 'name' field
+             elif "name" in item:
+                 name = item.get("name")
+                 return str(name) if name is not None else None
+
+         return None
+
+     def _is_user_message(self, item: TResponseInputItem) -> bool:
+         """Check if this is a user message.
+
+         Args:
+             item: The message item to check.
+
+         Returns:
+             True if the item is a user message, False otherwise.
+         """
+         return isinstance(item, dict) and item.get("role") == "user"
+
+     async def create_branch_from_turn(
+         self, turn_number: int, branch_name: str | None = None
+     ) -> str:
+         """Create a new branch starting from a specific user message turn.
+
+         Args:
+             turn_number: The branch turn number of the user message to branch from
+             branch_name: Optional name for the branch (auto-generated if None)
+
+         Returns:
+             The branch_id of the newly created branch
+
+         Raises:
+             ValueError: If turn doesn't exist or doesn't contain a user message
+         """
+         import time
+
+         # Validate the turn exists and contains a user message
+         def _validate_turn():
+             """Synchronous helper to validate turn exists and contains user message."""
+             conn = self._get_connection()
+             with closing(conn.cursor()) as cursor:
+                 cursor.execute(
+                     """
+                     SELECT am.message_data
+                     FROM message_structure ms
+                     JOIN agent_messages am ON ms.message_id = am.id
+                     WHERE ms.session_id = ? AND ms.branch_id = ?
+                     AND ms.branch_turn_number = ? AND ms.message_type = 'user'
+                     """,
+                     (self.session_id, self._current_branch_id, turn_number),
+                 )
+
+                 result = cursor.fetchone()
+                 if not result:
+                     raise ValueError(
+                         f"Turn {turn_number} does not contain a user message "
+                         f"in branch '{self._current_branch_id}'"
+                     )
+
+                 message_data = result[0]
+                 try:
+                     content = json.loads(message_data).get("content", "")
+                     return content[:50] + "..." if len(content) > 50 else content
+                 except Exception:
+                     return "Unable to parse content"
+
+         turn_content = await asyncio.to_thread(_validate_turn)
+
+         # Generate branch name if not provided
+         if branch_name is None:
+             timestamp = int(time.time())
+             branch_name = f"branch_from_turn_{turn_number}_{timestamp}"
+
+         # Copy messages before the branch point to the new branch
+         await self._copy_messages_to_new_branch(branch_name, turn_number)
+
+         # Switch to new branch
+         old_branch = self._current_branch_id
+         self._current_branch_id = branch_name
+
+         self._logger.debug(
+             f"Created branch '{branch_name}' from turn {turn_number} ('{turn_content}') in '{old_branch}'"  # noqa: E501
+         )
+         return branch_name
+
+     async def create_branch_from_content(
+         self, search_term: str, branch_name: str | None = None
+     ) -> str:
+         """Create branch from the first user turn matching the search term.
+
+         Args:
+             search_term: Text to search for in user messages.
+             branch_name: Optional name for the branch (auto-generated if None).
+
+         Returns:
+             The branch_id of the newly created branch.
+
+         Raises:
+             ValueError: If no matching turns are found.
+         """
+         matching_turns = await self.find_turns_by_content(search_term)
+         if not matching_turns:
+             raise ValueError(f"No user turns found containing '{search_term}'")
+
+         # Use the first (earliest) match
+         turn_number = matching_turns[0]["turn"]
+         return await self.create_branch_from_turn(turn_number, branch_name)
+
+     async def switch_to_branch(self, branch_id: str) -> None:
+         """Switch to a different branch.
+
+         Args:
+             branch_id: The branch to switch to.
+
+         Raises:
+             ValueError: If the branch doesn't exist.
+         """
+
+         # Validate branch exists
+         def _validate_branch():
+             """Synchronous helper to validate branch exists."""
+             conn = self._get_connection()
+             with closing(conn.cursor()) as cursor:
+                 cursor.execute(
+                     """
+                     SELECT COUNT(*) FROM message_structure
+                     WHERE session_id = ? AND branch_id = ?
+                     """,
+                     (self.session_id, branch_id),
+                 )
+
+                 count = cursor.fetchone()[0]
+                 if count == 0:
+                     raise ValueError(f"Branch '{branch_id}' does not exist")
+
+         await asyncio.to_thread(_validate_branch)
+
+         old_branch = self._current_branch_id
+         self._current_branch_id = branch_id
+         self._logger.info(f"Switched from branch '{old_branch}' to '{branch_id}'")
+
+     async def delete_branch(self, branch_id: str, force: bool = False) -> None:
+         """Delete a branch and all its associated data.
+
+         Args:
+             branch_id: The branch to delete.
+             force: If True, allows deleting the current branch (will switch to 'main').
+
+         Raises:
+             ValueError: If branch doesn't exist, is 'main', or is current branch without force.
+         """
+         if not branch_id or not branch_id.strip():
+             raise ValueError("Branch ID cannot be empty")
+
+         branch_id = branch_id.strip()
+
+         # Protect main branch
+         if branch_id == "main":
+             raise ValueError("Cannot delete the 'main' branch")
+
+         # Check if trying to delete current branch
+         if branch_id == self._current_branch_id:
+             if not force:
+                 raise ValueError(
+                     f"Cannot delete current branch '{branch_id}'. Use force=True or switch branches first"  # noqa: E501
+                 )
+             else:
+                 # Switch to main before deleting
+                 await self.switch_to_branch("main")
+
+         def _delete_sync():
+             """Synchronous helper to delete branch and associated data."""
+             conn = self._get_connection()
+             # TODO: Refactor SQLiteSession to use asyncio.Lock instead of threading.Lock and update this code  # noqa: E501
+             with self._lock if self._is_memory_db else threading.Lock():
+                 with closing(conn.cursor()) as cursor:
+                     # First verify the branch exists
+                     cursor.execute(
+                         """
+                         SELECT COUNT(*) FROM message_structure
+                         WHERE session_id = ? AND branch_id = ?
+                         """,
+                         (self.session_id, branch_id),
+                     )
+
+                     count = cursor.fetchone()[0]
+                     if count == 0:
+                         raise ValueError(f"Branch '{branch_id}' does not exist")
+
+                     # Delete from turn_usage first (foreign key constraint)
+                     cursor.execute(
+                         """
+                         DELETE FROM turn_usage
+                         WHERE session_id = ? AND branch_id = ?
+                         """,
+                         (self.session_id, branch_id),
+                     )
+
+                     usage_deleted = cursor.rowcount
+
+                     # Delete from message_structure
+                     cursor.execute(
+                         """
+                         DELETE FROM message_structure
+                         WHERE session_id = ? AND branch_id = ?
+                         """,
+                         (self.session_id, branch_id),
+                     )
+
+                     structure_deleted = cursor.rowcount
+
+                     conn.commit()
+
+                     return usage_deleted, structure_deleted
+
+         usage_deleted, structure_deleted = await asyncio.to_thread(_delete_sync)
+
+         self._logger.info(
+             f"Deleted branch '{branch_id}': {structure_deleted} message entries, {usage_deleted} usage entries"  # noqa: E501
+         )
+
+     async def list_branches(self) -> list[dict[str, Any]]:
+         """List all branches in this session.
+
+         Returns:
+             List of dicts with branch info containing:
+             - 'branch_id': Branch identifier
+             - 'message_count': Number of messages in branch
+             - 'user_turns': Number of user turns in branch
+             - 'is_current': Whether this is the current branch
+             - 'created_at': When the branch was first created
+         """
+
+         def _list_branches_sync():
+             """Synchronous helper to list all branches."""
+             conn = self._get_connection()
+             with closing(conn.cursor()) as cursor:
+                 cursor.execute(
+                     """
+                     SELECT
+                         ms.branch_id,
+                         COUNT(*) as message_count,
+                         COUNT(CASE WHEN ms.message_type = 'user' THEN 1 END) as user_turns,
+                         MIN(ms.created_at) as created_at
+                     FROM message_structure ms
+                     WHERE ms.session_id = ?
+                     GROUP BY ms.branch_id
+                     ORDER BY created_at
+                     """,
+                     (self.session_id,),
+                 )
+
+                 branches = []
+                 for row in cursor.fetchall():
+                     branch_id, msg_count, user_turns, created_at = row
+                     branches.append(
+                         {
+                             "branch_id": branch_id,
+                             "message_count": msg_count,
+                             "user_turns": user_turns,
+                             "is_current": branch_id == self._current_branch_id,
+                             "created_at": created_at,
+                         }
+                     )
+
+                 return branches
+
+         return await asyncio.to_thread(_list_branches_sync)
+
+     async def _copy_messages_to_new_branch(self, new_branch_id: str, from_turn_number: int) -> None:
+         """Copy messages before the branch point to the new branch.
+
+         Args:
+             new_branch_id: The ID of the new branch to copy messages to.
+             from_turn_number: The turn number to copy messages up to (exclusive).
+         """
+
+         def _copy_sync():
+             """Synchronous helper to copy messages to new branch."""
+             conn = self._get_connection()
+             # TODO: Refactor SQLiteSession to use asyncio.Lock instead of threading.Lock and update this code  # noqa: E501
+             with self._lock if self._is_memory_db else threading.Lock():
+                 with closing(conn.cursor()) as cursor:
+                     # Get all messages before the branch point
+                     cursor.execute(
+                         """
+                         SELECT
+                             ms.message_id,
+                             ms.message_type,
+                             ms.sequence_number,
+                             ms.user_turn_number,
+                             ms.branch_turn_number,
+                             ms.tool_name
+                         FROM message_structure ms
+                         WHERE ms.session_id = ? AND ms.branch_id = ?
+                         AND ms.branch_turn_number < ?
+                         ORDER BY ms.sequence_number
+                         """,
+                         (self.session_id, self._current_branch_id, from_turn_number),
+                     )
+
+                     messages_to_copy = cursor.fetchall()
+
+                     if messages_to_copy:
+                         # Get the max sequence number for the new inserts
+                         cursor.execute(
+                             """
+                             SELECT COALESCE(MAX(sequence_number), 0)
+                             FROM message_structure
+                             WHERE session_id = ?
+                             """,
+                             (self.session_id,),
+                         )
+
+                         seq_start = cursor.fetchone()[0]
+
+                         # Insert copied messages with new branch_id
+                         new_structure_data = []
+                         for i, (
+                             msg_id,
+                             msg_type,
+                             _,
+                             user_turn,
+                             branch_turn,
+                             tool_name,
+                         ) in enumerate(messages_to_copy):
+                             new_structure_data.append(
+                                 (
+                                     self.session_id,
+                                     msg_id,  # Same message_id (sharing the actual message data)
+                                     new_branch_id,
+                                     msg_type,
+                                     seq_start + i + 1,  # New sequence number
+                                     user_turn,  # Keep same global turn number
+                                     branch_turn,  # Keep same branch turn number
+                                     tool_name,
+                                 )
+                             )
+
+                         cursor.executemany(
+                             """
+                             INSERT INTO message_structure
+                             (session_id, message_id, branch_id, message_type, sequence_number,
+                              user_turn_number, branch_turn_number, tool_name)
+                             VALUES (?, ?, ?, ?, ?, ?, ?, ?)
+                             """,
+                             new_structure_data,
+                         )
+
+                         conn.commit()
+
+         await asyncio.to_thread(_copy_sync)
+
+     async def get_conversation_turns(self, branch_id: str | None = None) -> list[dict[str, Any]]:
+         """Get user turns with content for easy browsing and branching decisions.
+
+         Args:
+             branch_id: Branch to get turns from (current branch if None).
+
+         Returns:
+             List of dicts with turn info containing:
+             - 'turn': Branch turn number
+             - 'content': User message content (truncated)
+             - 'full_content': Full user message content
+             - 'timestamp': When the turn was created
+             - 'can_branch': Always True (all user messages can branch)
+         """
+         if branch_id is None:
+             branch_id = self._current_branch_id
+
+         def _get_turns_sync():
+             """Synchronous helper to get conversation turns."""
+             conn = self._get_connection()
+             with closing(conn.cursor()) as cursor:
+                 cursor.execute(
+                     """
+                     SELECT
+                         ms.branch_turn_number,
+                         am.message_data,
+                         ms.created_at
+                     FROM message_structure ms
+                     JOIN agent_messages am ON ms.message_id = am.id
+                     WHERE ms.session_id = ? AND ms.branch_id = ?
+                     AND ms.message_type = 'user'
+                     ORDER BY ms.branch_turn_number
+                     """,
+                     (self.session_id, branch_id),
+                 )
+
+                 turns = []
+                 for row in cursor.fetchall():
+                     turn_num, message_data, created_at = row
+                     try:
+                         content = json.loads(message_data).get("content", "")
+                         turns.append(
+                             {
+                                 "turn": turn_num,
+                                 "content": content[:100] + "..." if len(content) > 100 else content,
+                                 "full_content": content,
+                                 "timestamp": created_at,
+                                 "can_branch": True,
+                             }
+                         )
+                     except (json.JSONDecodeError, AttributeError):
+                         continue
+
+                 return turns
+
+         return await asyncio.to_thread(_get_turns_sync)
+
+     async def find_turns_by_content(
+         self, search_term: str, branch_id: str | None = None
+     ) -> list[dict[str, Any]]:
+         """Find user turns containing specific content.
+
+         Args:
+             search_term: Text to search for in user messages.
+             branch_id: Branch to search in (current branch if None).
+
+         Returns:
+             List of matching turns with same format as get_conversation_turns().
+         """
+         if branch_id is None:
+             branch_id = self._current_branch_id
+
+         def _search_sync():
+             """Synchronous helper to search turns by content."""
+             conn = self._get_connection()
+             with closing(conn.cursor()) as cursor:
+                 cursor.execute(
+                     """
+                     SELECT
+                         ms.branch_turn_number,
+                         am.message_data,
+                         ms.created_at
+                     FROM message_structure ms
+                     JOIN agent_messages am ON ms.message_id = am.id
+                     WHERE ms.session_id = ? AND ms.branch_id = ?
+                     AND ms.message_type = 'user'
+                     AND am.message_data LIKE ?
+                     ORDER BY ms.branch_turn_number
+                     """,
+                     (self.session_id, branch_id, f"%{search_term}%"),
+                 )
+
+                 matches = []
+                 for row in cursor.fetchall():
+                     turn_num, message_data, created_at = row
+                     try:
+                         content = json.loads(message_data).get("content", "")
+                         matches.append(
+                             {
+                                 "turn": turn_num,
+                                 "content": content,
+                                 "full_content": content,
+                                 "timestamp": created_at,
+                                 "can_branch": True,
+                             }
+                         )
+                     except (json.JSONDecodeError, AttributeError):
+                         continue
+
+                 return matches
+
+         return await asyncio.to_thread(_search_sync)
+
+     async def get_conversation_by_turns(
+         self, branch_id: str | None = None
+     ) -> dict[int, list[dict[str, str | None]]]:
+         """Get conversation grouped by user turns for specified branch.
+
+         Args:
+             branch_id: Branch to get conversation from (current branch if None).
+
+         Returns:
+             Dictionary mapping turn numbers to lists of message metadata.
+         """
+         if branch_id is None:
+             branch_id = self._current_branch_id
+
+         def _get_conversation_sync():
+             """Synchronous helper to get conversation by turns."""
+             conn = self._get_connection()
+             with closing(conn.cursor()) as cursor:
+                 cursor.execute(
+                     """
+                     SELECT user_turn_number, message_type, tool_name
+                     FROM message_structure
+                     WHERE session_id = ? AND branch_id = ?
+                     ORDER BY sequence_number
+                     """,
+                     (self.session_id, branch_id),
+                 )
+
+                 turns: dict[int, list[dict[str, str | None]]] = {}
+                 for row in cursor.fetchall():
+                     turn_num, msg_type, tool_name = row
+                     if turn_num not in turns:
+                         turns[turn_num] = []
+                     turns[turn_num].append({"type": msg_type, "tool_name": tool_name})
+                 return turns
+
+         return await asyncio.to_thread(_get_conversation_sync)
+
+     async def get_tool_usage(self, branch_id: str | None = None) -> list[tuple[str, int, int]]:
+         """Get all tool usage by turn for specified branch.
+
+         Args:
+             branch_id: Branch to get tool usage from (current branch if None).
+
+         Returns:
+             List of tuples containing (tool_name, usage_count, turn_number).
+         """
+         if branch_id is None:
+             branch_id = self._current_branch_id
+
+         def _get_tool_usage_sync():
+             """Synchronous helper to get tool usage statistics."""
+             conn = self._get_connection()
+             with closing(conn.cursor()) as cursor:
+                 cursor.execute(
+                     """
+                     SELECT tool_name, COUNT(*), user_turn_number
+                     FROM message_structure
+                     WHERE session_id = ? AND branch_id = ? AND message_type IN (
+                         'tool_call', 'function_call', 'computer_call', 'file_search_call',
+                         'web_search_call', 'code_interpreter_call', 'custom_tool_call',
+                         'mcp_call', 'mcp_approval_request'
+                     )
+                     GROUP BY tool_name, user_turn_number
+                     ORDER BY user_turn_number
+                     """,
+                     (self.session_id, branch_id),
+                 )
+                 return cursor.fetchall()
+
+         return await asyncio.to_thread(_get_tool_usage_sync)
+
+     async def get_session_usage(self, branch_id: str | None = None) -> dict[str, int] | None:
+         """Get cumulative usage for session or specific branch.
+
+         Args:
+             branch_id: If provided, only get usage for that branch. If None, get all branches.
+
+         Returns:
+             Dictionary with usage statistics or None if no usage data found.
+         """
+
+         def _get_usage_sync():
+             """Synchronous helper to get session usage data."""
+             conn = self._get_connection()
+             # TODO: Refactor SQLiteSession to use asyncio.Lock instead of threading.Lock and update this code  # noqa: E501
+             with self._lock if self._is_memory_db else threading.Lock():
+                 if branch_id:
+                     # Branch-specific usage
+                     query = """
+                         SELECT
+                             SUM(requests) as total_requests,
+                             SUM(input_tokens) as total_input_tokens,
+                             SUM(output_tokens) as total_output_tokens,
+                             SUM(total_tokens) as total_total_tokens,
+                             COUNT(*) as total_turns
+                         FROM turn_usage
+                         WHERE session_id = ? AND branch_id = ?
+                     """
+                     params: tuple[str, ...] = (self.session_id, branch_id)
+                 else:
+                     # All branches
+                     query = """
+                         SELECT
+                             SUM(requests) as total_requests,
+                             SUM(input_tokens) as total_input_tokens,
+                             SUM(output_tokens) as total_output_tokens,
+                             SUM(total_tokens) as total_total_tokens,
+                             COUNT(*) as total_turns
+                         FROM turn_usage
+                         WHERE session_id = ?
+                     """
+                     params = (self.session_id,)
+
+                 with closing(conn.cursor()) as cursor:
+                     cursor.execute(query, params)
+                     row = cursor.fetchone()
+
+                     if row and row[0] is not None:
+                         return {
+                             "requests": row[0] or 0,
+                             "input_tokens": row[1] or 0,
+                             "output_tokens": row[2] or 0,
+                             "total_tokens": row[3] or 0,
+                             "total_turns": row[4] or 0,
+                         }
+                     return None
+
+         result = await asyncio.to_thread(_get_usage_sync)
+
+         return cast(Union[dict[str, int], None], result)
+
+     async def get_turn_usage(
+         self,
+         user_turn_number: int | None = None,
+         branch_id: str | None = None,
+     ) -> list[dict[str, Any]] | dict[str, Any]:
+         """Get usage statistics by turn with full JSON token details.
+
+         Args:
+             user_turn_number: Specific turn to get usage for. If None, returns all turns.
+             branch_id: Branch to get usage from (current branch if None).
+
+         Returns:
+             Dictionary with usage data for specific turn, or list of dictionaries for all turns.
+         """
+
+         if branch_id is None:
+             branch_id = self._current_branch_id
+
+         def _get_turn_usage_sync():
+             """Synchronous helper to get turn usage statistics."""
+             conn = self._get_connection()
+
+             if user_turn_number is not None:
+                 query = """
+                     SELECT requests, input_tokens, output_tokens, total_tokens,
+                            input_tokens_details, output_tokens_details
+                     FROM turn_usage
+                     WHERE session_id = ? AND branch_id = ? AND user_turn_number = ?
+                 """
+
+                 with closing(conn.cursor()) as cursor:
+                     cursor.execute(query, (self.session_id, branch_id, user_turn_number))
+                     row = cursor.fetchone()
+
+                     if row:
+                         # Parse JSON details if present
+                         input_details = None
+                         output_details = None
+
+                         if row[4]:  # input_tokens_details
+                             try:
+                                 input_details = json.loads(row[4])
+                             except json.JSONDecodeError:
+                                 pass
+
+                         if row[5]:  # output_tokens_details
+                             try:
+                                 output_details = json.loads(row[5])
+                             except json.JSONDecodeError:
+                                 pass
+
+                         return {
+                             "requests": row[0],
+                             "input_tokens": row[1],
+                             "output_tokens": row[2],
+                             "total_tokens": row[3],
+                             "input_tokens_details": input_details,
+                             "output_tokens_details": output_details,
+                         }
+                     return {}
+             else:
+                 query = """
+                     SELECT user_turn_number, requests, input_tokens, output_tokens,
+                            total_tokens, input_tokens_details, output_tokens_details
+                     FROM turn_usage
+                     WHERE session_id = ? AND branch_id = ?
+                     ORDER BY user_turn_number
+                 """
+
+                 with closing(conn.cursor()) as cursor:
+                     cursor.execute(query, (self.session_id, branch_id))
+                     results = []
+                     for row in cursor.fetchall():
+                         # Parse JSON details if present
+                         input_details = None
+                         output_details = None
+
+                         if row[5]:  # input_tokens_details
+                             try:
+                                 input_details = json.loads(row[5])
+                             except json.JSONDecodeError:
+                                 pass
+
+                         if row[6]:  # output_tokens_details
+                             try:
+                                 output_details = json.loads(row[6])
+                             except json.JSONDecodeError:
+                                 pass
+
+                         results.append(
+                             {
+                                 "user_turn_number": row[0],
+                                 "requests": row[1],
+                                 "input_tokens": row[2],
+                                 "output_tokens": row[3],
+                                 "total_tokens": row[4],
+                                 "input_tokens_details": input_details,
+                                 "output_tokens_details": output_details,
+                             }
+                         )
+                     return results
+
+         result = await asyncio.to_thread(_get_turn_usage_sync)
+
+         return cast(Union[list[dict[str, Any]], dict[str, Any]], result)
+
+     async def _update_turn_usage_internal(self, user_turn_number: int, usage_data: Usage) -> None:
+         """Internal method to update usage for a specific turn with full JSON details.
+
+         Args:
+             user_turn_number: The turn number to update usage for.
+             usage_data: The usage data to store.
+         """
+
+         def _update_sync():
+             """Synchronous helper to update turn usage data."""
+             conn = self._get_connection()
+             # TODO: Refactor SQLiteSession to use asyncio.Lock instead of threading.Lock and update this code  # noqa: E501
+             with self._lock if self._is_memory_db else threading.Lock():
+                 # Serialize token details as JSON
+                 input_details_json = None
+                 output_details_json = None
+
+                 if hasattr(usage_data, "input_tokens_details") and usage_data.input_tokens_details:
+                     try:
+                         input_details_json = json.dumps(usage_data.input_tokens_details.__dict__)
+                     except (TypeError, ValueError) as e:
+                         self._logger.warning(f"Failed to serialize input tokens details: {e}")
+                         input_details_json = None
+
+                 if (
+                     hasattr(usage_data, "output_tokens_details")
+                     and usage_data.output_tokens_details
+                 ):
+                     try:
+                         output_details_json = json.dumps(
+                             usage_data.output_tokens_details.__dict__
+                         )
+                     except (TypeError, ValueError) as e:
+                         self._logger.warning(f"Failed to serialize output tokens details: {e}")
+                         output_details_json = None
+
+                 with closing(conn.cursor()) as cursor:
+                     cursor.execute(
+                         """
+                         INSERT OR REPLACE INTO turn_usage
+                         (session_id, branch_id, user_turn_number, requests, input_tokens, output_tokens,
+                          total_tokens, input_tokens_details, output_tokens_details)
+                         VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
+                         """,  # noqa: E501
+                         (
+                             self.session_id,
+                             self._current_branch_id,
+                             user_turn_number,
+                             usage_data.requests or 0,
+                             usage_data.input_tokens or 0,
+                             usage_data.output_tokens or 0,
+                             usage_data.total_tokens or 0,
+                             input_details_json,
+                             output_details_json,
+                         ),
+                     )
+                     conn.commit()
+
+         await asyncio.to_thread(_update_sync)
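A minimal usage sketch for the session class added above. It assumes Agent, Runner, and AdvancedSQLiteSession are importable as shown; the agent name, prompts, and database path are illustrative placeholders:

    import asyncio

    from agents import Agent, Runner
    from agents.extensions.memory import AdvancedSQLiteSession

    async def main() -> None:
        agent = Agent(name="assistant", instructions="Reply concisely.")
        session = AdvancedSQLiteSession(
            session_id="user-123",
            db_path="conversations.db",  # placeholder path; ":memory:" also works
            create_tables=True,
        )

        result = await Runner.run(agent, "What is the capital of France?", session=session)
        await session.store_run_usage(result)  # record per-turn token usage

        # Branch from user turn 1 and explore an alternative continuation
        await session.create_branch_from_turn(1)
        result = await Runner.run(agent, "And what about Germany?", session=session)
        await session.store_run_usage(result)

        print(await session.list_branches())
        print(await session.get_session_usage())

    asyncio.run(main())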
agents/extensions/memory/encrypt_session.py ADDED
@@ -0,0 +1,185 @@
+ """Encrypted Session wrapper for secure conversation storage.
+
+ This module provides transparent encryption for session storage with automatic
+ expiration of old data. When TTL expires, expired items are silently skipped.
+
+ Usage::
+
+     from agents.extensions.memory import EncryptedSession, SQLAlchemySession
+
+     # Create underlying session (e.g. SQLAlchemySession)
+     underlying_session = SQLAlchemySession.from_url(
+         session_id="user-123",
+         url="postgresql+asyncpg://app:[email protected]/agents",
+         create_tables=True,
+     )
+
+     # Wrap with encryption and TTL-based expiration
+     session = EncryptedSession(
+         session_id="user-123",
+         underlying_session=underlying_session,
+         encryption_key="your-encryption-key",
+         ttl=600,  # 10 minutes
+     )
+
+     await Runner.run(agent, "Hello", session=session)
+ """
+
+ from __future__ import annotations
+
+ import base64
+ import json
+ from typing import Any, cast
+
+ from cryptography.fernet import Fernet, InvalidToken
+ from cryptography.hazmat.primitives import hashes
+ from cryptography.hazmat.primitives.kdf.hkdf import HKDF
+ from typing_extensions import Literal, TypedDict, TypeGuard
+
+ from ...items import TResponseInputItem
+ from ...memory.session import SessionABC
+
+
+ class EncryptedEnvelope(TypedDict):
+     """TypedDict for encrypted message envelopes stored in the underlying session."""
+
+     __enc__: Literal[1]
+     v: int
+     kid: str
+     payload: str
+
+
+ def _ensure_fernet_key_bytes(master_key: str) -> bytes:
+     """
+     Accept either a Fernet key (urlsafe-b64, 32 bytes after decode) or a raw string.
+     Returns raw bytes suitable for HKDF input.
+     """
+     if not master_key:
+         raise ValueError("encryption_key not set; required for EncryptedSession.")
+     try:
+         key_bytes = base64.urlsafe_b64decode(master_key)
+         if len(key_bytes) == 32:
+             return key_bytes
+     except Exception:
+         pass
+     return master_key.encode("utf-8")
+
+
+ def _derive_session_fernet_key(master_key_bytes: bytes, session_id: str) -> Fernet:
+     hkdf = HKDF(
+         algorithm=hashes.SHA256(),
+         length=32,
+         salt=session_id.encode("utf-8"),
+         info=b"agents.session-store.hkdf.v1",
+     )
+     derived = hkdf.derive(master_key_bytes)
+     return Fernet(base64.urlsafe_b64encode(derived))
+
+
+ def _to_json_bytes(obj: Any) -> bytes:
+     return json.dumps(obj, ensure_ascii=False, separators=(",", ":"), default=str).encode("utf-8")
+
+
+ def _from_json_bytes(data: bytes) -> Any:
+     return json.loads(data.decode("utf-8"))
+
+
+ def _is_encrypted_envelope(item: object) -> TypeGuard[EncryptedEnvelope]:
+     """Type guard to check if an item is an encrypted envelope."""
+     return (
+         isinstance(item, dict)
+         and item.get("__enc__") == 1
+         and "payload" in item
+         and "kid" in item
+         and "v" in item
+     )
+
+
+ class EncryptedSession(SessionABC):
+     """Encrypted wrapper for Session implementations with TTL-based expiration.
+
+     This class wraps any SessionABC implementation to provide transparent
+     encryption/decryption of stored items using Fernet encryption with
+     per-session key derivation and automatic expiration of old data.
+
+     When items expire (exceed TTL), they are silently skipped during retrieval.
+
+     Note: Expired tokens are rejected based on the system clock of the application server.
+     To avoid valid tokens being rejected due to clock drift, ensure all servers in
+     your environment are synchronized using NTP.
+     """
+
+     def __init__(
+         self,
+         session_id: str,
+         underlying_session: SessionABC,
+         encryption_key: str,
+         ttl: int = 600,
+     ):
+         """
+         Args:
+             session_id: ID for this session
+             underlying_session: The real session store (e.g. SQLiteSession, SQLAlchemySession)
+             encryption_key: Master key (Fernet key or raw secret)
+             ttl: Token time-to-live in seconds (default 10 min)
+         """
+         self.session_id = session_id
+         self.underlying_session = underlying_session
+         self.ttl = ttl
+
+         master = _ensure_fernet_key_bytes(encryption_key)
+         self.cipher = _derive_session_fernet_key(master, session_id)
+         self._kid = "hkdf-v1"
+         self._ver = 1
+
+     def __getattr__(self, name):
+         return getattr(self.underlying_session, name)
+
+     def _wrap(self, item: TResponseInputItem) -> EncryptedEnvelope:
+         if isinstance(item, dict):
+             payload = item
+         elif hasattr(item, "model_dump"):
+             payload = item.model_dump()
+         elif hasattr(item, "__dict__"):
+             payload = item.__dict__
+         else:
+             payload = dict(item)
+
+         token = self.cipher.encrypt(_to_json_bytes(payload)).decode("utf-8")
+         return {"__enc__": 1, "v": self._ver, "kid": self._kid, "payload": token}
+
+     def _unwrap(self, item: TResponseInputItem | EncryptedEnvelope) -> TResponseInputItem | None:
+         if not _is_encrypted_envelope(item):
+             return cast(TResponseInputItem, item)
+
+         try:
+             token = item["payload"].encode("utf-8")
+             plaintext = self.cipher.decrypt(token, ttl=self.ttl)
+             return cast(TResponseInputItem, _from_json_bytes(plaintext))
+         except (InvalidToken, KeyError):
+             return None
+
+     async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]:
+         encrypted_items = await self.underlying_session.get_items(limit)
+         valid_items: list[TResponseInputItem] = []
+         for enc in encrypted_items:
+             item = self._unwrap(enc)
+             if item is not None:
+                 valid_items.append(item)
+         return valid_items
+
+     async def add_items(self, items: list[TResponseInputItem]) -> None:
+         wrapped: list[EncryptedEnvelope] = [self._wrap(it) for it in items]
+         await self.underlying_session.add_items(cast(list[TResponseInputItem], wrapped))
+
+     async def pop_item(self) -> TResponseInputItem | None:
+         while True:
+             enc = await self.underlying_session.pop_item()
+             if not enc:
+                 return None
+             item = self._unwrap(enc)
+             if item is not None:
+                 return item
+
+     async def clear_session(self) -> None:
+         await self.underlying_session.clear_session()
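A short sketch of key handling for the wrapper above; the SQLiteSession path is a placeholder assumption:

    from cryptography.fernet import Fernet

    from agents.extensions.memory import EncryptedSession
    from agents.memory import SQLiteSession

    # Generate a proper Fernet master key once and keep it in a secret manager;
    # raw strings also work, since keys are run through HKDF per session.
    master_key = Fernet.generate_key().decode("utf-8")

    session = EncryptedSession(
        session_id="user-123",
        underlying_session=SQLiteSession("user-123", "conversations.db"),  # placeholder path
        encryption_key=master_key,
        ttl=3600,  # items older than an hour fail decryption and are skipped
    )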
agents/extensions/memory/redis_session.py ADDED
@@ -0,0 +1,267 @@
+ """Redis-powered Session backend.
+
+ Usage::
+
+     from agents.extensions.memory import RedisSession
+
+     # Create from Redis URL
+     session = RedisSession.from_url(
+         session_id="user-123",
+         url="redis://localhost:6379/0",
+     )
+
+     # Or pass an existing Redis client that your application already manages
+     session = RedisSession(
+         session_id="user-123",
+         redis_client=my_redis_client,
+     )
+
+     await Runner.run(agent, "Hello", session=session)
+ """
+
+ from __future__ import annotations
+
+ import asyncio
+ import json
+ import time
+ from typing import Any
+ from urllib.parse import urlparse
+
+ try:
+     import redis.asyncio as redis
+     from redis.asyncio import Redis
+ except ImportError as e:
+     raise ImportError(
+         "RedisSession requires the 'redis' package. Install it with: pip install redis"
+     ) from e
+
+ from ...items import TResponseInputItem
+ from ...memory.session import SessionABC
+
+
+ class RedisSession(SessionABC):
+     """Redis implementation of :pyclass:`agents.memory.session.Session`."""
+
+     def __init__(
+         self,
+         session_id: str,
+         *,
+         redis_client: Redis,
+         key_prefix: str = "agents:session",
+         ttl: int | None = None,
+     ):
+         """Initializes a new RedisSession.
+
+         Args:
+             session_id (str): Unique identifier for the conversation.
+             redis_client (Redis[bytes]): A pre-configured Redis async client.
+             key_prefix (str, optional): Prefix for Redis keys to avoid collisions.
+                 Defaults to "agents:session".
+             ttl (int | None, optional): Time-to-live in seconds for session data.
+                 If None, data persists indefinitely. Defaults to None.
+         """
+         self.session_id = session_id
+         self._redis = redis_client
+         self._key_prefix = key_prefix
+         self._ttl = ttl
+         self._lock = asyncio.Lock()
+         self._owns_client = False  # Track if we own the Redis client
+
+         # Redis key patterns
+         self._session_key = f"{self._key_prefix}:{self.session_id}"
+         self._messages_key = f"{self._session_key}:messages"
+         self._counter_key = f"{self._session_key}:counter"
+
+     @classmethod
+     def from_url(
+         cls,
+         session_id: str,
+         *,
+         url: str,
+         redis_kwargs: dict[str, Any] | None = None,
+         **kwargs: Any,
+     ) -> RedisSession:
+         """Create a session from a Redis URL string.
+
+         Args:
+             session_id (str): Conversation ID.
+             url (str): Redis URL, e.g. "redis://localhost:6379/0" or "rediss://host:6380".
+             redis_kwargs (dict[str, Any] | None): Additional keyword arguments forwarded to
+                 redis.asyncio.from_url.
+             **kwargs: Additional keyword arguments forwarded to the main constructor
+                 (e.g., key_prefix, ttl, etc.).
+
+         Returns:
+             RedisSession: An instance of RedisSession connected to the specified Redis server.
+         """
+         redis_kwargs = redis_kwargs or {}
+
+         # Parse URL to determine if we need SSL
+         parsed = urlparse(url)
+         if parsed.scheme == "rediss":
+             redis_kwargs.setdefault("ssl", True)
+
+         redis_client = redis.from_url(url, **redis_kwargs)
+         session = cls(session_id, redis_client=redis_client, **kwargs)
+         session._owns_client = True  # We created the client, so we own it
+         return session
+
+     async def _serialize_item(self, item: TResponseInputItem) -> str:
+         """Serialize an item to JSON string. Can be overridden by subclasses."""
+         return json.dumps(item, separators=(",", ":"))
+
+     async def _deserialize_item(self, item: str) -> TResponseInputItem:
+         """Deserialize a JSON string to an item. Can be overridden by subclasses."""
+         return json.loads(item)  # type: ignore[no-any-return]  # json.loads returns Any but we know the structure
+
+     async def _get_next_id(self) -> int:
+         """Get the next message ID using Redis INCR for atomic increment."""
+         result = await self._redis.incr(self._counter_key)
+         return int(result)
+
+     async def _set_ttl_if_configured(self, *keys: str) -> None:
+         """Set TTL on keys if configured."""
+         if self._ttl is not None:
+             pipe = self._redis.pipeline()
+             for key in keys:
+                 pipe.expire(key, self._ttl)
+             await pipe.execute()
+
+     # ------------------------------------------------------------------
+     # Session protocol implementation
+     # ------------------------------------------------------------------
+
+     async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]:
+         """Retrieve the conversation history for this session.
+
+         Args:
+             limit: Maximum number of items to retrieve. If None, retrieves all items.
+                 When specified, returns the latest N items in chronological order.
+
+         Returns:
+             List of input items representing the conversation history
+         """
+         async with self._lock:
+             if limit is None:
+                 # Get all messages in chronological order
+                 raw_messages = await self._redis.lrange(self._messages_key, 0, -1)  # type: ignore[misc]  # Redis library returns Union[Awaitable[T], T] in async context
+             else:
+                 if limit <= 0:
+                     return []
+                 # Get the latest N messages (Redis list is ordered chronologically)
+                 # Use negative indices to get from the end - Redis uses -N to -1 for last N items
+                 raw_messages = await self._redis.lrange(self._messages_key, -limit, -1)  # type: ignore[misc]  # Redis library returns Union[Awaitable[T], T] in async context
+
+             items: list[TResponseInputItem] = []
+             for raw_msg in raw_messages:
+                 try:
+                     # Handle both bytes (default) and str (decode_responses=True) Redis clients
+                     if isinstance(raw_msg, bytes):
+                         msg_str = raw_msg.decode("utf-8")
+                     else:
+                         msg_str = raw_msg  # Already a string
+                     item = await self._deserialize_item(msg_str)
+                     items.append(item)
+                 except (json.JSONDecodeError, UnicodeDecodeError):
+                     # Skip corrupted messages
+                     continue
+
+             return items
+
+     async def add_items(self, items: list[TResponseInputItem]) -> None:
+         """Add new items to the conversation history.
+
+         Args:
+             items: List of input items to add to the history
+         """
+         if not items:
+             return
+
+         async with self._lock:
+             pipe = self._redis.pipeline()
+
+             # Set session metadata with current timestamp
+             pipe.hset(
+                 self._session_key,
+                 mapping={
+                     "session_id": self.session_id,
+                     "created_at": str(int(time.time())),
+                     "updated_at": str(int(time.time())),
+                 },
+             )
+
+             # Add all items to the messages list
+             serialized_items = []
+             for item in items:
+                 serialized = await self._serialize_item(item)
+                 serialized_items.append(serialized)
+
+             if serialized_items:
+                 pipe.rpush(self._messages_key, *serialized_items)
+
+             # Update the session timestamp
+             pipe.hset(self._session_key, "updated_at", str(int(time.time())))
+
+             # Execute all commands
+             await pipe.execute()
+
+             # Set TTL if configured
+             await self._set_ttl_if_configured(
+                 self._session_key, self._messages_key, self._counter_key
+             )
+
+     async def pop_item(self) -> TResponseInputItem | None:
+         """Remove and return the most recent item from the session.
+
+         Returns:
+             The most recent item if it exists, None if the session is empty
+         """
+         async with self._lock:
+             # Use RPOP to atomically remove and return the rightmost (most recent) item
+             raw_msg = await self._redis.rpop(self._messages_key)  # type: ignore[misc]  # Redis library returns Union[Awaitable[T], T] in async context
+
+             if raw_msg is None:
+                 return None
+
+             try:
+                 # Handle both bytes (default) and str (decode_responses=True) Redis clients
+                 if isinstance(raw_msg, bytes):
+                     msg_str = raw_msg.decode("utf-8")
+                 else:
+                     msg_str = raw_msg  # Already a string
+                 return await self._deserialize_item(msg_str)
+             except (json.JSONDecodeError, UnicodeDecodeError):
+                 # Return None for corrupted messages (already removed)
+                 return None
+
+     async def clear_session(self) -> None:
+         """Clear all items for this session."""
+         async with self._lock:
+             # Delete all keys associated with this session
+             await self._redis.delete(
+                 self._session_key,
+                 self._messages_key,
+                 self._counter_key,
+             )
+
+     async def close(self) -> None:
+         """Close the Redis connection.
+
+         Only closes the connection if this session owns the Redis client
+         (i.e., created via from_url). If the client was injected externally,
+         the caller is responsible for managing its lifecycle.
+         """
+         if self._owns_client:
+             await self._redis.aclose()
+
+     async def ping(self) -> bool:
+         """Test Redis connectivity.
+
+         Returns:
+             True if Redis is reachable, False otherwise.
+         """
+         try:
+             await self._redis.ping()
+             return True
+         except Exception:
+             return False
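A minimal end-to-end sketch for the backend above; the Redis URL, TTL, and agent are placeholder assumptions:

    import asyncio

    from agents import Agent, Runner
    from agents.extensions.memory import RedisSession

    async def main() -> None:
        session = RedisSession.from_url(
            session_id="user-123",
            url="redis://localhost:6379/0",  # placeholder URL
            ttl=86400,  # expire idle sessions after one day
        )
        if not await session.ping():
            raise RuntimeError("Redis is not reachable")

        agent = Agent(name="assistant", instructions="Reply concisely.")
        await Runner.run(agent, "Hello", session=session)

        await session.close()  # closes the client because from_url created it

    asyncio.run(main())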
agents/extensions/memory/sqlalchemy_session.py ADDED
@@ -0,0 +1,312 @@
+ """SQLAlchemy-powered Session backend.
2
+
3
+ Usage::
4
+
5
+ from agents.extensions.memory import SQLAlchemySession
6
+
7
+ # Create from SQLAlchemy URL (uses asyncpg driver under the hood for Postgres)
8
+ session = SQLAlchemySession.from_url(
9
+ session_id="user-123",
10
+ url="postgresql+asyncpg://app:[email protected]/agents",
11
+ create_tables=True, # If you want to auto-create tables, set to True.
12
+ )
13
+
14
+ # Or pass an existing AsyncEngine that your application already manages
15
+ session = SQLAlchemySession(
16
+ session_id="user-123",
17
+ engine=my_async_engine,
18
+ create_tables=True, # If you want to auto-create tables, set to True.
19
+ )
20
+
21
+ await Runner.run(agent, "Hello", session=session)
22
+ """
23
+
24
+ from __future__ import annotations
25
+
26
+ import asyncio
27
+ import json
28
+ from typing import Any
29
+
30
+ from sqlalchemy import (
31
+ TIMESTAMP,
32
+ Column,
33
+ ForeignKey,
34
+ Index,
35
+ Integer,
36
+ MetaData,
37
+ String,
38
+ Table,
39
+ Text,
40
+ delete,
41
+ insert,
42
+ select,
43
+ text as sql_text,
44
+ update,
45
+ )
46
+ from sqlalchemy.ext.asyncio import AsyncEngine, async_sessionmaker, create_async_engine
47
+
48
+ from ...items import TResponseInputItem
49
+ from ...memory.session import SessionABC
50
+
51
+
52
+ class SQLAlchemySession(SessionABC):
53
+ """SQLAlchemy implementation of :pyclass:`agents.memory.session.Session`."""
54
+
55
+ _metadata: MetaData
56
+ _sessions: Table
57
+ _messages: Table
58
+
59
+ def __init__(
60
+ self,
61
+ session_id: str,
62
+ *,
63
+ engine: AsyncEngine,
64
+ create_tables: bool = False,
65
+ sessions_table: str = "agent_sessions",
66
+ messages_table: str = "agent_messages",
67
+ ):
68
+ """Initializes a new SQLAlchemySession.
69
+
70
+ Args:
71
+ session_id (str): Unique identifier for the conversation.
72
+ engine (AsyncEngine): A pre-configured SQLAlchemy async engine. The engine
73
+ must be created with an async driver (e.g., 'postgresql+asyncpg://',
74
+ 'mysql+aiomysql://', or 'sqlite+aiosqlite://').
75
+ create_tables (bool, optional): Whether to automatically create the required
76
+ tables and indexes. Defaults to False for production use. Set to True for
77
+ development and testing when migrations aren't used.
78
+ sessions_table (str, optional): Override the default table name for sessions if needed.
79
+ messages_table (str, optional): Override the default table name for messages if needed.
80
+ """
81
+ self.session_id = session_id
82
+ self._engine = engine
83
+ self._lock = asyncio.Lock()
84
+
85
+ self._metadata = MetaData()
86
+ self._sessions = Table(
87
+ sessions_table,
88
+ self._metadata,
89
+ Column("session_id", String, primary_key=True),
90
+ Column(
91
+ "created_at",
92
+ TIMESTAMP(timezone=False),
93
+ server_default=sql_text("CURRENT_TIMESTAMP"),
94
+ nullable=False,
95
+ ),
96
+ Column(
97
+ "updated_at",
98
+ TIMESTAMP(timezone=False),
99
+ server_default=sql_text("CURRENT_TIMESTAMP"),
100
+ onupdate=sql_text("CURRENT_TIMESTAMP"),
101
+ nullable=False,
102
+ ),
103
+ )
104
+
105
+ self._messages = Table(
106
+ messages_table,
107
+ self._metadata,
108
+ Column("id", Integer, primary_key=True, autoincrement=True),
109
+ Column(
110
+ "session_id",
111
+ String,
112
+ ForeignKey(f"{sessions_table}.session_id", ondelete="CASCADE"),
113
+ nullable=False,
114
+ ),
115
+ Column("message_data", Text, nullable=False),
116
+ Column(
117
+ "created_at",
118
+ TIMESTAMP(timezone=False),
119
+ server_default=sql_text("CURRENT_TIMESTAMP"),
120
+ nullable=False,
121
+ ),
122
+ Index(
123
+ f"idx_{messages_table}_session_time",
124
+ "session_id",
125
+ "created_at",
126
+ ),
127
+ sqlite_autoincrement=True,
128
+ )
129
+
130
+ # Async session factory
131
+ self._session_factory = async_sessionmaker(self._engine, expire_on_commit=False)
132
+
133
+ self._create_tables = create_tables
134
+
135
+ # ---------------------------------------------------------------------
136
+ # Convenience constructors
137
+ # ---------------------------------------------------------------------
138
+ @classmethod
139
+ def from_url(
140
+ cls,
141
+ session_id: str,
142
+ *,
143
+ url: str,
144
+ engine_kwargs: dict[str, Any] | None = None,
145
+ **kwargs: Any,
146
+ ) -> SQLAlchemySession:
147
+ """Create a session from a database URL string.
148
+
149
+ Args:
150
+ session_id (str): Conversation ID.
151
+ url (str): Any SQLAlchemy async URL, e.g. "postgresql+asyncpg://user:pass@host/db".
152
+ engine_kwargs (dict[str, Any] | None): Additional keyword arguments forwarded to
153
+ sqlalchemy.ext.asyncio.create_async_engine.
154
+ **kwargs: Additional keyword arguments forwarded to the main constructor
155
+ (e.g., create_tables, custom table names, etc.).
156
+
157
+ Returns:
158
+ SQLAlchemySession: An instance of SQLAlchemySession connected to the specified database.
159
+ """
160
+ engine_kwargs = engine_kwargs or {}
161
+ engine = create_async_engine(url, **engine_kwargs)
162
+ return cls(session_id, engine=engine, **kwargs)
163
+
164
+ async def _serialize_item(self, item: TResponseInputItem) -> str:
165
+ """Serialize an item to JSON string. Can be overridden by subclasses."""
166
+ return json.dumps(item, separators=(",", ":"))
167
+
168
+ async def _deserialize_item(self, item: str) -> TResponseInputItem:
169
+ """Deserialize a JSON string to an item. Can be overridden by subclasses."""
170
+ return json.loads(item) # type: ignore[no-any-return]
171
+
172
+ # ------------------------------------------------------------------
173
+ # Session protocol implementation
174
+ # ------------------------------------------------------------------
175
+ async def _ensure_tables(self) -> None:
176
+ """Ensure tables are created before any database operations."""
177
+ if self._create_tables:
178
+ async with self._engine.begin() as conn:
179
+ await conn.run_sync(self._metadata.create_all)
180
+ self._create_tables = False # Only create once
181
+
182
+ async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]:
183
+ """Retrieve the conversation history for this session.
184
+
185
+ Args:
186
+ limit: Maximum number of items to retrieve. If None, retrieves all items.
187
+ When specified, returns the latest N items in chronological order.
188
+
189
+ Returns:
190
+ List of input items representing the conversation history
191
+ """
192
+ await self._ensure_tables()
193
+ async with self._session_factory() as sess:
194
+ if limit is None:
195
+ stmt = (
196
+ select(self._messages.c.message_data)
197
+ .where(self._messages.c.session_id == self.session_id)
198
+ .order_by(self._messages.c.created_at.asc())
199
+ )
200
+ else:
201
+ stmt = (
202
+ select(self._messages.c.message_data)
203
+ .where(self._messages.c.session_id == self.session_id)
204
+ # Use DESC + LIMIT to get the latest N
205
+ # then reverse later for chronological order.
206
+ .order_by(self._messages.c.created_at.desc())
207
+ .limit(limit)
208
+ )
209
+
210
+ result = await sess.execute(stmt)
211
+ rows: list[str] = [row[0] for row in result.all()]
212
+
213
+ if limit is not None:
214
+ rows.reverse()
215
+
216
+ items: list[TResponseInputItem] = []
217
+ for raw in rows:
218
+ try:
219
+ items.append(await self._deserialize_item(raw))
220
+ except json.JSONDecodeError:
221
+ # Skip corrupted rows
222
+ continue
223
+ return items
224
+
225
+ async def add_items(self, items: list[TResponseInputItem]) -> None:
226
+ """Add new items to the conversation history.
227
+
228
+ Args:
229
+ items: List of input items to add to the history
230
+ """
231
+ if not items:
232
+ return
233
+
234
+ await self._ensure_tables()
235
+ payload = [
236
+ {
237
+ "session_id": self.session_id,
238
+ "message_data": await self._serialize_item(item),
239
+ }
240
+ for item in items
241
+ ]
242
+
243
+ async with self._session_factory() as sess:
244
+ async with sess.begin():
245
+ # Ensure the parent session row exists - use merge for cross-DB compatibility
246
+ # Check if session exists
247
+ existing = await sess.execute(
248
+ select(self._sessions.c.session_id).where(
249
+ self._sessions.c.session_id == self.session_id
250
+ )
251
+ )
252
+ if not existing.scalar_one_or_none():
253
+ # Session doesn't exist, create it
254
+ await sess.execute(
255
+ insert(self._sessions).values({"session_id": self.session_id})
256
+ )
257
+
258
+ # Insert messages in bulk
259
+ await sess.execute(insert(self._messages), payload)
260
+
261
+ # Touch updated_at column
262
+ await sess.execute(
263
+ update(self._sessions)
264
+ .where(self._sessions.c.session_id == self.session_id)
265
+ .values(updated_at=sql_text("CURRENT_TIMESTAMP"))
266
+ )
267
+
268
+ async def pop_item(self) -> TResponseInputItem | None:
269
+ """Remove and return the most recent item from the session.
270
+
271
+ Returns:
272
+ The most recent item if it exists, None if the session is empty
273
+ """
274
+ await self._ensure_tables()
275
+ async with self._session_factory() as sess:
276
+ async with sess.begin():
277
+ # Fallback for all dialects - get ID first, then delete
278
+ subq = (
279
+ select(self._messages.c.id)
280
+ .where(self._messages.c.session_id == self.session_id)
281
+ .order_by(self._messages.c.created_at.desc())
282
+ .limit(1)
283
+ )
284
+ res = await sess.execute(subq)
285
+ row_id = res.scalar_one_or_none()
286
+ if row_id is None:
287
+ return None
288
+ # Fetch data before deleting
289
+ res_data = await sess.execute(
290
+ select(self._messages.c.message_data).where(self._messages.c.id == row_id)
291
+ )
292
+ row = res_data.scalar_one_or_none()
293
+ await sess.execute(delete(self._messages).where(self._messages.c.id == row_id))
294
+
295
+ if row is None:
296
+ return None
297
+ try:
298
+ return await self._deserialize_item(row)
299
+ except json.JSONDecodeError:
300
+ return None
301
+
302
+ async def clear_session(self) -> None:
303
+ """Clear all items for this session."""
304
+ await self._ensure_tables()
305
+ async with self._session_factory() as sess:
306
+ async with sess.begin():
307
+ await sess.execute(
308
+ delete(self._messages).where(self._messages.c.session_id == self.session_id)
309
+ )
310
+ await sess.execute(
311
+ delete(self._sessions).where(self._sessions.c.session_id == self.session_id)
312
+ )
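
Because `_serialize_item` and `_deserialize_item` are documented as override points, here is a small sketch of a subclass that stores pretty-printed JSON instead of the compact default. The subclass name, database path, and item payload are illustrative, and the `sqlite+aiosqlite` URL assumes the aiosqlite driver is installed.

    import asyncio
    import json

    from agents.extensions.memory import SQLAlchemySession  # assumed export path
    from agents.items import TResponseInputItem

    class PrettyJSONSession(SQLAlchemySession):
        async def _serialize_item(self, item: TResponseInputItem) -> str:
            # Human-readable rows; the inherited json.loads still deserializes them.
            return json.dumps(item, indent=2)

    async def main() -> None:
        session = PrettyJSONSession.from_url(
            "user-123",
            url="sqlite+aiosqlite:///./sessions_demo.db",  # illustrative path
            create_tables=True,  # auto-create tables for this demo
        )
        await session.add_items([{"role": "user", "content": "Hello"}])
        print(await session.get_items(limit=1))  # latest item, chronological order

    asyncio.run(main())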
agents/extensions/models/__init__.py ADDED
File without changes
agents/extensions/models/litellm_model.py ADDED
@@ -0,0 +1,601 @@
+ from __future__ import annotations
+
+ import json
+ import time
+ from collections.abc import AsyncIterator
+ from copy import copy
+ from typing import Any, Literal, cast, overload
+
+ from openai.types.responses.response_usage import InputTokensDetails, OutputTokensDetails
+
+ from agents.exceptions import ModelBehaviorError
+
+ try:
+     import litellm
+ except ImportError as _e:
+     raise ImportError(
+         "`litellm` is required to use the LitellmModel. You can install it via the optional "
+         "dependency group: `pip install 'openai-agents[litellm]'`."
+     ) from _e
+
+ from openai import NOT_GIVEN, AsyncStream, NotGiven
+ from openai.types.chat import (
+     ChatCompletionChunk,
+     ChatCompletionMessageCustomToolCall,
+     ChatCompletionMessageFunctionToolCall,
+     ChatCompletionMessageParam,
+ )
+ from openai.types.chat.chat_completion_message import (
+     Annotation,
+     AnnotationURLCitation,
+     ChatCompletionMessage,
+ )
+ from openai.types.chat.chat_completion_message_function_tool_call import Function
+ from openai.types.responses import Response
+
+ from ... import _debug
+ from ...agent_output import AgentOutputSchemaBase
+ from ...handoffs import Handoff
+ from ...items import ModelResponse, TResponseInputItem, TResponseStreamEvent
+ from ...logger import logger
+ from ...model_settings import ModelSettings
+ from ...models.chatcmpl_converter import Converter
+ from ...models.chatcmpl_helpers import HEADERS, HEADERS_OVERRIDE
+ from ...models.chatcmpl_stream_handler import ChatCmplStreamHandler
+ from ...models.fake_id import FAKE_RESPONSES_ID
+ from ...models.interface import Model, ModelTracing
+ from ...tool import Tool
+ from ...tracing import generation_span
+ from ...tracing.span_data import GenerationSpanData
+ from ...tracing.spans import Span
+ from ...usage import Usage
+ from ...util._json import _to_dump_compatible
+
+
+ class InternalChatCompletionMessage(ChatCompletionMessage):
+     """
+     An internal subclass to carry reasoning_content and thinking_blocks without modifying the original model.
+     """  # noqa: E501
+
+     reasoning_content: str
+     thinking_blocks: list[dict[str, Any]] | None = None
+
+
+ class LitellmModel(Model):
+     """This class enables using any model via LiteLLM. LiteLLM allows you to access OpenAI,
+     Anthropic, Gemini, Mistral, and many other models.
+     See supported models here: [litellm models](https://docs.litellm.ai/docs/providers).
+     """
+
+     def __init__(
+         self,
+         model: str,
+         base_url: str | None = None,
+         api_key: str | None = None,
+     ):
+         self.model = model
+         self.base_url = base_url
+         self.api_key = api_key
+
+     async def get_response(
+         self,
+         system_instructions: str | None,
+         input: str | list[TResponseInputItem],
+         model_settings: ModelSettings,
+         tools: list[Tool],
+         output_schema: AgentOutputSchemaBase | None,
+         handoffs: list[Handoff],
+         tracing: ModelTracing,
+         previous_response_id: str | None = None,  # unused
+         conversation_id: str | None = None,  # unused
+         prompt: Any | None = None,
+     ) -> ModelResponse:
+         with generation_span(
+             model=str(self.model),
+             model_config=model_settings.to_json_dict()
+             | {"base_url": str(self.base_url or ""), "model_impl": "litellm"},
+             disabled=tracing.is_disabled(),
+         ) as span_generation:
+             response = await self._fetch_response(
+                 system_instructions,
+                 input,
+                 model_settings,
+                 tools,
+                 output_schema,
+                 handoffs,
+                 span_generation,
+                 tracing,
+                 stream=False,
+                 prompt=prompt,
+             )
+
+             assert isinstance(response.choices[0], litellm.types.utils.Choices)
+
+             if _debug.DONT_LOG_MODEL_DATA:
+                 logger.debug("Received model response")
+             else:
+                 logger.debug(
+                     f"""LLM resp:\n{
+                         json.dumps(
+                             response.choices[0].message.model_dump(), indent=2, ensure_ascii=False
+                         )
+                     }\n"""
+                 )
+
+             if hasattr(response, "usage"):
+                 response_usage = response.usage
+                 usage = (
+                     Usage(
+                         requests=1,
+                         input_tokens=response_usage.prompt_tokens,
+                         output_tokens=response_usage.completion_tokens,
+                         total_tokens=response_usage.total_tokens,
+                         input_tokens_details=InputTokensDetails(
+                             cached_tokens=getattr(
+                                 response_usage.prompt_tokens_details, "cached_tokens", 0
+                             )
+                             or 0
+                         ),
+                         output_tokens_details=OutputTokensDetails(
+                             reasoning_tokens=getattr(
+                                 response_usage.completion_tokens_details, "reasoning_tokens", 0
+                             )
+                             or 0
+                         ),
+                     )
+                     if response.usage
+                     else Usage()
+                 )
+             else:
+                 usage = Usage()
+                 logger.warning("No usage information returned from Litellm")
+
+             if tracing.include_data():
+                 span_generation.span_data.output = [response.choices[0].message.model_dump()]
+             span_generation.span_data.usage = {
+                 "input_tokens": usage.input_tokens,
+                 "output_tokens": usage.output_tokens,
+             }
+
+             items = Converter.message_to_output_items(
+                 LitellmConverter.convert_message_to_openai(response.choices[0].message)
+             )
+
+             return ModelResponse(
+                 output=items,
+                 usage=usage,
+                 response_id=None,
+             )
+
+     async def stream_response(
+         self,
+         system_instructions: str | None,
+         input: str | list[TResponseInputItem],
+         model_settings: ModelSettings,
+         tools: list[Tool],
+         output_schema: AgentOutputSchemaBase | None,
+         handoffs: list[Handoff],
+         tracing: ModelTracing,
+         previous_response_id: str | None = None,  # unused
+         conversation_id: str | None = None,  # unused
+         prompt: Any | None = None,
+     ) -> AsyncIterator[TResponseStreamEvent]:
+         with generation_span(
+             model=str(self.model),
+             model_config=model_settings.to_json_dict()
+             | {"base_url": str(self.base_url or ""), "model_impl": "litellm"},
+             disabled=tracing.is_disabled(),
+         ) as span_generation:
+             response, stream = await self._fetch_response(
+                 system_instructions,
+                 input,
+                 model_settings,
+                 tools,
+                 output_schema,
+                 handoffs,
+                 span_generation,
+                 tracing,
+                 stream=True,
+                 prompt=prompt,
+             )
+
+             final_response: Response | None = None
+             async for chunk in ChatCmplStreamHandler.handle_stream(response, stream):
+                 yield chunk
+
+                 if chunk.type == "response.completed":
+                     final_response = chunk.response
+
+             if tracing.include_data() and final_response:
+                 span_generation.span_data.output = [final_response.model_dump()]
+
+             if final_response and final_response.usage:
+                 span_generation.span_data.usage = {
+                     "input_tokens": final_response.usage.input_tokens,
+                     "output_tokens": final_response.usage.output_tokens,
+                 }
+
+     @overload
+     async def _fetch_response(
+         self,
+         system_instructions: str | None,
+         input: str | list[TResponseInputItem],
+         model_settings: ModelSettings,
+         tools: list[Tool],
+         output_schema: AgentOutputSchemaBase | None,
+         handoffs: list[Handoff],
+         span: Span[GenerationSpanData],
+         tracing: ModelTracing,
+         stream: Literal[True],
+         prompt: Any | None = None,
+     ) -> tuple[Response, AsyncStream[ChatCompletionChunk]]: ...
+
+     @overload
+     async def _fetch_response(
+         self,
+         system_instructions: str | None,
+         input: str | list[TResponseInputItem],
+         model_settings: ModelSettings,
+         tools: list[Tool],
+         output_schema: AgentOutputSchemaBase | None,
+         handoffs: list[Handoff],
+         span: Span[GenerationSpanData],
+         tracing: ModelTracing,
+         stream: Literal[False],
+         prompt: Any | None = None,
+     ) -> litellm.types.utils.ModelResponse: ...
+
+     async def _fetch_response(
+         self,
+         system_instructions: str | None,
+         input: str | list[TResponseInputItem],
+         model_settings: ModelSettings,
+         tools: list[Tool],
+         output_schema: AgentOutputSchemaBase | None,
+         handoffs: list[Handoff],
+         span: Span[GenerationSpanData],
+         tracing: ModelTracing,
+         stream: bool = False,
+         prompt: Any | None = None,
+     ) -> litellm.types.utils.ModelResponse | tuple[Response, AsyncStream[ChatCompletionChunk]]:
+         # Preserve reasoning messages for tool calls when reasoning is on
+         # This is needed for models like Claude 4 Sonnet/Opus which support interleaved thinking
+         preserve_thinking_blocks = (
+             model_settings.reasoning is not None and model_settings.reasoning.effort is not None
+         )
+
+         converted_messages = Converter.items_to_messages(
+             input, preserve_thinking_blocks=preserve_thinking_blocks
+         )
+
+         # Fix for interleaved thinking bug: reorder messages to ensure tool_use comes before tool_result  # noqa: E501
+         if preserve_thinking_blocks:
+             converted_messages = self._fix_tool_message_ordering(converted_messages)
+
+         if system_instructions:
+             converted_messages.insert(
+                 0,
+                 {
+                     "content": system_instructions,
+                     "role": "system",
+                 },
+             )
+         converted_messages = _to_dump_compatible(converted_messages)
+
+         if tracing.include_data():
+             span.span_data.input = converted_messages
+
+         parallel_tool_calls = (
+             True
+             if model_settings.parallel_tool_calls and tools and len(tools) > 0
+             else False
+             if model_settings.parallel_tool_calls is False
+             else None
+         )
+         tool_choice = Converter.convert_tool_choice(model_settings.tool_choice)
+         response_format = Converter.convert_response_format(output_schema)
+
+         converted_tools = [Converter.tool_to_openai(tool) for tool in tools] if tools else []
+
+         for handoff in handoffs:
+             converted_tools.append(Converter.convert_handoff_tool(handoff))
+
+         converted_tools = _to_dump_compatible(converted_tools)
+
+         if _debug.DONT_LOG_MODEL_DATA:
+             logger.debug("Calling LLM")
+         else:
+             messages_json = json.dumps(
+                 converted_messages,
+                 indent=2,
+                 ensure_ascii=False,
+             )
+             tools_json = json.dumps(
+                 converted_tools,
+                 indent=2,
+                 ensure_ascii=False,
+             )
+             logger.debug(
+                 f"Calling Litellm model: {self.model}\n"
+                 f"{messages_json}\n"
+                 f"Tools:\n{tools_json}\n"
+                 f"Stream: {stream}\n"
+                 f"Tool choice: {tool_choice}\n"
+                 f"Response format: {response_format}\n"
+             )
+
+         reasoning_effort = model_settings.reasoning.effort if model_settings.reasoning else None
+
+         stream_options = None
+         if stream and model_settings.include_usage is not None:
+             stream_options = {"include_usage": model_settings.include_usage}
+
+         extra_kwargs = {}
+         if model_settings.extra_query:
+             extra_kwargs["extra_query"] = copy(model_settings.extra_query)
+         if model_settings.metadata:
+             extra_kwargs["metadata"] = copy(model_settings.metadata)
+         if model_settings.extra_body and isinstance(model_settings.extra_body, dict):
+             extra_kwargs.update(model_settings.extra_body)
+
+         # Add kwargs from model_settings.extra_args, filtering out None values
+         if model_settings.extra_args:
+             extra_kwargs.update(model_settings.extra_args)
+
+         ret = await litellm.acompletion(
+             model=self.model,
+             messages=converted_messages,
+             tools=converted_tools or None,
+             temperature=model_settings.temperature,
+             top_p=model_settings.top_p,
+             frequency_penalty=model_settings.frequency_penalty,
+             presence_penalty=model_settings.presence_penalty,
+             max_tokens=model_settings.max_tokens,
+             tool_choice=self._remove_not_given(tool_choice),
+             response_format=self._remove_not_given(response_format),
+             parallel_tool_calls=parallel_tool_calls,
+             stream=stream,
+             stream_options=stream_options,
+             reasoning_effort=reasoning_effort,
+             top_logprobs=model_settings.top_logprobs,
+             extra_headers=self._merge_headers(model_settings),
+             api_key=self.api_key,
+             base_url=self.base_url,
+             **extra_kwargs,
+         )
+
+         if isinstance(ret, litellm.types.utils.ModelResponse):
+             return ret
+
+         response = Response(
+             id=FAKE_RESPONSES_ID,
+             created_at=time.time(),
+             model=self.model,
+             object="response",
+             output=[],
+             tool_choice=cast(Literal["auto", "required", "none"], tool_choice)
+             if tool_choice != NOT_GIVEN
+             else "auto",
+             top_p=model_settings.top_p,
+             temperature=model_settings.temperature,
+             tools=[],
+             parallel_tool_calls=parallel_tool_calls or False,
+             reasoning=model_settings.reasoning,
+         )
+         return response, ret
+
+     def _fix_tool_message_ordering(
+         self, messages: list[ChatCompletionMessageParam]
+     ) -> list[ChatCompletionMessageParam]:
+         """
+         Fix the ordering of tool messages to ensure tool_use messages come before tool_result messages.
+
+         This addresses the interleaved thinking bug where conversation histories may contain
+         tool results before their corresponding tool calls, causing the Anthropic API to reject the request.
+         """  # noqa: E501
+         if not messages:
+             return messages
+
+         # Collect all tool calls and tool results
+         tool_call_messages = {}  # tool_id -> (index, message)
+         tool_result_messages = {}  # tool_id -> (index, message)
+         other_messages = []  # (index, message) for non-tool messages
+
+         for i, message in enumerate(messages):
+             if not isinstance(message, dict):
+                 other_messages.append((i, message))
+                 continue
+
+             role = message.get("role")
+
+             if role == "assistant" and message.get("tool_calls"):
+                 # Extract tool calls from this assistant message
+                 tool_calls = message.get("tool_calls", [])
+                 if isinstance(tool_calls, list):
+                     for tool_call in tool_calls:
+                         if isinstance(tool_call, dict):
+                             tool_id = tool_call.get("id")
+                             if tool_id:
+                                 # Create a separate assistant message for each tool call
+                                 single_tool_msg = cast(dict[str, Any], message.copy())
+                                 single_tool_msg["tool_calls"] = [tool_call]
+                                 tool_call_messages[tool_id] = (
+                                     i,
+                                     cast(ChatCompletionMessageParam, single_tool_msg),
+                                 )
+
+             elif role == "tool":
+                 tool_call_id = message.get("tool_call_id")
+                 if tool_call_id:
+                     tool_result_messages[tool_call_id] = (i, message)
+                 else:
+                     other_messages.append((i, message))
+             else:
+                 other_messages.append((i, message))
+
+         # First, identify which tool results will be paired to avoid duplicates
+         paired_tool_result_indices = set()
+         for tool_id in tool_call_messages:
+             if tool_id in tool_result_messages:
+                 tool_result_idx, _ = tool_result_messages[tool_id]
+                 paired_tool_result_indices.add(tool_result_idx)
+
+         # Create the fixed message sequence
+         fixed_messages: list[ChatCompletionMessageParam] = []
+         used_indices = set()
+
+         # Add messages in their original order, but ensure tool_use → tool_result pairing
+         for i, original_message in enumerate(messages):
+             if i in used_indices:
+                 continue
+
+             if not isinstance(original_message, dict):
+                 fixed_messages.append(original_message)
+                 used_indices.add(i)
+                 continue
+
+             role = original_message.get("role")
+
+             if role == "assistant" and original_message.get("tool_calls"):
+                 # Process each tool call in this assistant message
+                 tool_calls = original_message.get("tool_calls", [])
+                 if isinstance(tool_calls, list):
+                     for tool_call in tool_calls:
+                         if isinstance(tool_call, dict):
+                             tool_id = tool_call.get("id")
+                             if (
+                                 tool_id
+                                 and tool_id in tool_call_messages
+                                 and tool_id in tool_result_messages
+                             ):
+                                 # Add tool_use → tool_result pair
+                                 _, tool_call_msg = tool_call_messages[tool_id]
+                                 tool_result_idx, tool_result_msg = tool_result_messages[tool_id]
+
+                                 fixed_messages.append(tool_call_msg)
+                                 fixed_messages.append(tool_result_msg)
+
+                                 # Mark both as used
+                                 used_indices.add(tool_call_messages[tool_id][0])
+                                 used_indices.add(tool_result_idx)
+                             elif tool_id and tool_id in tool_call_messages:
+                                 # Tool call without result - add just the tool call
+                                 _, tool_call_msg = tool_call_messages[tool_id]
+                                 fixed_messages.append(tool_call_msg)
+                                 used_indices.add(tool_call_messages[tool_id][0])
+
+                 used_indices.add(i)  # Mark original multi-tool message as used
+
+             elif role == "tool":
+                 # Only preserve unmatched tool results to avoid duplicates
+                 if i not in paired_tool_result_indices:
+                     fixed_messages.append(original_message)
+                 used_indices.add(i)
+
+             else:
+                 # Regular message - add it normally
+                 fixed_messages.append(original_message)
+                 used_indices.add(i)
+
+         return fixed_messages
+
+     def _remove_not_given(self, value: Any) -> Any:
+         if isinstance(value, NotGiven):
+             return None
+         return value
+
+     def _merge_headers(self, model_settings: ModelSettings):
+         return {**HEADERS, **(model_settings.extra_headers or {}), **(HEADERS_OVERRIDE.get() or {})}
+
+
+ class LitellmConverter:
+     @classmethod
+     def convert_message_to_openai(
+         cls, message: litellm.types.utils.Message
+     ) -> ChatCompletionMessage:
+         if message.role != "assistant":
+             raise ModelBehaviorError(f"Unsupported role: {message.role}")
+
+         tool_calls: (
+             list[ChatCompletionMessageFunctionToolCall | ChatCompletionMessageCustomToolCall] | None
+         ) = (
+             [LitellmConverter.convert_tool_call_to_openai(tool) for tool in message.tool_calls]
+             if message.tool_calls
+             else None
+         )
+
+         provider_specific_fields = message.get("provider_specific_fields", None)
+         refusal = (
+             provider_specific_fields.get("refusal", None) if provider_specific_fields else None
+         )
+
+         reasoning_content = ""
+         if hasattr(message, "reasoning_content") and message.reasoning_content:
+             reasoning_content = message.reasoning_content
+
+         # Extract full thinking blocks including signatures (for Anthropic)
+         thinking_blocks: list[dict[str, Any]] | None = None
+         if hasattr(message, "thinking_blocks") and message.thinking_blocks:
+             # Convert thinking blocks to dict format for compatibility
+             thinking_blocks = []
+             for block in message.thinking_blocks:
+                 if isinstance(block, dict):
+                     thinking_blocks.append(cast(dict[str, Any], block))
+                 else:
+                     # Convert object to dict by accessing its attributes
+                     block_dict: dict[str, Any] = {}
+                     if hasattr(block, "__dict__"):
+                         block_dict = dict(block.__dict__.items())
+                     elif hasattr(block, "model_dump"):
+                         block_dict = block.model_dump()
+                     else:
+                         # Last resort: convert to string representation
+                         block_dict = {"thinking": str(block)}
+                     thinking_blocks.append(block_dict)
+
+         return InternalChatCompletionMessage(
+             content=message.content,
+             refusal=refusal,
+             role="assistant",
+             annotations=cls.convert_annotations_to_openai(message),
+             audio=message.get("audio", None),  # litellm deletes audio if not present
+             tool_calls=tool_calls,
+             reasoning_content=reasoning_content,
+             thinking_blocks=thinking_blocks,
+         )
+
+     @classmethod
+     def convert_annotations_to_openai(
+         cls, message: litellm.types.utils.Message
+     ) -> list[Annotation] | None:
+         annotations: list[litellm.types.llms.openai.ChatCompletionAnnotation] | None = message.get(
+             "annotations", None
+         )
+         if not annotations:
+             return None
+
+         return [
+             Annotation(
+                 type="url_citation",
+                 url_citation=AnnotationURLCitation(
+                     start_index=annotation["url_citation"]["start_index"],
+                     end_index=annotation["url_citation"]["end_index"],
+                     url=annotation["url_citation"]["url"],
+                     title=annotation["url_citation"]["title"],
+                 ),
+             )
+             for annotation in annotations
+         ]
+
+     @classmethod
+     def convert_tool_call_to_openai(
+         cls, tool_call: litellm.types.utils.ChatCompletionMessageToolCall
+     ) -> ChatCompletionMessageFunctionToolCall:
+         return ChatCompletionMessageFunctionToolCall(
+             id=tool_call.id,
+             type="function",
+             function=Function(
+                 name=tool_call.function.name or "",
+                 arguments=tool_call.function.arguments,
+             ),
+         )
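
A short usage sketch for `LitellmModel`: pass an instance as an agent's `model`. The model identifier and environment variable are illustrative; any `provider/model` string LiteLLM supports should work the same way.

    import asyncio
    import os

    from agents import Agent, Runner
    from agents.extensions.models.litellm_model import LitellmModel

    async def main() -> None:
        agent = Agent(
            name="Assistant",
            instructions="Reply concisely.",
            # Illustrative model id; see https://docs.litellm.ai/docs/providers
            model=LitellmModel(
                "anthropic/claude-3-5-sonnet-20240620",
                api_key=os.environ["ANTHROPIC_API_KEY"],
            ),
        )
        result = await Runner.run(agent, "Say hello.")
        print(result.final_output)

    asyncio.run(main())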
agents/extensions/models/litellm_provider.py ADDED
@@ -0,0 +1,23 @@
+ from ...models.default_models import get_default_model
+ from ...models.interface import Model, ModelProvider
+ from .litellm_model import LitellmModel
+
+ # This is kept for backward compatibility, but using the get_default_model() method is recommended.
+ DEFAULT_MODEL: str = "gpt-4.1"
+
+
+ class LitellmProvider(ModelProvider):
+     """A ModelProvider that uses LiteLLM to route to any model provider. You can use it via:
+     ```python
+     Runner.run(agent, input, run_config=RunConfig(model_provider=LitellmProvider()))
+     ```
+     See supported models here: [litellm models](https://docs.litellm.ai/docs/providers).
+
+     NOTE: API keys must be set via environment variables. If you're using models that require
+     additional configuration (e.g. Azure API base or version), those must also be set via the
+     environment variables that LiteLLM expects. If you have more advanced needs, we recommend
+     copy-pasting this class and making any modifications you need.
+     """
+
+     def get_model(self, model_name: str | None) -> Model:
+         return LitellmModel(model_name or get_default_model())
agents/extensions/visualization.py ADDED
@@ -0,0 +1,165 @@
+ from __future__ import annotations
+
+ import graphviz  # type: ignore
+
+ from agents import Agent
+ from agents.handoffs import Handoff
+ from agents.tool import Tool
+
+
+ def get_main_graph(agent: Agent) -> str:
+     """
+     Generates the main graph structure in DOT format for the given agent.
+
+     Args:
+         agent (Agent): The agent for which the graph is to be generated.
+
+     Returns:
+         str: The DOT format string representing the graph.
+     """
+     parts = [
+         """
+     digraph G {
+         graph [splines=true];
+         node [fontname="Arial"];
+         edge [penwidth=1.5];
+     """
+     ]
+     parts.append(get_all_nodes(agent))
+     parts.append(get_all_edges(agent))
+     parts.append("}")
+     return "".join(parts)
+
+
+ def get_all_nodes(
+     agent: Agent, parent: Agent | None = None, visited: set[str] | None = None
+ ) -> str:
+     """
+     Recursively generates the nodes for the given agent and its handoffs in DOT format.
+
+     Args:
+         agent (Agent): The agent for which the nodes are to be generated.
+
+     Returns:
+         str: The DOT format string representing the nodes.
+     """
+     if visited is None:
+         visited = set()
+     if agent.name in visited:
+         return ""
+     visited.add(agent.name)
+
+     parts = []
+
+     # Start and end nodes of the graph
+     if not parent:
+         parts.append(
+             '"__start__" [label="__start__", shape=ellipse, style=filled, '
+             "fillcolor=lightblue, width=0.5, height=0.3];"
+             '"__end__" [label="__end__", shape=ellipse, style=filled, '
+             "fillcolor=lightblue, width=0.5, height=0.3];"
+         )
+
+     # Ensure parent agent node is colored
+     parts.append(
+         f'"{agent.name}" [label="{agent.name}", shape=box, style=filled, '
+         "fillcolor=lightyellow, width=1.5, height=0.8];"
+     )
+
+     for tool in agent.tools:
+         parts.append(
+             f'"{tool.name}" [label="{tool.name}", shape=ellipse, style=filled, '
+             f"fillcolor=lightgreen, width=0.5, height=0.3];"
+         )
+
+     for mcp_server in agent.mcp_servers:
+         parts.append(
+             f'"{mcp_server.name}" [label="{mcp_server.name}", shape=box, style=filled, '
+             f"fillcolor=lightgrey, width=1, height=0.5];"
+         )
+
+     for handoff in agent.handoffs:
+         if isinstance(handoff, Handoff):
+             parts.append(
+                 f'"{handoff.agent_name}" [label="{handoff.agent_name}", '
+                 f"shape=box, style=filled, style=rounded, "
+                 f"fillcolor=lightyellow, width=1.5, height=0.8];"
+             )
+         if isinstance(handoff, Agent):
+             if handoff.name not in visited:
+                 parts.append(
+                     f'"{handoff.name}" [label="{handoff.name}", '
+                     f"shape=box, style=filled, style=rounded, "
+                     f"fillcolor=lightyellow, width=1.5, height=0.8];"
+                 )
+             parts.append(get_all_nodes(handoff, agent, visited))
+
+     return "".join(parts)
+
+
+ def get_all_edges(
+     agent: Agent, parent: Agent | None = None, visited: set[str] | None = None
+ ) -> str:
+     """
+     Recursively generates the edges for the given agent and its handoffs in DOT format.
+
+     Args:
+         agent (Agent): The agent for which the edges are to be generated.
+         parent (Agent, optional): The parent agent. Defaults to None.
+
+     Returns:
+         str: The DOT format string representing the edges.
+     """
+     if visited is None:
+         visited = set()
+     if agent.name in visited:
+         return ""
+     visited.add(agent.name)
+
+     parts = []
+
+     if not parent:
+         parts.append(f'"__start__" -> "{agent.name}";')
+
+     for tool in agent.tools:
+         parts.append(f"""
+     "{agent.name}" -> "{tool.name}" [style=dotted, penwidth=1.5];
+     "{tool.name}" -> "{agent.name}" [style=dotted, penwidth=1.5];""")
+
+     for mcp_server in agent.mcp_servers:
+         parts.append(f"""
+     "{agent.name}" -> "{mcp_server.name}" [style=dashed, penwidth=1.5];
+     "{mcp_server.name}" -> "{agent.name}" [style=dashed, penwidth=1.5];""")
+
+     for handoff in agent.handoffs:
+         if isinstance(handoff, Handoff):
+             parts.append(f"""
+     "{agent.name}" -> "{handoff.agent_name}";""")
+         if isinstance(handoff, Agent):
+             parts.append(f"""
+     "{agent.name}" -> "{handoff.name}";""")
+             parts.append(get_all_edges(handoff, agent, visited))
+
+     if not agent.handoffs and not isinstance(agent, Tool):  # type: ignore
+         parts.append(f'"{agent.name}" -> "__end__";')
+
+     return "".join(parts)
+
+
+ def draw_graph(agent: Agent, filename: str | None = None) -> graphviz.Source:
+     """
+     Draws the graph for the given agent and optionally saves it as a PNG file.
+
+     Args:
+         agent (Agent): The agent for which the graph is to be drawn.
+         filename (str): The name of the file to save the graph as a PNG.
+
+     Returns:
+         graphviz.Source: The graphviz Source object representing the graph.
+     """
+     dot_code = get_main_graph(agent)
+     graph = graphviz.Source(dot_code)
+
+     if filename:
+         graph.render(filename, format="png", cleanup=True)
+
+     return graph
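
A sketch of the visualization helpers in use, assuming the `graphviz` Python package and the Graphviz binaries are installed; agent and tool names are illustrative.

    from agents import Agent, function_tool
    from agents.extensions.visualization import draw_graph

    @function_tool
    def get_weather(city: str) -> str:
        """Return a fake forecast for the given city."""
        return f"Sunny in {city}"

    support = Agent(name="Support")
    triage = Agent(name="Triage", tools=[get_weather], handoffs=[support])

    # Emits dotted triage <-> get_weather edges and a triage -> Support handoff edge,
    # then writes agent_graph.png (cleanup=True removes the intermediate DOT file).
    draw_graph(triage, filename="agent_graph")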
agents/function_schema.py ADDED
@@ -0,0 +1,398 @@
+ from __future__ import annotations
+
+ import contextlib
+ import inspect
+ import logging
+ import re
+ from dataclasses import dataclass
+ from typing import Annotated, Any, Callable, Literal, get_args, get_origin, get_type_hints
+
+ from griffe import Docstring, DocstringSectionKind
+ from pydantic import BaseModel, Field, create_model
+ from pydantic.fields import FieldInfo
+
+ from .exceptions import UserError
+ from .run_context import RunContextWrapper
+ from .strict_schema import ensure_strict_json_schema
+ from .tool_context import ToolContext
+
+
+ @dataclass
+ class FuncSchema:
+     """
+     Captures the schema for a python function, in preparation for sending it to an LLM as a tool.
+     """
+
+     name: str
+     """The name of the function."""
+     description: str | None
+     """The description of the function."""
+     params_pydantic_model: type[BaseModel]
+     """A Pydantic model that represents the function's parameters."""
+     params_json_schema: dict[str, Any]
+     """The JSON schema for the function's parameters, derived from the Pydantic model."""
+     signature: inspect.Signature
+     """The signature of the function."""
+     takes_context: bool = False
+     """Whether the function takes a RunContextWrapper argument (must be the first argument)."""
+     strict_json_schema: bool = True
+     """Whether the JSON schema is in strict mode. We **strongly** recommend setting this to True,
+     as it increases the likelihood of correct JSON input."""
+
+     def to_call_args(self, data: BaseModel) -> tuple[list[Any], dict[str, Any]]:
+         """
+         Converts validated data from the Pydantic model into (args, kwargs), suitable for calling
+         the original function.
+         """
+         positional_args: list[Any] = []
+         keyword_args: dict[str, Any] = {}
+         seen_var_positional = False
+
+         # Use enumerate() so we can skip the first parameter if it's context.
+         for idx, (name, param) in enumerate(self.signature.parameters.items()):
+             # If the function takes a RunContextWrapper and this is the first parameter, skip it.
+             if self.takes_context and idx == 0:
+                 continue
+
+             value = getattr(data, name, None)
+             if param.kind == param.VAR_POSITIONAL:
+                 # e.g. *args: extend positional args and mark that *args is now seen
+                 positional_args.extend(value or [])
+                 seen_var_positional = True
+             elif param.kind == param.VAR_KEYWORD:
+                 # e.g. **kwargs handling
+                 keyword_args.update(value or {})
+             elif param.kind in (param.POSITIONAL_ONLY, param.POSITIONAL_OR_KEYWORD):
+                 # Before *args, add to positional args. After *args, add to keyword args.
+                 if not seen_var_positional:
+                     positional_args.append(value)
+                 else:
+                     keyword_args[name] = value
+             else:
+                 # For KEYWORD_ONLY parameters, always use keyword args.
+                 keyword_args[name] = value
+         return positional_args, keyword_args
+
+
+ @dataclass
+ class FuncDocumentation:
+     """Contains metadata about a Python function, extracted from its docstring."""
+
+     name: str
+     """The name of the function, via `__name__`."""
+     description: str | None
+     """The description of the function, derived from the docstring."""
+     param_descriptions: dict[str, str] | None
+     """The parameter descriptions of the function, derived from the docstring."""
+
+
+ DocstringStyle = Literal["google", "numpy", "sphinx"]
+
+
+ # As of Feb 2025, the automatic style detection in griffe is an Insiders feature. This
+ # code approximates it.
+ def _detect_docstring_style(doc: str) -> DocstringStyle:
+     scores: dict[DocstringStyle, int] = {"sphinx": 0, "numpy": 0, "google": 0}
+
+     # Sphinx style detection: look for :param, :type, :return:, and :rtype:
+     sphinx_patterns = [r"^:param\s", r"^:type\s", r"^:return:", r"^:rtype:"]
+     for pattern in sphinx_patterns:
+         if re.search(pattern, doc, re.MULTILINE):
+             scores["sphinx"] += 1
+
+     # Numpy style detection: look for headers like 'Parameters', 'Returns', or 'Yields' followed by
+     # a dashed underline
+     numpy_patterns = [
+         r"^Parameters\s*\n\s*-{3,}",
+         r"^Returns\s*\n\s*-{3,}",
+         r"^Yields\s*\n\s*-{3,}",
+     ]
+     for pattern in numpy_patterns:
+         if re.search(pattern, doc, re.MULTILINE):
+             scores["numpy"] += 1
+
+     # Google style detection: look for section headers with a trailing colon
+     google_patterns = [r"^(Args|Arguments):", r"^(Returns):", r"^(Raises):"]
+     for pattern in google_patterns:
+         if re.search(pattern, doc, re.MULTILINE):
+             scores["google"] += 1
+
+     max_score = max(scores.values())
+     if max_score == 0:
+         return "google"
+
+     # Priority order: sphinx > numpy > google in case of tie
+     styles: list[DocstringStyle] = ["sphinx", "numpy", "google"]
+
+     for style in styles:
+         if scores[style] == max_score:
+             return style
+
+     return "google"
+
+
+ @contextlib.contextmanager
+ def _suppress_griffe_logging():
+     # Suppresses warnings about missing annotations for params
+     logger = logging.getLogger("griffe")
+     previous_level = logger.getEffectiveLevel()
+     logger.setLevel(logging.ERROR)
+     try:
+         yield
+     finally:
+         logger.setLevel(previous_level)
+
+
+ def generate_func_documentation(
+     func: Callable[..., Any], style: DocstringStyle | None = None
+ ) -> FuncDocumentation:
+     """
+     Extracts metadata from a function docstring, in preparation for sending it to an LLM as a tool.
+
+     Args:
+         func: The function to extract documentation from.
+         style: The style of the docstring to use for parsing. If not provided, we will attempt to
+             auto-detect the style.
+
+     Returns:
+         A FuncDocumentation object containing the function's name, description, and parameter
+         descriptions.
+     """
+     name = func.__name__
+     doc = inspect.getdoc(func)
+     if not doc:
+         return FuncDocumentation(name=name, description=None, param_descriptions=None)
+
+     with _suppress_griffe_logging():
+         docstring = Docstring(doc, lineno=1, parser=style or _detect_docstring_style(doc))
+         parsed = docstring.parse()
+
+     description: str | None = next(
+         (section.value for section in parsed if section.kind == DocstringSectionKind.text), None
+     )
+
+     param_descriptions: dict[str, str] = {
+         param.name: param.description
+         for section in parsed
+         if section.kind == DocstringSectionKind.parameters
+         for param in section.value
+     }
+
+     return FuncDocumentation(
+         name=func.__name__,
+         description=description,
+         param_descriptions=param_descriptions or None,
+     )
+
+
+ def _strip_annotated(annotation: Any) -> tuple[Any, tuple[Any, ...]]:
+     """Returns the underlying annotation and any metadata from typing.Annotated."""
+
+     metadata: tuple[Any, ...] = ()
+     ann = annotation
+
+     while get_origin(ann) is Annotated:
+         args = get_args(ann)
+         if not args:
+             break
+         ann = args[0]
+         metadata = (*metadata, *args[1:])
+
+     return ann, metadata
+
+
+ def _extract_description_from_metadata(metadata: tuple[Any, ...]) -> str | None:
+     """Extracts a human readable description from Annotated metadata if present."""
+
+     for item in metadata:
+         if isinstance(item, str):
+             return item
+     return None
+
+
+ def function_schema(
+     func: Callable[..., Any],
+     docstring_style: DocstringStyle | None = None,
+     name_override: str | None = None,
+     description_override: str | None = None,
+     use_docstring_info: bool = True,
+     strict_json_schema: bool = True,
+ ) -> FuncSchema:
+     """
+     Given a Python function, extracts a `FuncSchema` from it, capturing the name, description,
+     parameter descriptions, and other metadata.
+
+     Args:
+         func: The function to extract the schema from.
+         docstring_style: The style of the docstring to use for parsing. If not provided, we will
+             attempt to auto-detect the style.
+         name_override: If provided, use this name instead of the function's `__name__`.
+         description_override: If provided, use this description instead of the one derived from the
+             docstring.
+         use_docstring_info: If True, uses the docstring to generate the description and parameter
+             descriptions.
+         strict_json_schema: Whether the JSON schema is in strict mode. If True, we'll ensure that
+             the schema adheres to the "strict" standard the OpenAI API expects. We **strongly**
+             recommend setting this to True, as it increases the likelihood of the LLM producing
+             correct JSON input.
+
+     Returns:
+         A `FuncSchema` object containing the function's name, description, parameter descriptions,
+         and other metadata.
+     """
+
+     # 1. Grab docstring info
+     if use_docstring_info:
+         doc_info = generate_func_documentation(func, docstring_style)
+         param_descs = dict(doc_info.param_descriptions or {})
+     else:
+         doc_info = None
+         param_descs = {}
+
+     type_hints_with_extras = get_type_hints(func, include_extras=True)
+     type_hints: dict[str, Any] = {}
+     annotated_param_descs: dict[str, str] = {}
+
+     for name, annotation in type_hints_with_extras.items():
+         if name == "return":
+             continue
+
+         stripped_ann, metadata = _strip_annotated(annotation)
+         type_hints[name] = stripped_ann
+
+         description = _extract_description_from_metadata(metadata)
+         if description is not None:
+             annotated_param_descs[name] = description
+
+     for name, description in annotated_param_descs.items():
+         param_descs.setdefault(name, description)
+
+     # Ensure name_override takes precedence even if docstring info is disabled.
+     func_name = name_override or (doc_info.name if doc_info else func.__name__)
+
+     # 2. Inspect function signature and get type hints
+     sig = inspect.signature(func)
+     params = list(sig.parameters.items())
+     takes_context = False
+     filtered_params = []
+
+     if params:
+         first_name, first_param = params[0]
+         # Prefer the evaluated type hint if available
+         ann = type_hints.get(first_name, first_param.annotation)
+         if ann != inspect._empty:
+             origin = get_origin(ann) or ann
+             if origin is RunContextWrapper or origin is ToolContext:
+                 takes_context = True  # Mark that the function takes context
+             else:
+                 filtered_params.append((first_name, first_param))
+         else:
+             filtered_params.append((first_name, first_param))
+
+     # For parameters other than the first, raise error if any use RunContextWrapper or ToolContext.
+     for name, param in params[1:]:
+         ann = type_hints.get(name, param.annotation)
+         if ann != inspect._empty:
+             origin = get_origin(ann) or ann
+             if origin is RunContextWrapper or origin is ToolContext:
+                 raise UserError(
+                     f"RunContextWrapper/ToolContext param found at non-first position in function"
+                     f" {func.__name__}"
+                 )
+         filtered_params.append((name, param))
+
+     # We will collect field definitions for create_model as a dict:
+     #   field_name -> (type_annotation, default_value_or_Field(...))
+     fields: dict[str, Any] = {}
+
+     for name, param in filtered_params:
+         ann = type_hints.get(name, param.annotation)
+         default = param.default
+
+         # If there's no type hint, assume `Any`
+         if ann == inspect._empty:
+             ann = Any
+
+         # If a docstring param description exists, use it
+         field_description = param_descs.get(name, None)
+
+         # Handle different parameter kinds
+         if param.kind == param.VAR_POSITIONAL:
+             # e.g. *args: extend positional args
+             if get_origin(ann) is tuple:
+                 # e.g. def foo(*args: tuple[int, ...]) -> treat as List[int]
+                 args_of_tuple = get_args(ann)
+                 if len(args_of_tuple) == 2 and args_of_tuple[1] is Ellipsis:
+                     ann = list[args_of_tuple[0]]  # type: ignore
+                 else:
+                     ann = list[Any]
+             else:
+                 # If user wrote *args: int, treat as List[int]
+                 ann = list[ann]  # type: ignore
+
+             # Default factory to empty list
+             fields[name] = (
+                 ann,
+                 Field(default_factory=list, description=field_description),
+             )
+
+         elif param.kind == param.VAR_KEYWORD:
+             # **kwargs handling
+             if get_origin(ann) is dict:
+                 # e.g. def foo(**kwargs: dict[str, int])
+                 dict_args = get_args(ann)
+                 if len(dict_args) == 2:
+                     ann = dict[dict_args[0], dict_args[1]]  # type: ignore
+                 else:
+                     ann = dict[str, Any]
+             else:
+                 # e.g. def foo(**kwargs: int) -> Dict[str, int]
+                 ann = dict[str, ann]  # type: ignore
+
+             fields[name] = (
+                 ann,
+                 Field(default_factory=dict, description=field_description),
+             )
+
+         else:
+             # Normal parameter
+             if default == inspect._empty:
+                 # Required field
+                 fields[name] = (
+                     ann,
+                     Field(..., description=field_description),
+                 )
+             elif isinstance(default, FieldInfo):
+                 # Parameter with a default value that is a Field(...)
+                 fields[name] = (
+                     ann,
+                     FieldInfo.merge_field_infos(
+                         default, description=field_description or default.description
+                     ),
+                 )
+             else:
+                 # Parameter with a default value
+                 fields[name] = (
+                     ann,
+                     Field(default=default, description=field_description),
+                 )
+
+     # 3. Dynamically build a Pydantic model
+     dynamic_model = create_model(f"{func_name}_args", __base__=BaseModel, **fields)
+
+     # 4. Build JSON schema from that model
+     json_schema = dynamic_model.model_json_schema()
+     if strict_json_schema:
+         json_schema = ensure_strict_json_schema(json_schema)
+
+     # 5. Return as a FuncSchema dataclass
+     return FuncSchema(
+         name=func_name,
+         # Ensure description_override takes precedence even if docstring info is disabled.
+         description=description_override or (doc_info.description if doc_info else None),
+         params_pydantic_model=dynamic_model,
+         params_json_schema=json_schema,
+         signature=sig,
+         takes_context=takes_context,
+         strict_json_schema=strict_json_schema,
+     )
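
A sketch of `function_schema` applied to a plain function, showing how google-style docstring text and an `Annotated` description flow into the JSON schema; the function itself is illustrative.

    import json
    from typing import Annotated

    from agents.function_schema import function_schema

    def add(a: int, b: Annotated[int, "The increment to apply"] = 1) -> int:
        """Add two integers.

        Args:
            a: The base value.
        """
        return a + b

    schema = function_schema(add)
    print(schema.name)         # "add"
    print(schema.description)  # "Add two integers."
    print(json.dumps(schema.params_json_schema, indent=2))
    # With strict_json_schema=True (the default), ensure_strict_json_schema marks
    # every property as required, so "b" keeps its default only at call time.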
agents/guardrail.py ADDED
@@ -0,0 +1,329 @@
+ from __future__ import annotations
+
+ import inspect
+ from collections.abc import Awaitable
+ from dataclasses import dataclass
+ from typing import TYPE_CHECKING, Any, Callable, Generic, Union, overload
+
+ from typing_extensions import TypeVar
+
+ from .exceptions import UserError
+ from .items import TResponseInputItem
+ from .run_context import RunContextWrapper, TContext
+ from .util._types import MaybeAwaitable
+
+ if TYPE_CHECKING:
+     from .agent import Agent
+
+
+ @dataclass
+ class GuardrailFunctionOutput:
+     """The output of a guardrail function."""
+
+     output_info: Any
+     """
+     Optional information about the guardrail's output. For example, the guardrail could include
+     information about the checks it performed and granular results.
+     """
+
+     tripwire_triggered: bool
+     """
+     Whether the tripwire was triggered. If triggered, the agent's execution will be halted.
+     """
+
+
+ @dataclass
+ class InputGuardrailResult:
+     """The result of a guardrail run."""
+
+     guardrail: InputGuardrail[Any]
+     """
+     The guardrail that was run.
+     """
+
+     output: GuardrailFunctionOutput
+     """The output of the guardrail function."""
+
+
+ @dataclass
+ class OutputGuardrailResult:
+     """The result of a guardrail run."""
+
+     guardrail: OutputGuardrail[Any]
+     """
+     The guardrail that was run.
+     """
+
+     agent_output: Any
+     """
+     The output of the agent that was checked by the guardrail.
+     """
+
+     agent: Agent[Any]
+     """
+     The agent that was checked by the guardrail.
+     """
+
+     output: GuardrailFunctionOutput
+     """The output of the guardrail function."""
+
+
+ @dataclass
+ class InputGuardrail(Generic[TContext]):
+     """Input guardrails are checks that run in parallel to the agent's execution.
+     They can be used to do things like:
+     - Check if input messages are off-topic
+     - Take over control of the agent's execution if an unexpected input is detected
+
+     You can use the `@input_guardrail()` decorator to turn a function into an `InputGuardrail`, or
+     create an `InputGuardrail` manually.
+
+     Guardrails return a `GuardrailResult`. If `result.tripwire_triggered` is `True`, the
+     agent's execution will immediately stop, and an `InputGuardrailTripwireTriggered`
+     exception will be raised.
+     """
+
+     guardrail_function: Callable[
+         [RunContextWrapper[TContext], Agent[Any], str | list[TResponseInputItem]],
+         MaybeAwaitable[GuardrailFunctionOutput],
+     ]
+     """A function that receives the agent input and the context, and returns a
+     `GuardrailResult`. The result marks whether the tripwire was triggered, and can optionally
+     include information about the guardrail's output.
+     """
+
+     name: str | None = None
+     """The name of the guardrail, used for tracing. If not provided, we'll use the guardrail
+     function's name.
+     """
+
+     def get_name(self) -> str:
+         if self.name:
+             return self.name
+
+         return self.guardrail_function.__name__
+
+     async def run(
+         self,
+         agent: Agent[Any],
+         input: str | list[TResponseInputItem],
+         context: RunContextWrapper[TContext],
+     ) -> InputGuardrailResult:
+         if not callable(self.guardrail_function):
+             raise UserError(f"Guardrail function must be callable, got {self.guardrail_function}")
+
+         output = self.guardrail_function(context, agent, input)
+         if inspect.isawaitable(output):
+             return InputGuardrailResult(
+                 guardrail=self,
+                 output=await output,
+             )
+
+         return InputGuardrailResult(
+             guardrail=self,
+             output=output,
+         )
+
+
+ @dataclass
+ class OutputGuardrail(Generic[TContext]):
+     """Output guardrails are checks that run on the final output of an agent.
+     They can be used to check if the output passes certain validation criteria.
+
+     You can use the `@output_guardrail()` decorator to turn a function into an `OutputGuardrail`,
+     or create an `OutputGuardrail` manually.
+
+     Guardrails return a `GuardrailResult`. If `result.tripwire_triggered` is `True`, an
+     `OutputGuardrailTripwireTriggered` exception will be raised.
+     """
+
+     guardrail_function: Callable[
+         [RunContextWrapper[TContext], Agent[Any], Any],
+         MaybeAwaitable[GuardrailFunctionOutput],
+     ]
+     """A function that receives the final agent, its output, and the context, and returns a
+     `GuardrailResult`. The result marks whether the tripwire was triggered, and can optionally
+     include information about the guardrail's output.
+     """
+
+     name: str | None = None
+     """The name of the guardrail, used for tracing. If not provided, we'll use the guardrail
151
+ function's name.
152
+ """
153
+
154
+ def get_name(self) -> str:
155
+ if self.name:
156
+ return self.name
157
+
158
+ return self.guardrail_function.__name__
159
+
160
+ async def run(
161
+ self, context: RunContextWrapper[TContext], agent: Agent[Any], agent_output: Any
162
+ ) -> OutputGuardrailResult:
163
+ if not callable(self.guardrail_function):
164
+ raise UserError(f"Guardrail function must be callable, got {self.guardrail_function}")
165
+
166
+ output = self.guardrail_function(context, agent, agent_output)
167
+ if inspect.isawaitable(output):
168
+ return OutputGuardrailResult(
169
+ guardrail=self,
170
+ agent=agent,
171
+ agent_output=agent_output,
172
+ output=await output,
173
+ )
174
+
175
+ return OutputGuardrailResult(
176
+ guardrail=self,
177
+ agent=agent,
178
+ agent_output=agent_output,
179
+ output=output,
180
+ )
181
+
182
+
183
+ TContext_co = TypeVar("TContext_co", bound=Any, covariant=True)
184
+
185
+ # For InputGuardrail
186
+ _InputGuardrailFuncSync = Callable[
187
+ [RunContextWrapper[TContext_co], "Agent[Any]", Union[str, list[TResponseInputItem]]],
188
+ GuardrailFunctionOutput,
189
+ ]
190
+ _InputGuardrailFuncAsync = Callable[
191
+ [RunContextWrapper[TContext_co], "Agent[Any]", Union[str, list[TResponseInputItem]]],
192
+ Awaitable[GuardrailFunctionOutput],
193
+ ]
194
+
195
+
196
+ @overload
197
+ def input_guardrail(
198
+ func: _InputGuardrailFuncSync[TContext_co],
199
+ ) -> InputGuardrail[TContext_co]: ...
200
+
201
+
202
+ @overload
203
+ def input_guardrail(
204
+ func: _InputGuardrailFuncAsync[TContext_co],
205
+ ) -> InputGuardrail[TContext_co]: ...
206
+
207
+
208
+ @overload
209
+ def input_guardrail(
210
+ *,
211
+ name: str | None = None,
212
+ ) -> Callable[
213
+ [_InputGuardrailFuncSync[TContext_co] | _InputGuardrailFuncAsync[TContext_co]],
214
+ InputGuardrail[TContext_co],
215
+ ]: ...
216
+
217
+
218
+ def input_guardrail(
219
+ func: _InputGuardrailFuncSync[TContext_co]
220
+ | _InputGuardrailFuncAsync[TContext_co]
221
+ | None = None,
222
+ *,
223
+ name: str | None = None,
224
+ ) -> (
225
+ InputGuardrail[TContext_co]
226
+ | Callable[
227
+ [_InputGuardrailFuncSync[TContext_co] | _InputGuardrailFuncAsync[TContext_co]],
228
+ InputGuardrail[TContext_co],
229
+ ]
230
+ ):
231
+ """
232
+ Decorator that transforms a sync or async function into an `InputGuardrail`.
233
+ It can be used directly (no parentheses) or with keyword args, e.g.:
234
+
235
+ @input_guardrail
236
+ def my_sync_guardrail(...): ...
237
+
238
+ @input_guardrail(name="guardrail_name")
239
+ async def my_async_guardrail(...): ...
240
+ """
241
+
242
+ def decorator(
243
+ f: _InputGuardrailFuncSync[TContext_co] | _InputGuardrailFuncAsync[TContext_co],
244
+ ) -> InputGuardrail[TContext_co]:
245
+ return InputGuardrail(
246
+ guardrail_function=f,
247
+ # If not set, guardrail name uses the function’s name by default.
248
+ name=name if name else f.__name__,
249
+ )
250
+
251
+ if func is not None:
252
+ # Decorator was used without parentheses
253
+ return decorator(func)
254
+
255
+ # Decorator used with keyword arguments
256
+ return decorator
257
+
258
+
259
+ _OutputGuardrailFuncSync = Callable[
260
+ [RunContextWrapper[TContext_co], "Agent[Any]", Any],
261
+ GuardrailFunctionOutput,
262
+ ]
263
+ _OutputGuardrailFuncAsync = Callable[
264
+ [RunContextWrapper[TContext_co], "Agent[Any]", Any],
265
+ Awaitable[GuardrailFunctionOutput],
266
+ ]
267
+
268
+
269
+ @overload
270
+ def output_guardrail(
271
+ func: _OutputGuardrailFuncSync[TContext_co],
272
+ ) -> OutputGuardrail[TContext_co]: ...
273
+
274
+
275
+ @overload
276
+ def output_guardrail(
277
+ func: _OutputGuardrailFuncAsync[TContext_co],
278
+ ) -> OutputGuardrail[TContext_co]: ...
279
+
280
+
281
+ @overload
282
+ def output_guardrail(
283
+ *,
284
+ name: str | None = None,
285
+ ) -> Callable[
286
+ [_OutputGuardrailFuncSync[TContext_co] | _OutputGuardrailFuncAsync[TContext_co]],
287
+ OutputGuardrail[TContext_co],
288
+ ]: ...
289
+
290
+
291
+ def output_guardrail(
292
+ func: _OutputGuardrailFuncSync[TContext_co]
293
+ | _OutputGuardrailFuncAsync[TContext_co]
294
+ | None = None,
295
+ *,
296
+ name: str | None = None,
297
+ ) -> (
298
+ OutputGuardrail[TContext_co]
299
+ | Callable[
300
+ [_OutputGuardrailFuncSync[TContext_co] | _OutputGuardrailFuncAsync[TContext_co]],
301
+ OutputGuardrail[TContext_co],
302
+ ]
303
+ ):
304
+ """
305
+ Decorator that transforms a sync or async function into an `OutputGuardrail`.
306
+ It can be used directly (no parentheses) or with keyword args, e.g.:
307
+
308
+ @output_guardrail
309
+ def my_sync_guardrail(...): ...
310
+
311
+ @output_guardrail(name="guardrail_name")
312
+ async def my_async_guardrail(...): ...
313
+ """
314
+
315
+ def decorator(
316
+ f: _OutputGuardrailFuncSync[TContext_co] | _OutputGuardrailFuncAsync[TContext_co],
317
+ ) -> OutputGuardrail[TContext_co]:
318
+ return OutputGuardrail(
319
+ guardrail_function=f,
320
+ # Guardrail name defaults to function's name when not specified (None).
321
+ name=name if name else f.__name__,
322
+ )
323
+
324
+ if func is not None:
325
+ # Decorator was used without parentheses
326
+ return decorator(func)
327
+
328
+ # Decorator used with keyword arguments
329
+ return decorator
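To make the decorator contract above concrete, here is a minimal sketch defining one sync input guardrail and one async output guardrail. The decorator names, `GuardrailFunctionOutput`, and the `(context, agent, input/agent_output)` argument order come straight from this file; attaching guardrails to an `Agent` via `input_guardrails=`/`output_guardrails=` is an assumption about the wider SDK, not shown in this diff.

    # Hypothetical usage sketch (not part of this commit).
    from agents.guardrail import GuardrailFunctionOutput, input_guardrail, output_guardrail

    @input_guardrail
    def reject_empty_input(context, agent, user_input) -> GuardrailFunctionOutput:
        # Sync guardrail: trip the wire when the input is an empty string.
        empty = isinstance(user_input, str) and not user_input.strip()
        return GuardrailFunctionOutput(output_info={"empty": empty}, tripwire_triggered=empty)

    @output_guardrail(name="length_check")
    async def limit_output_length(context, agent, agent_output) -> GuardrailFunctionOutput:
        # Async guardrail: flag outputs longer than 2000 characters.
        length = len(str(agent_output))
        return GuardrailFunctionOutput(output_info={"length": length},
                                       tripwire_triggered=length > 2000)

    # reject_empty_input is now an InputGuardrail and limit_output_length an OutputGuardrail;
    # typically they would be passed to an Agent via input_guardrails=[...] /
    # output_guardrails=[...] (assumed API, defined elsewhere in the SDK).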