Upload folder using huggingface_hub
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- README.md +3 -9
- agents/__init__.py +319 -0
- agents/__pycache__/__init__.cpython-312.pyc +0 -0
- agents/__pycache__/_config.cpython-312.pyc +0 -0
- agents/__pycache__/_debug.cpython-312.pyc +0 -0
- agents/__pycache__/_run_impl.cpython-312.pyc +0 -0
- agents/__pycache__/agent.cpython-312.pyc +0 -0
- agents/__pycache__/agent_output.cpython-312.pyc +0 -0
- agents/__pycache__/computer.cpython-312.pyc +0 -0
- agents/__pycache__/exceptions.cpython-312.pyc +0 -0
- agents/__pycache__/function_schema.cpython-312.pyc +0 -0
- agents/__pycache__/guardrail.cpython-312.pyc +0 -0
- agents/__pycache__/handoffs.cpython-312.pyc +0 -0
- agents/__pycache__/items.cpython-312.pyc +0 -0
- agents/__pycache__/lifecycle.cpython-312.pyc +0 -0
- agents/__pycache__/logger.cpython-312.pyc +0 -0
- agents/__pycache__/model_settings.cpython-312.pyc +0 -0
- agents/__pycache__/prompts.cpython-312.pyc +0 -0
- agents/__pycache__/repl.cpython-312.pyc +0 -0
- agents/__pycache__/result.cpython-312.pyc +0 -0
- agents/__pycache__/run.cpython-312.pyc +0 -0
- agents/__pycache__/run_context.cpython-312.pyc +0 -0
- agents/__pycache__/stream_events.cpython-312.pyc +0 -0
- agents/__pycache__/strict_schema.cpython-312.pyc +0 -0
- agents/__pycache__/tool.cpython-312.pyc +0 -0
- agents/__pycache__/tool_context.cpython-312.pyc +0 -0
- agents/__pycache__/tool_guardrails.cpython-312.pyc +0 -0
- agents/__pycache__/usage.cpython-312.pyc +0 -0
- agents/__pycache__/version.cpython-312.pyc +0 -0
- agents/_config.py +26 -0
- agents/_debug.py +28 -0
- agents/_run_impl.py +1442 -0
- agents/agent.py +476 -0
- agents/agent_output.py +194 -0
- agents/computer.py +107 -0
- agents/exceptions.py +131 -0
- agents/extensions/__init__.py +0 -0
- agents/extensions/handoff_filters.py +70 -0
- agents/extensions/handoff_prompt.py +19 -0
- agents/extensions/memory/__init__.py +65 -0
- agents/extensions/memory/advanced_sqlite_session.py +1285 -0
- agents/extensions/memory/encrypt_session.py +185 -0
- agents/extensions/memory/redis_session.py +267 -0
- agents/extensions/memory/sqlalchemy_session.py +312 -0
- agents/extensions/models/__init__.py +0 -0
- agents/extensions/models/litellm_model.py +601 -0
- agents/extensions/models/litellm_provider.py +23 -0
- agents/extensions/visualization.py +165 -0
- agents/function_schema.py +398 -0
- agents/guardrail.py +329 -0
README.md
CHANGED
|
@@ -1,12 +1,6 @@
|
|
| 1 |
---
|
| 2 |
-
title:
|
| 3 |
-
emoji: 🔥
|
| 4 |
-
colorFrom: gray
|
| 5 |
-
colorTo: blue
|
| 6 |
-
sdk: gradio
|
| 7 |
-
sdk_version: 6.1.0
|
| 8 |
app_file: app.py
|
| 9 |
-
|
|
|
|
| 10 |
---
|
| 11 |
-
|
| 12 |
-
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
|
|
|
| 1 |
---
|
| 2 |
+
title: deep_research-personal
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3 |
app_file: app.py
|
| 4 |
+
sdk: gradio
|
| 5 |
+
sdk_version: 6.0.2
|
| 6 |
---
|
|
|
|
|
|
agents/__init__.py
ADDED
|
@@ -0,0 +1,319 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import sys
|
| 3 |
+
from typing import Literal
|
| 4 |
+
|
| 5 |
+
from openai import AsyncOpenAI
|
| 6 |
+
|
| 7 |
+
from . import _config
|
| 8 |
+
from .agent import (
|
| 9 |
+
Agent,
|
| 10 |
+
AgentBase,
|
| 11 |
+
StopAtTools,
|
| 12 |
+
ToolsToFinalOutputFunction,
|
| 13 |
+
ToolsToFinalOutputResult,
|
| 14 |
+
)
|
| 15 |
+
from .agent_output import AgentOutputSchema, AgentOutputSchemaBase
|
| 16 |
+
from .computer import AsyncComputer, Button, Computer, Environment
|
| 17 |
+
from .exceptions import (
|
| 18 |
+
AgentsException,
|
| 19 |
+
InputGuardrailTripwireTriggered,
|
| 20 |
+
MaxTurnsExceeded,
|
| 21 |
+
ModelBehaviorError,
|
| 22 |
+
OutputGuardrailTripwireTriggered,
|
| 23 |
+
RunErrorDetails,
|
| 24 |
+
ToolInputGuardrailTripwireTriggered,
|
| 25 |
+
ToolOutputGuardrailTripwireTriggered,
|
| 26 |
+
UserError,
|
| 27 |
+
)
|
| 28 |
+
from .guardrail import (
|
| 29 |
+
GuardrailFunctionOutput,
|
| 30 |
+
InputGuardrail,
|
| 31 |
+
InputGuardrailResult,
|
| 32 |
+
OutputGuardrail,
|
| 33 |
+
OutputGuardrailResult,
|
| 34 |
+
input_guardrail,
|
| 35 |
+
output_guardrail,
|
| 36 |
+
)
|
| 37 |
+
from .handoffs import Handoff, HandoffInputData, HandoffInputFilter, handoff
|
| 38 |
+
from .items import (
|
| 39 |
+
HandoffCallItem,
|
| 40 |
+
HandoffOutputItem,
|
| 41 |
+
ItemHelpers,
|
| 42 |
+
MessageOutputItem,
|
| 43 |
+
ModelResponse,
|
| 44 |
+
ReasoningItem,
|
| 45 |
+
RunItem,
|
| 46 |
+
ToolCallItem,
|
| 47 |
+
ToolCallOutputItem,
|
| 48 |
+
TResponseInputItem,
|
| 49 |
+
)
|
| 50 |
+
from .lifecycle import AgentHooks, RunHooks
|
| 51 |
+
from .memory import OpenAIConversationsSession, Session, SessionABC, SQLiteSession
|
| 52 |
+
from .model_settings import ModelSettings
|
| 53 |
+
from .models.interface import Model, ModelProvider, ModelTracing
|
| 54 |
+
from .models.multi_provider import MultiProvider
|
| 55 |
+
from .models.openai_chatcompletions import OpenAIChatCompletionsModel
|
| 56 |
+
from .models.openai_provider import OpenAIProvider
|
| 57 |
+
from .models.openai_responses import OpenAIResponsesModel
|
| 58 |
+
from .prompts import DynamicPromptFunction, GenerateDynamicPromptData, Prompt
|
| 59 |
+
from .repl import run_demo_loop
|
| 60 |
+
from .result import RunResult, RunResultStreaming
|
| 61 |
+
from .run import RunConfig, Runner
|
| 62 |
+
from .run_context import RunContextWrapper, TContext
|
| 63 |
+
from .stream_events import (
|
| 64 |
+
AgentUpdatedStreamEvent,
|
| 65 |
+
RawResponsesStreamEvent,
|
| 66 |
+
RunItemStreamEvent,
|
| 67 |
+
StreamEvent,
|
| 68 |
+
)
|
| 69 |
+
from .tool import (
|
| 70 |
+
CodeInterpreterTool,
|
| 71 |
+
ComputerTool,
|
| 72 |
+
FileSearchTool,
|
| 73 |
+
FunctionTool,
|
| 74 |
+
FunctionToolResult,
|
| 75 |
+
HostedMCPTool,
|
| 76 |
+
ImageGenerationTool,
|
| 77 |
+
LocalShellCommandRequest,
|
| 78 |
+
LocalShellExecutor,
|
| 79 |
+
LocalShellTool,
|
| 80 |
+
MCPToolApprovalFunction,
|
| 81 |
+
MCPToolApprovalFunctionResult,
|
| 82 |
+
MCPToolApprovalRequest,
|
| 83 |
+
Tool,
|
| 84 |
+
WebSearchTool,
|
| 85 |
+
default_tool_error_function,
|
| 86 |
+
function_tool,
|
| 87 |
+
)
|
| 88 |
+
from .tool_guardrails import (
|
| 89 |
+
ToolGuardrailFunctionOutput,
|
| 90 |
+
ToolInputGuardrail,
|
| 91 |
+
ToolInputGuardrailData,
|
| 92 |
+
ToolInputGuardrailResult,
|
| 93 |
+
ToolOutputGuardrail,
|
| 94 |
+
ToolOutputGuardrailData,
|
| 95 |
+
ToolOutputGuardrailResult,
|
| 96 |
+
tool_input_guardrail,
|
| 97 |
+
tool_output_guardrail,
|
| 98 |
+
)
|
| 99 |
+
from .tracing import (
|
| 100 |
+
AgentSpanData,
|
| 101 |
+
CustomSpanData,
|
| 102 |
+
FunctionSpanData,
|
| 103 |
+
GenerationSpanData,
|
| 104 |
+
GuardrailSpanData,
|
| 105 |
+
HandoffSpanData,
|
| 106 |
+
MCPListToolsSpanData,
|
| 107 |
+
Span,
|
| 108 |
+
SpanData,
|
| 109 |
+
SpanError,
|
| 110 |
+
SpeechGroupSpanData,
|
| 111 |
+
SpeechSpanData,
|
| 112 |
+
Trace,
|
| 113 |
+
TracingProcessor,
|
| 114 |
+
TranscriptionSpanData,
|
| 115 |
+
add_trace_processor,
|
| 116 |
+
agent_span,
|
| 117 |
+
custom_span,
|
| 118 |
+
function_span,
|
| 119 |
+
gen_span_id,
|
| 120 |
+
gen_trace_id,
|
| 121 |
+
generation_span,
|
| 122 |
+
get_current_span,
|
| 123 |
+
get_current_trace,
|
| 124 |
+
guardrail_span,
|
| 125 |
+
handoff_span,
|
| 126 |
+
mcp_tools_span,
|
| 127 |
+
set_trace_processors,
|
| 128 |
+
set_trace_provider,
|
| 129 |
+
set_tracing_disabled,
|
| 130 |
+
set_tracing_export_api_key,
|
| 131 |
+
speech_group_span,
|
| 132 |
+
speech_span,
|
| 133 |
+
trace,
|
| 134 |
+
transcription_span,
|
| 135 |
+
)
|
| 136 |
+
from .usage import Usage
|
| 137 |
+
from .version import __version__
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
def set_default_openai_key(key: str, use_for_tracing: bool = True) -> None:
|
| 141 |
+
"""Set the default OpenAI API key to use for LLM requests (and optionally tracing()). This is
|
| 142 |
+
only necessary if the OPENAI_API_KEY environment variable is not already set.
|
| 143 |
+
|
| 144 |
+
If provided, this key will be used instead of the OPENAI_API_KEY environment variable.
|
| 145 |
+
|
| 146 |
+
Args:
|
| 147 |
+
key: The OpenAI key to use.
|
| 148 |
+
use_for_tracing: Whether to also use this key to send traces to OpenAI. Defaults to True
|
| 149 |
+
If False, you'll either need to set the OPENAI_API_KEY environment variable or call
|
| 150 |
+
set_tracing_export_api_key() with the API key you want to use for tracing.
|
| 151 |
+
"""
|
| 152 |
+
_config.set_default_openai_key(key, use_for_tracing)
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
def set_default_openai_client(client: AsyncOpenAI, use_for_tracing: bool = True) -> None:
|
| 156 |
+
"""Set the default OpenAI client to use for LLM requests and/or tracing. If provided, this
|
| 157 |
+
client will be used instead of the default OpenAI client.
|
| 158 |
+
|
| 159 |
+
Args:
|
| 160 |
+
client: The OpenAI client to use.
|
| 161 |
+
use_for_tracing: Whether to use the API key from this client for uploading traces. If False,
|
| 162 |
+
you'll either need to set the OPENAI_API_KEY environment variable or call
|
| 163 |
+
set_tracing_export_api_key() with the API key you want to use for tracing.
|
| 164 |
+
"""
|
| 165 |
+
_config.set_default_openai_client(client, use_for_tracing)
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
def set_default_openai_api(api: Literal["chat_completions", "responses"]) -> None:
|
| 169 |
+
"""Set the default API to use for OpenAI LLM requests. By default, we will use the responses API
|
| 170 |
+
but you can set this to use the chat completions API instead.
|
| 171 |
+
"""
|
| 172 |
+
_config.set_default_openai_api(api)
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
def enable_verbose_stdout_logging():
|
| 176 |
+
"""Enables verbose logging to stdout. This is useful for debugging."""
|
| 177 |
+
logger = logging.getLogger("openai.agents")
|
| 178 |
+
logger.setLevel(logging.DEBUG)
|
| 179 |
+
logger.addHandler(logging.StreamHandler(sys.stdout))
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
__all__ = [
|
| 183 |
+
"Agent",
|
| 184 |
+
"AgentBase",
|
| 185 |
+
"StopAtTools",
|
| 186 |
+
"ToolsToFinalOutputFunction",
|
| 187 |
+
"ToolsToFinalOutputResult",
|
| 188 |
+
"Runner",
|
| 189 |
+
"run_demo_loop",
|
| 190 |
+
"Model",
|
| 191 |
+
"ModelProvider",
|
| 192 |
+
"ModelTracing",
|
| 193 |
+
"ModelSettings",
|
| 194 |
+
"OpenAIChatCompletionsModel",
|
| 195 |
+
"MultiProvider",
|
| 196 |
+
"OpenAIProvider",
|
| 197 |
+
"OpenAIResponsesModel",
|
| 198 |
+
"AgentOutputSchema",
|
| 199 |
+
"AgentOutputSchemaBase",
|
| 200 |
+
"Computer",
|
| 201 |
+
"AsyncComputer",
|
| 202 |
+
"Environment",
|
| 203 |
+
"Button",
|
| 204 |
+
"AgentsException",
|
| 205 |
+
"InputGuardrailTripwireTriggered",
|
| 206 |
+
"OutputGuardrailTripwireTriggered",
|
| 207 |
+
"ToolInputGuardrailTripwireTriggered",
|
| 208 |
+
"ToolOutputGuardrailTripwireTriggered",
|
| 209 |
+
"DynamicPromptFunction",
|
| 210 |
+
"GenerateDynamicPromptData",
|
| 211 |
+
"Prompt",
|
| 212 |
+
"MaxTurnsExceeded",
|
| 213 |
+
"ModelBehaviorError",
|
| 214 |
+
"UserError",
|
| 215 |
+
"InputGuardrail",
|
| 216 |
+
"InputGuardrailResult",
|
| 217 |
+
"OutputGuardrail",
|
| 218 |
+
"OutputGuardrailResult",
|
| 219 |
+
"GuardrailFunctionOutput",
|
| 220 |
+
"input_guardrail",
|
| 221 |
+
"output_guardrail",
|
| 222 |
+
"ToolInputGuardrail",
|
| 223 |
+
"ToolOutputGuardrail",
|
| 224 |
+
"ToolGuardrailFunctionOutput",
|
| 225 |
+
"ToolInputGuardrailData",
|
| 226 |
+
"ToolInputGuardrailResult",
|
| 227 |
+
"ToolOutputGuardrailData",
|
| 228 |
+
"ToolOutputGuardrailResult",
|
| 229 |
+
"tool_input_guardrail",
|
| 230 |
+
"tool_output_guardrail",
|
| 231 |
+
"handoff",
|
| 232 |
+
"Handoff",
|
| 233 |
+
"HandoffInputData",
|
| 234 |
+
"HandoffInputFilter",
|
| 235 |
+
"TResponseInputItem",
|
| 236 |
+
"MessageOutputItem",
|
| 237 |
+
"ModelResponse",
|
| 238 |
+
"RunItem",
|
| 239 |
+
"HandoffCallItem",
|
| 240 |
+
"HandoffOutputItem",
|
| 241 |
+
"ToolCallItem",
|
| 242 |
+
"ToolCallOutputItem",
|
| 243 |
+
"ReasoningItem",
|
| 244 |
+
"ItemHelpers",
|
| 245 |
+
"RunHooks",
|
| 246 |
+
"AgentHooks",
|
| 247 |
+
"Session",
|
| 248 |
+
"SessionABC",
|
| 249 |
+
"SQLiteSession",
|
| 250 |
+
"OpenAIConversationsSession",
|
| 251 |
+
"RunContextWrapper",
|
| 252 |
+
"TContext",
|
| 253 |
+
"RunErrorDetails",
|
| 254 |
+
"RunResult",
|
| 255 |
+
"RunResultStreaming",
|
| 256 |
+
"RunConfig",
|
| 257 |
+
"RawResponsesStreamEvent",
|
| 258 |
+
"RunItemStreamEvent",
|
| 259 |
+
"AgentUpdatedStreamEvent",
|
| 260 |
+
"StreamEvent",
|
| 261 |
+
"FunctionTool",
|
| 262 |
+
"FunctionToolResult",
|
| 263 |
+
"ComputerTool",
|
| 264 |
+
"FileSearchTool",
|
| 265 |
+
"CodeInterpreterTool",
|
| 266 |
+
"ImageGenerationTool",
|
| 267 |
+
"LocalShellCommandRequest",
|
| 268 |
+
"LocalShellExecutor",
|
| 269 |
+
"LocalShellTool",
|
| 270 |
+
"Tool",
|
| 271 |
+
"WebSearchTool",
|
| 272 |
+
"HostedMCPTool",
|
| 273 |
+
"MCPToolApprovalFunction",
|
| 274 |
+
"MCPToolApprovalRequest",
|
| 275 |
+
"MCPToolApprovalFunctionResult",
|
| 276 |
+
"function_tool",
|
| 277 |
+
"Usage",
|
| 278 |
+
"add_trace_processor",
|
| 279 |
+
"agent_span",
|
| 280 |
+
"custom_span",
|
| 281 |
+
"function_span",
|
| 282 |
+
"generation_span",
|
| 283 |
+
"get_current_span",
|
| 284 |
+
"get_current_trace",
|
| 285 |
+
"guardrail_span",
|
| 286 |
+
"handoff_span",
|
| 287 |
+
"set_trace_processors",
|
| 288 |
+
"set_trace_provider",
|
| 289 |
+
"set_tracing_disabled",
|
| 290 |
+
"speech_group_span",
|
| 291 |
+
"transcription_span",
|
| 292 |
+
"speech_span",
|
| 293 |
+
"mcp_tools_span",
|
| 294 |
+
"trace",
|
| 295 |
+
"Trace",
|
| 296 |
+
"TracingProcessor",
|
| 297 |
+
"SpanError",
|
| 298 |
+
"Span",
|
| 299 |
+
"SpanData",
|
| 300 |
+
"AgentSpanData",
|
| 301 |
+
"CustomSpanData",
|
| 302 |
+
"FunctionSpanData",
|
| 303 |
+
"GenerationSpanData",
|
| 304 |
+
"GuardrailSpanData",
|
| 305 |
+
"HandoffSpanData",
|
| 306 |
+
"SpeechGroupSpanData",
|
| 307 |
+
"SpeechSpanData",
|
| 308 |
+
"MCPListToolsSpanData",
|
| 309 |
+
"TranscriptionSpanData",
|
| 310 |
+
"set_default_openai_key",
|
| 311 |
+
"set_default_openai_client",
|
| 312 |
+
"set_default_openai_api",
|
| 313 |
+
"set_tracing_export_api_key",
|
| 314 |
+
"enable_verbose_stdout_logging",
|
| 315 |
+
"gen_trace_id",
|
| 316 |
+
"gen_span_id",
|
| 317 |
+
"default_tool_error_function",
|
| 318 |
+
"__version__",
|
| 319 |
+
]
|
agents/__pycache__/__init__.cpython-312.pyc
ADDED
|
Binary file (7.96 kB). View file
|
|
|
agents/__pycache__/_config.cpython-312.pyc
ADDED
|
Binary file (1.36 kB). View file
|
|
|
agents/__pycache__/_debug.cpython-312.pyc
ADDED
|
Binary file (1.1 kB). View file
|
|
|
agents/__pycache__/_run_impl.cpython-312.pyc
ADDED
|
Binary file (54.7 kB). View file
|
|
|
agents/__pycache__/agent.cpython-312.pyc
ADDED
|
Binary file (18.7 kB). View file
|
|
|
agents/__pycache__/agent_output.cpython-312.pyc
ADDED
|
Binary file (8.27 kB). View file
|
|
|
agents/__pycache__/computer.cpython-312.pyc
ADDED
|
Binary file (5.74 kB). View file
|
|
|
agents/__pycache__/exceptions.cpython-312.pyc
ADDED
|
Binary file (6.27 kB). View file
|
|
|
agents/__pycache__/function_schema.cpython-312.pyc
ADDED
|
Binary file (13.9 kB). View file
|
|
|
agents/__pycache__/guardrail.cpython-312.pyc
ADDED
|
Binary file (10 kB). View file
|
|
|
agents/__pycache__/handoffs.cpython-312.pyc
ADDED
|
Binary file (10.6 kB). View file
|
|
|
agents/__pycache__/items.cpython-312.pyc
ADDED
|
Binary file (10.6 kB). View file
|
|
|
agents/__pycache__/lifecycle.cpython-312.pyc
ADDED
|
Binary file (5.49 kB). View file
|
|
|
agents/__pycache__/logger.cpython-312.pyc
ADDED
|
Binary file (271 Bytes). View file
|
|
|
agents/__pycache__/model_settings.cpython-312.pyc
ADDED
|
Binary file (6.23 kB). View file
|
|
|
agents/__pycache__/prompts.cpython-312.pyc
ADDED
|
Binary file (2.76 kB). View file
|
|
|
agents/__pycache__/repl.cpython-312.pyc
ADDED
|
Binary file (3.44 kB). View file
|
|
|
agents/__pycache__/result.cpython-312.pyc
ADDED
|
Binary file (14.3 kB). View file
|
|
|
agents/__pycache__/run.cpython-312.pyc
ADDED
|
Binary file (57.2 kB). View file
|
|
|
agents/__pycache__/run_context.cpython-312.pyc
ADDED
|
Binary file (1.12 kB). View file
|
|
|
agents/__pycache__/stream_events.cpython-312.pyc
ADDED
|
Binary file (2.15 kB). View file
|
|
|
agents/__pycache__/strict_schema.cpython-312.pyc
ADDED
|
Binary file (5.98 kB). View file
|
|
|
agents/__pycache__/tool.cpython-312.pyc
ADDED
|
Binary file (18.3 kB). View file
|
|
|
agents/__pycache__/tool_context.cpython-312.pyc
ADDED
|
Binary file (2.57 kB). View file
|
|
|
agents/__pycache__/tool_guardrails.cpython-312.pyc
ADDED
|
Binary file (10.4 kB). View file
|
|
|
agents/__pycache__/usage.cpython-312.pyc
ADDED
|
Binary file (2.31 kB). View file
|
|
|
agents/__pycache__/version.cpython-312.pyc
ADDED
|
Binary file (477 Bytes). View file
|
|
|
agents/_config.py
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from openai import AsyncOpenAI
|
| 2 |
+
from typing_extensions import Literal
|
| 3 |
+
|
| 4 |
+
from .models import _openai_shared
|
| 5 |
+
from .tracing import set_tracing_export_api_key
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def set_default_openai_key(key: str, use_for_tracing: bool) -> None:
|
| 9 |
+
_openai_shared.set_default_openai_key(key)
|
| 10 |
+
|
| 11 |
+
if use_for_tracing:
|
| 12 |
+
set_tracing_export_api_key(key)
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def set_default_openai_client(client: AsyncOpenAI, use_for_tracing: bool) -> None:
|
| 16 |
+
_openai_shared.set_default_openai_client(client)
|
| 17 |
+
|
| 18 |
+
if use_for_tracing:
|
| 19 |
+
set_tracing_export_api_key(client.api_key)
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def set_default_openai_api(api: Literal["chat_completions", "responses"]) -> None:
|
| 23 |
+
if api == "chat_completions":
|
| 24 |
+
_openai_shared.set_use_responses_by_default(False)
|
| 25 |
+
else:
|
| 26 |
+
_openai_shared.set_use_responses_by_default(True)
|
agents/_debug.py
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
def _debug_flag_enabled(flag: str, default: bool = False) -> bool:
|
| 5 |
+
flag_value = os.getenv(flag)
|
| 6 |
+
if flag_value is None:
|
| 7 |
+
return default
|
| 8 |
+
else:
|
| 9 |
+
return flag_value == "1" or flag_value.lower() == "true"
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def _load_dont_log_model_data() -> bool:
|
| 13 |
+
return _debug_flag_enabled("OPENAI_AGENTS_DONT_LOG_MODEL_DATA", default=True)
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def _load_dont_log_tool_data() -> bool:
|
| 17 |
+
return _debug_flag_enabled("OPENAI_AGENTS_DONT_LOG_TOOL_DATA", default=True)
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
DONT_LOG_MODEL_DATA = _load_dont_log_model_data()
|
| 21 |
+
"""By default we don't log LLM inputs/outputs, to prevent exposing sensitive information. Set this
|
| 22 |
+
flag to enable logging them.
|
| 23 |
+
"""
|
| 24 |
+
|
| 25 |
+
DONT_LOG_TOOL_DATA = _load_dont_log_tool_data()
|
| 26 |
+
"""By default we don't log tool call inputs/outputs, to prevent exposing sensitive information. Set
|
| 27 |
+
this flag to enable logging them.
|
| 28 |
+
"""
|
agents/_run_impl.py
ADDED
|
@@ -0,0 +1,1442 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import asyncio
|
| 4 |
+
import dataclasses
|
| 5 |
+
import inspect
|
| 6 |
+
from collections.abc import Awaitable
|
| 7 |
+
from dataclasses import dataclass, field
|
| 8 |
+
from typing import TYPE_CHECKING, Any, cast
|
| 9 |
+
|
| 10 |
+
from openai.types.responses import (
|
| 11 |
+
ResponseComputerToolCall,
|
| 12 |
+
ResponseFileSearchToolCall,
|
| 13 |
+
ResponseFunctionToolCall,
|
| 14 |
+
ResponseFunctionWebSearch,
|
| 15 |
+
ResponseOutputMessage,
|
| 16 |
+
)
|
| 17 |
+
from openai.types.responses.response_code_interpreter_tool_call import (
|
| 18 |
+
ResponseCodeInterpreterToolCall,
|
| 19 |
+
)
|
| 20 |
+
from openai.types.responses.response_computer_tool_call import (
|
| 21 |
+
ActionClick,
|
| 22 |
+
ActionDoubleClick,
|
| 23 |
+
ActionDrag,
|
| 24 |
+
ActionKeypress,
|
| 25 |
+
ActionMove,
|
| 26 |
+
ActionScreenshot,
|
| 27 |
+
ActionScroll,
|
| 28 |
+
ActionType,
|
| 29 |
+
ActionWait,
|
| 30 |
+
)
|
| 31 |
+
from openai.types.responses.response_input_item_param import (
|
| 32 |
+
ComputerCallOutputAcknowledgedSafetyCheck,
|
| 33 |
+
)
|
| 34 |
+
from openai.types.responses.response_input_param import ComputerCallOutput, McpApprovalResponse
|
| 35 |
+
from openai.types.responses.response_output_item import (
|
| 36 |
+
ImageGenerationCall,
|
| 37 |
+
LocalShellCall,
|
| 38 |
+
McpApprovalRequest,
|
| 39 |
+
McpCall,
|
| 40 |
+
McpListTools,
|
| 41 |
+
)
|
| 42 |
+
from openai.types.responses.response_reasoning_item import ResponseReasoningItem
|
| 43 |
+
|
| 44 |
+
from .agent import Agent, ToolsToFinalOutputResult
|
| 45 |
+
from .agent_output import AgentOutputSchemaBase
|
| 46 |
+
from .computer import AsyncComputer, Computer
|
| 47 |
+
from .exceptions import (
|
| 48 |
+
AgentsException,
|
| 49 |
+
ModelBehaviorError,
|
| 50 |
+
ToolInputGuardrailTripwireTriggered,
|
| 51 |
+
ToolOutputGuardrailTripwireTriggered,
|
| 52 |
+
UserError,
|
| 53 |
+
)
|
| 54 |
+
from .guardrail import InputGuardrail, InputGuardrailResult, OutputGuardrail, OutputGuardrailResult
|
| 55 |
+
from .handoffs import Handoff, HandoffInputData
|
| 56 |
+
from .items import (
|
| 57 |
+
HandoffCallItem,
|
| 58 |
+
HandoffOutputItem,
|
| 59 |
+
ItemHelpers,
|
| 60 |
+
MCPApprovalRequestItem,
|
| 61 |
+
MCPApprovalResponseItem,
|
| 62 |
+
MCPListToolsItem,
|
| 63 |
+
MessageOutputItem,
|
| 64 |
+
ModelResponse,
|
| 65 |
+
ReasoningItem,
|
| 66 |
+
RunItem,
|
| 67 |
+
ToolCallItem,
|
| 68 |
+
ToolCallOutputItem,
|
| 69 |
+
TResponseInputItem,
|
| 70 |
+
)
|
| 71 |
+
from .lifecycle import RunHooks
|
| 72 |
+
from .logger import logger
|
| 73 |
+
from .model_settings import ModelSettings
|
| 74 |
+
from .models.interface import ModelTracing
|
| 75 |
+
from .run_context import RunContextWrapper, TContext
|
| 76 |
+
from .stream_events import RunItemStreamEvent, StreamEvent
|
| 77 |
+
from .tool import (
|
| 78 |
+
ComputerTool,
|
| 79 |
+
ComputerToolSafetyCheckData,
|
| 80 |
+
FunctionTool,
|
| 81 |
+
FunctionToolResult,
|
| 82 |
+
HostedMCPTool,
|
| 83 |
+
LocalShellCommandRequest,
|
| 84 |
+
LocalShellTool,
|
| 85 |
+
MCPToolApprovalRequest,
|
| 86 |
+
Tool,
|
| 87 |
+
)
|
| 88 |
+
from .tool_context import ToolContext
|
| 89 |
+
from .tool_guardrails import (
|
| 90 |
+
ToolInputGuardrailData,
|
| 91 |
+
ToolInputGuardrailResult,
|
| 92 |
+
ToolOutputGuardrailData,
|
| 93 |
+
ToolOutputGuardrailResult,
|
| 94 |
+
)
|
| 95 |
+
from .tracing import (
|
| 96 |
+
SpanError,
|
| 97 |
+
Trace,
|
| 98 |
+
function_span,
|
| 99 |
+
get_current_trace,
|
| 100 |
+
guardrail_span,
|
| 101 |
+
handoff_span,
|
| 102 |
+
trace,
|
| 103 |
+
)
|
| 104 |
+
from .util import _coro, _error_tracing
|
| 105 |
+
|
| 106 |
+
if TYPE_CHECKING:
|
| 107 |
+
from .run import RunConfig
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
class QueueCompleteSentinel:
|
| 111 |
+
pass
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
QUEUE_COMPLETE_SENTINEL = QueueCompleteSentinel()
|
| 115 |
+
|
| 116 |
+
_NOT_FINAL_OUTPUT = ToolsToFinalOutputResult(is_final_output=False, final_output=None)
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
@dataclass
|
| 120 |
+
class AgentToolUseTracker:
|
| 121 |
+
agent_to_tools: list[tuple[Agent, list[str]]] = field(default_factory=list)
|
| 122 |
+
"""Tuple of (agent, list of tools used). Can't use a dict because agents aren't hashable."""
|
| 123 |
+
|
| 124 |
+
def add_tool_use(self, agent: Agent[Any], tool_names: list[str]) -> None:
|
| 125 |
+
existing_data = next((item for item in self.agent_to_tools if item[0] == agent), None)
|
| 126 |
+
if existing_data:
|
| 127 |
+
existing_data[1].extend(tool_names)
|
| 128 |
+
else:
|
| 129 |
+
self.agent_to_tools.append((agent, tool_names))
|
| 130 |
+
|
| 131 |
+
def has_used_tools(self, agent: Agent[Any]) -> bool:
|
| 132 |
+
existing_data = next((item for item in self.agent_to_tools if item[0] == agent), None)
|
| 133 |
+
return existing_data is not None and len(existing_data[1]) > 0
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
@dataclass
|
| 137 |
+
class ToolRunHandoff:
|
| 138 |
+
handoff: Handoff
|
| 139 |
+
tool_call: ResponseFunctionToolCall
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
@dataclass
|
| 143 |
+
class ToolRunFunction:
|
| 144 |
+
tool_call: ResponseFunctionToolCall
|
| 145 |
+
function_tool: FunctionTool
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
@dataclass
|
| 149 |
+
class ToolRunComputerAction:
|
| 150 |
+
tool_call: ResponseComputerToolCall
|
| 151 |
+
computer_tool: ComputerTool
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
@dataclass
|
| 155 |
+
class ToolRunMCPApprovalRequest:
|
| 156 |
+
request_item: McpApprovalRequest
|
| 157 |
+
mcp_tool: HostedMCPTool
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
@dataclass
|
| 161 |
+
class ToolRunLocalShellCall:
|
| 162 |
+
tool_call: LocalShellCall
|
| 163 |
+
local_shell_tool: LocalShellTool
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
@dataclass
|
| 167 |
+
class ProcessedResponse:
|
| 168 |
+
new_items: list[RunItem]
|
| 169 |
+
handoffs: list[ToolRunHandoff]
|
| 170 |
+
functions: list[ToolRunFunction]
|
| 171 |
+
computer_actions: list[ToolRunComputerAction]
|
| 172 |
+
local_shell_calls: list[ToolRunLocalShellCall]
|
| 173 |
+
tools_used: list[str] # Names of all tools used, including hosted tools
|
| 174 |
+
mcp_approval_requests: list[ToolRunMCPApprovalRequest] # Only requests with callbacks
|
| 175 |
+
|
| 176 |
+
def has_tools_or_approvals_to_run(self) -> bool:
|
| 177 |
+
# Handoffs, functions and computer actions need local processing
|
| 178 |
+
# Hosted tools have already run, so there's nothing to do.
|
| 179 |
+
return any(
|
| 180 |
+
[
|
| 181 |
+
self.handoffs,
|
| 182 |
+
self.functions,
|
| 183 |
+
self.computer_actions,
|
| 184 |
+
self.local_shell_calls,
|
| 185 |
+
self.mcp_approval_requests,
|
| 186 |
+
]
|
| 187 |
+
)
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
@dataclass
|
| 191 |
+
class NextStepHandoff:
|
| 192 |
+
new_agent: Agent[Any]
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
@dataclass
|
| 196 |
+
class NextStepFinalOutput:
|
| 197 |
+
output: Any
|
| 198 |
+
|
| 199 |
+
|
| 200 |
+
@dataclass
|
| 201 |
+
class NextStepRunAgain:
|
| 202 |
+
pass
|
| 203 |
+
|
| 204 |
+
|
| 205 |
+
@dataclass
|
| 206 |
+
class SingleStepResult:
|
| 207 |
+
original_input: str | list[TResponseInputItem]
|
| 208 |
+
"""The input items i.e. the items before run() was called. May be mutated by handoff input
|
| 209 |
+
filters."""
|
| 210 |
+
|
| 211 |
+
model_response: ModelResponse
|
| 212 |
+
"""The model response for the current step."""
|
| 213 |
+
|
| 214 |
+
pre_step_items: list[RunItem]
|
| 215 |
+
"""Items generated before the current step."""
|
| 216 |
+
|
| 217 |
+
new_step_items: list[RunItem]
|
| 218 |
+
"""Items generated during this current step."""
|
| 219 |
+
|
| 220 |
+
next_step: NextStepHandoff | NextStepFinalOutput | NextStepRunAgain
|
| 221 |
+
"""The next step to take."""
|
| 222 |
+
|
| 223 |
+
tool_input_guardrail_results: list[ToolInputGuardrailResult]
|
| 224 |
+
"""Tool input guardrail results from this step."""
|
| 225 |
+
|
| 226 |
+
tool_output_guardrail_results: list[ToolOutputGuardrailResult]
|
| 227 |
+
"""Tool output guardrail results from this step."""
|
| 228 |
+
|
| 229 |
+
@property
|
| 230 |
+
def generated_items(self) -> list[RunItem]:
|
| 231 |
+
"""Items generated during the agent run (i.e. everything generated after
|
| 232 |
+
`original_input`)."""
|
| 233 |
+
return self.pre_step_items + self.new_step_items
|
| 234 |
+
|
| 235 |
+
|
| 236 |
+
def get_model_tracing_impl(
|
| 237 |
+
tracing_disabled: bool, trace_include_sensitive_data: bool
|
| 238 |
+
) -> ModelTracing:
|
| 239 |
+
if tracing_disabled:
|
| 240 |
+
return ModelTracing.DISABLED
|
| 241 |
+
elif trace_include_sensitive_data:
|
| 242 |
+
return ModelTracing.ENABLED
|
| 243 |
+
else:
|
| 244 |
+
return ModelTracing.ENABLED_WITHOUT_DATA
|
| 245 |
+
|
| 246 |
+
|
| 247 |
+
class RunImpl:
|
| 248 |
+
@classmethod
|
| 249 |
+
async def execute_tools_and_side_effects(
|
| 250 |
+
cls,
|
| 251 |
+
*,
|
| 252 |
+
agent: Agent[TContext],
|
| 253 |
+
# The original input to the Runner
|
| 254 |
+
original_input: str | list[TResponseInputItem],
|
| 255 |
+
# Everything generated by Runner since the original input, but before the current step
|
| 256 |
+
pre_step_items: list[RunItem],
|
| 257 |
+
new_response: ModelResponse,
|
| 258 |
+
processed_response: ProcessedResponse,
|
| 259 |
+
output_schema: AgentOutputSchemaBase | None,
|
| 260 |
+
hooks: RunHooks[TContext],
|
| 261 |
+
context_wrapper: RunContextWrapper[TContext],
|
| 262 |
+
run_config: RunConfig,
|
| 263 |
+
) -> SingleStepResult:
|
| 264 |
+
# Make a copy of the generated items
|
| 265 |
+
pre_step_items = list(pre_step_items)
|
| 266 |
+
|
| 267 |
+
new_step_items: list[RunItem] = []
|
| 268 |
+
new_step_items.extend(processed_response.new_items)
|
| 269 |
+
|
| 270 |
+
# First, lets run the tool calls - function tools and computer actions
|
| 271 |
+
(
|
| 272 |
+
(function_results, tool_input_guardrail_results, tool_output_guardrail_results),
|
| 273 |
+
computer_results,
|
| 274 |
+
) = await asyncio.gather(
|
| 275 |
+
cls.execute_function_tool_calls(
|
| 276 |
+
agent=agent,
|
| 277 |
+
tool_runs=processed_response.functions,
|
| 278 |
+
hooks=hooks,
|
| 279 |
+
context_wrapper=context_wrapper,
|
| 280 |
+
config=run_config,
|
| 281 |
+
),
|
| 282 |
+
cls.execute_computer_actions(
|
| 283 |
+
agent=agent,
|
| 284 |
+
actions=processed_response.computer_actions,
|
| 285 |
+
hooks=hooks,
|
| 286 |
+
context_wrapper=context_wrapper,
|
| 287 |
+
config=run_config,
|
| 288 |
+
),
|
| 289 |
+
)
|
| 290 |
+
new_step_items.extend([result.run_item for result in function_results])
|
| 291 |
+
new_step_items.extend(computer_results)
|
| 292 |
+
|
| 293 |
+
# Next, run the MCP approval requests
|
| 294 |
+
if processed_response.mcp_approval_requests:
|
| 295 |
+
approval_results = await cls.execute_mcp_approval_requests(
|
| 296 |
+
agent=agent,
|
| 297 |
+
approval_requests=processed_response.mcp_approval_requests,
|
| 298 |
+
context_wrapper=context_wrapper,
|
| 299 |
+
)
|
| 300 |
+
new_step_items.extend(approval_results)
|
| 301 |
+
|
| 302 |
+
# Next, check if there are any handoffs
|
| 303 |
+
if run_handoffs := processed_response.handoffs:
|
| 304 |
+
return await cls.execute_handoffs(
|
| 305 |
+
agent=agent,
|
| 306 |
+
original_input=original_input,
|
| 307 |
+
pre_step_items=pre_step_items,
|
| 308 |
+
new_step_items=new_step_items,
|
| 309 |
+
new_response=new_response,
|
| 310 |
+
run_handoffs=run_handoffs,
|
| 311 |
+
hooks=hooks,
|
| 312 |
+
context_wrapper=context_wrapper,
|
| 313 |
+
run_config=run_config,
|
| 314 |
+
)
|
| 315 |
+
|
| 316 |
+
# Next, we'll check if the tool use should result in a final output
|
| 317 |
+
check_tool_use = await cls._check_for_final_output_from_tools(
|
| 318 |
+
agent=agent,
|
| 319 |
+
tool_results=function_results,
|
| 320 |
+
context_wrapper=context_wrapper,
|
| 321 |
+
config=run_config,
|
| 322 |
+
)
|
| 323 |
+
|
| 324 |
+
if check_tool_use.is_final_output:
|
| 325 |
+
# If the output type is str, then let's just stringify it
|
| 326 |
+
if not agent.output_type or agent.output_type is str:
|
| 327 |
+
check_tool_use.final_output = str(check_tool_use.final_output)
|
| 328 |
+
|
| 329 |
+
if check_tool_use.final_output is None:
|
| 330 |
+
logger.error(
|
| 331 |
+
"Model returned a final output of None. Not raising an error because we assume"
|
| 332 |
+
"you know what you're doing."
|
| 333 |
+
)
|
| 334 |
+
|
| 335 |
+
return await cls.execute_final_output(
|
| 336 |
+
agent=agent,
|
| 337 |
+
original_input=original_input,
|
| 338 |
+
new_response=new_response,
|
| 339 |
+
pre_step_items=pre_step_items,
|
| 340 |
+
new_step_items=new_step_items,
|
| 341 |
+
final_output=check_tool_use.final_output,
|
| 342 |
+
hooks=hooks,
|
| 343 |
+
context_wrapper=context_wrapper,
|
| 344 |
+
tool_input_guardrail_results=tool_input_guardrail_results,
|
| 345 |
+
tool_output_guardrail_results=tool_output_guardrail_results,
|
| 346 |
+
)
|
| 347 |
+
|
| 348 |
+
# Now we can check if the model also produced a final output
|
| 349 |
+
message_items = [item for item in new_step_items if isinstance(item, MessageOutputItem)]
|
| 350 |
+
|
| 351 |
+
# We'll use the last content output as the final output
|
| 352 |
+
potential_final_output_text = (
|
| 353 |
+
ItemHelpers.extract_last_text(message_items[-1].raw_item) if message_items else None
|
| 354 |
+
)
|
| 355 |
+
|
| 356 |
+
# Generate final output only when there are no pending tool calls or approval requests.
|
| 357 |
+
if not processed_response.has_tools_or_approvals_to_run():
|
| 358 |
+
if output_schema and not output_schema.is_plain_text() and potential_final_output_text:
|
| 359 |
+
final_output = output_schema.validate_json(potential_final_output_text)
|
| 360 |
+
return await cls.execute_final_output(
|
| 361 |
+
agent=agent,
|
| 362 |
+
original_input=original_input,
|
| 363 |
+
new_response=new_response,
|
| 364 |
+
pre_step_items=pre_step_items,
|
| 365 |
+
new_step_items=new_step_items,
|
| 366 |
+
final_output=final_output,
|
| 367 |
+
hooks=hooks,
|
| 368 |
+
context_wrapper=context_wrapper,
|
| 369 |
+
tool_input_guardrail_results=tool_input_guardrail_results,
|
| 370 |
+
tool_output_guardrail_results=tool_output_guardrail_results,
|
| 371 |
+
)
|
| 372 |
+
elif not output_schema or output_schema.is_plain_text():
|
| 373 |
+
return await cls.execute_final_output(
|
| 374 |
+
agent=agent,
|
| 375 |
+
original_input=original_input,
|
| 376 |
+
new_response=new_response,
|
| 377 |
+
pre_step_items=pre_step_items,
|
| 378 |
+
new_step_items=new_step_items,
|
| 379 |
+
final_output=potential_final_output_text or "",
|
| 380 |
+
hooks=hooks,
|
| 381 |
+
context_wrapper=context_wrapper,
|
| 382 |
+
tool_input_guardrail_results=tool_input_guardrail_results,
|
| 383 |
+
tool_output_guardrail_results=tool_output_guardrail_results,
|
| 384 |
+
)
|
| 385 |
+
|
| 386 |
+
# If there's no final output, we can just run again
|
| 387 |
+
return SingleStepResult(
|
| 388 |
+
original_input=original_input,
|
| 389 |
+
model_response=new_response,
|
| 390 |
+
pre_step_items=pre_step_items,
|
| 391 |
+
new_step_items=new_step_items,
|
| 392 |
+
next_step=NextStepRunAgain(),
|
| 393 |
+
tool_input_guardrail_results=tool_input_guardrail_results,
|
| 394 |
+
tool_output_guardrail_results=tool_output_guardrail_results,
|
| 395 |
+
)
|
| 396 |
+
|
| 397 |
+
@classmethod
|
| 398 |
+
def maybe_reset_tool_choice(
|
| 399 |
+
cls, agent: Agent[Any], tool_use_tracker: AgentToolUseTracker, model_settings: ModelSettings
|
| 400 |
+
) -> ModelSettings:
|
| 401 |
+
"""Resets tool choice to None if the agent has used tools and the agent's reset_tool_choice
|
| 402 |
+
flag is True."""
|
| 403 |
+
|
| 404 |
+
if agent.reset_tool_choice is True and tool_use_tracker.has_used_tools(agent):
|
| 405 |
+
return dataclasses.replace(model_settings, tool_choice=None)
|
| 406 |
+
|
| 407 |
+
return model_settings
|
| 408 |
+
|
| 409 |
+
@classmethod
|
| 410 |
+
def process_model_response(
|
| 411 |
+
cls,
|
| 412 |
+
*,
|
| 413 |
+
agent: Agent[Any],
|
| 414 |
+
all_tools: list[Tool],
|
| 415 |
+
response: ModelResponse,
|
| 416 |
+
output_schema: AgentOutputSchemaBase | None,
|
| 417 |
+
handoffs: list[Handoff],
|
| 418 |
+
) -> ProcessedResponse:
|
| 419 |
+
items: list[RunItem] = []
|
| 420 |
+
|
| 421 |
+
run_handoffs = []
|
| 422 |
+
functions = []
|
| 423 |
+
computer_actions = []
|
| 424 |
+
local_shell_calls = []
|
| 425 |
+
mcp_approval_requests = []
|
| 426 |
+
tools_used: list[str] = []
|
| 427 |
+
handoff_map = {handoff.tool_name: handoff for handoff in handoffs}
|
| 428 |
+
function_map = {tool.name: tool for tool in all_tools if isinstance(tool, FunctionTool)}
|
| 429 |
+
computer_tool = next((tool for tool in all_tools if isinstance(tool, ComputerTool)), None)
|
| 430 |
+
local_shell_tool = next(
|
| 431 |
+
(tool for tool in all_tools if isinstance(tool, LocalShellTool)), None
|
| 432 |
+
)
|
| 433 |
+
hosted_mcp_server_map = {
|
| 434 |
+
tool.tool_config["server_label"]: tool
|
| 435 |
+
for tool in all_tools
|
| 436 |
+
if isinstance(tool, HostedMCPTool)
|
| 437 |
+
}
|
| 438 |
+
|
| 439 |
+
for output in response.output:
|
| 440 |
+
if isinstance(output, ResponseOutputMessage):
|
| 441 |
+
items.append(MessageOutputItem(raw_item=output, agent=agent))
|
| 442 |
+
elif isinstance(output, ResponseFileSearchToolCall):
|
| 443 |
+
items.append(ToolCallItem(raw_item=output, agent=agent))
|
| 444 |
+
tools_used.append("file_search")
|
| 445 |
+
elif isinstance(output, ResponseFunctionWebSearch):
|
| 446 |
+
items.append(ToolCallItem(raw_item=output, agent=agent))
|
| 447 |
+
tools_used.append("web_search")
|
| 448 |
+
elif isinstance(output, ResponseReasoningItem):
|
| 449 |
+
items.append(ReasoningItem(raw_item=output, agent=agent))
|
| 450 |
+
elif isinstance(output, ResponseComputerToolCall):
|
| 451 |
+
items.append(ToolCallItem(raw_item=output, agent=agent))
|
| 452 |
+
tools_used.append("computer_use")
|
| 453 |
+
if not computer_tool:
|
| 454 |
+
_error_tracing.attach_error_to_current_span(
|
| 455 |
+
SpanError(
|
| 456 |
+
message="Computer tool not found",
|
| 457 |
+
data={},
|
| 458 |
+
)
|
| 459 |
+
)
|
| 460 |
+
raise ModelBehaviorError(
|
| 461 |
+
"Model produced computer action without a computer tool."
|
| 462 |
+
)
|
| 463 |
+
computer_actions.append(
|
| 464 |
+
ToolRunComputerAction(tool_call=output, computer_tool=computer_tool)
|
| 465 |
+
)
|
| 466 |
+
elif isinstance(output, McpApprovalRequest):
|
| 467 |
+
items.append(MCPApprovalRequestItem(raw_item=output, agent=agent))
|
| 468 |
+
if output.server_label not in hosted_mcp_server_map:
|
| 469 |
+
_error_tracing.attach_error_to_current_span(
|
| 470 |
+
SpanError(
|
| 471 |
+
message="MCP server label not found",
|
| 472 |
+
data={"server_label": output.server_label},
|
| 473 |
+
)
|
| 474 |
+
)
|
| 475 |
+
raise ModelBehaviorError(f"MCP server label {output.server_label} not found")
|
| 476 |
+
else:
|
| 477 |
+
server = hosted_mcp_server_map[output.server_label]
|
| 478 |
+
if server.on_approval_request:
|
| 479 |
+
mcp_approval_requests.append(
|
| 480 |
+
ToolRunMCPApprovalRequest(
|
| 481 |
+
request_item=output,
|
| 482 |
+
mcp_tool=server,
|
| 483 |
+
)
|
| 484 |
+
)
|
| 485 |
+
else:
|
| 486 |
+
logger.warning(
|
| 487 |
+
f"MCP server {output.server_label} has no on_approval_request hook"
|
| 488 |
+
)
|
| 489 |
+
elif isinstance(output, McpListTools):
|
| 490 |
+
items.append(MCPListToolsItem(raw_item=output, agent=agent))
|
| 491 |
+
elif isinstance(output, McpCall):
|
| 492 |
+
items.append(ToolCallItem(raw_item=output, agent=agent))
|
| 493 |
+
tools_used.append("mcp")
|
| 494 |
+
elif isinstance(output, ImageGenerationCall):
|
| 495 |
+
items.append(ToolCallItem(raw_item=output, agent=agent))
|
| 496 |
+
tools_used.append("image_generation")
|
| 497 |
+
elif isinstance(output, ResponseCodeInterpreterToolCall):
|
| 498 |
+
items.append(ToolCallItem(raw_item=output, agent=agent))
|
| 499 |
+
tools_used.append("code_interpreter")
|
| 500 |
+
elif isinstance(output, LocalShellCall):
|
| 501 |
+
items.append(ToolCallItem(raw_item=output, agent=agent))
|
| 502 |
+
tools_used.append("local_shell")
|
| 503 |
+
if not local_shell_tool:
|
| 504 |
+
_error_tracing.attach_error_to_current_span(
|
| 505 |
+
SpanError(
|
| 506 |
+
message="Local shell tool not found",
|
| 507 |
+
data={},
|
| 508 |
+
)
|
| 509 |
+
)
|
| 510 |
+
raise ModelBehaviorError(
|
| 511 |
+
"Model produced local shell call without a local shell tool."
|
| 512 |
+
)
|
| 513 |
+
local_shell_calls.append(
|
| 514 |
+
ToolRunLocalShellCall(tool_call=output, local_shell_tool=local_shell_tool)
|
| 515 |
+
)
|
| 516 |
+
|
| 517 |
+
elif not isinstance(output, ResponseFunctionToolCall):
|
| 518 |
+
logger.warning(f"Unexpected output type, ignoring: {type(output)}")
|
| 519 |
+
continue
|
| 520 |
+
|
| 521 |
+
# At this point we know it's a function tool call
|
| 522 |
+
if not isinstance(output, ResponseFunctionToolCall):
|
| 523 |
+
continue
|
| 524 |
+
|
| 525 |
+
tools_used.append(output.name)
|
| 526 |
+
|
| 527 |
+
# Handoffs
|
| 528 |
+
if output.name in handoff_map:
|
| 529 |
+
items.append(HandoffCallItem(raw_item=output, agent=agent))
|
| 530 |
+
handoff = ToolRunHandoff(
|
| 531 |
+
tool_call=output,
|
| 532 |
+
handoff=handoff_map[output.name],
|
| 533 |
+
)
|
| 534 |
+
run_handoffs.append(handoff)
|
| 535 |
+
# Regular function tool call
|
| 536 |
+
else:
|
| 537 |
+
if output.name not in function_map:
|
| 538 |
+
if output_schema is not None and output.name == "json_tool_call":
|
| 539 |
+
# LiteLLM could generate non-existent tool calls for structured outputs
|
| 540 |
+
items.append(ToolCallItem(raw_item=output, agent=agent))
|
| 541 |
+
functions.append(
|
| 542 |
+
ToolRunFunction(
|
| 543 |
+
tool_call=output,
|
| 544 |
+
# this tool does not exist in function_map, so generate ad-hoc one,
|
| 545 |
+
# which just parses the input if it's a string, and returns the
|
| 546 |
+
# value otherwise
|
| 547 |
+
function_tool=_build_litellm_json_tool_call(output),
|
| 548 |
+
)
|
| 549 |
+
)
|
| 550 |
+
continue
|
| 551 |
+
else:
|
| 552 |
+
_error_tracing.attach_error_to_current_span(
|
| 553 |
+
SpanError(
|
| 554 |
+
message="Tool not found",
|
| 555 |
+
data={"tool_name": output.name},
|
| 556 |
+
)
|
| 557 |
+
)
|
| 558 |
+
error = f"Tool {output.name} not found in agent {agent.name}"
|
| 559 |
+
raise ModelBehaviorError(error)
|
| 560 |
+
|
| 561 |
+
items.append(ToolCallItem(raw_item=output, agent=agent))
|
| 562 |
+
functions.append(
|
| 563 |
+
ToolRunFunction(
|
| 564 |
+
tool_call=output,
|
| 565 |
+
function_tool=function_map[output.name],
|
| 566 |
+
)
|
| 567 |
+
)
|
| 568 |
+
|
| 569 |
+
return ProcessedResponse(
|
| 570 |
+
new_items=items,
|
| 571 |
+
handoffs=run_handoffs,
|
| 572 |
+
functions=functions,
|
| 573 |
+
computer_actions=computer_actions,
|
| 574 |
+
local_shell_calls=local_shell_calls,
|
| 575 |
+
tools_used=tools_used,
|
| 576 |
+
mcp_approval_requests=mcp_approval_requests,
|
| 577 |
+
)
|
| 578 |
+
|
| 579 |
+
@classmethod
|
| 580 |
+
async def _execute_input_guardrails(
|
| 581 |
+
cls,
|
| 582 |
+
*,
|
| 583 |
+
func_tool: FunctionTool,
|
| 584 |
+
tool_context: ToolContext[TContext],
|
| 585 |
+
agent: Agent[TContext],
|
| 586 |
+
        tool_input_guardrail_results: list[ToolInputGuardrailResult],
    ) -> str | None:
        """Execute input guardrails for a tool.

        Args:
            func_tool: The function tool being executed.
            tool_context: The tool execution context.
            agent: The agent executing the tool.
            tool_input_guardrail_results: List to append guardrail results to.

        Returns:
            None if tool execution should proceed, or a message string if execution should be
            skipped.

        Raises:
            ToolInputGuardrailTripwireTriggered: If a guardrail triggers an exception.
        """
        if not func_tool.tool_input_guardrails:
            return None

        for guardrail in func_tool.tool_input_guardrails:
            gr_out = await guardrail.run(
                ToolInputGuardrailData(
                    context=tool_context,
                    agent=agent,
                )
            )

            # Store the guardrail result
            tool_input_guardrail_results.append(
                ToolInputGuardrailResult(
                    guardrail=guardrail,
                    output=gr_out,
                )
            )

            # Handle different behavior types
            if gr_out.behavior["type"] == "raise_exception":
                raise ToolInputGuardrailTripwireTriggered(guardrail=guardrail, output=gr_out)
            elif gr_out.behavior["type"] == "reject_content":
                # Set final_result to the message and skip tool execution
                return gr_out.behavior["message"]
            elif gr_out.behavior["type"] == "allow":
                # Continue to next guardrail or tool execution
                continue

        return None

    @classmethod
    async def _execute_output_guardrails(
        cls,
        *,
        func_tool: FunctionTool,
        tool_context: ToolContext[TContext],
        agent: Agent[TContext],
        real_result: Any,
        tool_output_guardrail_results: list[ToolOutputGuardrailResult],
    ) -> Any:
        """Execute output guardrails for a tool.

        Args:
            func_tool: The function tool being executed.
            tool_context: The tool execution context.
            agent: The agent executing the tool.
            real_result: The actual result from the tool execution.
            tool_output_guardrail_results: List to append guardrail results to.

        Returns:
            The final result after guardrail processing (may be modified).

        Raises:
            ToolOutputGuardrailTripwireTriggered: If a guardrail triggers an exception.
        """
        if not func_tool.tool_output_guardrails:
            return real_result

        final_result = real_result
        for output_guardrail in func_tool.tool_output_guardrails:
            gr_out = await output_guardrail.run(
                ToolOutputGuardrailData(
                    context=tool_context,
                    agent=agent,
                    output=real_result,
                )
            )

            # Store the guardrail result
            tool_output_guardrail_results.append(
                ToolOutputGuardrailResult(
                    guardrail=output_guardrail,
                    output=gr_out,
                )
            )

            # Handle different behavior types
            if gr_out.behavior["type"] == "raise_exception":
                raise ToolOutputGuardrailTripwireTriggered(
                    guardrail=output_guardrail, output=gr_out
                )
            elif gr_out.behavior["type"] == "reject_content":
                # Override the result with the guardrail message
                final_result = gr_out.behavior["message"]
                break
            elif gr_out.behavior["type"] == "allow":
                # Continue to next guardrail
                continue

        return final_result

    @classmethod
    async def _execute_tool_with_hooks(
        cls,
        *,
        func_tool: FunctionTool,
        tool_context: ToolContext[TContext],
        agent: Agent[TContext],
        hooks: RunHooks[TContext],
        tool_call: ResponseFunctionToolCall,
    ) -> Any:
        """Execute the core tool function with before/after hooks.

        Args:
            func_tool: The function tool being executed.
            tool_context: The tool execution context.
            agent: The agent executing the tool.
            hooks: The run hooks to execute.
            tool_call: The tool call details.

        Returns:
            The result from the tool execution.
        """
        await asyncio.gather(
            hooks.on_tool_start(tool_context, agent, func_tool),
            (
                agent.hooks.on_tool_start(tool_context, agent, func_tool)
                if agent.hooks
                else _coro.noop_coroutine()
            ),
        )

        return await func_tool.on_invoke_tool(tool_context, tool_call.arguments)

    @classmethod
    async def execute_function_tool_calls(
        cls,
        *,
        agent: Agent[TContext],
        tool_runs: list[ToolRunFunction],
        hooks: RunHooks[TContext],
        context_wrapper: RunContextWrapper[TContext],
        config: RunConfig,
    ) -> tuple[
        list[FunctionToolResult], list[ToolInputGuardrailResult], list[ToolOutputGuardrailResult]
    ]:
        # Collect guardrail results
        tool_input_guardrail_results: list[ToolInputGuardrailResult] = []
        tool_output_guardrail_results: list[ToolOutputGuardrailResult] = []

        async def run_single_tool(
            func_tool: FunctionTool, tool_call: ResponseFunctionToolCall
        ) -> Any:
            with function_span(func_tool.name) as span_fn:
                tool_context = ToolContext.from_agent_context(
                    context_wrapper,
                    tool_call.call_id,
                    tool_call=tool_call,
                )
                if config.trace_include_sensitive_data:
                    span_fn.span_data.input = tool_call.arguments
                try:
                    # 1) Run input tool guardrails, if any
                    rejected_message = await cls._execute_input_guardrails(
                        func_tool=func_tool,
                        tool_context=tool_context,
                        agent=agent,
                        tool_input_guardrail_results=tool_input_guardrail_results,
                    )

                    if rejected_message is not None:
                        # Input guardrail rejected the tool call
                        final_result = rejected_message
                    else:
                        # 2) Actually run the tool
                        real_result = await cls._execute_tool_with_hooks(
                            func_tool=func_tool,
                            tool_context=tool_context,
                            agent=agent,
                            hooks=hooks,
                            tool_call=tool_call,
                        )

                        # 3) Run output tool guardrails, if any
                        final_result = await cls._execute_output_guardrails(
                            func_tool=func_tool,
                            tool_context=tool_context,
                            agent=agent,
                            real_result=real_result,
                            tool_output_guardrail_results=tool_output_guardrail_results,
                        )

                    # 4) Tool end hooks (with final result, which may have been overridden)
                    await asyncio.gather(
                        hooks.on_tool_end(tool_context, agent, func_tool, final_result),
                        (
                            agent.hooks.on_tool_end(
                                tool_context, agent, func_tool, final_result
                            )
                            if agent.hooks
                            else _coro.noop_coroutine()
                        ),
                    )
                    result = final_result
                except Exception as e:
                    _error_tracing.attach_error_to_current_span(
                        SpanError(
                            message="Error running tool",
                            data={"tool_name": func_tool.name, "error": str(e)},
                        )
                    )
                    if isinstance(e, AgentsException):
                        raise e
                    raise UserError(f"Error running tool {func_tool.name}: {e}") from e

                if config.trace_include_sensitive_data:
                    span_fn.span_data.output = result
                return result

        tasks = []
        for tool_run in tool_runs:
            function_tool = tool_run.function_tool
            tasks.append(run_single_tool(function_tool, tool_run.tool_call))

        results = await asyncio.gather(*tasks)

        function_tool_results = [
            FunctionToolResult(
                tool=tool_run.function_tool,
                output=result,
                run_item=ToolCallOutputItem(
                    output=result,
                    raw_item=ItemHelpers.tool_call_output_item(tool_run.tool_call, str(result)),
                    agent=agent,
                ),
            )
            for tool_run, result in zip(tool_runs, results)
        ]

        return function_tool_results, tool_input_guardrail_results, tool_output_guardrail_results

    @classmethod
    async def execute_local_shell_calls(
        cls,
        *,
        agent: Agent[TContext],
        calls: list[ToolRunLocalShellCall],
        context_wrapper: RunContextWrapper[TContext],
        hooks: RunHooks[TContext],
        config: RunConfig,
    ) -> list[RunItem]:
        results: list[RunItem] = []
        # Need to run these serially, because each call can affect the local shell state
        for call in calls:
            results.append(
                await LocalShellAction.execute(
                    agent=agent,
                    call=call,
                    hooks=hooks,
                    context_wrapper=context_wrapper,
                    config=config,
                )
            )
        return results

    @classmethod
    async def execute_computer_actions(
        cls,
        *,
        agent: Agent[TContext],
        actions: list[ToolRunComputerAction],
        hooks: RunHooks[TContext],
        context_wrapper: RunContextWrapper[TContext],
        config: RunConfig,
    ) -> list[RunItem]:
        results: list[RunItem] = []
        # Need to run these serially, because each action can affect the computer state
        for action in actions:
            acknowledged: list[ComputerCallOutputAcknowledgedSafetyCheck] | None = None
            if action.tool_call.pending_safety_checks and action.computer_tool.on_safety_check:
                acknowledged = []
                for check in action.tool_call.pending_safety_checks:
                    data = ComputerToolSafetyCheckData(
                        ctx_wrapper=context_wrapper,
                        agent=agent,
                        tool_call=action.tool_call,
                        safety_check=check,
                    )
                    maybe = action.computer_tool.on_safety_check(data)
                    ack = await maybe if inspect.isawaitable(maybe) else maybe
                    if ack:
                        acknowledged.append(
                            ComputerCallOutputAcknowledgedSafetyCheck(
                                id=check.id,
                                code=check.code,
                                message=check.message,
                            )
                        )
                    else:
                        raise UserError("Computer tool safety check was not acknowledged")

            results.append(
                await ComputerAction.execute(
                    agent=agent,
                    action=action,
                    hooks=hooks,
                    context_wrapper=context_wrapper,
                    config=config,
                    acknowledged_safety_checks=acknowledged,
                )
            )

        return results

    @classmethod
    async def execute_handoffs(
        cls,
        *,
        agent: Agent[TContext],
        original_input: str | list[TResponseInputItem],
        pre_step_items: list[RunItem],
        new_step_items: list[RunItem],
        new_response: ModelResponse,
        run_handoffs: list[ToolRunHandoff],
        hooks: RunHooks[TContext],
        context_wrapper: RunContextWrapper[TContext],
        run_config: RunConfig,
    ) -> SingleStepResult:
        # If there is more than one handoff, add tool responses that reject those handoffs
        multiple_handoffs = len(run_handoffs) > 1
        if multiple_handoffs:
            output_message = "Multiple handoffs detected, ignoring this one."
            new_step_items.extend(
                [
                    ToolCallOutputItem(
                        output=output_message,
                        raw_item=ItemHelpers.tool_call_output_item(
                            handoff.tool_call, output_message
                        ),
                        agent=agent,
                    )
                    for handoff in run_handoffs[1:]
                ]
            )

        actual_handoff = run_handoffs[0]
        with handoff_span(from_agent=agent.name) as span_handoff:
            handoff = actual_handoff.handoff
            new_agent: Agent[Any] = await handoff.on_invoke_handoff(
                context_wrapper, actual_handoff.tool_call.arguments
            )
            span_handoff.span_data.to_agent = new_agent.name
            if multiple_handoffs:
                requested_agents = [handoff.handoff.agent_name for handoff in run_handoffs]
                span_handoff.set_error(
                    SpanError(
                        message="Multiple handoffs requested",
                        data={
                            "requested_agents": requested_agents,
                        },
                    )
                )

            # Append a tool output item for the handoff
            new_step_items.append(
                HandoffOutputItem(
                    agent=agent,
                    raw_item=ItemHelpers.tool_call_output_item(
                        actual_handoff.tool_call,
                        handoff.get_transfer_message(new_agent),
                    ),
                    source_agent=agent,
                    target_agent=new_agent,
                )
            )

            # Execute handoff hooks
            await asyncio.gather(
                hooks.on_handoff(
                    context=context_wrapper,
                    from_agent=agent,
                    to_agent=new_agent,
                ),
                (
                    agent.hooks.on_handoff(
                        context_wrapper,
                        agent=new_agent,
                        source=agent,
                    )
                    if agent.hooks
                    else _coro.noop_coroutine()
                ),
            )

            # If there's an input filter, filter the input for the next agent
            input_filter = handoff.input_filter or (
                run_config.handoff_input_filter if run_config else None
            )
            if input_filter:
                logger.debug("Filtering inputs for handoff")
                handoff_input_data = HandoffInputData(
                    input_history=tuple(original_input)
                    if isinstance(original_input, list)
                    else original_input,
                    pre_handoff_items=tuple(pre_step_items),
                    new_items=tuple(new_step_items),
                    run_context=context_wrapper,
                )
                if not callable(input_filter):
                    _error_tracing.attach_error_to_span(
                        span_handoff,
                        SpanError(
                            message="Invalid input filter",
                            data={"details": "not callable()"},
                        ),
                    )
                    raise UserError(f"Invalid input filter: {input_filter}")
                filtered = input_filter(handoff_input_data)
                if inspect.isawaitable(filtered):
                    filtered = await filtered
                if not isinstance(filtered, HandoffInputData):
                    _error_tracing.attach_error_to_span(
                        span_handoff,
                        SpanError(
                            message="Invalid input filter result",
                            data={"details": "not a HandoffInputData"},
                        ),
                    )
                    raise UserError(f"Invalid input filter result: {filtered}")

                original_input = (
                    filtered.input_history
                    if isinstance(filtered.input_history, str)
                    else list(filtered.input_history)
                )
                pre_step_items = list(filtered.pre_handoff_items)
                new_step_items = list(filtered.new_items)

            return SingleStepResult(
                original_input=original_input,
                model_response=new_response,
                pre_step_items=pre_step_items,
                new_step_items=new_step_items,
                next_step=NextStepHandoff(new_agent),
                tool_input_guardrail_results=[],
                tool_output_guardrail_results=[],
            )

    @classmethod
    async def execute_mcp_approval_requests(
        cls,
        *,
        agent: Agent[TContext],
        approval_requests: list[ToolRunMCPApprovalRequest],
        context_wrapper: RunContextWrapper[TContext],
    ) -> list[RunItem]:
        async def run_single_approval(approval_request: ToolRunMCPApprovalRequest) -> RunItem:
            callback = approval_request.mcp_tool.on_approval_request
            assert callback is not None, "Callback is required for MCP approval requests"
            maybe_awaitable_result = callback(
                MCPToolApprovalRequest(context_wrapper, approval_request.request_item)
            )
            if inspect.isawaitable(maybe_awaitable_result):
                result = await maybe_awaitable_result
            else:
                result = maybe_awaitable_result
            reason = result.get("reason", None)
            raw_item: McpApprovalResponse = {
                "approval_request_id": approval_request.request_item.id,
                "approve": result["approve"],
                "type": "mcp_approval_response",
            }
            if not result["approve"] and reason:
                raw_item["reason"] = reason
            return MCPApprovalResponseItem(
                raw_item=raw_item,
                agent=agent,
            )

        tasks = [run_single_approval(approval_request) for approval_request in approval_requests]
        return await asyncio.gather(*tasks)

    @classmethod
    async def execute_final_output(
        cls,
        *,
        agent: Agent[TContext],
        original_input: str | list[TResponseInputItem],
        new_response: ModelResponse,
        pre_step_items: list[RunItem],
        new_step_items: list[RunItem],
        final_output: Any,
        hooks: RunHooks[TContext],
        context_wrapper: RunContextWrapper[TContext],
        tool_input_guardrail_results: list[ToolInputGuardrailResult],
        tool_output_guardrail_results: list[ToolOutputGuardrailResult],
    ) -> SingleStepResult:
        # Run the on_end hooks
        await cls.run_final_output_hooks(agent, hooks, context_wrapper, final_output)

        return SingleStepResult(
            original_input=original_input,
            model_response=new_response,
            pre_step_items=pre_step_items,
            new_step_items=new_step_items,
            next_step=NextStepFinalOutput(final_output),
            tool_input_guardrail_results=tool_input_guardrail_results,
            tool_output_guardrail_results=tool_output_guardrail_results,
        )

    @classmethod
    async def run_final_output_hooks(
        cls,
        agent: Agent[TContext],
        hooks: RunHooks[TContext],
        context_wrapper: RunContextWrapper[TContext],
        final_output: Any,
    ):
        await asyncio.gather(
            hooks.on_agent_end(context_wrapper, agent, final_output),
            agent.hooks.on_end(context_wrapper, agent, final_output)
            if agent.hooks
            else _coro.noop_coroutine(),
        )

    @classmethod
    async def run_single_input_guardrail(
        cls,
        agent: Agent[Any],
        guardrail: InputGuardrail[TContext],
        input: str | list[TResponseInputItem],
        context: RunContextWrapper[TContext],
    ) -> InputGuardrailResult:
        with guardrail_span(guardrail.get_name()) as span_guardrail:
            result = await guardrail.run(agent, input, context)
            span_guardrail.span_data.triggered = result.output.tripwire_triggered
            return result

    @classmethod
    async def run_single_output_guardrail(
        cls,
        guardrail: OutputGuardrail[TContext],
        agent: Agent[Any],
        agent_output: Any,
        context: RunContextWrapper[TContext],
    ) -> OutputGuardrailResult:
        with guardrail_span(guardrail.get_name()) as span_guardrail:
            result = await guardrail.run(agent=agent, agent_output=agent_output, context=context)
            span_guardrail.span_data.triggered = result.output.tripwire_triggered
            return result

    @classmethod
    def stream_step_items_to_queue(
        cls,
        new_step_items: list[RunItem],
        queue: asyncio.Queue[StreamEvent | QueueCompleteSentinel],
    ):
        for item in new_step_items:
            if isinstance(item, MessageOutputItem):
                event = RunItemStreamEvent(item=item, name="message_output_created")
            elif isinstance(item, HandoffCallItem):
                event = RunItemStreamEvent(item=item, name="handoff_requested")
            elif isinstance(item, HandoffOutputItem):
                event = RunItemStreamEvent(item=item, name="handoff_occured")
            elif isinstance(item, ToolCallItem):
                event = RunItemStreamEvent(item=item, name="tool_called")
            elif isinstance(item, ToolCallOutputItem):
                event = RunItemStreamEvent(item=item, name="tool_output")
            elif isinstance(item, ReasoningItem):
                event = RunItemStreamEvent(item=item, name="reasoning_item_created")
            elif isinstance(item, MCPApprovalRequestItem):
                event = RunItemStreamEvent(item=item, name="mcp_approval_requested")
            elif isinstance(item, MCPListToolsItem):
                event = RunItemStreamEvent(item=item, name="mcp_list_tools")
            else:
                logger.warning(f"Unexpected item type: {type(item)}")
                event = None

            if event:
                queue.put_nowait(event)

    @classmethod
    def stream_step_result_to_queue(
        cls,
        step_result: SingleStepResult,
        queue: asyncio.Queue[StreamEvent | QueueCompleteSentinel],
    ):
        cls.stream_step_items_to_queue(step_result.new_step_items, queue)

    @classmethod
    async def _check_for_final_output_from_tools(
        cls,
        *,
        agent: Agent[TContext],
        tool_results: list[FunctionToolResult],
        context_wrapper: RunContextWrapper[TContext],
        config: RunConfig,
    ) -> ToolsToFinalOutputResult:
        """Determine if tool results should produce a final output.

        Returns:
            ToolsToFinalOutputResult: Indicates whether final output is ready, and the output value.
        """
        if not tool_results:
            return _NOT_FINAL_OUTPUT

        if agent.tool_use_behavior == "run_llm_again":
            return _NOT_FINAL_OUTPUT
        elif agent.tool_use_behavior == "stop_on_first_tool":
            return ToolsToFinalOutputResult(
                is_final_output=True, final_output=tool_results[0].output
            )
        elif isinstance(agent.tool_use_behavior, dict):
            names = agent.tool_use_behavior.get("stop_at_tool_names", [])
            for tool_result in tool_results:
                if tool_result.tool.name in names:
                    return ToolsToFinalOutputResult(
                        is_final_output=True, final_output=tool_result.output
                    )
            return ToolsToFinalOutputResult(is_final_output=False, final_output=None)
        elif callable(agent.tool_use_behavior):
            if inspect.iscoroutinefunction(agent.tool_use_behavior):
                return await cast(
                    Awaitable[ToolsToFinalOutputResult],
                    agent.tool_use_behavior(context_wrapper, tool_results),
                )
            else:
                return cast(
                    ToolsToFinalOutputResult, agent.tool_use_behavior(context_wrapper, tool_results)
                )

        logger.error(f"Invalid tool_use_behavior: {agent.tool_use_behavior}")
        raise UserError(f"Invalid tool_use_behavior: {agent.tool_use_behavior}")


class TraceCtxManager:
    """Creates a trace only if there is no current trace, and manages the trace lifecycle."""

    def __init__(
        self,
        workflow_name: str,
        trace_id: str | None,
        group_id: str | None,
        metadata: dict[str, Any] | None,
        disabled: bool,
    ):
        self.trace: Trace | None = None
        self.workflow_name = workflow_name
        self.trace_id = trace_id
        self.group_id = group_id
        self.metadata = metadata
        self.disabled = disabled

    def __enter__(self) -> TraceCtxManager:
        current_trace = get_current_trace()
        if not current_trace:
            self.trace = trace(
                workflow_name=self.workflow_name,
                trace_id=self.trace_id,
                group_id=self.group_id,
                metadata=self.metadata,
                disabled=self.disabled,
            )
            self.trace.start(mark_as_current=True)

        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        if self.trace:
            self.trace.finish(reset_current=True)


class ComputerAction:
    @classmethod
    async def execute(
        cls,
        *,
        agent: Agent[TContext],
        action: ToolRunComputerAction,
        hooks: RunHooks[TContext],
        context_wrapper: RunContextWrapper[TContext],
        config: RunConfig,
        acknowledged_safety_checks: list[ComputerCallOutputAcknowledgedSafetyCheck] | None = None,
    ) -> RunItem:
        output_func = (
            cls._get_screenshot_async(action.computer_tool.computer, action.tool_call)
            if isinstance(action.computer_tool.computer, AsyncComputer)
            else cls._get_screenshot_sync(action.computer_tool.computer, action.tool_call)
        )

        _, _, output = await asyncio.gather(
            hooks.on_tool_start(context_wrapper, agent, action.computer_tool),
            (
                agent.hooks.on_tool_start(context_wrapper, agent, action.computer_tool)
                if agent.hooks
                else _coro.noop_coroutine()
            ),
            output_func,
        )

        await asyncio.gather(
            hooks.on_tool_end(context_wrapper, agent, action.computer_tool, output),
            (
                agent.hooks.on_tool_end(context_wrapper, agent, action.computer_tool, output)
                if agent.hooks
                else _coro.noop_coroutine()
            ),
        )

        # TODO: don't send a screenshot every single time, use references
        image_url = f"data:image/png;base64,{output}"
        return ToolCallOutputItem(
            agent=agent,
            output=image_url,
            raw_item=ComputerCallOutput(
                call_id=action.tool_call.call_id,
                output={
                    "type": "computer_screenshot",
                    "image_url": image_url,
                },
                type="computer_call_output",
                acknowledged_safety_checks=acknowledged_safety_checks,
            ),
        )

    @classmethod
    async def _get_screenshot_sync(
        cls,
        computer: Computer,
        tool_call: ResponseComputerToolCall,
    ) -> str:
        action = tool_call.action
        if isinstance(action, ActionClick):
            computer.click(action.x, action.y, action.button)
        elif isinstance(action, ActionDoubleClick):
            computer.double_click(action.x, action.y)
        elif isinstance(action, ActionDrag):
            computer.drag([(p.x, p.y) for p in action.path])
        elif isinstance(action, ActionKeypress):
            computer.keypress(action.keys)
        elif isinstance(action, ActionMove):
            computer.move(action.x, action.y)
        elif isinstance(action, ActionScreenshot):
            computer.screenshot()
        elif isinstance(action, ActionScroll):
            computer.scroll(action.x, action.y, action.scroll_x, action.scroll_y)
        elif isinstance(action, ActionType):
            computer.type(action.text)
        elif isinstance(action, ActionWait):
            computer.wait()

        return computer.screenshot()

    @classmethod
    async def _get_screenshot_async(
        cls,
        computer: AsyncComputer,
        tool_call: ResponseComputerToolCall,
    ) -> str:
        action = tool_call.action
        if isinstance(action, ActionClick):
            await computer.click(action.x, action.y, action.button)
        elif isinstance(action, ActionDoubleClick):
            await computer.double_click(action.x, action.y)
        elif isinstance(action, ActionDrag):
            await computer.drag([(p.x, p.y) for p in action.path])
        elif isinstance(action, ActionKeypress):
            await computer.keypress(action.keys)
        elif isinstance(action, ActionMove):
            await computer.move(action.x, action.y)
        elif isinstance(action, ActionScreenshot):
            await computer.screenshot()
        elif isinstance(action, ActionScroll):
            await computer.scroll(action.x, action.y, action.scroll_x, action.scroll_y)
        elif isinstance(action, ActionType):
            await computer.type(action.text)
        elif isinstance(action, ActionWait):
            await computer.wait()

        return await computer.screenshot()


class LocalShellAction:
    @classmethod
    async def execute(
        cls,
        *,
        agent: Agent[TContext],
        call: ToolRunLocalShellCall,
        hooks: RunHooks[TContext],
        context_wrapper: RunContextWrapper[TContext],
        config: RunConfig,
    ) -> RunItem:
        await asyncio.gather(
            hooks.on_tool_start(context_wrapper, agent, call.local_shell_tool),
            (
                agent.hooks.on_tool_start(context_wrapper, agent, call.local_shell_tool)
                if agent.hooks
                else _coro.noop_coroutine()
            ),
        )

        request = LocalShellCommandRequest(
            ctx_wrapper=context_wrapper,
            data=call.tool_call,
        )
        output = call.local_shell_tool.executor(request)
        if inspect.isawaitable(output):
            result = await output
        else:
            result = output

        await asyncio.gather(
            hooks.on_tool_end(context_wrapper, agent, call.local_shell_tool, result),
            (
                agent.hooks.on_tool_end(context_wrapper, agent, call.local_shell_tool, result)
                if agent.hooks
                else _coro.noop_coroutine()
            ),
        )

        return ToolCallOutputItem(
            agent=agent,
            output=output,
            raw_item={
                "type": "local_shell_call_output",
                "id": call.tool_call.call_id,
                "output": result,
                # "id": "out" + call.tool_call.id,  # TODO remove this, it should be optional
            },
        )


def _build_litellm_json_tool_call(output: ResponseFunctionToolCall) -> FunctionTool:
    async def on_invoke_tool(_ctx: ToolContext[Any], value: Any) -> Any:
        if isinstance(value, str):
            import json

            return json.loads(value)
        return value

    return FunctionTool(
        name=output.name,
        description=output.name,
        params_json_schema={},
        on_invoke_tool=on_invoke_tool,
        strict_json_schema=True,
        is_enabled=True,
    )
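The guardrail handling above dispatches on `gr_out.behavior["type"]` with three outcomes: raise, reject with a replacement message, or allow. Below is a minimal standalone sketch of that dispatch pattern using plain dicts and functions in place of the SDK's guardrail classes; `run_guardrails` and `block_secrets` are hypothetical names for illustration only, not part of this package.

```python
# Illustrative sketch only: mirrors the behavior-dict dispatch used in
# execute_function_tool_calls, with plain callables standing in for guardrails.
from typing import Any, Callable

Guardrail = Callable[[str], dict[str, Any]]  # returns a behavior dict


def block_secrets(tool_args: str) -> dict[str, Any]:
    # Hypothetical guardrail: reject tool input that looks like it leaks a secret.
    if "password" in tool_args.lower():
        return {"type": "reject_content", "message": "Tool call blocked: possible secret."}
    return {"type": "allow"}


def run_guardrails(guardrails: list[Guardrail], tool_args: str) -> str | None:
    """Return None to proceed with the tool call, or a replacement message on rejection."""
    for guardrail in guardrails:
        behavior = guardrail(tool_args)
        if behavior["type"] == "raise_exception":
            raise RuntimeError("Guardrail tripwire triggered")
        elif behavior["type"] == "reject_content":
            return behavior["message"]
        # "allow" falls through to the next guardrail
    return None


print(run_guardrails([block_secrets], '{"query": "my password is hunter2"}'))
```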
agents/agent.py
ADDED
@@ -0,0 +1,476 @@
from __future__ import annotations

import asyncio
import dataclasses
import inspect
from collections.abc import Awaitable
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, cast

from openai.types.responses.response_prompt_param import ResponsePromptParam
from typing_extensions import NotRequired, TypeAlias, TypedDict

from .agent_output import AgentOutputSchemaBase
from .guardrail import InputGuardrail, OutputGuardrail
from .handoffs import Handoff
from .items import ItemHelpers
from .logger import logger
from .mcp import MCPUtil
from .model_settings import ModelSettings
from .models.default_models import (
    get_default_model_settings,
    gpt_5_reasoning_settings_required,
    is_gpt_5_default,
)
from .models.interface import Model
from .prompts import DynamicPromptFunction, Prompt, PromptUtil
from .run_context import RunContextWrapper, TContext
from .tool import FunctionTool, FunctionToolResult, Tool, function_tool
from .util import _transforms
from .util._types import MaybeAwaitable

if TYPE_CHECKING:
    from .lifecycle import AgentHooks, RunHooks
    from .mcp import MCPServer
    from .memory.session import Session
    from .result import RunResult
    from .run import RunConfig


@dataclass
class ToolsToFinalOutputResult:
    is_final_output: bool
    """Whether this is the final output. If False, the LLM will run again and receive the tool call
    output.
    """

    final_output: Any | None = None
    """The final output. Can be None if `is_final_output` is False, otherwise must match the
    `output_type` of the agent.
    """


ToolsToFinalOutputFunction: TypeAlias = Callable[
    [RunContextWrapper[TContext], list[FunctionToolResult]],
    MaybeAwaitable[ToolsToFinalOutputResult],
]
"""A function that takes a run context and a list of tool results, and returns a
`ToolsToFinalOutputResult`.
"""


class StopAtTools(TypedDict):
    stop_at_tool_names: list[str]
    """A list of tool names, any of which will stop the agent from running further."""


class MCPConfig(TypedDict):
    """Configuration for MCP servers."""

    convert_schemas_to_strict: NotRequired[bool]
    """If True, we will attempt to convert the MCP schemas to strict-mode schemas. This is a
    best-effort conversion, so some schemas may not be convertible. Defaults to False.
    """


@dataclass
class AgentBase(Generic[TContext]):
    """Base class for `Agent` and `RealtimeAgent`."""

    name: str
    """The name of the agent."""

    handoff_description: str | None = None
    """A description of the agent. This is used when the agent is used as a handoff, so that an
    LLM knows what it does and when to invoke it.
    """

    tools: list[Tool] = field(default_factory=list)
    """A list of tools that the agent can use."""

    mcp_servers: list[MCPServer] = field(default_factory=list)
    """A list of [Model Context Protocol](https://modelcontextprotocol.io/) servers that
    the agent can use. Every time the agent runs, it will include tools from these servers in the
    list of available tools.

    NOTE: You are expected to manage the lifecycle of these servers. Specifically, you must call
    `server.connect()` before passing it to the agent, and `server.cleanup()` when the server is no
    longer needed.
    """

    mcp_config: MCPConfig = field(default_factory=lambda: MCPConfig())
    """Configuration for MCP servers."""

    async def get_mcp_tools(self, run_context: RunContextWrapper[TContext]) -> list[Tool]:
        """Fetches the available tools from the MCP servers."""
        convert_schemas_to_strict = self.mcp_config.get("convert_schemas_to_strict", False)
        return await MCPUtil.get_all_function_tools(
            self.mcp_servers, convert_schemas_to_strict, run_context, self
        )

    async def get_all_tools(self, run_context: RunContextWrapper[TContext]) -> list[Tool]:
        """All agent tools, including MCP tools and function tools."""
        mcp_tools = await self.get_mcp_tools(run_context)

        async def _check_tool_enabled(tool: Tool) -> bool:
            if not isinstance(tool, FunctionTool):
                return True

            attr = tool.is_enabled
            if isinstance(attr, bool):
                return attr
            res = attr(run_context, self)
            if inspect.isawaitable(res):
                return bool(await res)
            return bool(res)

        results = await asyncio.gather(*(_check_tool_enabled(t) for t in self.tools))
        enabled: list[Tool] = [t for t, ok in zip(self.tools, results) if ok]
        return [*mcp_tools, *enabled]


@dataclass
class Agent(AgentBase, Generic[TContext]):
    """An agent is an AI model configured with instructions, tools, guardrails, handoffs and more.

    We strongly recommend passing `instructions`, which is the "system prompt" for the agent. In
    addition, you can pass `handoff_description`, which is a human-readable description of the
    agent, used when the agent is used inside tools/handoffs.

    Agents are generic on the context type. The context is a (mutable) object you create. It is
    passed to tool functions, handoffs, guardrails, etc.

    See `AgentBase` for base parameters that are shared with `RealtimeAgent`s.
    """

    instructions: (
        str
        | Callable[
            [RunContextWrapper[TContext], Agent[TContext]],
            MaybeAwaitable[str],
        ]
        | None
    ) = None
    """The instructions for the agent. Will be used as the "system prompt" when this agent is
    invoked. Describes what the agent should do, and how it responds.

    Can either be a string, or a function that dynamically generates instructions for the agent. If
    you provide a function, it will be called with the context and the agent instance. It must
    return a string.
    """

    prompt: Prompt | DynamicPromptFunction | None = None
    """A prompt object (or a function that returns a Prompt). Prompts allow you to dynamically
    configure the instructions, tools and other config for an agent outside of your code. Only
    usable with OpenAI models, using the Responses API.
    """

    handoffs: list[Agent[Any] | Handoff[TContext, Any]] = field(default_factory=list)
    """Handoffs are sub-agents that the agent can delegate to. You can provide a list of handoffs,
    and the agent can choose to delegate to them if relevant. Allows for separation of concerns and
    modularity.
    """

    model: str | Model | None = None
    """The model implementation to use when invoking the LLM.

    By default, if not set, the agent will use the default model configured in
    `agents.models.get_default_model()` (currently "gpt-4.1").
    """

    model_settings: ModelSettings = field(default_factory=get_default_model_settings)
    """Configures model-specific tuning parameters (e.g. temperature, top_p).
    """

    input_guardrails: list[InputGuardrail[TContext]] = field(default_factory=list)
    """A list of checks that run in parallel to the agent's execution, before generating a
    response. Runs only if the agent is the first agent in the chain.
    """

    output_guardrails: list[OutputGuardrail[TContext]] = field(default_factory=list)
    """A list of checks that run on the final output of the agent, after generating a response.
    Runs only if the agent produces a final output.
    """

    output_type: type[Any] | AgentOutputSchemaBase | None = None
    """The type of the output object. If not provided, the output will be `str`. In most cases,
    you should pass a regular Python type (e.g. a dataclass, Pydantic model, TypedDict, etc).
    You can customize this in two ways:
    1. If you want non-strict schemas, pass `AgentOutputSchema(MyClass, strict_json_schema=False)`.
    2. If you want to use a custom JSON schema (i.e. without using the SDK's automatic schema)
       creation, subclass and pass an `AgentOutputSchemaBase` subclass.
    """

    hooks: AgentHooks[TContext] | None = None
    """A class that receives callbacks on various lifecycle events for this agent.
    """

    tool_use_behavior: (
        Literal["run_llm_again", "stop_on_first_tool"] | StopAtTools | ToolsToFinalOutputFunction
    ) = "run_llm_again"
    """
    This lets you configure how tool use is handled.
    - "run_llm_again": The default behavior. Tools are run, and then the LLM receives the results
        and gets to respond.
    - "stop_on_first_tool": The output from the first tool call is treated as the final result.
        In other words, it isn’t sent back to the LLM for further processing but is used directly
        as the final output.
    - A StopAtTools object: The agent will stop running if any of the tools listed in
        `stop_at_tool_names` is called.
        The final output will be the output of the first matching tool call.
        The LLM does not process the result of the tool call.
    - A function: If you pass a function, it will be called with the run context and the list of
      tool results. It must return a `ToolsToFinalOutputResult`, which determines whether the tool
      calls result in a final output.

    NOTE: This configuration is specific to FunctionTools. Hosted tools, such as file search,
    web search, etc. are always processed by the LLM.
    """

    reset_tool_choice: bool = True
    """Whether to reset the tool choice to the default value after a tool has been called. Defaults
    to True. This ensures that the agent doesn't enter an infinite loop of tool usage."""

    def __post_init__(self):
        from typing import get_origin

        if not isinstance(self.name, str):
            raise TypeError(f"Agent name must be a string, got {type(self.name).__name__}")

        if self.handoff_description is not None and not isinstance(self.handoff_description, str):
            raise TypeError(
                f"Agent handoff_description must be a string or None, "
                f"got {type(self.handoff_description).__name__}"
            )

        if not isinstance(self.tools, list):
            raise TypeError(f"Agent tools must be a list, got {type(self.tools).__name__}")

        if not isinstance(self.mcp_servers, list):
            raise TypeError(
                f"Agent mcp_servers must be a list, got {type(self.mcp_servers).__name__}"
            )

        if not isinstance(self.mcp_config, dict):
            raise TypeError(
                f"Agent mcp_config must be a dict, got {type(self.mcp_config).__name__}"
            )

        if (
            self.instructions is not None
            and not isinstance(self.instructions, str)
            and not callable(self.instructions)
        ):
            raise TypeError(
                f"Agent instructions must be a string, callable, or None, "
                f"got {type(self.instructions).__name__}"
            )

        if (
            self.prompt is not None
            and not callable(self.prompt)
            and not hasattr(self.prompt, "get")
        ):
            raise TypeError(
                f"Agent prompt must be a Prompt, DynamicPromptFunction, or None, "
                f"got {type(self.prompt).__name__}"
            )

        if not isinstance(self.handoffs, list):
            raise TypeError(f"Agent handoffs must be a list, got {type(self.handoffs).__name__}")

        if self.model is not None and not isinstance(self.model, str):
            from .models.interface import Model

            if not isinstance(self.model, Model):
                raise TypeError(
                    f"Agent model must be a string, Model, or None, got {type(self.model).__name__}"
                )

        if not isinstance(self.model_settings, ModelSettings):
            raise TypeError(
                f"Agent model_settings must be a ModelSettings instance, "
                f"got {type(self.model_settings).__name__}"
            )

        if (
            # The user sets a non-default model
            self.model is not None
            and (
                # The default model is gpt-5
                is_gpt_5_default() is True
                # However, the specified model is not a gpt-5 model
                and (
                    isinstance(self.model, str) is False
                    or gpt_5_reasoning_settings_required(self.model) is False  # type: ignore
                )
                # The model settings are not customized for the specified model
                and self.model_settings == get_default_model_settings()
            )
        ):
            # In this scenario, we should use a generic model settings
            # because non-gpt-5 models are not compatible with the default gpt-5 model settings.
            # This is a best-effort attempt to make the agent work with non-gpt-5 models.
            self.model_settings = ModelSettings()

        if not isinstance(self.input_guardrails, list):
            raise TypeError(
                f"Agent input_guardrails must be a list, got {type(self.input_guardrails).__name__}"
            )

        if not isinstance(self.output_guardrails, list):
            raise TypeError(
                f"Agent output_guardrails must be a list, "
                f"got {type(self.output_guardrails).__name__}"
            )

        if self.output_type is not None:
            from .agent_output import AgentOutputSchemaBase

            if not (
                isinstance(self.output_type, (type, AgentOutputSchemaBase))
                or get_origin(self.output_type) is not None
            ):
                raise TypeError(
                    f"Agent output_type must be a type, AgentOutputSchemaBase, or None, "
                    f"got {type(self.output_type).__name__}"
                )

        if self.hooks is not None:
            from .lifecycle import AgentHooksBase

            if not isinstance(self.hooks, AgentHooksBase):
                raise TypeError(
                    f"Agent hooks must be an AgentHooks instance or None, "
                    f"got {type(self.hooks).__name__}"
                )

        if (
            not (
                isinstance(self.tool_use_behavior, str)
                and self.tool_use_behavior in ["run_llm_again", "stop_on_first_tool"]
            )
            and not isinstance(self.tool_use_behavior, dict)
            and not callable(self.tool_use_behavior)
        ):
            raise TypeError(
                f"Agent tool_use_behavior must be 'run_llm_again', 'stop_on_first_tool', "
                f"StopAtTools dict, or callable, got {type(self.tool_use_behavior).__name__}"
            )

        if not isinstance(self.reset_tool_choice, bool):
            raise TypeError(
                f"Agent reset_tool_choice must be a boolean, "
                f"got {type(self.reset_tool_choice).__name__}"
            )

    def clone(self, **kwargs: Any) -> Agent[TContext]:
        """Make a copy of the agent, with the given arguments changed.
        Notes:
            - Uses `dataclasses.replace`, which performs a **shallow copy**.
            - Mutable attributes like `tools` and `handoffs` are shallow-copied:
              new list objects are created only if overridden, but their contents
              (tool functions and handoff objects) are shared with the original.
            - To modify these independently, pass new lists when calling `clone()`.
        Example:
            ```python
            new_agent = agent.clone(instructions="New instructions")
            ```
        """
        return dataclasses.replace(self, **kwargs)

    def as_tool(
        self,
        tool_name: str | None,
        tool_description: str | None,
        custom_output_extractor: Callable[[RunResult], Awaitable[str]] | None = None,
        is_enabled: bool
        | Callable[[RunContextWrapper[Any], AgentBase[Any]], MaybeAwaitable[bool]] = True,
        run_config: RunConfig | None = None,
        max_turns: int | None = None,
        hooks: RunHooks[TContext] | None = None,
        previous_response_id: str | None = None,
        conversation_id: str | None = None,
        session: Session | None = None,
    ) -> Tool:
        """Transform this agent into a tool, callable by other agents.

        This is different from handoffs in two ways:
        1. In handoffs, the new agent receives the conversation history. In this tool, the new
           agent receives generated input.
        2. In handoffs, the new agent takes over the conversation. In this tool, the new agent is
           called as a tool, and the conversation is continued by the original agent.

        Args:
            tool_name: The name of the tool. If not provided, the agent's name will be used.
            tool_description: The description of the tool, which should indicate what it does and
                when to use it.
            custom_output_extractor: A function that extracts the output from the agent. If not
                provided, the last message from the agent will be used.
            is_enabled: Whether the tool is enabled. Can be a bool or a callable that takes the run
                context and agent and returns whether the tool is enabled. Disabled tools are
                hidden from the LLM at runtime.
        """

        @function_tool(
            name_override=tool_name or _transforms.transform_string_function_style(self.name),
            description_override=tool_description or "",
            is_enabled=is_enabled,
        )
        async def run_agent(context: RunContextWrapper, input: str) -> str:
            from .run import DEFAULT_MAX_TURNS, Runner

            resolved_max_turns = max_turns if max_turns is not None else DEFAULT_MAX_TURNS

            output = await Runner.run(
                starting_agent=self,
                input=input,
                context=context.context,
                run_config=run_config,
                max_turns=resolved_max_turns,
                hooks=hooks,
                previous_response_id=previous_response_id,
                conversation_id=conversation_id,
                session=session,
            )
            if custom_output_extractor:
                return await custom_output_extractor(output)

            return ItemHelpers.text_message_outputs(output.new_items)

        return run_agent

    async def get_system_prompt(self, run_context: RunContextWrapper[TContext]) -> str | None:
        if isinstance(self.instructions, str):
            return self.instructions
        elif callable(self.instructions):
            # Inspect the signature of the instructions function
            sig = inspect.signature(self.instructions)
            params = list(sig.parameters.values())

            # Enforce exactly 2 parameters
            if len(params) != 2:
                raise TypeError(
                    f"'instructions' callable must accept exactly 2 arguments (context, agent), "
                    f"but got {len(params)}: {[p.name for p in params]}"
                )

            # Call the instructions function properly
            if inspect.iscoroutinefunction(self.instructions):
                return await cast(Awaitable[str], self.instructions(run_context, self))
            else:
                return cast(str, self.instructions(run_context, self))

        elif self.instructions is not None:
            logger.error(
                f"Instructions must be a string or a callable function, "
                f"got {type(self.instructions).__name__}"
            )

        return None

    async def get_prompt(
        self, run_context: RunContextWrapper[TContext]
    ) -> ResponsePromptParam | None:
        """Get the prompt for the agent."""
        return await PromptUtil.to_model_input(self.prompt, run_context, self)
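A minimal usage sketch of the `Agent` options defined above. It assumes the package is importable as `agents` and re-exports `Agent` and `function_tool`; the weather tool and agent names are made up for illustration.

```python
# Illustrative sketch only: exercises Agent fields defined in agent.py above.
# Assumes `agents` is importable and exports Agent and function_tool.
from agents import Agent, function_tool


@function_tool
def get_weather(city: str) -> str:
    """Return a canned weather report for `city`."""
    return f"It is sunny in {city}."


base_agent = Agent(
    name="Weather assistant",
    instructions="Answer weather questions using the get_weather tool.",
    tools=[get_weather],
    # Stop the run as soon as get_weather returns, instead of sending the
    # tool output back to the LLM (see tool_use_behavior above).
    tool_use_behavior={"stop_at_tool_names": ["get_weather"]},
)

# clone() uses dataclasses.replace (shallow copy); pass new lists to change
# tools or handoffs independently of the original agent.
terse_agent = base_agent.clone(instructions="Answer in one short sentence.")
```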
agents/agent_output.py
ADDED
@@ -0,0 +1,194 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import abc
from dataclasses import dataclass
from typing import Any

from pydantic import BaseModel, TypeAdapter
from typing_extensions import TypedDict, get_args, get_origin

from .exceptions import ModelBehaviorError, UserError
from .strict_schema import ensure_strict_json_schema
from .tracing import SpanError
from .util import _error_tracing, _json

_WRAPPER_DICT_KEY = "response"


class AgentOutputSchemaBase(abc.ABC):
    """An object that captures the JSON schema of the output, as well as validating/parsing JSON
    produced by the LLM into the output type.
    """

    @abc.abstractmethod
    def is_plain_text(self) -> bool:
        """Whether the output type is plain text (versus a JSON object)."""
        pass

    @abc.abstractmethod
    def name(self) -> str:
        """The name of the output type."""
        pass

    @abc.abstractmethod
    def json_schema(self) -> dict[str, Any]:
        """Returns the JSON schema of the output. Will only be called if the output type is not
        plain text.
        """
        pass

    @abc.abstractmethod
    def is_strict_json_schema(self) -> bool:
        """Whether the JSON schema is in strict mode. Strict mode constrains the JSON schema
        features, but guarantees valid JSON. See here for details:
        https://platform.openai.com/docs/guides/structured-outputs#supported-schemas
        """
        pass

    @abc.abstractmethod
    def validate_json(self, json_str: str) -> Any:
        """Validate a JSON string against the output type. You must return the validated object,
        or raise a `ModelBehaviorError` if the JSON is invalid.
        """
        pass


@dataclass(init=False)
class AgentOutputSchema(AgentOutputSchemaBase):
    """An object that captures the JSON schema of the output, as well as validating/parsing JSON
    produced by the LLM into the output type.
    """

    output_type: type[Any]
    """The type of the output."""

    _type_adapter: TypeAdapter[Any]
    """A type adapter that wraps the output type, so that we can validate JSON."""

    _is_wrapped: bool
    """Whether the output type is wrapped in a dictionary. This is generally done if the base
    output type cannot be represented as a JSON Schema object.
    """

    _output_schema: dict[str, Any]
    """The JSON schema of the output."""

    _strict_json_schema: bool
    """Whether the JSON schema is in strict mode. We **strongly** recommend setting this to True,
    as it increases the likelihood of correct JSON input.
    """

    def __init__(self, output_type: type[Any], strict_json_schema: bool = True):
        """
        Args:
            output_type: The type of the output.
            strict_json_schema: Whether the JSON schema is in strict mode. We **strongly** recommend
                setting this to True, as it increases the likelihood of correct JSON input.
        """
        self.output_type = output_type
        self._strict_json_schema = strict_json_schema

        if output_type is None or output_type is str:
            self._is_wrapped = False
            self._type_adapter = TypeAdapter(output_type)
            self._output_schema = self._type_adapter.json_schema()
            return

        # We should wrap for things that are not plain text, and for things that would definitely
        # not be a JSON Schema object.
        self._is_wrapped = not _is_subclass_of_base_model_or_dict(output_type)

        if self._is_wrapped:
            OutputType = TypedDict(
                "OutputType",
                {
                    _WRAPPER_DICT_KEY: output_type,  # type: ignore
                },
            )
            self._type_adapter = TypeAdapter(OutputType)
            self._output_schema = self._type_adapter.json_schema()
        else:
            self._type_adapter = TypeAdapter(output_type)
            self._output_schema = self._type_adapter.json_schema()

        if self._strict_json_schema:
            try:
                self._output_schema = ensure_strict_json_schema(self._output_schema)
            except UserError as e:
                raise UserError(
                    "Strict JSON schema is enabled, but the output type is not valid. "
                    "Either make the output type strict, "
                    "or wrap your type with AgentOutputSchema(YourType, strict_json_schema=False)"
                ) from e

    def is_plain_text(self) -> bool:
        """Whether the output type is plain text (versus a JSON object)."""
        return self.output_type is None or self.output_type is str

    def is_strict_json_schema(self) -> bool:
        """Whether the JSON schema is in strict mode."""
        return self._strict_json_schema

    def json_schema(self) -> dict[str, Any]:
        """The JSON schema of the output type."""
        if self.is_plain_text():
            raise UserError("Output type is plain text, so no JSON schema is available")
        return self._output_schema

    def validate_json(self, json_str: str) -> Any:
        """Validate a JSON string against the output type. Returns the validated object, or raises
        a `ModelBehaviorError` if the JSON is invalid.
        """
        validated = _json.validate_json(json_str, self._type_adapter, partial=False)
        if self._is_wrapped:
            if not isinstance(validated, dict):
                _error_tracing.attach_error_to_current_span(
                    SpanError(
                        message="Invalid JSON",
                        data={"details": f"Expected a dict, got {type(validated)}"},
                    )
                )
                raise ModelBehaviorError(
                    f"Expected a dict, got {type(validated)} for JSON: {json_str}"
                )

            if _WRAPPER_DICT_KEY not in validated:
                _error_tracing.attach_error_to_current_span(
                    SpanError(
                        message="Invalid JSON",
                        data={"details": f"Could not find key {_WRAPPER_DICT_KEY} in JSON"},
                    )
                )
                raise ModelBehaviorError(
                    f"Could not find key {_WRAPPER_DICT_KEY} in JSON: {json_str}"
                )
            return validated[_WRAPPER_DICT_KEY]
        return validated

    def name(self) -> str:
        """The name of the output type."""
        return _type_to_str(self.output_type)


def _is_subclass_of_base_model_or_dict(t: Any) -> bool:
    if not isinstance(t, type):
        return False

    # If it's a generic alias, 'origin' will be the actual type, e.g. 'list'
    origin = get_origin(t)

    allowed_types = (BaseModel, dict)
    # If it's a generic alias e.g. list[str], then we should check the origin type i.e. list
    return issubclass(origin or t, allowed_types)


def _type_to_str(t: type[Any]) -> str:
    origin = get_origin(t)
    args = get_args(t)

    if origin is None:
        # It's a simple type like `str`, `int`, etc.
        return t.__name__
    elif args:
        args_str = ", ".join(_type_to_str(arg) for arg in args)
        return f"{origin.__name__}[{args_str}]"
    else:
        return str(t)
agents/computer.py
ADDED
|
@@ -0,0 +1,107 @@
import abc
from typing import Literal

Environment = Literal["mac", "windows", "ubuntu", "browser"]
Button = Literal["left", "right", "wheel", "back", "forward"]


class Computer(abc.ABC):
    """A computer implemented with sync operations. The Computer interface abstracts the
    operations needed to control a computer or browser."""

    @property
    @abc.abstractmethod
    def environment(self) -> Environment:
        pass

    @property
    @abc.abstractmethod
    def dimensions(self) -> tuple[int, int]:
        pass

    @abc.abstractmethod
    def screenshot(self) -> str:
        pass

    @abc.abstractmethod
    def click(self, x: int, y: int, button: Button) -> None:
        pass

    @abc.abstractmethod
    def double_click(self, x: int, y: int) -> None:
        pass

    @abc.abstractmethod
    def scroll(self, x: int, y: int, scroll_x: int, scroll_y: int) -> None:
        pass

    @abc.abstractmethod
    def type(self, text: str) -> None:
        pass

    @abc.abstractmethod
    def wait(self) -> None:
        pass

    @abc.abstractmethod
    def move(self, x: int, y: int) -> None:
        pass

    @abc.abstractmethod
    def keypress(self, keys: list[str]) -> None:
        pass

    @abc.abstractmethod
    def drag(self, path: list[tuple[int, int]]) -> None:
        pass


class AsyncComputer(abc.ABC):
    """A computer implemented with async operations. The Computer interface abstracts the
    operations needed to control a computer or browser."""

    @property
    @abc.abstractmethod
    def environment(self) -> Environment:
        pass

    @property
    @abc.abstractmethod
    def dimensions(self) -> tuple[int, int]:
        pass

    @abc.abstractmethod
    async def screenshot(self) -> str:
        pass

    @abc.abstractmethod
    async def click(self, x: int, y: int, button: Button) -> None:
        pass

    @abc.abstractmethod
    async def double_click(self, x: int, y: int) -> None:
        pass

    @abc.abstractmethod
    async def scroll(self, x: int, y: int, scroll_x: int, scroll_y: int) -> None:
        pass

    @abc.abstractmethod
    async def type(self, text: str) -> None:
        pass

    @abc.abstractmethod
    async def wait(self) -> None:
        pass

    @abc.abstractmethod
    async def move(self, x: int, y: int) -> None:
        pass

    @abc.abstractmethod
    async def keypress(self, keys: list[str]) -> None:
        pass

    @abc.abstractmethod
    async def drag(self, path: list[tuple[int, int]]) -> None:
        pass
agents/exceptions.py
ADDED
|
@@ -0,0 +1,131 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
    from .agent import Agent
    from .guardrail import InputGuardrailResult, OutputGuardrailResult
    from .items import ModelResponse, RunItem, TResponseInputItem
    from .run_context import RunContextWrapper
    from .tool_guardrails import (
        ToolGuardrailFunctionOutput,
        ToolInputGuardrail,
        ToolOutputGuardrail,
    )

from .util._pretty_print import pretty_print_run_error_details


@dataclass
class RunErrorDetails:
    """Data collected from an agent run when an exception occurs."""

    input: str | list[TResponseInputItem]
    new_items: list[RunItem]
    raw_responses: list[ModelResponse]
    last_agent: Agent[Any]
    context_wrapper: RunContextWrapper[Any]
    input_guardrail_results: list[InputGuardrailResult]
    output_guardrail_results: list[OutputGuardrailResult]

    def __str__(self) -> str:
        return pretty_print_run_error_details(self)


class AgentsException(Exception):
    """Base class for all exceptions in the Agents SDK."""

    run_data: RunErrorDetails | None

    def __init__(self, *args: object) -> None:
        super().__init__(*args)
        self.run_data = None


class MaxTurnsExceeded(AgentsException):
    """Exception raised when the maximum number of turns is exceeded."""

    message: str

    def __init__(self, message: str):
        self.message = message
        super().__init__(message)


class ModelBehaviorError(AgentsException):
    """Exception raised when the model does something unexpected, e.g. calling a tool that doesn't
    exist, or providing malformed JSON.
    """

    message: str

    def __init__(self, message: str):
        self.message = message
        super().__init__(message)


class UserError(AgentsException):
    """Exception raised when the user makes an error using the SDK."""

    message: str

    def __init__(self, message: str):
        self.message = message
        super().__init__(message)


class InputGuardrailTripwireTriggered(AgentsException):
    """Exception raised when a guardrail tripwire is triggered."""

    guardrail_result: InputGuardrailResult
    """The result data of the guardrail that was triggered."""

    def __init__(self, guardrail_result: InputGuardrailResult):
        self.guardrail_result = guardrail_result
        super().__init__(
            f"Guardrail {guardrail_result.guardrail.__class__.__name__} triggered tripwire"
        )


class OutputGuardrailTripwireTriggered(AgentsException):
    """Exception raised when a guardrail tripwire is triggered."""

    guardrail_result: OutputGuardrailResult
    """The result data of the guardrail that was triggered."""

    def __init__(self, guardrail_result: OutputGuardrailResult):
        self.guardrail_result = guardrail_result
        super().__init__(
            f"Guardrail {guardrail_result.guardrail.__class__.__name__} triggered tripwire"
        )


class ToolInputGuardrailTripwireTriggered(AgentsException):
    """Exception raised when a tool input guardrail tripwire is triggered."""

    guardrail: ToolInputGuardrail[Any]
    """The guardrail that was triggered."""

    output: ToolGuardrailFunctionOutput
    """The output from the guardrail function."""

    def __init__(self, guardrail: ToolInputGuardrail[Any], output: ToolGuardrailFunctionOutput):
        self.guardrail = guardrail
        self.output = output
        super().__init__(f"Tool input guardrail {guardrail.__class__.__name__} triggered tripwire")


class ToolOutputGuardrailTripwireTriggered(AgentsException):
    """Exception raised when a tool output guardrail tripwire is triggered."""

    guardrail: ToolOutputGuardrail[Any]
    """The guardrail that was triggered."""

    output: ToolGuardrailFunctionOutput
    """The output from the guardrail function."""

    def __init__(self, guardrail: ToolOutputGuardrail[Any], output: ToolGuardrailFunctionOutput):
        self.guardrail = guardrail
        self.output = output
        super().__init__(f"Tool output guardrail {guardrail.__class__.__name__} triggered tripwire")
agents/extensions/__init__.py
ADDED
|
File without changes
|
agents/extensions/handoff_filters.py
ADDED
|
@@ -0,0 +1,70 @@
from __future__ import annotations

from ..handoffs import HandoffInputData
from ..items import (
    HandoffCallItem,
    HandoffOutputItem,
    ReasoningItem,
    RunItem,
    ToolCallItem,
    ToolCallOutputItem,
    TResponseInputItem,
)

"""Contains common handoff input filters, for convenience. """


def remove_all_tools(handoff_input_data: HandoffInputData) -> HandoffInputData:
    """Filters out all tool items: file search, web search and function calls+output."""

    history = handoff_input_data.input_history
    new_items = handoff_input_data.new_items

    filtered_history = (
        _remove_tool_types_from_input(history) if isinstance(history, tuple) else history
    )
    filtered_pre_handoff_items = _remove_tools_from_items(handoff_input_data.pre_handoff_items)
    filtered_new_items = _remove_tools_from_items(new_items)

    return HandoffInputData(
        input_history=filtered_history,
        pre_handoff_items=filtered_pre_handoff_items,
        new_items=filtered_new_items,
        run_context=handoff_input_data.run_context,
    )


def _remove_tools_from_items(items: tuple[RunItem, ...]) -> tuple[RunItem, ...]:
    filtered_items = []
    for item in items:
        if (
            isinstance(item, HandoffCallItem)
            or isinstance(item, HandoffOutputItem)
            or isinstance(item, ToolCallItem)
            or isinstance(item, ToolCallOutputItem)
            or isinstance(item, ReasoningItem)
        ):
            continue
        filtered_items.append(item)
    return tuple(filtered_items)


def _remove_tool_types_from_input(
    items: tuple[TResponseInputItem, ...],
) -> tuple[TResponseInputItem, ...]:
    tool_types = [
        "function_call",
        "function_call_output",
        "computer_call",
        "computer_call_output",
        "file_search_call",
        "web_search_call",
    ]

    filtered_items: list[TResponseInputItem] = []
    for item in items:
        itype = item.get("type")
        if itype in tool_types:
            continue
        filtered_items.append(item)
    return tuple(filtered_items)
|
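A usage sketch for remove_all_tools. It assumes the handoff() helper defined in agents/handoffs.py (part of this upload, not shown here) accepts an input_filter callable; both agents are hypothetical.

from agents import Agent, handoff
from agents.extensions import handoff_filters

support_agent = Agent(name="Support")  # hypothetical downstream agent
triage_agent = Agent(
    name="Triage",
    # Strip tool calls/outputs from the history the next agent sees.
    handoffs=[handoff(support_agent, input_filter=handoff_filters.remove_all_tools)],
)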
agents/extensions/handoff_prompt.py
ADDED
|
@@ -0,0 +1,19 @@
# A recommended prompt prefix for agents that use handoffs. We recommend including this or
# similar instructions in any agents that use handoffs.
RECOMMENDED_PROMPT_PREFIX = (
    "# System context\n"
    "You are part of a multi-agent system called the Agents SDK, designed to make agent "
    "coordination and execution easy. Agents uses two primary abstraction: **Agents** and "
    "**Handoffs**. An agent encompasses instructions and tools and can hand off a "
    "conversation to another agent when appropriate. "
    "Handoffs are achieved by calling a handoff function, generally named "
    "`transfer_to_<agent_name>`. Transfers between agents are handled seamlessly in the background;"
    " do not mention or draw attention to these transfers in your conversation with the user.\n"
)


def prompt_with_handoff_instructions(prompt: str) -> str:
    """
    Add recommended instructions to the prompt for agents that use handoffs.
    """
    return f"{RECOMMENDED_PROMPT_PREFIX}\n\n{prompt}"
agents/extensions/memory/__init__.py
ADDED
|
@@ -0,0 +1,65 @@
"""Session memory backends living in the extensions namespace.

This package contains optional, production-grade session implementations that
introduce extra third-party dependencies (database drivers, ORMs, etc.). They
conform to the :class:`agents.memory.session.Session` protocol so they can be
used as a drop-in replacement for :class:`agents.memory.session.SQLiteSession`.
"""

from __future__ import annotations

from typing import Any

__all__: list[str] = [
    "EncryptedSession",
    "RedisSession",
    "SQLAlchemySession",
    "AdvancedSQLiteSession",
]


def __getattr__(name: str) -> Any:
    if name == "EncryptedSession":
        try:
            from .encrypt_session import EncryptedSession  # noqa: F401

            return EncryptedSession
        except ModuleNotFoundError as e:
            raise ImportError(
                "EncryptedSession requires the 'cryptography' extra. "
                "Install it with: pip install openai-agents[encrypt]"
            ) from e

    if name == "RedisSession":
        try:
            from .redis_session import RedisSession  # noqa: F401

            return RedisSession
        except ModuleNotFoundError as e:
            raise ImportError(
                "RedisSession requires the 'redis' extra. "
                "Install it with: pip install openai-agents[redis]"
            ) from e

    if name == "SQLAlchemySession":
        try:
            from .sqlalchemy_session import SQLAlchemySession  # noqa: F401

            return SQLAlchemySession
        except ModuleNotFoundError as e:
            raise ImportError(
                "SQLAlchemySession requires the 'sqlalchemy' extra. "
                "Install it with: pip install openai-agents[sqlalchemy]"
            ) from e

    if name == "AdvancedSQLiteSession":
        try:
            from .advanced_sqlite_session import AdvancedSQLiteSession  # noqa: F401

            return AdvancedSQLiteSession
        except ModuleNotFoundError as e:
            raise ImportError(
                f"Failed to import AdvancedSQLiteSession: {e}"
            ) from e

    raise AttributeError(f"module {__name__} has no attribute {name}")
|
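Because of the module-level __getattr__ above, backends are imported lazily and a missing optional dependency only surfaces as ImportError when the attribute is first accessed; a small sketch:

# Importing the package itself is cheap; no backend modules are loaded yet.
from agents.extensions import memory

try:
    RedisSession = memory.RedisSession  # triggers the lazy import
except ImportError as e:
    print(f"Redis backend unavailable: {e}")  # e.g. the 'redis' extra is not installed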
agents/extensions/memory/advanced_sqlite_session.py
ADDED
|
@@ -0,0 +1,1285 @@
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import asyncio
|
| 4 |
+
import json
|
| 5 |
+
import logging
|
| 6 |
+
import threading
|
| 7 |
+
from contextlib import closing
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
from typing import Any, Union, cast
|
| 10 |
+
|
| 11 |
+
from agents.result import RunResult
|
| 12 |
+
from agents.usage import Usage
|
| 13 |
+
|
| 14 |
+
from ...items import TResponseInputItem
|
| 15 |
+
from ...memory import SQLiteSession
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class AdvancedSQLiteSession(SQLiteSession):
|
| 19 |
+
"""Enhanced SQLite session with conversation branching and usage analytics."""
|
| 20 |
+
|
| 21 |
+
def __init__(
|
| 22 |
+
self,
|
| 23 |
+
*,
|
| 24 |
+
session_id: str,
|
| 25 |
+
db_path: str | Path = ":memory:",
|
| 26 |
+
create_tables: bool = False,
|
| 27 |
+
logger: logging.Logger | None = None,
|
| 28 |
+
**kwargs,
|
| 29 |
+
):
|
| 30 |
+
"""Initialize the AdvancedSQLiteSession.
|
| 31 |
+
|
| 32 |
+
Args:
|
| 33 |
+
session_id: The ID of the session
|
| 34 |
+
db_path: The path to the SQLite database file. Defaults to `:memory:` for in-memory storage
|
| 35 |
+
create_tables: Whether to create the structure tables
|
| 36 |
+
logger: The logger to use. Defaults to the module logger
|
| 37 |
+
**kwargs: Additional keyword arguments to pass to the superclass
|
| 38 |
+
""" # noqa: E501
|
| 39 |
+
super().__init__(session_id, db_path, **kwargs)
|
| 40 |
+
if create_tables:
|
| 41 |
+
self._init_structure_tables()
|
| 42 |
+
self._current_branch_id = "main"
|
| 43 |
+
self._logger = logger or logging.getLogger(__name__)
|
| 44 |
+
|
| 45 |
+
def _init_structure_tables(self):
|
| 46 |
+
"""Add structure and usage tracking tables.
|
| 47 |
+
|
| 48 |
+
Creates the message_structure and turn_usage tables with appropriate
|
| 49 |
+
indexes for conversation branching and usage analytics.
|
| 50 |
+
"""
|
| 51 |
+
conn = self._get_connection()
|
| 52 |
+
|
| 53 |
+
# Message structure with branch support
|
| 54 |
+
conn.execute("""
|
| 55 |
+
CREATE TABLE IF NOT EXISTS message_structure (
|
| 56 |
+
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
| 57 |
+
session_id TEXT NOT NULL,
|
| 58 |
+
message_id INTEGER NOT NULL,
|
| 59 |
+
branch_id TEXT NOT NULL DEFAULT 'main',
|
| 60 |
+
message_type TEXT NOT NULL,
|
| 61 |
+
sequence_number INTEGER NOT NULL,
|
| 62 |
+
user_turn_number INTEGER,
|
| 63 |
+
branch_turn_number INTEGER,
|
| 64 |
+
tool_name TEXT,
|
| 65 |
+
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
| 66 |
+
FOREIGN KEY (session_id) REFERENCES agent_sessions(session_id) ON DELETE CASCADE,
|
| 67 |
+
FOREIGN KEY (message_id) REFERENCES agent_messages(id) ON DELETE CASCADE
|
| 68 |
+
)
|
| 69 |
+
""")
|
| 70 |
+
|
| 71 |
+
# Turn-level usage tracking with branch support and full JSON details
|
| 72 |
+
conn.execute("""
|
| 73 |
+
CREATE TABLE IF NOT EXISTS turn_usage (
|
| 74 |
+
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
| 75 |
+
session_id TEXT NOT NULL,
|
| 76 |
+
branch_id TEXT NOT NULL DEFAULT 'main',
|
| 77 |
+
user_turn_number INTEGER NOT NULL,
|
| 78 |
+
requests INTEGER DEFAULT 0,
|
| 79 |
+
input_tokens INTEGER DEFAULT 0,
|
| 80 |
+
output_tokens INTEGER DEFAULT 0,
|
| 81 |
+
total_tokens INTEGER DEFAULT 0,
|
| 82 |
+
input_tokens_details JSON,
|
| 83 |
+
output_tokens_details JSON,
|
| 84 |
+
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
| 85 |
+
FOREIGN KEY (session_id) REFERENCES agent_sessions(session_id) ON DELETE CASCADE,
|
| 86 |
+
UNIQUE(session_id, branch_id, user_turn_number)
|
| 87 |
+
)
|
| 88 |
+
""")
|
| 89 |
+
|
| 90 |
+
# Indexes
|
| 91 |
+
conn.execute("""
|
| 92 |
+
CREATE INDEX IF NOT EXISTS idx_structure_session_seq
|
| 93 |
+
ON message_structure(session_id, sequence_number)
|
| 94 |
+
""")
|
| 95 |
+
conn.execute("""
|
| 96 |
+
CREATE INDEX IF NOT EXISTS idx_structure_branch
|
| 97 |
+
ON message_structure(session_id, branch_id)
|
| 98 |
+
""")
|
| 99 |
+
conn.execute("""
|
| 100 |
+
CREATE INDEX IF NOT EXISTS idx_structure_turn
|
| 101 |
+
ON message_structure(session_id, branch_id, user_turn_number)
|
| 102 |
+
""")
|
| 103 |
+
conn.execute("""
|
| 104 |
+
CREATE INDEX IF NOT EXISTS idx_structure_branch_seq
|
| 105 |
+
ON message_structure(session_id, branch_id, sequence_number)
|
| 106 |
+
""")
|
| 107 |
+
conn.execute("""
|
| 108 |
+
CREATE INDEX IF NOT EXISTS idx_turn_usage_session_turn
|
| 109 |
+
ON turn_usage(session_id, branch_id, user_turn_number)
|
| 110 |
+
""")
|
| 111 |
+
|
| 112 |
+
conn.commit()
|
| 113 |
+
|
| 114 |
+
async def add_items(self, items: list[TResponseInputItem]) -> None:
|
| 115 |
+
"""Add items to the session.
|
| 116 |
+
|
| 117 |
+
Args:
|
| 118 |
+
items: The items to add to the session
|
| 119 |
+
"""
|
| 120 |
+
# Add to base table first
|
| 121 |
+
await super().add_items(items)
|
| 122 |
+
|
| 123 |
+
# Extract structure metadata with precise sequencing
|
| 124 |
+
if items:
|
| 125 |
+
await self._add_structure_metadata(items)
|
| 126 |
+
|
| 127 |
+
async def get_items(
|
| 128 |
+
self,
|
| 129 |
+
limit: int | None = None,
|
| 130 |
+
branch_id: str | None = None,
|
| 131 |
+
) -> list[TResponseInputItem]:
|
| 132 |
+
"""Get items from current or specified branch.
|
| 133 |
+
|
| 134 |
+
Args:
|
| 135 |
+
limit: Maximum number of items to return. If None, returns all items.
|
| 136 |
+
branch_id: Branch to get items from. If None, uses current branch.
|
| 137 |
+
|
| 138 |
+
Returns:
|
| 139 |
+
List of conversation items from the specified branch.
|
| 140 |
+
"""
|
| 141 |
+
if branch_id is None:
|
| 142 |
+
branch_id = self._current_branch_id
|
| 143 |
+
|
| 144 |
+
# Get all items for this branch
|
| 145 |
+
def _get_all_items_sync():
|
| 146 |
+
"""Synchronous helper to get all items for a branch."""
|
| 147 |
+
conn = self._get_connection()
|
| 148 |
+
# TODO: Refactor SQLiteSession to use asyncio.Lock instead of threading.Lock and update this code # noqa: E501
|
| 149 |
+
with self._lock if self._is_memory_db else threading.Lock():
|
| 150 |
+
with closing(conn.cursor()) as cursor:
|
| 151 |
+
if limit is None:
|
| 152 |
+
cursor.execute(
|
| 153 |
+
"""
|
| 154 |
+
SELECT m.message_data
|
| 155 |
+
FROM agent_messages m
|
| 156 |
+
JOIN message_structure s ON m.id = s.message_id
|
| 157 |
+
WHERE m.session_id = ? AND s.branch_id = ?
|
| 158 |
+
ORDER BY s.sequence_number ASC
|
| 159 |
+
""",
|
| 160 |
+
(self.session_id, branch_id),
|
| 161 |
+
)
|
| 162 |
+
else:
|
| 163 |
+
cursor.execute(
|
| 164 |
+
"""
|
| 165 |
+
SELECT m.message_data
|
| 166 |
+
FROM agent_messages m
|
| 167 |
+
JOIN message_structure s ON m.id = s.message_id
|
| 168 |
+
WHERE m.session_id = ? AND s.branch_id = ?
|
| 169 |
+
ORDER BY s.sequence_number DESC
|
| 170 |
+
LIMIT ?
|
| 171 |
+
""",
|
| 172 |
+
(self.session_id, branch_id, limit),
|
| 173 |
+
)
|
| 174 |
+
|
| 175 |
+
rows = cursor.fetchall()
|
| 176 |
+
if limit is not None:
|
| 177 |
+
rows = list(reversed(rows))
|
| 178 |
+
|
| 179 |
+
items = []
|
| 180 |
+
for (message_data,) in rows:
|
| 181 |
+
try:
|
| 182 |
+
item = json.loads(message_data)
|
| 183 |
+
items.append(item)
|
| 184 |
+
except json.JSONDecodeError:
|
| 185 |
+
continue
|
| 186 |
+
return items
|
| 187 |
+
|
| 188 |
+
return await asyncio.to_thread(_get_all_items_sync)
|
| 189 |
+
|
| 190 |
+
def _get_items_sync():
|
| 191 |
+
"""Synchronous helper to get items for a specific branch."""
|
| 192 |
+
conn = self._get_connection()
|
| 193 |
+
# TODO: Refactor SQLiteSession to use asyncio.Lock instead of threading.Lock and update this code # noqa: E501
|
| 194 |
+
with self._lock if self._is_memory_db else threading.Lock():
|
| 195 |
+
with closing(conn.cursor()) as cursor:
|
| 196 |
+
# Get message IDs in correct order for this branch
|
| 197 |
+
if limit is None:
|
| 198 |
+
cursor.execute(
|
| 199 |
+
"""
|
| 200 |
+
SELECT m.message_data
|
| 201 |
+
FROM agent_messages m
|
| 202 |
+
JOIN message_structure s ON m.id = s.message_id
|
| 203 |
+
WHERE m.session_id = ? AND s.branch_id = ?
|
| 204 |
+
ORDER BY s.sequence_number ASC
|
| 205 |
+
""",
|
| 206 |
+
(self.session_id, branch_id),
|
| 207 |
+
)
|
| 208 |
+
else:
|
| 209 |
+
cursor.execute(
|
| 210 |
+
"""
|
| 211 |
+
SELECT m.message_data
|
| 212 |
+
FROM agent_messages m
|
| 213 |
+
JOIN message_structure s ON m.id = s.message_id
|
| 214 |
+
WHERE m.session_id = ? AND s.branch_id = ?
|
| 215 |
+
ORDER BY s.sequence_number DESC
|
| 216 |
+
LIMIT ?
|
| 217 |
+
""",
|
| 218 |
+
(self.session_id, branch_id, limit),
|
| 219 |
+
)
|
| 220 |
+
|
| 221 |
+
rows = cursor.fetchall()
|
| 222 |
+
if limit is not None:
|
| 223 |
+
rows = list(reversed(rows))
|
| 224 |
+
|
| 225 |
+
items = []
|
| 226 |
+
for (message_data,) in rows:
|
| 227 |
+
try:
|
| 228 |
+
item = json.loads(message_data)
|
| 229 |
+
items.append(item)
|
| 230 |
+
except json.JSONDecodeError:
|
| 231 |
+
continue
|
| 232 |
+
return items
|
| 233 |
+
|
| 234 |
+
return await asyncio.to_thread(_get_items_sync)
|
| 235 |
+
|
| 236 |
+
async def store_run_usage(self, result: RunResult) -> None:
|
| 237 |
+
"""Store usage data for the current conversation turn.
|
| 238 |
+
|
| 239 |
+
This is designed to be called after `Runner.run()` completes.
|
| 240 |
+
Session-level usage can be aggregated from turn data when needed.
|
| 241 |
+
|
| 242 |
+
Args:
|
| 243 |
+
result: The result from the run
|
| 244 |
+
"""
|
| 245 |
+
try:
|
| 246 |
+
if result.context_wrapper.usage is not None:
|
| 247 |
+
# Get the current turn number for this branch
|
| 248 |
+
current_turn = self._get_current_turn_number()
|
| 249 |
+
# Only update turn-level usage - session usage is aggregated on demand
|
| 250 |
+
await self._update_turn_usage_internal(current_turn, result.context_wrapper.usage)
|
| 251 |
+
except Exception as e:
|
| 252 |
+
self._logger.error(f"Failed to store usage for session {self.session_id}: {e}")
|
| 253 |
+
|
| 254 |
+
def _get_next_turn_number(self, branch_id: str) -> int:
|
| 255 |
+
"""Get the next turn number for a specific branch.
|
| 256 |
+
|
| 257 |
+
Args:
|
| 258 |
+
branch_id: The branch ID to get the next turn number for.
|
| 259 |
+
|
| 260 |
+
Returns:
|
| 261 |
+
The next available turn number for the specified branch.
|
| 262 |
+
"""
|
| 263 |
+
conn = self._get_connection()
|
| 264 |
+
with closing(conn.cursor()) as cursor:
|
| 265 |
+
cursor.execute(
|
| 266 |
+
"""
|
| 267 |
+
SELECT COALESCE(MAX(user_turn_number), 0)
|
| 268 |
+
FROM message_structure
|
| 269 |
+
WHERE session_id = ? AND branch_id = ?
|
| 270 |
+
""",
|
| 271 |
+
(self.session_id, branch_id),
|
| 272 |
+
)
|
| 273 |
+
result = cursor.fetchone()
|
| 274 |
+
max_turn = result[0] if result else 0
|
| 275 |
+
return max_turn + 1
|
| 276 |
+
|
| 277 |
+
def _get_next_branch_turn_number(self, branch_id: str) -> int:
|
| 278 |
+
"""Get the next branch turn number for a specific branch.
|
| 279 |
+
|
| 280 |
+
Args:
|
| 281 |
+
branch_id: The branch ID to get the next branch turn number for.
|
| 282 |
+
|
| 283 |
+
Returns:
|
| 284 |
+
The next available branch turn number for the specified branch.
|
| 285 |
+
"""
|
| 286 |
+
conn = self._get_connection()
|
| 287 |
+
with closing(conn.cursor()) as cursor:
|
| 288 |
+
cursor.execute(
|
| 289 |
+
"""
|
| 290 |
+
SELECT COALESCE(MAX(branch_turn_number), 0)
|
| 291 |
+
FROM message_structure
|
| 292 |
+
WHERE session_id = ? AND branch_id = ?
|
| 293 |
+
""",
|
| 294 |
+
(self.session_id, branch_id),
|
| 295 |
+
)
|
| 296 |
+
result = cursor.fetchone()
|
| 297 |
+
max_turn = result[0] if result else 0
|
| 298 |
+
return max_turn + 1
|
| 299 |
+
|
| 300 |
+
def _get_current_turn_number(self) -> int:
|
| 301 |
+
"""Get the current turn number for the current branch.
|
| 302 |
+
|
| 303 |
+
Returns:
|
| 304 |
+
The current turn number for the active branch.
|
| 305 |
+
"""
|
| 306 |
+
conn = self._get_connection()
|
| 307 |
+
with closing(conn.cursor()) as cursor:
|
| 308 |
+
cursor.execute(
|
| 309 |
+
"""
|
| 310 |
+
SELECT COALESCE(MAX(user_turn_number), 0)
|
| 311 |
+
FROM message_structure
|
| 312 |
+
WHERE session_id = ? AND branch_id = ?
|
| 313 |
+
""",
|
| 314 |
+
(self.session_id, self._current_branch_id),
|
| 315 |
+
)
|
| 316 |
+
result = cursor.fetchone()
|
| 317 |
+
return result[0] if result else 0
|
| 318 |
+
|
| 319 |
+
async def _add_structure_metadata(self, items: list[TResponseInputItem]) -> None:
|
| 320 |
+
"""Extract structure metadata with branch-aware turn tracking.
|
| 321 |
+
|
| 322 |
+
This method:
|
| 323 |
+
- Assigns turn numbers per branch (not globally)
|
| 324 |
+
- Assigns explicit sequence numbers for precise ordering
|
| 325 |
+
- Links messages to their database IDs for structure tracking
|
| 326 |
+
- Handles multiple user messages in a single batch correctly
|
| 327 |
+
|
| 328 |
+
Args:
|
| 329 |
+
items: The items to add to the session
|
| 330 |
+
"""
|
| 331 |
+
|
| 332 |
+
def _add_structure_sync():
|
| 333 |
+
"""Synchronous helper to add structure metadata to database."""
|
| 334 |
+
conn = self._get_connection()
|
| 335 |
+
# TODO: Refactor SQLiteSession to use asyncio.Lock instead of threading.Lock and update this code # noqa: E501
|
| 336 |
+
with self._lock if self._is_memory_db else threading.Lock():
|
| 337 |
+
# Get the IDs of messages we just inserted, in order
|
| 338 |
+
with closing(conn.cursor()) as cursor:
|
| 339 |
+
cursor.execute(
|
| 340 |
+
f"SELECT id FROM {self.messages_table} "
|
| 341 |
+
f"WHERE session_id = ? ORDER BY id DESC LIMIT ?",
|
| 342 |
+
(self.session_id, len(items)),
|
| 343 |
+
)
|
| 344 |
+
message_ids = [row[0] for row in cursor.fetchall()]
|
| 345 |
+
message_ids.reverse() # Match order of items
|
| 346 |
+
|
| 347 |
+
# Get current max sequence number (global)
|
| 348 |
+
with closing(conn.cursor()) as cursor:
|
| 349 |
+
cursor.execute(
|
| 350 |
+
"""
|
| 351 |
+
SELECT COALESCE(MAX(sequence_number), 0)
|
| 352 |
+
FROM message_structure
|
| 353 |
+
WHERE session_id = ?
|
| 354 |
+
""",
|
| 355 |
+
(self.session_id,),
|
| 356 |
+
)
|
| 357 |
+
seq_start = cursor.fetchone()[0]
|
| 358 |
+
|
| 359 |
+
# Get current turn numbers atomically with a single query
|
| 360 |
+
with closing(conn.cursor()) as cursor:
|
| 361 |
+
cursor.execute(
|
| 362 |
+
"""
|
| 363 |
+
SELECT
|
| 364 |
+
COALESCE(MAX(user_turn_number), 0) as max_global_turn,
|
| 365 |
+
COALESCE(MAX(branch_turn_number), 0) as max_branch_turn
|
| 366 |
+
FROM message_structure
|
| 367 |
+
WHERE session_id = ? AND branch_id = ?
|
| 368 |
+
""",
|
| 369 |
+
(self.session_id, self._current_branch_id),
|
| 370 |
+
)
|
| 371 |
+
result = cursor.fetchone()
|
| 372 |
+
current_turn = result[0] if result else 0
|
| 373 |
+
current_branch_turn = result[1] if result else 0
|
| 374 |
+
|
| 375 |
+
# Process items and assign turn numbers correctly
|
| 376 |
+
structure_data = []
|
| 377 |
+
user_message_count = 0
|
| 378 |
+
|
| 379 |
+
for i, (item, msg_id) in enumerate(zip(items, message_ids)):
|
| 380 |
+
msg_type = self._classify_message_type(item)
|
| 381 |
+
tool_name = self._extract_tool_name(item)
|
| 382 |
+
|
| 383 |
+
# If this is a user message, increment turn counters
|
| 384 |
+
if self._is_user_message(item):
|
| 385 |
+
user_message_count += 1
|
| 386 |
+
item_turn = current_turn + user_message_count
|
| 387 |
+
item_branch_turn = current_branch_turn + user_message_count
|
| 388 |
+
else:
|
| 389 |
+
# Non-user messages inherit the turn number of the most recent user message
|
| 390 |
+
item_turn = current_turn + user_message_count
|
| 391 |
+
item_branch_turn = current_branch_turn + user_message_count
|
| 392 |
+
|
| 393 |
+
structure_data.append(
|
| 394 |
+
(
|
| 395 |
+
self.session_id,
|
| 396 |
+
msg_id,
|
| 397 |
+
self._current_branch_id,
|
| 398 |
+
msg_type,
|
| 399 |
+
seq_start + i + 1, # Global sequence
|
| 400 |
+
item_turn, # Global turn number
|
| 401 |
+
item_branch_turn, # Branch-specific turn number
|
| 402 |
+
tool_name,
|
| 403 |
+
)
|
| 404 |
+
)
|
| 405 |
+
|
| 406 |
+
with closing(conn.cursor()) as cursor:
|
| 407 |
+
cursor.executemany(
|
| 408 |
+
"""
|
| 409 |
+
INSERT INTO message_structure
|
| 410 |
+
(session_id, message_id, branch_id, message_type, sequence_number,
|
| 411 |
+
user_turn_number, branch_turn_number, tool_name)
|
| 412 |
+
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
|
| 413 |
+
""",
|
| 414 |
+
structure_data,
|
| 415 |
+
)
|
| 416 |
+
conn.commit()
|
| 417 |
+
|
| 418 |
+
try:
|
| 419 |
+
await asyncio.to_thread(_add_structure_sync)
|
| 420 |
+
except Exception as e:
|
| 421 |
+
self._logger.error(
|
| 422 |
+
f"Failed to add structure metadata for session {self.session_id}: {e}"
|
| 423 |
+
)
|
| 424 |
+
# Try to clean up any orphaned messages to maintain consistency
|
| 425 |
+
try:
|
| 426 |
+
await self._cleanup_orphaned_messages()
|
| 427 |
+
except Exception as cleanup_error:
|
| 428 |
+
self._logger.error(f"Failed to cleanup orphaned messages: {cleanup_error}")
|
| 429 |
+
# Don't re-raise - structure metadata is supplementary
|
| 430 |
+
|
| 431 |
+
async def _cleanup_orphaned_messages(self) -> None:
|
| 432 |
+
"""Remove messages that exist in agent_messages but not in message_structure.
|
| 433 |
+
|
| 434 |
+
This can happen if _add_structure_metadata fails after super().add_items() succeeds.
|
| 435 |
+
Used for maintaining data consistency.
|
| 436 |
+
"""
|
| 437 |
+
|
| 438 |
+
def _cleanup_sync():
|
| 439 |
+
"""Synchronous helper to cleanup orphaned messages."""
|
| 440 |
+
conn = self._get_connection()
|
| 441 |
+
# TODO: Refactor SQLiteSession to use asyncio.Lock instead of threading.Lock and update this code # noqa: E501
|
| 442 |
+
with self._lock if self._is_memory_db else threading.Lock():
|
| 443 |
+
with closing(conn.cursor()) as cursor:
|
| 444 |
+
# Find messages without structure metadata
|
| 445 |
+
cursor.execute(
|
| 446 |
+
"""
|
| 447 |
+
SELECT am.id
|
| 448 |
+
FROM agent_messages am
|
| 449 |
+
LEFT JOIN message_structure ms ON am.id = ms.message_id
|
| 450 |
+
WHERE am.session_id = ? AND ms.message_id IS NULL
|
| 451 |
+
""",
|
| 452 |
+
(self.session_id,),
|
| 453 |
+
)
|
| 454 |
+
|
| 455 |
+
orphaned_ids = [row[0] for row in cursor.fetchall()]
|
| 456 |
+
|
| 457 |
+
if orphaned_ids:
|
| 458 |
+
# Delete orphaned messages
|
| 459 |
+
placeholders = ",".join("?" * len(orphaned_ids))
|
| 460 |
+
cursor.execute(
|
| 461 |
+
f"DELETE FROM agent_messages WHERE id IN ({placeholders})", orphaned_ids
|
| 462 |
+
)
|
| 463 |
+
|
| 464 |
+
deleted_count = cursor.rowcount
|
| 465 |
+
conn.commit()
|
| 466 |
+
|
| 467 |
+
self._logger.info(f"Cleaned up {deleted_count} orphaned messages")
|
| 468 |
+
return deleted_count
|
| 469 |
+
|
| 470 |
+
return 0
|
| 471 |
+
|
| 472 |
+
return await asyncio.to_thread(_cleanup_sync)
|
| 473 |
+
|
| 474 |
+
def _classify_message_type(self, item: TResponseInputItem) -> str:
|
| 475 |
+
"""Classify the type of a message item.
|
| 476 |
+
|
| 477 |
+
Args:
|
| 478 |
+
item: The message item to classify.
|
| 479 |
+
|
| 480 |
+
Returns:
|
| 481 |
+
String representing the message type (user, assistant, etc.).
|
| 482 |
+
"""
|
| 483 |
+
if isinstance(item, dict):
|
| 484 |
+
if item.get("role") == "user":
|
| 485 |
+
return "user"
|
| 486 |
+
elif item.get("role") == "assistant":
|
| 487 |
+
return "assistant"
|
| 488 |
+
elif item.get("type"):
|
| 489 |
+
return str(item.get("type"))
|
| 490 |
+
return "other"
|
| 491 |
+
|
| 492 |
+
def _extract_tool_name(self, item: TResponseInputItem) -> str | None:
|
| 493 |
+
"""Extract tool name if this is a tool call/output.
|
| 494 |
+
|
| 495 |
+
Args:
|
| 496 |
+
item: The message item to extract tool name from.
|
| 497 |
+
|
| 498 |
+
Returns:
|
| 499 |
+
Tool name if item is a tool call, None otherwise.
|
| 500 |
+
"""
|
| 501 |
+
if isinstance(item, dict):
|
| 502 |
+
item_type = item.get("type")
|
| 503 |
+
|
| 504 |
+
# For MCP tools, try to extract from server_label if available
|
| 505 |
+
if item_type in {"mcp_call", "mcp_approval_request"} and "server_label" in item:
|
| 506 |
+
server_label = item.get("server_label")
|
| 507 |
+
tool_name = item.get("name")
|
| 508 |
+
if tool_name and server_label:
|
| 509 |
+
return f"{server_label}.{tool_name}"
|
| 510 |
+
elif server_label:
|
| 511 |
+
return str(server_label)
|
| 512 |
+
elif tool_name:
|
| 513 |
+
return str(tool_name)
|
| 514 |
+
|
| 515 |
+
# For tool types without a 'name' field, derive from the type
|
| 516 |
+
elif item_type in {
|
| 517 |
+
"computer_call",
|
| 518 |
+
"file_search_call",
|
| 519 |
+
"web_search_call",
|
| 520 |
+
"code_interpreter_call",
|
| 521 |
+
}:
|
| 522 |
+
return item_type
|
| 523 |
+
|
| 524 |
+
# Most other tool calls have a 'name' field
|
| 525 |
+
elif "name" in item:
|
| 526 |
+
name = item.get("name")
|
| 527 |
+
return str(name) if name is not None else None
|
| 528 |
+
|
| 529 |
+
return None
|
| 530 |
+
|
| 531 |
+
def _is_user_message(self, item: TResponseInputItem) -> bool:
|
| 532 |
+
"""Check if this is a user message.
|
| 533 |
+
|
| 534 |
+
Args:
|
| 535 |
+
            item: The message item to check.

        Returns:
            True if the item is a user message, False otherwise.
        """
        return isinstance(item, dict) and item.get("role") == "user"

    async def create_branch_from_turn(
        self, turn_number: int, branch_name: str | None = None
    ) -> str:
        """Create a new branch starting from a specific user message turn.

        Args:
            turn_number: The branch turn number of the user message to branch from
            branch_name: Optional name for the branch (auto-generated if None)

        Returns:
            The branch_id of the newly created branch

        Raises:
            ValueError: If turn doesn't exist or doesn't contain a user message
        """
        import time

        # Validate the turn exists and contains a user message
        def _validate_turn():
            """Synchronous helper to validate turn exists and contains user message."""
            conn = self._get_connection()
            with closing(conn.cursor()) as cursor:
                cursor.execute(
                    """
                    SELECT am.message_data
                    FROM message_structure ms
                    JOIN agent_messages am ON ms.message_id = am.id
                    WHERE ms.session_id = ? AND ms.branch_id = ?
                    AND ms.branch_turn_number = ? AND ms.message_type = 'user'
                    """,
                    (self.session_id, self._current_branch_id, turn_number),
                )

                result = cursor.fetchone()
                if not result:
                    raise ValueError(
                        f"Turn {turn_number} does not contain a user message "
                        f"in branch '{self._current_branch_id}'"
                    )

                message_data = result[0]
                try:
                    content = json.loads(message_data).get("content", "")
                    return content[:50] + "..." if len(content) > 50 else content
                except Exception:
                    return "Unable to parse content"

        turn_content = await asyncio.to_thread(_validate_turn)

        # Generate branch name if not provided
        if branch_name is None:
            timestamp = int(time.time())
            branch_name = f"branch_from_turn_{turn_number}_{timestamp}"

        # Copy messages before the branch point to the new branch
        await self._copy_messages_to_new_branch(branch_name, turn_number)

        # Switch to new branch
        old_branch = self._current_branch_id
        self._current_branch_id = branch_name

        self._logger.debug(
            f"Created branch '{branch_name}' from turn {turn_number} ('{turn_content}') in '{old_branch}'"  # noqa: E501
        )
        return branch_name

    async def create_branch_from_content(
        self, search_term: str, branch_name: str | None = None
    ) -> str:
        """Create branch from the first user turn matching the search term.

        Args:
            search_term: Text to search for in user messages.
            branch_name: Optional name for the branch (auto-generated if None).

        Returns:
            The branch_id of the newly created branch.

        Raises:
            ValueError: If no matching turns are found.
        """
        matching_turns = await self.find_turns_by_content(search_term)
        if not matching_turns:
            raise ValueError(f"No user turns found containing '{search_term}'")

        # Use the first (earliest) match
        turn_number = matching_turns[0]["turn"]
        return await self.create_branch_from_turn(turn_number, branch_name)

    async def switch_to_branch(self, branch_id: str) -> None:
        """Switch to a different branch.

        Args:
            branch_id: The branch to switch to.

        Raises:
            ValueError: If the branch doesn't exist.
        """

        # Validate branch exists
        def _validate_branch():
            """Synchronous helper to validate branch exists."""
            conn = self._get_connection()
            with closing(conn.cursor()) as cursor:
                cursor.execute(
                    """
                    SELECT COUNT(*) FROM message_structure
                    WHERE session_id = ? AND branch_id = ?
                    """,
                    (self.session_id, branch_id),
                )

                count = cursor.fetchone()[0]
                if count == 0:
                    raise ValueError(f"Branch '{branch_id}' does not exist")

        await asyncio.to_thread(_validate_branch)

        old_branch = self._current_branch_id
        self._current_branch_id = branch_id
        self._logger.info(f"Switched from branch '{old_branch}' to '{branch_id}'")

    async def delete_branch(self, branch_id: str, force: bool = False) -> None:
        """Delete a branch and all its associated data.

        Args:
            branch_id: The branch to delete.
            force: If True, allows deleting the current branch (will switch to 'main').

        Raises:
            ValueError: If branch doesn't exist, is 'main', or is current branch without force.
        """
        if not branch_id or not branch_id.strip():
            raise ValueError("Branch ID cannot be empty")

        branch_id = branch_id.strip()

        # Protect main branch
        if branch_id == "main":
            raise ValueError("Cannot delete the 'main' branch")

        # Check if trying to delete current branch
        if branch_id == self._current_branch_id:
            if not force:
                raise ValueError(
                    f"Cannot delete current branch '{branch_id}'. Use force=True or switch branches first"  # noqa: E501
                )
            else:
                # Switch to main before deleting
                await self.switch_to_branch("main")

        def _delete_sync():
            """Synchronous helper to delete branch and associated data."""
            conn = self._get_connection()
            # TODO: Refactor SQLiteSession to use asyncio.Lock instead of threading.Lock and update this code  # noqa: E501
            with self._lock if self._is_memory_db else threading.Lock():
                with closing(conn.cursor()) as cursor:
                    # First verify the branch exists
                    cursor.execute(
                        """
                        SELECT COUNT(*) FROM message_structure
                        WHERE session_id = ? AND branch_id = ?
                        """,
                        (self.session_id, branch_id),
                    )

                    count = cursor.fetchone()[0]
                    if count == 0:
                        raise ValueError(f"Branch '{branch_id}' does not exist")

                    # Delete from turn_usage first (foreign key constraint)
                    cursor.execute(
                        """
                        DELETE FROM turn_usage
                        WHERE session_id = ? AND branch_id = ?
                        """,
                        (self.session_id, branch_id),
                    )

                    usage_deleted = cursor.rowcount

                    # Delete from message_structure
                    cursor.execute(
                        """
                        DELETE FROM message_structure
                        WHERE session_id = ? AND branch_id = ?
                        """,
                        (self.session_id, branch_id),
                    )

                    structure_deleted = cursor.rowcount

                    conn.commit()

                    return usage_deleted, structure_deleted

        usage_deleted, structure_deleted = await asyncio.to_thread(_delete_sync)

        self._logger.info(
            f"Deleted branch '{branch_id}': {structure_deleted} message entries, {usage_deleted} usage entries"  # noqa: E501
        )

    async def list_branches(self) -> list[dict[str, Any]]:
        """List all branches in this session.

        Returns:
            List of dicts with branch info containing:
            - 'branch_id': Branch identifier
            - 'message_count': Number of messages in branch
            - 'user_turns': Number of user turns in branch
            - 'is_current': Whether this is the current branch
            - 'created_at': When the branch was first created
        """

        def _list_branches_sync():
            """Synchronous helper to list all branches."""
            conn = self._get_connection()
            with closing(conn.cursor()) as cursor:
                cursor.execute(
                    """
                    SELECT
                        ms.branch_id,
                        COUNT(*) as message_count,
                        COUNT(CASE WHEN ms.message_type = 'user' THEN 1 END) as user_turns,
                        MIN(ms.created_at) as created_at
                    FROM message_structure ms
                    WHERE ms.session_id = ?
                    GROUP BY ms.branch_id
                    ORDER BY created_at
                    """,
                    (self.session_id,),
                )

                branches = []
                for row in cursor.fetchall():
                    branch_id, msg_count, user_turns, created_at = row
                    branches.append(
                        {
                            "branch_id": branch_id,
                            "message_count": msg_count,
                            "user_turns": user_turns,
                            "is_current": branch_id == self._current_branch_id,
                            "created_at": created_at,
                        }
                    )

                return branches

        return await asyncio.to_thread(_list_branches_sync)

    async def _copy_messages_to_new_branch(self, new_branch_id: str, from_turn_number: int) -> None:
        """Copy messages before the branch point to the new branch.

        Args:
            new_branch_id: The ID of the new branch to copy messages to.
            from_turn_number: The turn number to copy messages up to (exclusive).
        """

        def _copy_sync():
            """Synchronous helper to copy messages to new branch."""
            conn = self._get_connection()
            # TODO: Refactor SQLiteSession to use asyncio.Lock instead of threading.Lock and update this code  # noqa: E501
            with self._lock if self._is_memory_db else threading.Lock():
                with closing(conn.cursor()) as cursor:
                    # Get all messages before the branch point
                    cursor.execute(
                        """
                        SELECT
                            ms.message_id,
                            ms.message_type,
                            ms.sequence_number,
                            ms.user_turn_number,
                            ms.branch_turn_number,
                            ms.tool_name
                        FROM message_structure ms
                        WHERE ms.session_id = ? AND ms.branch_id = ?
                        AND ms.branch_turn_number < ?
                        ORDER BY ms.sequence_number
                        """,
                        (self.session_id, self._current_branch_id, from_turn_number),
                    )

                    messages_to_copy = cursor.fetchall()

                    if messages_to_copy:
                        # Get the max sequence number for the new inserts
                        cursor.execute(
                            """
                            SELECT COALESCE(MAX(sequence_number), 0)
                            FROM message_structure
                            WHERE session_id = ?
                            """,
                            (self.session_id,),
                        )

                        seq_start = cursor.fetchone()[0]

                        # Insert copied messages with new branch_id
                        new_structure_data = []
                        for i, (
                            msg_id,
                            msg_type,
                            _,
                            user_turn,
                            branch_turn,
                            tool_name,
                        ) in enumerate(messages_to_copy):
                            new_structure_data.append(
                                (
                                    self.session_id,
                                    msg_id,  # Same message_id (sharing the actual message data)
                                    new_branch_id,
                                    msg_type,
                                    seq_start + i + 1,  # New sequence number
                                    user_turn,  # Keep same global turn number
                                    branch_turn,  # Keep same branch turn number
                                    tool_name,
                                )
                            )

                        cursor.executemany(
                            """
                            INSERT INTO message_structure
                            (session_id, message_id, branch_id, message_type, sequence_number,
                             user_turn_number, branch_turn_number, tool_name)
                            VALUES (?, ?, ?, ?, ?, ?, ?, ?)
                            """,
                            new_structure_data,
                        )

                        conn.commit()

        await asyncio.to_thread(_copy_sync)

    async def get_conversation_turns(self, branch_id: str | None = None) -> list[dict[str, Any]]:
        """Get user turns with content for easy browsing and branching decisions.

        Args:
            branch_id: Branch to get turns from (current branch if None).

        Returns:
            List of dicts with turn info containing:
            - 'turn': Branch turn number
            - 'content': User message content (truncated)
            - 'full_content': Full user message content
            - 'timestamp': When the turn was created
            - 'can_branch': Always True (all user messages can branch)
        """
        if branch_id is None:
            branch_id = self._current_branch_id

        def _get_turns_sync():
            """Synchronous helper to get conversation turns."""
            conn = self._get_connection()
            with closing(conn.cursor()) as cursor:
                cursor.execute(
                    """
                    SELECT
                        ms.branch_turn_number,
                        am.message_data,
                        ms.created_at
                    FROM message_structure ms
                    JOIN agent_messages am ON ms.message_id = am.id
                    WHERE ms.session_id = ? AND ms.branch_id = ?
                    AND ms.message_type = 'user'
                    ORDER BY ms.branch_turn_number
                    """,
                    (self.session_id, branch_id),
                )

                turns = []
                for row in cursor.fetchall():
                    turn_num, message_data, created_at = row
                    try:
                        content = json.loads(message_data).get("content", "")
                        turns.append(
                            {
                                "turn": turn_num,
                                "content": content[:100] + "..." if len(content) > 100 else content,
                                "full_content": content,
                                "timestamp": created_at,
                                "can_branch": True,
                            }
                        )
                    except (json.JSONDecodeError, AttributeError):
                        continue

                return turns

        return await asyncio.to_thread(_get_turns_sync)

    async def find_turns_by_content(
        self, search_term: str, branch_id: str | None = None
    ) -> list[dict[str, Any]]:
        """Find user turns containing specific content.

        Args:
            search_term: Text to search for in user messages.
            branch_id: Branch to search in (current branch if None).

        Returns:
            List of matching turns with same format as get_conversation_turns().
        """
        if branch_id is None:
            branch_id = self._current_branch_id

        def _search_sync():
            """Synchronous helper to search turns by content."""
            conn = self._get_connection()
            with closing(conn.cursor()) as cursor:
                cursor.execute(
                    """
                    SELECT
                        ms.branch_turn_number,
                        am.message_data,
                        ms.created_at
                    FROM message_structure ms
                    JOIN agent_messages am ON ms.message_id = am.id
                    WHERE ms.session_id = ? AND ms.branch_id = ?
                    AND ms.message_type = 'user'
                    AND am.message_data LIKE ?
                    ORDER BY ms.branch_turn_number
                    """,
                    (self.session_id, branch_id, f"%{search_term}%"),
                )

                matches = []
                for row in cursor.fetchall():
                    turn_num, message_data, created_at = row
                    try:
                        content = json.loads(message_data).get("content", "")
                        matches.append(
                            {
                                "turn": turn_num,
                                "content": content,
                                "full_content": content,
                                "timestamp": created_at,
                                "can_branch": True,
                            }
                        )
                    except (json.JSONDecodeError, AttributeError):
                        continue

                return matches

        return await asyncio.to_thread(_search_sync)

    async def get_conversation_by_turns(
        self, branch_id: str | None = None
    ) -> dict[int, list[dict[str, str | None]]]:
        """Get conversation grouped by user turns for specified branch.

        Args:
            branch_id: Branch to get conversation from (current branch if None).

        Returns:
            Dictionary mapping turn numbers to lists of message metadata.
        """
        if branch_id is None:
            branch_id = self._current_branch_id

        def _get_conversation_sync():
            """Synchronous helper to get conversation by turns."""
            conn = self._get_connection()
            with closing(conn.cursor()) as cursor:
                cursor.execute(
                    """
                    SELECT user_turn_number, message_type, tool_name
                    FROM message_structure
                    WHERE session_id = ? AND branch_id = ?
                    ORDER BY sequence_number
                    """,
                    (self.session_id, branch_id),
                )

                turns: dict[int, list[dict[str, str | None]]] = {}
                for row in cursor.fetchall():
                    turn_num, msg_type, tool_name = row
                    if turn_num not in turns:
                        turns[turn_num] = []
                    turns[turn_num].append({"type": msg_type, "tool_name": tool_name})
                return turns

        return await asyncio.to_thread(_get_conversation_sync)

    async def get_tool_usage(self, branch_id: str | None = None) -> list[tuple[str, int, int]]:
        """Get all tool usage by turn for specified branch.

        Args:
            branch_id: Branch to get tool usage from (current branch if None).

        Returns:
            List of tuples containing (tool_name, usage_count, turn_number).
        """
        if branch_id is None:
            branch_id = self._current_branch_id

        def _get_tool_usage_sync():
            """Synchronous helper to get tool usage statistics."""
            conn = self._get_connection()
            with closing(conn.cursor()) as cursor:
                cursor.execute(
                    """
                    SELECT tool_name, COUNT(*), user_turn_number
                    FROM message_structure
                    WHERE session_id = ? AND branch_id = ? AND message_type IN (
                        'tool_call', 'function_call', 'computer_call', 'file_search_call',
                        'web_search_call', 'code_interpreter_call', 'custom_tool_call',
                        'mcp_call', 'mcp_approval_request'
                    )
                    GROUP BY tool_name, user_turn_number
                    ORDER BY user_turn_number
                    """,
                    (self.session_id, branch_id),
                )
                return cursor.fetchall()

        return await asyncio.to_thread(_get_tool_usage_sync)

    async def get_session_usage(self, branch_id: str | None = None) -> dict[str, int] | None:
        """Get cumulative usage for session or specific branch.

        Args:
            branch_id: If provided, only get usage for that branch. If None, get all branches.

        Returns:
            Dictionary with usage statistics or None if no usage data found.
        """

        def _get_usage_sync():
            """Synchronous helper to get session usage data."""
            conn = self._get_connection()
            # TODO: Refactor SQLiteSession to use asyncio.Lock instead of threading.Lock and update this code  # noqa: E501
            with self._lock if self._is_memory_db else threading.Lock():
                if branch_id:
                    # Branch-specific usage
                    query = """
                        SELECT
                            SUM(requests) as total_requests,
                            SUM(input_tokens) as total_input_tokens,
                            SUM(output_tokens) as total_output_tokens,
                            SUM(total_tokens) as total_total_tokens,
                            COUNT(*) as total_turns
                        FROM turn_usage
                        WHERE session_id = ? AND branch_id = ?
                    """
                    params: tuple[str, ...] = (self.session_id, branch_id)
                else:
                    # All branches
                    query = """
                        SELECT
                            SUM(requests) as total_requests,
                            SUM(input_tokens) as total_input_tokens,
                            SUM(output_tokens) as total_output_tokens,
                            SUM(total_tokens) as total_total_tokens,
                            COUNT(*) as total_turns
                        FROM turn_usage
                        WHERE session_id = ?
                    """
                    params = (self.session_id,)

                with closing(conn.cursor()) as cursor:
                    cursor.execute(query, params)
                    row = cursor.fetchone()

                    if row and row[0] is not None:
                        return {
                            "requests": row[0] or 0,
                            "input_tokens": row[1] or 0,
                            "output_tokens": row[2] or 0,
                            "total_tokens": row[3] or 0,
                            "total_turns": row[4] or 0,
                        }
                    return None

        result = await asyncio.to_thread(_get_usage_sync)

        return cast(Union[dict[str, int], None], result)

    async def get_turn_usage(
        self,
        user_turn_number: int | None = None,
        branch_id: str | None = None,
    ) -> list[dict[str, Any]] | dict[str, Any]:
        """Get usage statistics by turn with full JSON token details.

        Args:
            user_turn_number: Specific turn to get usage for. If None, returns all turns.
            branch_id: Branch to get usage from (current branch if None).

        Returns:
            Dictionary with usage data for specific turn, or list of dictionaries for all turns.
        """

        if branch_id is None:
            branch_id = self._current_branch_id

        def _get_turn_usage_sync():
            """Synchronous helper to get turn usage statistics."""
            conn = self._get_connection()

            if user_turn_number is not None:
                query = """
                    SELECT requests, input_tokens, output_tokens, total_tokens,
                           input_tokens_details, output_tokens_details
                    FROM turn_usage
                    WHERE session_id = ? AND branch_id = ? AND user_turn_number = ?
                """

                with closing(conn.cursor()) as cursor:
                    cursor.execute(query, (self.session_id, branch_id, user_turn_number))
                    row = cursor.fetchone()

                    if row:
                        # Parse JSON details if present
                        input_details = None
                        output_details = None

                        if row[4]:  # input_tokens_details
                            try:
                                input_details = json.loads(row[4])
                            except json.JSONDecodeError:
                                pass

                        if row[5]:  # output_tokens_details
                            try:
                                output_details = json.loads(row[5])
                            except json.JSONDecodeError:
                                pass

                        return {
                            "requests": row[0],
                            "input_tokens": row[1],
                            "output_tokens": row[2],
                            "total_tokens": row[3],
                            "input_tokens_details": input_details,
                            "output_tokens_details": output_details,
                        }
                    return {}
            else:
                query = """
                    SELECT user_turn_number, requests, input_tokens, output_tokens,
                           total_tokens, input_tokens_details, output_tokens_details
                    FROM turn_usage
                    WHERE session_id = ? AND branch_id = ?
                    ORDER BY user_turn_number
                """

                with closing(conn.cursor()) as cursor:
                    cursor.execute(query, (self.session_id, branch_id))
                    results = []
                    for row in cursor.fetchall():
                        # Parse JSON details if present
                        input_details = None
                        output_details = None

                        if row[5]:  # input_tokens_details
                            try:
                                input_details = json.loads(row[5])
                            except json.JSONDecodeError:
                                pass

                        if row[6]:  # output_tokens_details
                            try:
                                output_details = json.loads(row[6])
                            except json.JSONDecodeError:
                                pass

                        results.append(
                            {
                                "user_turn_number": row[0],
                                "requests": row[1],
                                "input_tokens": row[2],
                                "output_tokens": row[3],
                                "total_tokens": row[4],
                                "input_tokens_details": input_details,
                                "output_tokens_details": output_details,
                            }
                        )
                    return results

        result = await asyncio.to_thread(_get_turn_usage_sync)

        return cast(Union[list[dict[str, Any]], dict[str, Any]], result)

    async def _update_turn_usage_internal(self, user_turn_number: int, usage_data: Usage) -> None:
        """Internal method to update usage for a specific turn with full JSON details.

        Args:
            user_turn_number: The turn number to update usage for.
            usage_data: The usage data to store.
        """

        def _update_sync():
            """Synchronous helper to update turn usage data."""
            conn = self._get_connection()
            # TODO: Refactor SQLiteSession to use asyncio.Lock instead of threading.Lock and update this code  # noqa: E501
            with self._lock if self._is_memory_db else threading.Lock():
                # Serialize token details as JSON
                input_details_json = None
                output_details_json = None

                if hasattr(usage_data, "input_tokens_details") and usage_data.input_tokens_details:
                    try:
                        input_details_json = json.dumps(usage_data.input_tokens_details.__dict__)
                    except (TypeError, ValueError) as e:
                        self._logger.warning(f"Failed to serialize input tokens details: {e}")
                        input_details_json = None

                if (
                    hasattr(usage_data, "output_tokens_details")
                    and usage_data.output_tokens_details
                ):
                    try:
                        output_details_json = json.dumps(
                            usage_data.output_tokens_details.__dict__
                        )
                    except (TypeError, ValueError) as e:
                        self._logger.warning(f"Failed to serialize output tokens details: {e}")
                        output_details_json = None

                with closing(conn.cursor()) as cursor:
                    cursor.execute(
                        """
                        INSERT OR REPLACE INTO turn_usage
                        (session_id, branch_id, user_turn_number, requests, input_tokens, output_tokens,
                         total_tokens, input_tokens_details, output_tokens_details)
                        VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
                        """,  # noqa: E501
                        (
                            self.session_id,
                            self._current_branch_id,
                            user_turn_number,
                            usage_data.requests or 0,
                            usage_data.input_tokens or 0,
                            usage_data.output_tokens or 0,
                            usage_data.total_tokens or 0,
                            input_details_json,
                            output_details_json,
                        ),
                    )
                    conn.commit()

        await asyncio.to_thread(_update_sync)
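A minimal usage sketch of the branching API above. It assumes AdvancedSQLiteSession is exported from agents.extensions.memory (as the package layout suggests) and that its constructor takes session_id/db_path arguments like the base SQLiteSession; the turn number and file names are illustrative only.

import asyncio

from agents.extensions.memory import AdvancedSQLiteSession  # assumed export path


async def main() -> None:
    # Constructor arguments assumed; see the class definition earlier in this file.
    session = AdvancedSQLiteSession(session_id="user-123", db_path="conversations.db")

    # Browse user turns on the current branch, then fork from one of them.
    for turn in await session.get_conversation_turns():
        print(turn["turn"], turn["content"])

    branch_id = await session.create_branch_from_turn(2)  # assumes turn 2 exists
    print(f"now on branch: {branch_id}")

    # Enumerate branches and their cumulative token usage.
    for info in await session.list_branches():
        usage = await session.get_session_usage(branch_id=info["branch_id"])
        print(info["branch_id"], info["user_turns"], usage)

    # Return to the original history.
    await session.switch_to_branch("main")


asyncio.run(main())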
agents/extensions/memory/encrypt_session.py
ADDED
@@ -0,0 +1,185 @@
"""Encrypted Session wrapper for secure conversation storage.

This module provides transparent encryption for session storage with automatic
expiration of old data. When TTL expires, expired items are silently skipped.

Usage::

    from agents.extensions.memory import EncryptedSession, SQLAlchemySession

    # Create underlying session (e.g. SQLAlchemySession)
    underlying_session = SQLAlchemySession.from_url(
        session_id="user-123",
        url="postgresql+asyncpg://app:[email protected]/agents",
        create_tables=True,
    )

    # Wrap with encryption and TTL-based expiration
    session = EncryptedSession(
        session_id="user-123",
        underlying_session=underlying_session,
        encryption_key="your-encryption-key",
        ttl=600,  # 10 minutes
    )

    await Runner.run(agent, "Hello", session=session)
"""

from __future__ import annotations

import base64
import json
from typing import Any, cast

from cryptography.fernet import Fernet, InvalidToken
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.kdf.hkdf import HKDF
from typing_extensions import Literal, TypedDict, TypeGuard

from ...items import TResponseInputItem
from ...memory.session import SessionABC


class EncryptedEnvelope(TypedDict):
    """TypedDict for encrypted message envelopes stored in the underlying session."""

    __enc__: Literal[1]
    v: int
    kid: str
    payload: str


def _ensure_fernet_key_bytes(master_key: str) -> bytes:
    """
    Accept either a Fernet key (urlsafe-b64, 32 bytes after decode) or a raw string.
    Returns raw bytes suitable for HKDF input.
    """
    if not master_key:
        raise ValueError("encryption_key not set; required for EncryptedSession.")
    try:
        key_bytes = base64.urlsafe_b64decode(master_key)
        if len(key_bytes) == 32:
            return key_bytes
    except Exception:
        pass
    return master_key.encode("utf-8")


def _derive_session_fernet_key(master_key_bytes: bytes, session_id: str) -> Fernet:
    hkdf = HKDF(
        algorithm=hashes.SHA256(),
        length=32,
        salt=session_id.encode("utf-8"),
        info=b"agents.session-store.hkdf.v1",
    )
    derived = hkdf.derive(master_key_bytes)
    return Fernet(base64.urlsafe_b64encode(derived))


def _to_json_bytes(obj: Any) -> bytes:
    return json.dumps(obj, ensure_ascii=False, separators=(",", ":"), default=str).encode("utf-8")


def _from_json_bytes(data: bytes) -> Any:
    return json.loads(data.decode("utf-8"))


def _is_encrypted_envelope(item: object) -> TypeGuard[EncryptedEnvelope]:
    """Type guard to check if an item is an encrypted envelope."""
    return (
        isinstance(item, dict)
        and item.get("__enc__") == 1
        and "payload" in item
        and "kid" in item
        and "v" in item
    )


class EncryptedSession(SessionABC):
    """Encrypted wrapper for Session implementations with TTL-based expiration.

    This class wraps any SessionABC implementation to provide transparent
    encryption/decryption of stored items using Fernet encryption with
    per-session key derivation and automatic expiration of old data.

    When items expire (exceed TTL), they are silently skipped during retrieval.

    Note: Expired tokens are rejected based on the system clock of the application server.
    To avoid valid tokens being rejected due to clock drift, ensure all servers in
    your environment are synchronized using NTP.
    """

    def __init__(
        self,
        session_id: str,
        underlying_session: SessionABC,
        encryption_key: str,
        ttl: int = 600,
    ):
        """
        Args:
            session_id: ID for this session
            underlying_session: The real session store (e.g. SQLiteSession, SQLAlchemySession)
            encryption_key: Master key (Fernet key or raw secret)
            ttl: Token time-to-live in seconds (default 10 min)
        """
        self.session_id = session_id
        self.underlying_session = underlying_session
        self.ttl = ttl

        master = _ensure_fernet_key_bytes(encryption_key)
        self.cipher = _derive_session_fernet_key(master, session_id)
        self._kid = "hkdf-v1"
        self._ver = 1

    def __getattr__(self, name):
        return getattr(self.underlying_session, name)

    def _wrap(self, item: TResponseInputItem) -> EncryptedEnvelope:
        if isinstance(item, dict):
            payload = item
        elif hasattr(item, "model_dump"):
            payload = item.model_dump()
        elif hasattr(item, "__dict__"):
            payload = item.__dict__
        else:
            payload = dict(item)

        token = self.cipher.encrypt(_to_json_bytes(payload)).decode("utf-8")
        return {"__enc__": 1, "v": self._ver, "kid": self._kid, "payload": token}

    def _unwrap(self, item: TResponseInputItem | EncryptedEnvelope) -> TResponseInputItem | None:
        if not _is_encrypted_envelope(item):
            return cast(TResponseInputItem, item)

        try:
            token = item["payload"].encode("utf-8")
            plaintext = self.cipher.decrypt(token, ttl=self.ttl)
            return cast(TResponseInputItem, _from_json_bytes(plaintext))
        except (InvalidToken, KeyError):
            return None

    async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]:
        encrypted_items = await self.underlying_session.get_items(limit)
        valid_items: list[TResponseInputItem] = []
        for enc in encrypted_items:
            item = self._unwrap(enc)
            if item is not None:
                valid_items.append(item)
        return valid_items

    async def add_items(self, items: list[TResponseInputItem]) -> None:
        wrapped: list[EncryptedEnvelope] = [self._wrap(it) for it in items]
        await self.underlying_session.add_items(cast(list[TResponseInputItem], wrapped))

    async def pop_item(self) -> TResponseInputItem | None:
        while True:
            enc = await self.underlying_session.pop_item()
            if not enc:
                return None
            item = self._unwrap(enc)
            if item is not None:
                return item

    async def clear_session(self) -> None:
        await self.underlying_session.clear_session()
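A small sketch of the key handling above: a valid Fernet master key is used as-is, any other string is fed through HKDF, and either way each session_id derives its own per-session cipher, so the same master key yields different ciphertext keys per conversation. Wrapping SQLiteSession and the agents.memory import path are assumptions for illustration.

import asyncio

from cryptography.fernet import Fernet

from agents.extensions.memory import EncryptedSession
from agents.memory import SQLiteSession  # assumed import path for the base store


async def main() -> None:
    master_key = Fernet.generate_key().decode()  # or any raw secret string

    session = EncryptedSession(
        session_id="user-123",
        underlying_session=SQLiteSession("user-123", "conversations.db"),
        encryption_key=master_key,
        ttl=3600,  # items older than an hour are silently skipped on read
    )

    await session.add_items([{"role": "user", "content": "hello"}])
    print(await session.get_items())  # decrypted transparently


asyncio.run(main())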
agents/extensions/memory/redis_session.py
ADDED
@@ -0,0 +1,267 @@
"""Redis-powered Session backend.

Usage::

    from agents.extensions.memory import RedisSession

    # Create from Redis URL
    session = RedisSession.from_url(
        session_id="user-123",
        url="redis://localhost:6379/0",
    )

    # Or pass an existing Redis client that your application already manages
    session = RedisSession(
        session_id="user-123",
        redis_client=my_redis_client,
    )

    await Runner.run(agent, "Hello", session=session)
"""

from __future__ import annotations

import asyncio
import json
import time
from typing import Any
from urllib.parse import urlparse

try:
    import redis.asyncio as redis
    from redis.asyncio import Redis
except ImportError as e:
    raise ImportError(
        "RedisSession requires the 'redis' package. Install it with: pip install redis"
    ) from e

from ...items import TResponseInputItem
from ...memory.session import SessionABC


class RedisSession(SessionABC):
    """Redis implementation of :pyclass:`agents.memory.session.Session`."""

    def __init__(
        self,
        session_id: str,
        *,
        redis_client: Redis,
        key_prefix: str = "agents:session",
        ttl: int | None = None,
    ):
        """Initializes a new RedisSession.

        Args:
            session_id (str): Unique identifier for the conversation.
            redis_client (Redis[bytes]): A pre-configured Redis async client.
            key_prefix (str, optional): Prefix for Redis keys to avoid collisions.
                Defaults to "agents:session".
            ttl (int | None, optional): Time-to-live in seconds for session data.
                If None, data persists indefinitely. Defaults to None.
        """
        self.session_id = session_id
        self._redis = redis_client
        self._key_prefix = key_prefix
        self._ttl = ttl
        self._lock = asyncio.Lock()
        self._owns_client = False  # Track if we own the Redis client

        # Redis key patterns
        self._session_key = f"{self._key_prefix}:{self.session_id}"
        self._messages_key = f"{self._session_key}:messages"
        self._counter_key = f"{self._session_key}:counter"

    @classmethod
    def from_url(
        cls,
        session_id: str,
        *,
        url: str,
        redis_kwargs: dict[str, Any] | None = None,
        **kwargs: Any,
    ) -> RedisSession:
        """Create a session from a Redis URL string.

        Args:
            session_id (str): Conversation ID.
            url (str): Redis URL, e.g. "redis://localhost:6379/0" or "rediss://host:6380".
            redis_kwargs (dict[str, Any] | None): Additional keyword arguments forwarded to
                redis.asyncio.from_url.
            **kwargs: Additional keyword arguments forwarded to the main constructor
                (e.g., key_prefix, ttl, etc.).

        Returns:
            RedisSession: An instance of RedisSession connected to the specified Redis server.
        """
        redis_kwargs = redis_kwargs or {}

        # Parse URL to determine if we need SSL
        parsed = urlparse(url)
        if parsed.scheme == "rediss":
            redis_kwargs.setdefault("ssl", True)

        redis_client = redis.from_url(url, **redis_kwargs)
        session = cls(session_id, redis_client=redis_client, **kwargs)
        session._owns_client = True  # We created the client, so we own it
        return session

    async def _serialize_item(self, item: TResponseInputItem) -> str:
        """Serialize an item to JSON string. Can be overridden by subclasses."""
        return json.dumps(item, separators=(",", ":"))

    async def _deserialize_item(self, item: str) -> TResponseInputItem:
        """Deserialize a JSON string to an item. Can be overridden by subclasses."""
        return json.loads(item)  # type: ignore[no-any-return]  # json.loads returns Any but we know the structure

    async def _get_next_id(self) -> int:
        """Get the next message ID using Redis INCR for atomic increment."""
        result = await self._redis.incr(self._counter_key)
        return int(result)

    async def _set_ttl_if_configured(self, *keys: str) -> None:
        """Set TTL on keys if configured."""
        if self._ttl is not None:
            pipe = self._redis.pipeline()
            for key in keys:
                pipe.expire(key, self._ttl)
            await pipe.execute()

    # ------------------------------------------------------------------
    # Session protocol implementation
    # ------------------------------------------------------------------

    async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]:
        """Retrieve the conversation history for this session.

        Args:
            limit: Maximum number of items to retrieve. If None, retrieves all items.
                When specified, returns the latest N items in chronological order.

        Returns:
            List of input items representing the conversation history
        """
        async with self._lock:
            if limit is None:
                # Get all messages in chronological order
                raw_messages = await self._redis.lrange(self._messages_key, 0, -1)  # type: ignore[misc]  # Redis library returns Union[Awaitable[T], T] in async context
            else:
                if limit <= 0:
                    return []
                # Get the latest N messages (Redis list is ordered chronologically)
                # Use negative indices to get from the end - Redis uses -N to -1 for last N items
                raw_messages = await self._redis.lrange(self._messages_key, -limit, -1)  # type: ignore[misc]  # Redis library returns Union[Awaitable[T], T] in async context

            items: list[TResponseInputItem] = []
            for raw_msg in raw_messages:
                try:
                    # Handle both bytes (default) and str (decode_responses=True) Redis clients
                    if isinstance(raw_msg, bytes):
                        msg_str = raw_msg.decode("utf-8")
                    else:
                        msg_str = raw_msg  # Already a string
                    item = await self._deserialize_item(msg_str)
                    items.append(item)
                except (json.JSONDecodeError, UnicodeDecodeError):
                    # Skip corrupted messages
                    continue

            return items

    async def add_items(self, items: list[TResponseInputItem]) -> None:
        """Add new items to the conversation history.

        Args:
            items: List of input items to add to the history
        """
        if not items:
            return

        async with self._lock:
            pipe = self._redis.pipeline()

            # Set session metadata with current timestamp
            pipe.hset(
                self._session_key,
                mapping={
                    "session_id": self.session_id,
                    "created_at": str(int(time.time())),
                    "updated_at": str(int(time.time())),
                },
            )

            # Add all items to the messages list
            serialized_items = []
            for item in items:
                serialized = await self._serialize_item(item)
                serialized_items.append(serialized)

            if serialized_items:
                pipe.rpush(self._messages_key, *serialized_items)

            # Update the session timestamp
            pipe.hset(self._session_key, "updated_at", str(int(time.time())))

            # Execute all commands
            await pipe.execute()

            # Set TTL if configured
            await self._set_ttl_if_configured(
                self._session_key, self._messages_key, self._counter_key
            )

    async def pop_item(self) -> TResponseInputItem | None:
        """Remove and return the most recent item from the session.

        Returns:
            The most recent item if it exists, None if the session is empty
        """
        async with self._lock:
            # Use RPOP to atomically remove and return the rightmost (most recent) item
            raw_msg = await self._redis.rpop(self._messages_key)  # type: ignore[misc]  # Redis library returns Union[Awaitable[T], T] in async context

            if raw_msg is None:
                return None

            try:
                # Handle both bytes (default) and str (decode_responses=True) Redis clients
                if isinstance(raw_msg, bytes):
                    msg_str = raw_msg.decode("utf-8")
                else:
                    msg_str = raw_msg  # Already a string
                return await self._deserialize_item(msg_str)
            except (json.JSONDecodeError, UnicodeDecodeError):
                # Return None for corrupted messages (already removed)
                return None

    async def clear_session(self) -> None:
        """Clear all items for this session."""
        async with self._lock:
            # Delete all keys associated with this session
            await self._redis.delete(
                self._session_key,
                self._messages_key,
                self._counter_key,
            )

    async def close(self) -> None:
        """Close the Redis connection.

        Only closes the connection if this session owns the Redis client
        (i.e., created via from_url). If the client was injected externally,
        the caller is responsible for managing its lifecycle.
        """
        if self._owns_client:
            await self._redis.aclose()

    async def ping(self) -> bool:
        """Test Redis connectivity.

        Returns:
            True if Redis is reachable, False otherwise.
        """
        try:
            await self._redis.ping()
            return True
        except Exception:
            return False
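A usage sketch for RedisSession based on the API above; it assumes a Redis server on localhost and a running agent loop, and the ttl value is illustrative. Note that `ttl` is forwarded through from_url's **kwargs to the constructor.

import asyncio

from agents.extensions.memory import RedisSession


async def main() -> None:
    session = RedisSession.from_url(
        session_id="user-123",
        url="redis://localhost:6379/0",
        ttl=86400,  # expire the whole conversation after a day
    )
    if not await session.ping():
        raise RuntimeError("Redis is not reachable")

    await session.add_items([{"role": "user", "content": "hi"}])
    print(await session.get_items(limit=10))  # latest items, oldest first

    # from_url created the client, so this session owns it and close() disconnects.
    await session.close()


asyncio.run(main())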
agents/extensions/memory/sqlalchemy_session.py
ADDED
@@ -0,0 +1,312 @@
| 1 |
+
"""SQLAlchemy-powered Session backend.
|
| 2 |
+
|
| 3 |
+
Usage::
|
| 4 |
+
|
| 5 |
+
from agents.extensions.memory import SQLAlchemySession
|
| 6 |
+
|
| 7 |
+
# Create from SQLAlchemy URL (uses asyncpg driver under the hood for Postgres)
|
| 8 |
+
session = SQLAlchemySession.from_url(
|
| 9 |
+
session_id="user-123",
|
| 10 |
+
url="postgresql+asyncpg://app:[email protected]/agents",
|
| 11 |
+
create_tables=True, # If you want to auto-create tables, set to True.
|
| 12 |
+
)
|
| 13 |
+
|
| 14 |
+
# Or pass an existing AsyncEngine that your application already manages
|
| 15 |
+
session = SQLAlchemySession(
|
| 16 |
+
session_id="user-123",
|
| 17 |
+
engine=my_async_engine,
|
| 18 |
+
create_tables=True, # If you want to auto-create tables, set to True.
|
| 19 |
+
)
|
| 20 |
+
|
| 21 |
+
await Runner.run(agent, "Hello", session=session)
|
| 22 |
+
"""
|
| 23 |
+
|
| 24 |
+
from __future__ import annotations
|
| 25 |
+
|
| 26 |
+
import asyncio
|
| 27 |
+
import json
|
| 28 |
+
from typing import Any
|
| 29 |
+
|
| 30 |
+
from sqlalchemy import (
|
| 31 |
+
TIMESTAMP,
|
| 32 |
+
Column,
|
| 33 |
+
ForeignKey,
|
| 34 |
+
Index,
|
| 35 |
+
Integer,
|
| 36 |
+
MetaData,
|
| 37 |
+
String,
|
| 38 |
+
Table,
|
| 39 |
+
Text,
|
| 40 |
+
delete,
|
| 41 |
+
insert,
|
| 42 |
+
select,
|
| 43 |
+
text as sql_text,
|
| 44 |
+
update,
|
| 45 |
+
)
|
| 46 |
+
from sqlalchemy.ext.asyncio import AsyncEngine, async_sessionmaker, create_async_engine
|
| 47 |
+
|
| 48 |
+
from ...items import TResponseInputItem
|
| 49 |
+
from ...memory.session import SessionABC
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
class SQLAlchemySession(SessionABC):
|
| 53 |
+
"""SQLAlchemy implementation of :pyclass:`agents.memory.session.Session`."""
|
| 54 |
+
|
| 55 |
+
_metadata: MetaData
|
| 56 |
+
_sessions: Table
|
| 57 |
+
_messages: Table
|
| 58 |
+
|
| 59 |
+
def __init__(
|
| 60 |
+
self,
|
| 61 |
+
session_id: str,
|
| 62 |
+
*,
|
| 63 |
+
engine: AsyncEngine,
|
| 64 |
+
create_tables: bool = False,
|
| 65 |
+
sessions_table: str = "agent_sessions",
|
| 66 |
+
messages_table: str = "agent_messages",
|
| 67 |
+
):
|
| 68 |
+
"""Initializes a new SQLAlchemySession.
|
| 69 |
+
|
| 70 |
+
Args:
|
| 71 |
+
session_id (str): Unique identifier for the conversation.
|
| 72 |
+
engine (AsyncEngine): A pre-configured SQLAlchemy async engine. The engine
|
| 73 |
+
must be created with an async driver (e.g., 'postgresql+asyncpg://',
|
| 74 |
+
'mysql+aiomysql://', or 'sqlite+aiosqlite://').
|
| 75 |
+
create_tables (bool, optional): Whether to automatically create the required
|
| 76 |
+
tables and indexes. Defaults to False for production use. Set to True for
|
| 77 |
+
development and testing when migrations aren't used.
|
| 78 |
+
sessions_table (str, optional): Override the default table name for sessions if needed.
|
| 79 |
+
messages_table (str, optional): Override the default table name for messages if needed.
|
| 80 |
+
"""
|
| 81 |
+
self.session_id = session_id
|
| 82 |
+
self._engine = engine
|
| 83 |
+
self._lock = asyncio.Lock()
|
| 84 |
+
|
| 85 |
+
self._metadata = MetaData()
|
| 86 |
+
self._sessions = Table(
|
| 87 |
+
sessions_table,
|
| 88 |
+
self._metadata,
|
| 89 |
+
Column("session_id", String, primary_key=True),
|
| 90 |
+
Column(
|
| 91 |
+
"created_at",
|
| 92 |
+
TIMESTAMP(timezone=False),
|
| 93 |
+
server_default=sql_text("CURRENT_TIMESTAMP"),
|
| 94 |
+
nullable=False,
|
| 95 |
+
),
|
| 96 |
+
Column(
|
| 97 |
+
"updated_at",
|
| 98 |
+
TIMESTAMP(timezone=False),
|
| 99 |
+
server_default=sql_text("CURRENT_TIMESTAMP"),
|
| 100 |
+
onupdate=sql_text("CURRENT_TIMESTAMP"),
|
| 101 |
+
nullable=False,
|
| 102 |
+
),
|
| 103 |
+
)
|
| 104 |
+
|
| 105 |
+
self._messages = Table(
|
| 106 |
+
messages_table,
|
| 107 |
+
self._metadata,
|
| 108 |
+
Column("id", Integer, primary_key=True, autoincrement=True),
|
| 109 |
+
Column(
|
| 110 |
+
"session_id",
|
| 111 |
+
String,
|
| 112 |
+
ForeignKey(f"{sessions_table}.session_id", ondelete="CASCADE"),
|
| 113 |
+
nullable=False,
|
| 114 |
+
),
|
| 115 |
+
Column("message_data", Text, nullable=False),
|
| 116 |
+
Column(
|
| 117 |
+
"created_at",
|
| 118 |
+
TIMESTAMP(timezone=False),
|
| 119 |
+
server_default=sql_text("CURRENT_TIMESTAMP"),
|
| 120 |
+
nullable=False,
|
| 121 |
+
),
|
| 122 |
+
Index(
|
| 123 |
+
f"idx_{messages_table}_session_time",
|
| 124 |
+
"session_id",
|
| 125 |
+
"created_at",
|
| 126 |
+
),
|
| 127 |
+
sqlite_autoincrement=True,
|
| 128 |
+
)
|
| 129 |
+
|
| 130 |
+
# Async session factory
|
| 131 |
+
self._session_factory = async_sessionmaker(self._engine, expire_on_commit=False)
|
| 132 |
+
|
| 133 |
+
self._create_tables = create_tables
|
| 134 |
+
|
| 135 |
+
# ---------------------------------------------------------------------
|
| 136 |
+
# Convenience constructors
|
| 137 |
+
# ---------------------------------------------------------------------
|
| 138 |
+
@classmethod
|
| 139 |
+
def from_url(
|
| 140 |
+
cls,
|
| 141 |
+
session_id: str,
|
| 142 |
+
*,
|
| 143 |
+
url: str,
|
| 144 |
+
engine_kwargs: dict[str, Any] | None = None,
|
| 145 |
+
**kwargs: Any,
|
| 146 |
+
) -> SQLAlchemySession:
|
| 147 |
+
"""Create a session from a database URL string.
|
| 148 |
+
|
| 149 |
+
+        Args:
+            session_id (str): Conversation ID.
+            url (str): Any SQLAlchemy async URL, e.g. "postgresql+asyncpg://user:pass@host/db".
+            engine_kwargs (dict[str, Any] | None): Additional keyword arguments forwarded to
+                sqlalchemy.ext.asyncio.create_async_engine.
+            **kwargs: Additional keyword arguments forwarded to the main constructor
+                (e.g., create_tables, custom table names, etc.).
+
+        Returns:
+            SQLAlchemySession: An instance of SQLAlchemySession connected to the specified database.
+        """
+        engine_kwargs = engine_kwargs or {}
+        engine = create_async_engine(url, **engine_kwargs)
+        return cls(session_id, engine=engine, **kwargs)
+
+    async def _serialize_item(self, item: TResponseInputItem) -> str:
+        """Serialize an item to JSON string. Can be overridden by subclasses."""
+        return json.dumps(item, separators=(",", ":"))
+
+    async def _deserialize_item(self, item: str) -> TResponseInputItem:
+        """Deserialize a JSON string to an item. Can be overridden by subclasses."""
+        return json.loads(item)  # type: ignore[no-any-return]
+
+    # ------------------------------------------------------------------
+    # Session protocol implementation
+    # ------------------------------------------------------------------
+    async def _ensure_tables(self) -> None:
+        """Ensure tables are created before any database operations."""
+        if self._create_tables:
+            async with self._engine.begin() as conn:
+                await conn.run_sync(self._metadata.create_all)
+            self._create_tables = False  # Only create once
+
+    async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]:
+        """Retrieve the conversation history for this session.
+
+        Args:
+            limit: Maximum number of items to retrieve. If None, retrieves all items.
+                When specified, returns the latest N items in chronological order.
+
+        Returns:
+            List of input items representing the conversation history
+        """
+        await self._ensure_tables()
+        async with self._session_factory() as sess:
+            if limit is None:
+                stmt = (
+                    select(self._messages.c.message_data)
+                    .where(self._messages.c.session_id == self.session_id)
+                    .order_by(self._messages.c.created_at.asc())
+                )
+            else:
+                stmt = (
+                    select(self._messages.c.message_data)
+                    .where(self._messages.c.session_id == self.session_id)
+                    # Use DESC + LIMIT to get the latest N
+                    # then reverse later for chronological order.
+                    .order_by(self._messages.c.created_at.desc())
+                    .limit(limit)
+                )
+
+            result = await sess.execute(stmt)
+            rows: list[str] = [row[0] for row in result.all()]
+
+            if limit is not None:
+                rows.reverse()
+
+            items: list[TResponseInputItem] = []
+            for raw in rows:
+                try:
+                    items.append(await self._deserialize_item(raw))
+                except json.JSONDecodeError:
+                    # Skip corrupted rows
+                    continue
+            return items
+
+    async def add_items(self, items: list[TResponseInputItem]) -> None:
+        """Add new items to the conversation history.
+
+        Args:
+            items: List of input items to add to the history
+        """
+        if not items:
+            return
+
+        await self._ensure_tables()
+        payload = [
+            {
+                "session_id": self.session_id,
+                "message_data": await self._serialize_item(item),
+            }
+            for item in items
+        ]
+
+        async with self._session_factory() as sess:
+            async with sess.begin():
+                # Ensure the parent session row exists - use merge for cross-DB compatibility
+                # Check if session exists
+                existing = await sess.execute(
+                    select(self._sessions.c.session_id).where(
+                        self._sessions.c.session_id == self.session_id
+                    )
+                )
+                if not existing.scalar_one_or_none():
+                    # Session doesn't exist, create it
+                    await sess.execute(
+                        insert(self._sessions).values({"session_id": self.session_id})
+                    )
+
+                # Insert messages in bulk
+                await sess.execute(insert(self._messages), payload)
+
+                # Touch updated_at column
+                await sess.execute(
+                    update(self._sessions)
+                    .where(self._sessions.c.session_id == self.session_id)
+                    .values(updated_at=sql_text("CURRENT_TIMESTAMP"))
+                )
+
+    async def pop_item(self) -> TResponseInputItem | None:
+        """Remove and return the most recent item from the session.
+
+        Returns:
+            The most recent item if it exists, None if the session is empty
+        """
+        await self._ensure_tables()
+        async with self._session_factory() as sess:
+            async with sess.begin():
+                # Fallback for all dialects - get ID first, then delete
+                subq = (
+                    select(self._messages.c.id)
+                    .where(self._messages.c.session_id == self.session_id)
+                    .order_by(self._messages.c.created_at.desc())
+                    .limit(1)
+                )
+                res = await sess.execute(subq)
+                row_id = res.scalar_one_or_none()
+                if row_id is None:
+                    return None
+                # Fetch data before deleting
+                res_data = await sess.execute(
+                    select(self._messages.c.message_data).where(self._messages.c.id == row_id)
+                )
+                row = res_data.scalar_one_or_none()
+                await sess.execute(delete(self._messages).where(self._messages.c.id == row_id))
+
+            if row is None:
+                return None
+            try:
+                return await self._deserialize_item(row)
+            except json.JSONDecodeError:
+                return None
+
+    async def clear_session(self) -> None:
+        """Clear all items for this session."""
+        await self._ensure_tables()
+        async with self._session_factory() as sess:
+            async with sess.begin():
+                await sess.execute(
+                    delete(self._messages).where(self._messages.c.session_id == self.session_id)
+                )
+                await sess.execute(
+                    delete(self._sessions).where(self._sessions.c.session_id == self.session_id)
+                )
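(Not part of the diff.) To make the session API above concrete, here is a minimal usage sketch. Assumptions: the optional `aiosqlite` driver is installed for the in-memory SQLite URL, plain message dicts are acceptable `TResponseInputItem` payloads, and `create_tables=True` is one of the constructor kwargs the `from_url` docstring alludes to.

```python
import asyncio

from agents.extensions.memory.sqlalchemy_session import SQLAlchemySession


async def main() -> None:
    # from_url builds the async engine for us; sqlite+aiosqlite keeps the
    # sketch self-contained.
    session = SQLAlchemySession.from_url(
        "demo-conversation",
        url="sqlite+aiosqlite:///:memory:",
        create_tables=True,  # lets _ensure_tables() create the schema on first use
    )

    # Items are stored as JSON via _serialize_item / _deserialize_item.
    await session.add_items(
        [
            {"role": "user", "content": "hello"},
            {"role": "assistant", "content": "hi there"},
        ]
    )

    print(await session.get_items(limit=1))  # latest item, in chronological order
    print(await session.pop_item())  # removes and returns the newest row
    await session.clear_session()


asyncio.run(main())
```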
agents/extensions/models/__init__.py
ADDED
File without changes

agents/extensions/models/litellm_model.py
ADDED
@@ -0,0 +1,601 @@
+from __future__ import annotations
+
+import json
+import time
+from collections.abc import AsyncIterator
+from copy import copy
+from typing import Any, Literal, cast, overload
+
+from openai.types.responses.response_usage import InputTokensDetails, OutputTokensDetails
+
+from agents.exceptions import ModelBehaviorError
+
+try:
+    import litellm
+except ImportError as _e:
+    raise ImportError(
+        "`litellm` is required to use the LitellmModel. You can install it via the optional "
+        "dependency group: `pip install 'openai-agents[litellm]'`."
+    ) from _e
+
+from openai import NOT_GIVEN, AsyncStream, NotGiven
+from openai.types.chat import (
+    ChatCompletionChunk,
+    ChatCompletionMessageCustomToolCall,
+    ChatCompletionMessageFunctionToolCall,
+    ChatCompletionMessageParam,
+)
+from openai.types.chat.chat_completion_message import (
+    Annotation,
+    AnnotationURLCitation,
+    ChatCompletionMessage,
+)
+from openai.types.chat.chat_completion_message_function_tool_call import Function
+from openai.types.responses import Response
+
+from ... import _debug
+from ...agent_output import AgentOutputSchemaBase
+from ...handoffs import Handoff
+from ...items import ModelResponse, TResponseInputItem, TResponseStreamEvent
+from ...logger import logger
+from ...model_settings import ModelSettings
+from ...models.chatcmpl_converter import Converter
+from ...models.chatcmpl_helpers import HEADERS, HEADERS_OVERRIDE
+from ...models.chatcmpl_stream_handler import ChatCmplStreamHandler
+from ...models.fake_id import FAKE_RESPONSES_ID
+from ...models.interface import Model, ModelTracing
+from ...tool import Tool
+from ...tracing import generation_span
+from ...tracing.span_data import GenerationSpanData
+from ...tracing.spans import Span
+from ...usage import Usage
+from ...util._json import _to_dump_compatible
+
+
+class InternalChatCompletionMessage(ChatCompletionMessage):
+    """
+    An internal subclass to carry reasoning_content and thinking_blocks without modifying the original model.
+    """  # noqa: E501
+
+    reasoning_content: str
+    thinking_blocks: list[dict[str, Any]] | None = None
+
+
+class LitellmModel(Model):
+    """This class enables using any model via LiteLLM. LiteLLM allows you to access OpenAI,
+    Anthropic, Gemini, Mistral, and many other models.
+    See supported models here: [litellm models](https://docs.litellm.ai/docs/providers).
+    """
+
+    def __init__(
+        self,
+        model: str,
+        base_url: str | None = None,
+        api_key: str | None = None,
+    ):
+        self.model = model
+        self.base_url = base_url
+        self.api_key = api_key
+
+    async def get_response(
+        self,
+        system_instructions: str | None,
+        input: str | list[TResponseInputItem],
+        model_settings: ModelSettings,
+        tools: list[Tool],
+        output_schema: AgentOutputSchemaBase | None,
+        handoffs: list[Handoff],
+        tracing: ModelTracing,
+        previous_response_id: str | None = None,  # unused
+        conversation_id: str | None = None,  # unused
+        prompt: Any | None = None,
+    ) -> ModelResponse:
+        with generation_span(
+            model=str(self.model),
+            model_config=model_settings.to_json_dict()
+            | {"base_url": str(self.base_url or ""), "model_impl": "litellm"},
+            disabled=tracing.is_disabled(),
+        ) as span_generation:
+            response = await self._fetch_response(
+                system_instructions,
+                input,
+                model_settings,
+                tools,
+                output_schema,
+                handoffs,
+                span_generation,
+                tracing,
+                stream=False,
+                prompt=prompt,
+            )
+
+            assert isinstance(response.choices[0], litellm.types.utils.Choices)
+
+            if _debug.DONT_LOG_MODEL_DATA:
+                logger.debug("Received model response")
+            else:
+                logger.debug(
+                    f"""LLM resp:\n{
+                        json.dumps(
+                            response.choices[0].message.model_dump(), indent=2, ensure_ascii=False
+                        )
+                    }\n"""
+                )
+
+            if hasattr(response, "usage"):
+                response_usage = response.usage
+                usage = (
+                    Usage(
+                        requests=1,
+                        input_tokens=response_usage.prompt_tokens,
+                        output_tokens=response_usage.completion_tokens,
+                        total_tokens=response_usage.total_tokens,
+                        input_tokens_details=InputTokensDetails(
+                            cached_tokens=getattr(
+                                response_usage.prompt_tokens_details, "cached_tokens", 0
+                            )
+                            or 0
+                        ),
+                        output_tokens_details=OutputTokensDetails(
+                            reasoning_tokens=getattr(
+                                response_usage.completion_tokens_details, "reasoning_tokens", 0
+                            )
+                            or 0
+                        ),
+                    )
+                    if response.usage
+                    else Usage()
+                )
+            else:
+                usage = Usage()
+                logger.warning("No usage information returned from Litellm")
+
+            if tracing.include_data():
+                span_generation.span_data.output = [response.choices[0].message.model_dump()]
+            span_generation.span_data.usage = {
+                "input_tokens": usage.input_tokens,
+                "output_tokens": usage.output_tokens,
+            }
+
+            items = Converter.message_to_output_items(
+                LitellmConverter.convert_message_to_openai(response.choices[0].message)
+            )
+
+            return ModelResponse(
+                output=items,
+                usage=usage,
+                response_id=None,
+            )
+
+    async def stream_response(
+        self,
+        system_instructions: str | None,
+        input: str | list[TResponseInputItem],
+        model_settings: ModelSettings,
+        tools: list[Tool],
+        output_schema: AgentOutputSchemaBase | None,
+        handoffs: list[Handoff],
+        tracing: ModelTracing,
+        previous_response_id: str | None = None,  # unused
+        conversation_id: str | None = None,  # unused
+        prompt: Any | None = None,
+    ) -> AsyncIterator[TResponseStreamEvent]:
+        with generation_span(
+            model=str(self.model),
+            model_config=model_settings.to_json_dict()
+            | {"base_url": str(self.base_url or ""), "model_impl": "litellm"},
+            disabled=tracing.is_disabled(),
+        ) as span_generation:
+            response, stream = await self._fetch_response(
+                system_instructions,
+                input,
+                model_settings,
+                tools,
+                output_schema,
+                handoffs,
+                span_generation,
+                tracing,
+                stream=True,
+                prompt=prompt,
+            )
+
+            final_response: Response | None = None
+            async for chunk in ChatCmplStreamHandler.handle_stream(response, stream):
+                yield chunk
+
+                if chunk.type == "response.completed":
+                    final_response = chunk.response
+
+            if tracing.include_data() and final_response:
+                span_generation.span_data.output = [final_response.model_dump()]
+
+            if final_response and final_response.usage:
+                span_generation.span_data.usage = {
+                    "input_tokens": final_response.usage.input_tokens,
+                    "output_tokens": final_response.usage.output_tokens,
+                }
+
+    @overload
+    async def _fetch_response(
+        self,
+        system_instructions: str | None,
+        input: str | list[TResponseInputItem],
+        model_settings: ModelSettings,
+        tools: list[Tool],
+        output_schema: AgentOutputSchemaBase | None,
+        handoffs: list[Handoff],
+        span: Span[GenerationSpanData],
+        tracing: ModelTracing,
+        stream: Literal[True],
+        prompt: Any | None = None,
+    ) -> tuple[Response, AsyncStream[ChatCompletionChunk]]: ...
+
+    @overload
+    async def _fetch_response(
+        self,
+        system_instructions: str | None,
+        input: str | list[TResponseInputItem],
+        model_settings: ModelSettings,
+        tools: list[Tool],
+        output_schema: AgentOutputSchemaBase | None,
+        handoffs: list[Handoff],
+        span: Span[GenerationSpanData],
+        tracing: ModelTracing,
+        stream: Literal[False],
+        prompt: Any | None = None,
+    ) -> litellm.types.utils.ModelResponse: ...
+
+    async def _fetch_response(
+        self,
+        system_instructions: str | None,
+        input: str | list[TResponseInputItem],
+        model_settings: ModelSettings,
+        tools: list[Tool],
+        output_schema: AgentOutputSchemaBase | None,
+        handoffs: list[Handoff],
+        span: Span[GenerationSpanData],
+        tracing: ModelTracing,
+        stream: bool = False,
+        prompt: Any | None = None,
+    ) -> litellm.types.utils.ModelResponse | tuple[Response, AsyncStream[ChatCompletionChunk]]:
+        # Preserve reasoning messages for tool calls when reasoning is on
+        # This is needed for models like Claude 4 Sonnet/Opus which support interleaved thinking
+        preserve_thinking_blocks = (
+            model_settings.reasoning is not None and model_settings.reasoning.effort is not None
+        )
+
+        converted_messages = Converter.items_to_messages(
+            input, preserve_thinking_blocks=preserve_thinking_blocks
+        )
+
+        # Fix for interleaved thinking bug: reorder messages to ensure tool_use comes before tool_result  # noqa: E501
+        if preserve_thinking_blocks:
+            converted_messages = self._fix_tool_message_ordering(converted_messages)
+
+        if system_instructions:
+            converted_messages.insert(
+                0,
+                {
+                    "content": system_instructions,
+                    "role": "system",
+                },
+            )
+        converted_messages = _to_dump_compatible(converted_messages)
+
+        if tracing.include_data():
+            span.span_data.input = converted_messages
+
+        parallel_tool_calls = (
+            True
+            if model_settings.parallel_tool_calls and tools and len(tools) > 0
+            else False
+            if model_settings.parallel_tool_calls is False
+            else None
+        )
+        tool_choice = Converter.convert_tool_choice(model_settings.tool_choice)
+        response_format = Converter.convert_response_format(output_schema)
+
+        converted_tools = [Converter.tool_to_openai(tool) for tool in tools] if tools else []
+
+        for handoff in handoffs:
+            converted_tools.append(Converter.convert_handoff_tool(handoff))
+
+        converted_tools = _to_dump_compatible(converted_tools)
+
+        if _debug.DONT_LOG_MODEL_DATA:
+            logger.debug("Calling LLM")
+        else:
+            messages_json = json.dumps(
+                converted_messages,
+                indent=2,
+                ensure_ascii=False,
+            )
+            tools_json = json.dumps(
+                converted_tools,
+                indent=2,
+                ensure_ascii=False,
+            )
+            logger.debug(
+                f"Calling Litellm model: {self.model}\n"
+                f"{messages_json}\n"
+                f"Tools:\n{tools_json}\n"
+                f"Stream: {stream}\n"
+                f"Tool choice: {tool_choice}\n"
+                f"Response format: {response_format}\n"
+            )
+
+        reasoning_effort = model_settings.reasoning.effort if model_settings.reasoning else None
+
+        stream_options = None
+        if stream and model_settings.include_usage is not None:
+            stream_options = {"include_usage": model_settings.include_usage}
+
+        extra_kwargs = {}
+        if model_settings.extra_query:
+            extra_kwargs["extra_query"] = copy(model_settings.extra_query)
+        if model_settings.metadata:
+            extra_kwargs["metadata"] = copy(model_settings.metadata)
+        if model_settings.extra_body and isinstance(model_settings.extra_body, dict):
+            extra_kwargs.update(model_settings.extra_body)
+
+        # Add kwargs from model_settings.extra_args, filtering out None values
+        if model_settings.extra_args:
+            extra_kwargs.update(model_settings.extra_args)
+
+        ret = await litellm.acompletion(
+            model=self.model,
+            messages=converted_messages,
+            tools=converted_tools or None,
+            temperature=model_settings.temperature,
+            top_p=model_settings.top_p,
+            frequency_penalty=model_settings.frequency_penalty,
+            presence_penalty=model_settings.presence_penalty,
+            max_tokens=model_settings.max_tokens,
+            tool_choice=self._remove_not_given(tool_choice),
+            response_format=self._remove_not_given(response_format),
+            parallel_tool_calls=parallel_tool_calls,
+            stream=stream,
+            stream_options=stream_options,
+            reasoning_effort=reasoning_effort,
+            top_logprobs=model_settings.top_logprobs,
+            extra_headers=self._merge_headers(model_settings),
+            api_key=self.api_key,
+            base_url=self.base_url,
+            **extra_kwargs,
+        )
+
+        if isinstance(ret, litellm.types.utils.ModelResponse):
+            return ret
+
+        response = Response(
+            id=FAKE_RESPONSES_ID,
+            created_at=time.time(),
+            model=self.model,
+            object="response",
+            output=[],
+            tool_choice=cast(Literal["auto", "required", "none"], tool_choice)
+            if tool_choice != NOT_GIVEN
+            else "auto",
+            top_p=model_settings.top_p,
+            temperature=model_settings.temperature,
+            tools=[],
+            parallel_tool_calls=parallel_tool_calls or False,
+            reasoning=model_settings.reasoning,
+        )
+        return response, ret
+
+    def _fix_tool_message_ordering(
+        self, messages: list[ChatCompletionMessageParam]
+    ) -> list[ChatCompletionMessageParam]:
+        """
+        Fix the ordering of tool messages to ensure tool_use messages come before tool_result messages.
+
+        This addresses the interleaved thinking bug where conversation histories may contain
+        tool results before their corresponding tool calls, causing Anthropic API to reject the request.
+        """  # noqa: E501
+        if not messages:
+            return messages
+
+        # Collect all tool calls and tool results
+        tool_call_messages = {}  # tool_id -> (index, message)
+        tool_result_messages = {}  # tool_id -> (index, message)
+        other_messages = []  # (index, message) for non-tool messages
+
+        for i, message in enumerate(messages):
+            if not isinstance(message, dict):
+                other_messages.append((i, message))
+                continue
+
+            role = message.get("role")
+
+            if role == "assistant" and message.get("tool_calls"):
+                # Extract tool calls from this assistant message
+                tool_calls = message.get("tool_calls", [])
+                if isinstance(tool_calls, list):
+                    for tool_call in tool_calls:
+                        if isinstance(tool_call, dict):
+                            tool_id = tool_call.get("id")
+                            if tool_id:
+                                # Create a separate assistant message for each tool call
+                                single_tool_msg = cast(dict[str, Any], message.copy())
+                                single_tool_msg["tool_calls"] = [tool_call]
+                                tool_call_messages[tool_id] = (
+                                    i,
+                                    cast(ChatCompletionMessageParam, single_tool_msg),
+                                )
+
+            elif role == "tool":
+                tool_call_id = message.get("tool_call_id")
+                if tool_call_id:
+                    tool_result_messages[tool_call_id] = (i, message)
+                else:
+                    other_messages.append((i, message))
+            else:
+                other_messages.append((i, message))
+
+        # First, identify which tool results will be paired to avoid duplicates
+        paired_tool_result_indices = set()
+        for tool_id in tool_call_messages:
+            if tool_id in tool_result_messages:
+                tool_result_idx, _ = tool_result_messages[tool_id]
+                paired_tool_result_indices.add(tool_result_idx)
+
+        # Create the fixed message sequence
+        fixed_messages: list[ChatCompletionMessageParam] = []
+        used_indices = set()
+
+        # Add messages in their original order, but ensure tool_use → tool_result pairing
+        for i, original_message in enumerate(messages):
+            if i in used_indices:
+                continue
+
+            if not isinstance(original_message, dict):
+                fixed_messages.append(original_message)
+                used_indices.add(i)
+                continue
+
+            role = original_message.get("role")
+
+            if role == "assistant" and original_message.get("tool_calls"):
+                # Process each tool call in this assistant message
+                tool_calls = original_message.get("tool_calls", [])
+                if isinstance(tool_calls, list):
+                    for tool_call in tool_calls:
+                        if isinstance(tool_call, dict):
+                            tool_id = tool_call.get("id")
+                            if (
+                                tool_id
+                                and tool_id in tool_call_messages
+                                and tool_id in tool_result_messages
+                            ):
+                                # Add tool_use → tool_result pair
+                                _, tool_call_msg = tool_call_messages[tool_id]
+                                tool_result_idx, tool_result_msg = tool_result_messages[tool_id]
+
+                                fixed_messages.append(tool_call_msg)
+                                fixed_messages.append(tool_result_msg)
+
+                                # Mark both as used
+                                used_indices.add(tool_call_messages[tool_id][0])
+                                used_indices.add(tool_result_idx)
+                            elif tool_id and tool_id in tool_call_messages:
+                                # Tool call without result - add just the tool call
+                                _, tool_call_msg = tool_call_messages[tool_id]
+                                fixed_messages.append(tool_call_msg)
+                                used_indices.add(tool_call_messages[tool_id][0])
+
+                used_indices.add(i)  # Mark original multi-tool message as used
+
+            elif role == "tool":
+                # Only preserve unmatched tool results to avoid duplicates
+                if i not in paired_tool_result_indices:
+                    fixed_messages.append(original_message)
+                    used_indices.add(i)
+
+            else:
+                # Regular message - add it normally
+                fixed_messages.append(original_message)
+                used_indices.add(i)
+
+        return fixed_messages
+
+    def _remove_not_given(self, value: Any) -> Any:
+        if isinstance(value, NotGiven):
+            return None
+        return value
+
+    def _merge_headers(self, model_settings: ModelSettings):
+        return {**HEADERS, **(model_settings.extra_headers or {}), **(HEADERS_OVERRIDE.get() or {})}
+
+
+class LitellmConverter:
+    @classmethod
+    def convert_message_to_openai(
+        cls, message: litellm.types.utils.Message
+    ) -> ChatCompletionMessage:
+        if message.role != "assistant":
+            raise ModelBehaviorError(f"Unsupported role: {message.role}")
+
+        tool_calls: (
+            list[ChatCompletionMessageFunctionToolCall | ChatCompletionMessageCustomToolCall] | None
+        ) = (
+            [LitellmConverter.convert_tool_call_to_openai(tool) for tool in message.tool_calls]
+            if message.tool_calls
+            else None
+        )
+
+        provider_specific_fields = message.get("provider_specific_fields", None)
+        refusal = (
+            provider_specific_fields.get("refusal", None) if provider_specific_fields else None
+        )
+
+        reasoning_content = ""
+        if hasattr(message, "reasoning_content") and message.reasoning_content:
+            reasoning_content = message.reasoning_content
+
+        # Extract full thinking blocks including signatures (for Anthropic)
+        thinking_blocks: list[dict[str, Any]] | None = None
+        if hasattr(message, "thinking_blocks") and message.thinking_blocks:
+            # Convert thinking blocks to dict format for compatibility
+            thinking_blocks = []
+            for block in message.thinking_blocks:
+                if isinstance(block, dict):
+                    thinking_blocks.append(cast(dict[str, Any], block))
+                else:
+                    # Convert object to dict by accessing its attributes
+                    block_dict: dict[str, Any] = {}
+                    if hasattr(block, "__dict__"):
+                        block_dict = dict(block.__dict__.items())
+                    elif hasattr(block, "model_dump"):
+                        block_dict = block.model_dump()
+                    else:
+                        # Last resort: convert to string representation
+                        block_dict = {"thinking": str(block)}
+                    thinking_blocks.append(block_dict)
+
+        return InternalChatCompletionMessage(
+            content=message.content,
+            refusal=refusal,
+            role="assistant",
+            annotations=cls.convert_annotations_to_openai(message),
+            audio=message.get("audio", None),  # litellm deletes audio if not present
+            tool_calls=tool_calls,
+            reasoning_content=reasoning_content,
+            thinking_blocks=thinking_blocks,
+        )
+
+    @classmethod
+    def convert_annotations_to_openai(
+        cls, message: litellm.types.utils.Message
+    ) -> list[Annotation] | None:
+        annotations: list[litellm.types.llms.openai.ChatCompletionAnnotation] | None = message.get(
+            "annotations", None
+        )
+        if not annotations:
+            return None
+
+        return [
+            Annotation(
+                type="url_citation",
+                url_citation=AnnotationURLCitation(
+                    start_index=annotation["url_citation"]["start_index"],
+                    end_index=annotation["url_citation"]["end_index"],
+                    url=annotation["url_citation"]["url"],
+                    title=annotation["url_citation"]["title"],
+                ),
+            )
+            for annotation in annotations
+        ]
+
+    @classmethod
+    def convert_tool_call_to_openai(
+        cls, tool_call: litellm.types.utils.ChatCompletionMessageToolCall
+    ) -> ChatCompletionMessageFunctionToolCall:
+        return ChatCompletionMessageFunctionToolCall(
+            id=tool_call.id,
+            type="function",
+            function=Function(
+                name=tool_call.function.name or "",
+                arguments=tool_call.function.arguments,
+            ),
+        )
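(Not part of the diff.) A minimal sketch of plugging `LitellmModel` into an agent run. The `anthropic/...` model string and the environment variable are illustrative choices, not part of the module above; `api_key` and `base_url` are forwarded straight to `litellm.acompletion`, so any provider/model pair LiteLLM supports should slot in the same way.

```python
import asyncio
import os

from agents import Agent, Runner
from agents.extensions.models.litellm_model import LitellmModel


async def main() -> None:
    # Illustrative model name; see https://docs.litellm.ai/docs/providers.
    model = LitellmModel(
        "anthropic/claude-3-5-sonnet-20240620",
        api_key=os.environ["ANTHROPIC_API_KEY"],
    )

    agent = Agent(name="Assistant", instructions="Reply concisely.", model=model)
    result = await Runner.run(agent, "What is LiteLLM?")
    print(result.final_output)


asyncio.run(main())
```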
agents/extensions/models/litellm_provider.py
ADDED
@@ -0,0 +1,23 @@
+from ...models.default_models import get_default_model
+from ...models.interface import Model, ModelProvider
+from .litellm_model import LitellmModel
+
+# This is kept for backward compatibility but using get_default_model() method is recommended.
+DEFAULT_MODEL: str = "gpt-4.1"
+
+
+class LitellmProvider(ModelProvider):
+    """A ModelProvider that uses LiteLLM to route to any model provider. You can use it via:
+    ```python
+    Runner.run(agent, input, run_config=RunConfig(model_provider=LitellmProvider()))
+    ```
+    See supported models here: [litellm models](https://docs.litellm.ai/docs/providers).
+
+    NOTE: API keys must be set via environment variables. If you're using models that require
+    additional configuration (e.g. Azure API base or version), those must also be set via the
+    environment variables that LiteLLM expects. If you have more advanced needs, we recommend
+    copy-pasting this class and making any modifications you need.
+    """
+
+    def get_model(self, model_name: str | None) -> Model:
+        return LitellmModel(model_name or get_default_model())
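(Not part of the diff.) The docstring's one-liner, expanded into a runnable sketch. It assumes the credentials LiteLLM needs are already present in the environment, as the NOTE above requires; with no model set on the agent, the provider falls back to `get_default_model()`.

```python
import asyncio

from agents import Agent, Runner, RunConfig
from agents.extensions.models.litellm_provider import LitellmProvider


async def main() -> None:
    agent = Agent(name="Assistant", instructions="Reply concisely.")

    # The provider resolves whichever model name the run requests.
    result = await Runner.run(
        agent,
        "Say hello.",
        run_config=RunConfig(model_provider=LitellmProvider()),
    )
    print(result.final_output)


asyncio.run(main())
```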
agents/extensions/visualization.py
ADDED
@@ -0,0 +1,165 @@
+from __future__ import annotations
+
+import graphviz  # type: ignore
+
+from agents import Agent
+from agents.handoffs import Handoff
+from agents.tool import Tool
+
+
+def get_main_graph(agent: Agent) -> str:
+    """
+    Generates the main graph structure in DOT format for the given agent.
+
+    Args:
+        agent (Agent): The agent for which the graph is to be generated.
+
+    Returns:
+        str: The DOT format string representing the graph.
+    """
+    parts = [
+        """
+    digraph G {
+    graph [splines=true];
+    node [fontname="Arial"];
+    edge [penwidth=1.5];
+    """
+    ]
+    parts.append(get_all_nodes(agent))
+    parts.append(get_all_edges(agent))
+    parts.append("}")
+    return "".join(parts)
+
+
+def get_all_nodes(
+    agent: Agent, parent: Agent | None = None, visited: set[str] | None = None
+) -> str:
+    """
+    Recursively generates the nodes for the given agent and its handoffs in DOT format.
+
+    Args:
+        agent (Agent): The agent for which the nodes are to be generated.
+
+    Returns:
+        str: The DOT format string representing the nodes.
+    """
+    if visited is None:
+        visited = set()
+    if agent.name in visited:
+        return ""
+    visited.add(agent.name)
+
+    parts = []
+
+    # Start and end the graph
+    if not parent:
+        parts.append(
+            '"__start__" [label="__start__", shape=ellipse, style=filled, '
+            "fillcolor=lightblue, width=0.5, height=0.3];"
+            '"__end__" [label="__end__", shape=ellipse, style=filled, '
+            "fillcolor=lightblue, width=0.5, height=0.3];"
+        )
+    # Ensure parent agent node is colored
+    parts.append(
+        f'"{agent.name}" [label="{agent.name}", shape=box, style=filled, '
+        "fillcolor=lightyellow, width=1.5, height=0.8];"
+    )
+
+    for tool in agent.tools:
+        parts.append(
+            f'"{tool.name}" [label="{tool.name}", shape=ellipse, style=filled, '
+            f"fillcolor=lightgreen, width=0.5, height=0.3];"
+        )
+
+    for mcp_server in agent.mcp_servers:
+        parts.append(
+            f'"{mcp_server.name}" [label="{mcp_server.name}", shape=box, style=filled, '
+            f"fillcolor=lightgrey, width=1, height=0.5];"
+        )
+
+    for handoff in agent.handoffs:
+        if isinstance(handoff, Handoff):
+            parts.append(
+                f'"{handoff.agent_name}" [label="{handoff.agent_name}", '
+                f"shape=box, style=filled, style=rounded, "
+                f"fillcolor=lightyellow, width=1.5, height=0.8];"
+            )
+        if isinstance(handoff, Agent):
+            if handoff.name not in visited:
+                parts.append(
+                    f'"{handoff.name}" [label="{handoff.name}", '
+                    f"shape=box, style=filled, style=rounded, "
+                    f"fillcolor=lightyellow, width=1.5, height=0.8];"
+                )
+            parts.append(get_all_nodes(handoff, agent, visited))
+
+    return "".join(parts)
+
+
+def get_all_edges(
+    agent: Agent, parent: Agent | None = None, visited: set[str] | None = None
+) -> str:
+    """
+    Recursively generates the edges for the given agent and its handoffs in DOT format.
+
+    Args:
+        agent (Agent): The agent for which the edges are to be generated.
+        parent (Agent, optional): The parent agent. Defaults to None.
+
+    Returns:
+        str: The DOT format string representing the edges.
+    """
+    if visited is None:
+        visited = set()
+    if agent.name in visited:
+        return ""
+    visited.add(agent.name)
+
+    parts = []
+
+    if not parent:
+        parts.append(f'"__start__" -> "{agent.name}";')
+
+    for tool in agent.tools:
+        parts.append(f"""
+        "{agent.name}" -> "{tool.name}" [style=dotted, penwidth=1.5];
+        "{tool.name}" -> "{agent.name}" [style=dotted, penwidth=1.5];""")
+
+    for mcp_server in agent.mcp_servers:
+        parts.append(f"""
+        "{agent.name}" -> "{mcp_server.name}" [style=dashed, penwidth=1.5];
+        "{mcp_server.name}" -> "{agent.name}" [style=dashed, penwidth=1.5];""")
+
+    for handoff in agent.handoffs:
+        if isinstance(handoff, Handoff):
+            parts.append(f"""
+        "{agent.name}" -> "{handoff.agent_name}";""")
+        if isinstance(handoff, Agent):
+            parts.append(f"""
+        "{agent.name}" -> "{handoff.name}";""")
+            parts.append(get_all_edges(handoff, agent, visited))
+
+    if not agent.handoffs and not isinstance(agent, Tool):  # type: ignore
+        parts.append(f'"{agent.name}" -> "__end__";')
+
+    return "".join(parts)
+
+
+def draw_graph(agent: Agent, filename: str | None = None) -> graphviz.Source:
+    """
+    Draws the graph for the given agent and optionally saves it as a PNG file.
+
+    Args:
+        agent (Agent): The agent for which the graph is to be drawn.
+        filename (str): The name of the file to save the graph as a PNG.
+
+    Returns:
+        graphviz.Source: The graphviz Source object representing the graph.
+    """
+    dot_code = get_main_graph(agent)
+    graph = graphviz.Source(dot_code)
+
+    if filename:
+        graph.render(filename, format="png", cleanup=True)
+
+    return graph
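(Not part of the diff.) A small sketch of the visualization helpers above. The agent names are made up for illustration; rendering to PNG requires the Graphviz system binaries, not just the `graphviz` Python package.

```python
from agents import Agent
from agents.extensions.visualization import draw_graph

# A tiny agent graph: triage hands off to a billing specialist.
billing = Agent(name="Billing agent", instructions="Handle billing questions.")
triage = Agent(name="Triage agent", instructions="Route the user.", handoffs=[billing])

graph = draw_graph(triage, filename="triage_graph")  # writes triage_graph.png
print(graph.source)  # the DOT text assembled by get_main_graph()
```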
agents/function_schema.py
ADDED
@@ -0,0 +1,398 @@
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import contextlib
|
| 4 |
+
import inspect
|
| 5 |
+
import logging
|
| 6 |
+
import re
|
| 7 |
+
from dataclasses import dataclass
|
| 8 |
+
from typing import Annotated, Any, Callable, Literal, get_args, get_origin, get_type_hints
|
| 9 |
+
|
| 10 |
+
from griffe import Docstring, DocstringSectionKind
|
| 11 |
+
from pydantic import BaseModel, Field, create_model
|
| 12 |
+
from pydantic.fields import FieldInfo
|
| 13 |
+
|
| 14 |
+
from .exceptions import UserError
|
| 15 |
+
from .run_context import RunContextWrapper
|
| 16 |
+
from .strict_schema import ensure_strict_json_schema
|
| 17 |
+
from .tool_context import ToolContext
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
@dataclass
|
| 21 |
+
class FuncSchema:
|
| 22 |
+
"""
|
| 23 |
+
Captures the schema for a python function, in preparation for sending it to an LLM as a tool.
|
| 24 |
+
"""
|
| 25 |
+
|
| 26 |
+
name: str
|
| 27 |
+
"""The name of the function."""
|
| 28 |
+
description: str | None
|
| 29 |
+
"""The description of the function."""
|
| 30 |
+
params_pydantic_model: type[BaseModel]
|
| 31 |
+
"""A Pydantic model that represents the function's parameters."""
|
| 32 |
+
params_json_schema: dict[str, Any]
|
| 33 |
+
"""The JSON schema for the function's parameters, derived from the Pydantic model."""
|
| 34 |
+
signature: inspect.Signature
|
| 35 |
+
"""The signature of the function."""
|
| 36 |
+
takes_context: bool = False
|
| 37 |
+
"""Whether the function takes a RunContextWrapper argument (must be the first argument)."""
|
| 38 |
+
strict_json_schema: bool = True
|
| 39 |
+
"""Whether the JSON schema is in strict mode. We **strongly** recommend setting this to True,
|
| 40 |
+
as it increases the likelihood of correct JSON input."""
|
| 41 |
+
|
| 42 |
+
def to_call_args(self, data: BaseModel) -> tuple[list[Any], dict[str, Any]]:
|
| 43 |
+
"""
|
| 44 |
+
Converts validated data from the Pydantic model into (args, kwargs), suitable for calling
|
| 45 |
+
the original function.
|
| 46 |
+
"""
|
| 47 |
+
positional_args: list[Any] = []
|
| 48 |
+
keyword_args: dict[str, Any] = {}
|
| 49 |
+
seen_var_positional = False
|
| 50 |
+
|
| 51 |
+
# Use enumerate() so we can skip the first parameter if it's context.
|
| 52 |
+
for idx, (name, param) in enumerate(self.signature.parameters.items()):
|
| 53 |
+
# If the function takes a RunContextWrapper and this is the first parameter, skip it.
|
| 54 |
+
if self.takes_context and idx == 0:
|
| 55 |
+
continue
|
| 56 |
+
|
| 57 |
+
value = getattr(data, name, None)
|
| 58 |
+
if param.kind == param.VAR_POSITIONAL:
|
| 59 |
+
# e.g. *args: extend positional args and mark that *args is now seen
|
| 60 |
+
positional_args.extend(value or [])
|
| 61 |
+
seen_var_positional = True
|
| 62 |
+
elif param.kind == param.VAR_KEYWORD:
|
| 63 |
+
# e.g. **kwargs handling
|
| 64 |
+
keyword_args.update(value or {})
|
| 65 |
+
elif param.kind in (param.POSITIONAL_ONLY, param.POSITIONAL_OR_KEYWORD):
|
| 66 |
+
# Before *args, add to positional args. After *args, add to keyword args.
|
| 67 |
+
if not seen_var_positional:
|
| 68 |
+
positional_args.append(value)
|
| 69 |
+
else:
|
| 70 |
+
keyword_args[name] = value
|
| 71 |
+
else:
|
| 72 |
+
# For KEYWORD_ONLY parameters, always use keyword args.
|
| 73 |
+
keyword_args[name] = value
|
| 74 |
+
return positional_args, keyword_args
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
@dataclass
|
| 78 |
+
class FuncDocumentation:
|
| 79 |
+
"""Contains metadata about a Python function, extracted from its docstring."""
|
| 80 |
+
|
| 81 |
+
name: str
|
| 82 |
+
"""The name of the function, via `__name__`."""
|
| 83 |
+
description: str | None
|
| 84 |
+
"""The description of the function, derived from the docstring."""
|
| 85 |
+
param_descriptions: dict[str, str] | None
|
| 86 |
+
"""The parameter descriptions of the function, derived from the docstring."""
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
DocstringStyle = Literal["google", "numpy", "sphinx"]
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
# As of Feb 2025, the automatic style detection in griffe is an Insiders feature. This
|
| 93 |
+
# code approximates it.
|
| 94 |
+
def _detect_docstring_style(doc: str) -> DocstringStyle:
|
| 95 |
+
scores: dict[DocstringStyle, int] = {"sphinx": 0, "numpy": 0, "google": 0}
|
| 96 |
+
|
| 97 |
+
# Sphinx style detection: look for :param, :type, :return:, and :rtype:
|
| 98 |
+
sphinx_patterns = [r"^:param\s", r"^:type\s", r"^:return:", r"^:rtype:"]
|
| 99 |
+
for pattern in sphinx_patterns:
|
| 100 |
+
if re.search(pattern, doc, re.MULTILINE):
|
| 101 |
+
scores["sphinx"] += 1
|
| 102 |
+
|
| 103 |
+
# Numpy style detection: look for headers like 'Parameters', 'Returns', or 'Yields' followed by
|
| 104 |
+
# a dashed underline
|
| 105 |
+
numpy_patterns = [
|
| 106 |
+
r"^Parameters\s*\n\s*-{3,}",
|
| 107 |
+
r"^Returns\s*\n\s*-{3,}",
|
| 108 |
+
r"^Yields\s*\n\s*-{3,}",
|
| 109 |
+
]
|
| 110 |
+
for pattern in numpy_patterns:
|
| 111 |
+
if re.search(pattern, doc, re.MULTILINE):
|
| 112 |
+
scores["numpy"] += 1
|
| 113 |
+
|
| 114 |
+
# Google style detection: look for section headers with a trailing colon
|
| 115 |
+
google_patterns = [r"^(Args|Arguments):", r"^(Returns):", r"^(Raises):"]
|
| 116 |
+
for pattern in google_patterns:
|
| 117 |
+
if re.search(pattern, doc, re.MULTILINE):
|
| 118 |
+
scores["google"] += 1
|
| 119 |
+
|
| 120 |
+
max_score = max(scores.values())
|
| 121 |
+
if max_score == 0:
|
| 122 |
+
return "google"
|
| 123 |
+
|
| 124 |
+
# Priority order: sphinx > numpy > google in case of tie
|
| 125 |
+
styles: list[DocstringStyle] = ["sphinx", "numpy", "google"]
|
| 126 |
+
|
| 127 |
+
for style in styles:
|
| 128 |
+
if scores[style] == max_score:
|
| 129 |
+
return style
|
| 130 |
+
|
| 131 |
+
return "google"
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
@contextlib.contextmanager
|
| 135 |
+
def _suppress_griffe_logging():
|
| 136 |
+
# Suppresses warnings about missing annotations for params
|
| 137 |
+
logger = logging.getLogger("griffe")
|
| 138 |
+
previous_level = logger.getEffectiveLevel()
|
| 139 |
+
logger.setLevel(logging.ERROR)
|
| 140 |
+
try:
|
| 141 |
+
yield
|
| 142 |
+
finally:
|
| 143 |
+
logger.setLevel(previous_level)
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
def generate_func_documentation(
|
| 147 |
+
func: Callable[..., Any], style: DocstringStyle | None = None
|
| 148 |
+
) -> FuncDocumentation:
|
| 149 |
+
"""
|
| 150 |
+
Extracts metadata from a function docstring, in preparation for sending it to an LLM as a tool.
|
| 151 |
+
|
| 152 |
+
Args:
|
| 153 |
+
func: The function to extract documentation from.
|
| 154 |
+
style: The style of the docstring to use for parsing. If not provided, we will attempt to
|
| 155 |
+
auto-detect the style.
|
| 156 |
+
|
| 157 |
+
Returns:
|
| 158 |
+
A FuncDocumentation object containing the function's name, description, and parameter
|
| 159 |
+
descriptions.
|
| 160 |
+
"""
|
| 161 |
+
name = func.__name__
|
| 162 |
+
doc = inspect.getdoc(func)
|
| 163 |
+
if not doc:
|
| 164 |
+
return FuncDocumentation(name=name, description=None, param_descriptions=None)
|
| 165 |
+
|
| 166 |
+
with _suppress_griffe_logging():
|
| 167 |
+
docstring = Docstring(doc, lineno=1, parser=style or _detect_docstring_style(doc))
|
| 168 |
+
parsed = docstring.parse()
|
| 169 |
+
|
| 170 |
+
description: str | None = next(
|
| 171 |
+
(section.value for section in parsed if section.kind == DocstringSectionKind.text), None
|
| 172 |
+
)
|
| 173 |
+
|
| 174 |
+
param_descriptions: dict[str, str] = {
|
| 175 |
+
param.name: param.description
|
| 176 |
+
for section in parsed
|
| 177 |
+
if section.kind == DocstringSectionKind.parameters
|
| 178 |
+
for param in section.value
|
| 179 |
+
}
|
| 180 |
+
|
| 181 |
+
return FuncDocumentation(
|
| 182 |
+
name=func.__name__,
|
| 183 |
+
description=description,
|
| 184 |
+
param_descriptions=param_descriptions or None,
|
| 185 |
+
)
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
def _strip_annotated(annotation: Any) -> tuple[Any, tuple[Any, ...]]:
|
| 189 |
+
"""Returns the underlying annotation and any metadata from typing.Annotated."""
|
| 190 |
+
|
| 191 |
+
metadata: tuple[Any, ...] = ()
|
| 192 |
+
ann = annotation
|
| 193 |
+
|
| 194 |
+
while get_origin(ann) is Annotated:
|
| 195 |
+
args = get_args(ann)
|
| 196 |
+
if not args:
|
| 197 |
+
break
|
| 198 |
+
ann = args[0]
|
| 199 |
+
metadata = (*metadata, *args[1:])
|
| 200 |
+
|
| 201 |
+
return ann, metadata
|
| 202 |
+
|
| 203 |
+
|
| 204 |
+
def _extract_description_from_metadata(metadata: tuple[Any, ...]) -> str | None:
|
| 205 |
+
"""Extracts a human readable description from Annotated metadata if present."""
|
| 206 |
+
|
| 207 |
+
for item in metadata:
|
| 208 |
+
if isinstance(item, str):
|
| 209 |
+
return item
|
| 210 |
+
return None
|
| 211 |
+
|
| 212 |
+
|
| 213 |
+
def function_schema(
|
| 214 |
+
func: Callable[..., Any],
|
| 215 |
+
docstring_style: DocstringStyle | None = None,
|
| 216 |
+
name_override: str | None = None,
|
| 217 |
+
description_override: str | None = None,
|
| 218 |
+
use_docstring_info: bool = True,
|
| 219 |
+
strict_json_schema: bool = True,
|
| 220 |
+
) -> FuncSchema:
|
| 221 |
+
"""
|
| 222 |
+
Given a Python function, extracts a `FuncSchema` from it, capturing the name, description,
|
| 223 |
+
parameter descriptions, and other metadata.
|
| 224 |
+
|
| 225 |
+
Args:
|
| 226 |
+
func: The function to extract the schema from.
|
| 227 |
+
docstring_style: The style of the docstring to use for parsing. If not provided, we will
|
| 228 |
+
attempt to auto-detect the style.
|
| 229 |
+
name_override: If provided, use this name instead of the function's `__name__`.
|
| 230 |
+
description_override: If provided, use this description instead of the one derived from the
|
| 231 |
+
docstring.
|
| 232 |
+
use_docstring_info: If True, uses the docstring to generate the description and parameter
|
| 233 |
+
descriptions.
|
| 234 |
+
strict_json_schema: Whether the JSON schema is in strict mode. If True, we'll ensure that
|
| 235 |
+
the schema adheres to the "strict" standard the OpenAI API expects. We **strongly**
|
| 236 |
+
recommend setting this to True, as it increases the likelihood of the LLM producing
|
| 237 |
+
correct JSON input.
|
| 238 |
+
|
| 239 |
+
Returns:
|
| 240 |
+
A `FuncSchema` object containing the function's name, description, parameter descriptions,
|
| 241 |
+
and other metadata.
|
| 242 |
+
"""
|
| 243 |
+
|
| 244 |
+
# 1. Grab docstring info
|
| 245 |
+
if use_docstring_info:
|
| 246 |
+
doc_info = generate_func_documentation(func, docstring_style)
|
| 247 |
+
param_descs = dict(doc_info.param_descriptions or {})
|
| 248 |
+
else:
|
| 249 |
+
doc_info = None
|
| 250 |
+
param_descs = {}
|
| 251 |
+
|
| 252 |
+
type_hints_with_extras = get_type_hints(func, include_extras=True)
|
| 253 |
+
+    type_hints: dict[str, Any] = {}
+    annotated_param_descs: dict[str, str] = {}
+
+    for name, annotation in type_hints_with_extras.items():
+        if name == "return":
+            continue
+
+        stripped_ann, metadata = _strip_annotated(annotation)
+        type_hints[name] = stripped_ann
+
+        description = _extract_description_from_metadata(metadata)
+        if description is not None:
+            annotated_param_descs[name] = description
+
+    for name, description in annotated_param_descs.items():
+        param_descs.setdefault(name, description)
+
+    # Ensure name_override takes precedence even if docstring info is disabled.
+    func_name = name_override or (doc_info.name if doc_info else func.__name__)
+
+    # 2. Inspect function signature and get type hints
+    sig = inspect.signature(func)
+    params = list(sig.parameters.items())
+    takes_context = False
+    filtered_params = []
+
+    if params:
+        first_name, first_param = params[0]
+        # Prefer the evaluated type hint if available
+        ann = type_hints.get(first_name, first_param.annotation)
+        if ann != inspect._empty:
+            origin = get_origin(ann) or ann
+            if origin is RunContextWrapper or origin is ToolContext:
+                takes_context = True  # Mark that the function takes context
+            else:
+                filtered_params.append((first_name, first_param))
+        else:
+            filtered_params.append((first_name, first_param))
+
+    # For parameters other than the first, raise error if any use RunContextWrapper or ToolContext.
+    for name, param in params[1:]:
+        ann = type_hints.get(name, param.annotation)
+        if ann != inspect._empty:
+            origin = get_origin(ann) or ann
+            if origin is RunContextWrapper or origin is ToolContext:
+                raise UserError(
+                    f"RunContextWrapper/ToolContext param found at non-first position in function"
+                    f" {func.__name__}"
+                )
+        filtered_params.append((name, param))
+
+    # We will collect field definitions for create_model as a dict:
+    # field_name -> (type_annotation, default_value_or_Field(...))
+    fields: dict[str, Any] = {}
+
+    for name, param in filtered_params:
+        ann = type_hints.get(name, param.annotation)
+        default = param.default
+
+        # If there's no type hint, assume `Any`
+        if ann == inspect._empty:
+            ann = Any
+
+        # If a docstring param description exists, use it
+        field_description = param_descs.get(name, None)
+
+        # Handle different parameter kinds
+        if param.kind == param.VAR_POSITIONAL:
+            # e.g. *args: extend positional args
+            if get_origin(ann) is tuple:
+                # e.g. def foo(*args: tuple[int, ...]) -> treat as List[int]
+                args_of_tuple = get_args(ann)
+                if len(args_of_tuple) == 2 and args_of_tuple[1] is Ellipsis:
+                    ann = list[args_of_tuple[0]]  # type: ignore
+                else:
+                    ann = list[Any]
+            else:
+                # If user wrote *args: int, treat as List[int]
+                ann = list[ann]  # type: ignore
+
+            # Default factory to empty list
+            fields[name] = (
+                ann,
+                Field(default_factory=list, description=field_description),
+            )
+
+        elif param.kind == param.VAR_KEYWORD:
+            # **kwargs handling
+            if get_origin(ann) is dict:
+                # e.g. def foo(**kwargs: dict[str, int])
+                dict_args = get_args(ann)
+                if len(dict_args) == 2:
+                    ann = dict[dict_args[0], dict_args[1]]  # type: ignore
+                else:
+                    ann = dict[str, Any]
+            else:
+                # e.g. def foo(**kwargs: int) -> Dict[str, int]
+                ann = dict[str, ann]  # type: ignore
+
+            fields[name] = (
+                ann,
+                Field(default_factory=dict, description=field_description),
+            )
+
+        else:
+            # Normal parameter
+            if default == inspect._empty:
+                # Required field
+                fields[name] = (
+                    ann,
+                    Field(..., description=field_description),
+                )
+            elif isinstance(default, FieldInfo):
+                # Parameter with a default value that is a Field(...)
+                fields[name] = (
+                    ann,
+                    FieldInfo.merge_field_infos(
+                        default, description=field_description or default.description
+                    ),
+                )
+            else:
+                # Parameter with a default value
+                fields[name] = (
+                    ann,
+                    Field(default=default, description=field_description),
+                )
+
+    # 3. Dynamically build a Pydantic model
+    dynamic_model = create_model(f"{func_name}_args", __base__=BaseModel, **fields)
+
+    # 4. Build JSON schema from that model
+    json_schema = dynamic_model.model_json_schema()
+    if strict_json_schema:
+        json_schema = ensure_strict_json_schema(json_schema)
+
+    # 5. Return as a FuncSchema dataclass
+    return FuncSchema(
+        name=func_name,
+        # Ensure description_override takes precedence even if docstring info is disabled.
+        description=description_override or (doc_info.description if doc_info else None),
+        params_pydantic_model=dynamic_model,
+        params_json_schema=json_schema,
+        signature=sig,
+        takes_context=takes_context,
+        strict_json_schema=strict_json_schema,
+    )
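The hunk above finishes the schema builder: Annotated metadata and docstring descriptions feed per-parameter descriptions, a context-typed first parameter is stripped out, *args/**kwargs become list/dict fields, and everything is compiled into a Pydantic model plus a JSON schema. A rough usage sketch, assuming the enclosing helper is this module's function_schema(...) (its full signature is outside this hunk) and using an invented sample function:

from agents.function_schema import function_schema
from agents.run_context import RunContextWrapper

def fetch(ctx: RunContextWrapper, query: str, limit: int = 10, *tags: str) -> str:
    """Fetch records matching a query."""
    return f"{query}|{limit}|{tags}"

schema = function_schema(fetch)
print(schema.name)           # "fetch"
print(schema.takes_context)  # True: the first parameter is a RunContextWrapper, so it is excluded
print(sorted(schema.params_json_schema["properties"]))  # ["limit", "query", "tags"]
# *tags: str is modeled as list[str] with an empty-list default factory.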
agents/guardrail.py
ADDED
@@ -0,0 +1,329 @@
+from __future__ import annotations
+
+import inspect
+from collections.abc import Awaitable
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Any, Callable, Generic, Union, overload
+
+from typing_extensions import TypeVar
+
+from .exceptions import UserError
+from .items import TResponseInputItem
+from .run_context import RunContextWrapper, TContext
+from .util._types import MaybeAwaitable
+
+if TYPE_CHECKING:
+    from .agent import Agent
+
+
+@dataclass
+class GuardrailFunctionOutput:
+    """The output of a guardrail function."""
+
+    output_info: Any
+    """
+    Optional information about the guardrail's output. For example, the guardrail could include
+    information about the checks it performed and granular results.
+    """
+
+    tripwire_triggered: bool
+    """
+    Whether the tripwire was triggered. If triggered, the agent's execution will be halted.
+    """
+
+
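A minimal sketch of how a check packs its verdict into the dataclass above; the banned-word rule and helper name are invented for illustration, and the full (context, agent, input) guardrail signature appears with InputGuardrail below:

def check_banned_words(text: str) -> GuardrailFunctionOutput:
    banned = {"banned_word"}  # hypothetical deny-list
    hits = [word for word in text.split() if word.lower() in banned]
    return GuardrailFunctionOutput(
        output_info={"hits": hits},     # free-form diagnostics, e.g. for tracing
        tripwire_triggered=bool(hits),  # True halts the agent's execution
    )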
+@dataclass
+class InputGuardrailResult:
+    """The result of a guardrail run."""
+
+    guardrail: InputGuardrail[Any]
+    """
+    The guardrail that was run.
+    """
+
+    output: GuardrailFunctionOutput
+    """The output of the guardrail function."""
+
+
+@dataclass
+class OutputGuardrailResult:
+    """The result of a guardrail run."""
+
+    guardrail: OutputGuardrail[Any]
+    """
+    The guardrail that was run.
+    """
+
+    agent_output: Any
+    """
+    The output of the agent that was checked by the guardrail.
+    """
+
+    agent: Agent[Any]
+    """
+    The agent that was checked by the guardrail.
+    """
+
+    output: GuardrailFunctionOutput
+    """The output of the guardrail function."""
+
+
+@dataclass
+class InputGuardrail(Generic[TContext]):
+    """Input guardrails are checks that run in parallel to the agent's execution.
+    They can be used to do things like:
+    - Check if input messages are off-topic
+    - Take over control of the agent's execution if an unexpected input is detected
+
+    You can use the `@input_guardrail()` decorator to turn a function into an `InputGuardrail`, or
+    create an `InputGuardrail` manually.
+
+    Guardrails return a `GuardrailResult`. If `result.tripwire_triggered` is `True`, the agent's
+    execution will immediately stop, and an `InputGuardrailTripwireTriggered` exception will be
+    raised.
+    """
+
+    guardrail_function: Callable[
+        [RunContextWrapper[TContext], Agent[Any], str | list[TResponseInputItem]],
+        MaybeAwaitable[GuardrailFunctionOutput],
+    ]
+    """A function that receives the agent input and the context, and returns a
+    `GuardrailResult`. The result marks whether the tripwire was triggered, and can optionally
+    include information about the guardrail's output.
+    """
+
+    name: str | None = None
+    """The name of the guardrail, used for tracing. If not provided, we'll use the guardrail
+    function's name.
+    """
+
+    def get_name(self) -> str:
+        if self.name:
+            return self.name
+
+        return self.guardrail_function.__name__
+
+    async def run(
+        self,
+        agent: Agent[Any],
+        input: str | list[TResponseInputItem],
+        context: RunContextWrapper[TContext],
+    ) -> InputGuardrailResult:
+        if not callable(self.guardrail_function):
+            raise UserError(f"Guardrail function must be callable, got {self.guardrail_function}")
+
+        output = self.guardrail_function(context, agent, input)
+        if inspect.isawaitable(output):
+            return InputGuardrailResult(
+                guardrail=self,
+                output=await output,
+            )
+
+        return InputGuardrailResult(
+            guardrail=self,
+            output=output,
+        )
+
+
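Putting the pieces so far together, a hand-built InputGuardrail might look like the sketch below. off_topic_check and its refund rule are hypothetical, the module's own imports are assumed, and the same function could instead be wrapped with the @input_guardrail decorator defined later in this file:

def off_topic_check(
    ctx: RunContextWrapper[Any],
    agent: Agent[Any],
    user_input: str | list[TResponseInputItem],
) -> GuardrailFunctionOutput:
    text = user_input if isinstance(user_input, str) else str(user_input)
    return GuardrailFunctionOutput(
        output_info=None,
        tripwire_triggered="refund" in text.lower(),  # illustrative rule only
    )

guardrail = InputGuardrail(guardrail_function=off_topic_check, name="off_topic")
assert guardrail.get_name() == "off_topic"
# `await guardrail.run(agent, user_input, ctx)` then yields an InputGuardrailResult.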
+
@dataclass
|
| 129 |
+
class OutputGuardrail(Generic[TContext]):
|
| 130 |
+
"""Output guardrails are checks that run on the final output of an agent.
|
| 131 |
+
They can be used to do check if the output passes certain validation criteria
|
| 132 |
+
|
| 133 |
+
You can use the `@output_guardrail()` decorator to turn a function into an `OutputGuardrail`,
|
| 134 |
+
or create an `OutputGuardrail` manually.
|
| 135 |
+
|
| 136 |
+
Guardrails return a `GuardrailResult`. If `result.tripwire_triggered` is `True`, an
|
| 137 |
+
`OutputGuardrailTripwireTriggered` exception will be raised.
|
| 138 |
+
"""
|
| 139 |
+
|
| 140 |
+
guardrail_function: Callable[
|
| 141 |
+
[RunContextWrapper[TContext], Agent[Any], Any],
|
| 142 |
+
MaybeAwaitable[GuardrailFunctionOutput],
|
| 143 |
+
]
|
| 144 |
+
"""A function that receives the final agent, its output, and the context, and returns a
|
| 145 |
+
`GuardrailResult`. The result marks whether the tripwire was triggered, and can optionally
|
| 146 |
+
include information about the guardrail's output.
|
| 147 |
+
"""
|
| 148 |
+
|
| 149 |
+
name: str | None = None
|
| 150 |
+
"""The name of the guardrail, used for tracing. If not provided, we'll use the guardrail
|
| 151 |
+
function's name.
|
| 152 |
+
"""
|
| 153 |
+
|
| 154 |
+
def get_name(self) -> str:
|
| 155 |
+
if self.name:
|
| 156 |
+
return self.name
|
| 157 |
+
|
| 158 |
+
return self.guardrail_function.__name__
|
| 159 |
+
|
| 160 |
+
async def run(
|
| 161 |
+
self, context: RunContextWrapper[TContext], agent: Agent[Any], agent_output: Any
|
| 162 |
+
) -> OutputGuardrailResult:
|
| 163 |
+
if not callable(self.guardrail_function):
|
| 164 |
+
raise UserError(f"Guardrail function must be callable, got {self.guardrail_function}")
|
| 165 |
+
|
| 166 |
+
output = self.guardrail_function(context, agent, agent_output)
|
| 167 |
+
if inspect.isawaitable(output):
|
| 168 |
+
return OutputGuardrailResult(
|
| 169 |
+
guardrail=self,
|
| 170 |
+
agent=agent,
|
| 171 |
+
agent_output=agent_output,
|
| 172 |
+
output=await output,
|
| 173 |
+
)
|
| 174 |
+
|
| 175 |
+
return OutputGuardrailResult(
|
| 176 |
+
guardrail=self,
|
| 177 |
+
agent=agent,
|
| 178 |
+
agent_output=agent_output,
|
| 179 |
+
output=output,
|
| 180 |
+
)
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
TContext_co = TypeVar("TContext_co", bound=Any, covariant=True)
|
| 184 |
+
|
| 185 |
+
# For InputGuardrail
|
| 186 |
+
_InputGuardrailFuncSync = Callable[
|
| 187 |
+
[RunContextWrapper[TContext_co], "Agent[Any]", Union[str, list[TResponseInputItem]]],
|
| 188 |
+
GuardrailFunctionOutput,
|
| 189 |
+
]
|
| 190 |
+
_InputGuardrailFuncAsync = Callable[
|
| 191 |
+
[RunContextWrapper[TContext_co], "Agent[Any]", Union[str, list[TResponseInputItem]]],
|
| 192 |
+
Awaitable[GuardrailFunctionOutput],
|
| 193 |
+
]
|
| 194 |
+
|
| 195 |
+
|
| 196 |
+
@overload
|
| 197 |
+
def input_guardrail(
|
| 198 |
+
func: _InputGuardrailFuncSync[TContext_co],
|
| 199 |
+
) -> InputGuardrail[TContext_co]: ...
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
@overload
|
| 203 |
+
def input_guardrail(
|
| 204 |
+
func: _InputGuardrailFuncAsync[TContext_co],
|
| 205 |
+
) -> InputGuardrail[TContext_co]: ...
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
@overload
|
| 209 |
+
def input_guardrail(
|
| 210 |
+
*,
|
| 211 |
+
name: str | None = None,
|
| 212 |
+
) -> Callable[
|
| 213 |
+
[_InputGuardrailFuncSync[TContext_co] | _InputGuardrailFuncAsync[TContext_co]],
|
| 214 |
+
InputGuardrail[TContext_co],
|
| 215 |
+
]: ...
|
| 216 |
+
|
| 217 |
+
|
| 218 |
+
def input_guardrail(
|
| 219 |
+
func: _InputGuardrailFuncSync[TContext_co]
|
| 220 |
+
| _InputGuardrailFuncAsync[TContext_co]
|
| 221 |
+
| None = None,
|
| 222 |
+
*,
|
| 223 |
+
name: str | None = None,
|
| 224 |
+
) -> (
|
| 225 |
+
InputGuardrail[TContext_co]
|
| 226 |
+
| Callable[
|
| 227 |
+
[_InputGuardrailFuncSync[TContext_co] | _InputGuardrailFuncAsync[TContext_co]],
|
| 228 |
+
InputGuardrail[TContext_co],
|
| 229 |
+
]
|
| 230 |
+
):
|
| 231 |
+
"""
|
| 232 |
+
Decorator that transforms a sync or async function into an `InputGuardrail`.
|
| 233 |
+
It can be used directly (no parentheses) or with keyword args, e.g.:
|
| 234 |
+
|
| 235 |
+
@input_guardrail
|
| 236 |
+
def my_sync_guardrail(...): ...
|
| 237 |
+
|
| 238 |
+
@input_guardrail(name="guardrail_name")
|
| 239 |
+
async def my_async_guardrail(...): ...
|
| 240 |
+
"""
|
| 241 |
+
|
| 242 |
+
def decorator(
|
| 243 |
+
f: _InputGuardrailFuncSync[TContext_co] | _InputGuardrailFuncAsync[TContext_co],
|
| 244 |
+
) -> InputGuardrail[TContext_co]:
|
| 245 |
+
return InputGuardrail(
|
| 246 |
+
guardrail_function=f,
|
| 247 |
+
# If not set, guardrail name uses the function’s name by default.
|
| 248 |
+
name=name if name else f.__name__,
|
| 249 |
+
)
|
| 250 |
+
|
| 251 |
+
if func is not None:
|
| 252 |
+
# Decorator was used without parentheses
|
| 253 |
+
return decorator(func)
|
| 254 |
+
|
| 255 |
+
# Decorator used with keyword arguments
|
| 256 |
+
return decorator
|
| 257 |
+
|
| 258 |
+
|
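Both spellings the docstring above describes, sketched with trivial bodies (the guardrail names are made up; note how get_name() resolves in each case):

@input_guardrail
def bare_guardrail(ctx, agent, user_input) -> GuardrailFunctionOutput:
    return GuardrailFunctionOutput(output_info=None, tripwire_triggered=False)

@input_guardrail(name="off_topic")
async def named_guardrail(ctx, agent, user_input) -> GuardrailFunctionOutput:
    return GuardrailFunctionOutput(output_info=None, tripwire_triggered=False)

assert bare_guardrail.get_name() == "bare_guardrail"  # defaults to the function's __name__
assert named_guardrail.get_name() == "off_topic"      # explicit name wins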
+_OutputGuardrailFuncSync = Callable[
+    [RunContextWrapper[TContext_co], "Agent[Any]", Any],
+    GuardrailFunctionOutput,
+]
+_OutputGuardrailFuncAsync = Callable[
+    [RunContextWrapper[TContext_co], "Agent[Any]", Any],
+    Awaitable[GuardrailFunctionOutput],
+]
+
+
+@overload
+def output_guardrail(
+    func: _OutputGuardrailFuncSync[TContext_co],
+) -> OutputGuardrail[TContext_co]: ...
+
+
+@overload
+def output_guardrail(
+    func: _OutputGuardrailFuncAsync[TContext_co],
+) -> OutputGuardrail[TContext_co]: ...
+
+
+@overload
+def output_guardrail(
+    *,
+    name: str | None = None,
+) -> Callable[
+    [_OutputGuardrailFuncSync[TContext_co] | _OutputGuardrailFuncAsync[TContext_co]],
+    OutputGuardrail[TContext_co],
+]: ...
+
+
+def output_guardrail(
+    func: _OutputGuardrailFuncSync[TContext_co]
+    | _OutputGuardrailFuncAsync[TContext_co]
+    | None = None,
+    *,
+    name: str | None = None,
+) -> (
+    OutputGuardrail[TContext_co]
+    | Callable[
+        [_OutputGuardrailFuncSync[TContext_co] | _OutputGuardrailFuncAsync[TContext_co]],
+        OutputGuardrail[TContext_co],
+    ]
+):
+    """
+    Decorator that transforms a sync or async function into an `OutputGuardrail`.
+    It can be used directly (no parentheses) or with keyword args, e.g.:
+
+        @output_guardrail
+        def my_sync_guardrail(...): ...
+
+        @output_guardrail(name="guardrail_name")
+        async def my_async_guardrail(...): ...
+    """
+
+    def decorator(
+        f: _OutputGuardrailFuncSync[TContext_co] | _OutputGuardrailFuncAsync[TContext_co],
+    ) -> OutputGuardrail[TContext_co]:
+        return OutputGuardrail(
+            guardrail_function=f,
+            # Guardrail name defaults to function's name when not specified (None).
+            name=name if name else f.__name__,
+        )
+
+    if func is not None:
+        # Decorator was used without parentheses
+        return decorator(func)
+
+    # Decorator used with keyword arguments
+    return decorator
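Finally, an end-to-end sketch of wiring an output guardrail into a run. Agent, Runner, result.final_output, and the output_guardrails parameter are assumed to come from this package's top-level API, and the length cutoff is arbitrary; only the OutputGuardrailTripwireTriggered exception is confirmed by the docstrings above:

import asyncio

from agents import Agent, GuardrailFunctionOutput, Runner, output_guardrail
from agents.exceptions import OutputGuardrailTripwireTriggered

@output_guardrail
def too_long(ctx, agent, agent_output) -> GuardrailFunctionOutput:
    text = str(agent_output)
    return GuardrailFunctionOutput(
        output_info={"length": len(text)},
        tripwire_triggered=len(text) > 2000,  # arbitrary cutoff for the sketch
    )

async def main() -> None:
    agent = Agent(name="assistant", instructions="Be brief.", output_guardrails=[too_long])
    try:
        result = await Runner.run(agent, "Summarize the report.")
        print(result.final_output)
    except OutputGuardrailTripwireTriggered:
        print("Output guardrail tripped; run halted.")

asyncio.run(main())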