from __future__ import annotations

import json
from collections.abc import AsyncIterator
from contextvars import ContextVar
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Literal, cast, overload

from openai import NOT_GIVEN, APIStatusError, AsyncOpenAI, AsyncStream, NotGiven
from openai.types import ChatModel
from openai.types.responses import (
    Response,
    ResponseCompletedEvent,
    ResponseIncludable,
    ResponseStreamEvent,
    ResponseTextConfigParam,
    ToolParam,
    response_create_params,
)
from openai.types.responses.response_prompt_param import ResponsePromptParam

from .. import _debug
from ..agent_output import AgentOutputSchemaBase
from ..exceptions import UserError
from ..handoffs import Handoff
from ..items import ItemHelpers, ModelResponse, TResponseInputItem
from ..logger import logger
from ..model_settings import MCPToolChoice
from ..tool import (
    CodeInterpreterTool,
    ComputerTool,
    FileSearchTool,
    FunctionTool,
    HostedMCPTool,
    ImageGenerationTool,
    LocalShellTool,
    Tool,
    WebSearchTool,
)
from ..tracing import SpanError, response_span
from ..usage import Usage
from ..util._json import _to_dump_compatible
from ..version import __version__
from .interface import Model, ModelTracing

if TYPE_CHECKING:
    from ..model_settings import ModelSettings


_USER_AGENT = f"Agents/Python {__version__}"
_HEADERS = {"User-Agent": _USER_AGENT}
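
# This ContextVar lets callers override the outgoing HTTP headers for the current
# async context (e.g. one logical request) without mutating the process-wide
# _HEADERS default; see _merge_headers below for how it is consumed.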
_HEADERS_OVERRIDE: ContextVar[dict[str, str] | None] = ContextVar(
    "openai_responses_headers_override", default=None
)


class OpenAIResponsesModel(Model):
    """
    Implementation of `Model` that uses the OpenAI Responses API.
    """

    def __init__(
        self,
        model: str | ChatModel,
        openai_client: AsyncOpenAI,
    ) -> None:
        self.model = model
        self._client = openai_client
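
    # NOT_GIVEN is the OpenAI SDK sentinel for "omit this parameter entirely",
    # which is distinct from sending an explicit null.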
    def _non_null_or_not_given(self, value: Any) -> Any:
        return value if value is not None else NOT_GIVEN

    async def get_response(
        self,
        system_instructions: str | None,
        input: str | list[TResponseInputItem],
        model_settings: ModelSettings,
        tools: list[Tool],
        output_schema: AgentOutputSchemaBase | None,
        handoffs: list[Handoff],
        tracing: ModelTracing,
        previous_response_id: str | None = None,
        conversation_id: str | None = None,
        prompt: ResponsePromptParam | None = None,
    ) -> ModelResponse:
        with response_span(disabled=tracing.is_disabled()) as span_response:
            try:
                response = await self._fetch_response(
                    system_instructions,
                    input,
                    model_settings,
                    tools,
                    output_schema,
                    handoffs,
                    previous_response_id=previous_response_id,
                    conversation_id=conversation_id,
                    stream=False,
                    prompt=prompt,
                )

                if _debug.DONT_LOG_MODEL_DATA:
                    logger.debug("LLM responded")
                else:
                    logger.debug(
                        "LLM resp:\n"
                        f"""{
                            json.dumps(
                                [x.model_dump() for x in response.output],
                                indent=2,
                                ensure_ascii=False,
                            )
                        }\n"""
                    )
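
                # The API may omit usage data, so fall back to a zeroed Usage
                # rather than assuming it is always present.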
                usage = (
                    Usage(
                        requests=1,
                        input_tokens=response.usage.input_tokens,
                        output_tokens=response.usage.output_tokens,
                        total_tokens=response.usage.total_tokens,
                        input_tokens_details=response.usage.input_tokens_details,
                        output_tokens_details=response.usage.output_tokens_details,
                    )
                    if response.usage
                    else Usage()
                )

                if tracing.include_data():
                    span_response.span_data.response = response
                    span_response.span_data.input = input
            except Exception as e:
                span_response.set_error(
                    SpanError(
                        message="Error getting response",
                        data={
                            "error": str(e) if tracing.include_data() else e.__class__.__name__,
                        },
                    )
                )
                request_id = e.request_id if isinstance(e, APIStatusError) else None
                logger.error(f"Error getting response: {e}. (request_id: {request_id})")
                raise

        return ModelResponse(
            output=response.output,
            usage=usage,
            response_id=response.id,
        )

    async def stream_response(
        self,
        system_instructions: str | None,
        input: str | list[TResponseInputItem],
        model_settings: ModelSettings,
        tools: list[Tool],
        output_schema: AgentOutputSchemaBase | None,
        handoffs: list[Handoff],
        tracing: ModelTracing,
        previous_response_id: str | None = None,
        conversation_id: str | None = None,
        prompt: ResponsePromptParam | None = None,
    ) -> AsyncIterator[ResponseStreamEvent]:
        """
        Yields a partial message as it is generated, as well as the usage information.
        """
        with response_span(disabled=tracing.is_disabled()) as span_response:
            try:
                stream = await self._fetch_response(
                    system_instructions,
                    input,
                    model_settings,
                    tools,
                    output_schema,
                    handoffs,
                    previous_response_id=previous_response_id,
                    conversation_id=conversation_id,
                    stream=True,
                    prompt=prompt,
                )
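
                # Forward every event to the caller, but remember the terminal
                # ResponseCompletedEvent so the final response can be attached
                # to the trace span after the stream ends.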
                final_response: Response | None = None

                async for chunk in stream:
                    if isinstance(chunk, ResponseCompletedEvent):
                        final_response = chunk.response
                    yield chunk

                if final_response and tracing.include_data():
                    span_response.span_data.response = final_response
                    span_response.span_data.input = input

            except Exception as e:
                span_response.set_error(
                    SpanError(
                        message="Error streaming response",
                        data={
                            "error": str(e) if tracing.include_data() else e.__class__.__name__,
                        },
                    )
                )
                logger.error(f"Error streaming response: {e}")
                raise

    @overload
    async def _fetch_response(
        self,
        system_instructions: str | None,
        input: str | list[TResponseInputItem],
        model_settings: ModelSettings,
        tools: list[Tool],
        output_schema: AgentOutputSchemaBase | None,
        handoffs: list[Handoff],
        previous_response_id: str | None,
        conversation_id: str | None,
        stream: Literal[True],
        prompt: ResponsePromptParam | None = None,
    ) -> AsyncStream[ResponseStreamEvent]: ...

    @overload
    async def _fetch_response(
        self,
        system_instructions: str | None,
        input: str | list[TResponseInputItem],
        model_settings: ModelSettings,
        tools: list[Tool],
        output_schema: AgentOutputSchemaBase | None,
        handoffs: list[Handoff],
        previous_response_id: str | None,
        conversation_id: str | None,
        stream: Literal[False],
        prompt: ResponsePromptParam | None = None,
    ) -> Response: ...

    async def _fetch_response(
        self,
        system_instructions: str | None,
        input: str | list[TResponseInputItem],
        model_settings: ModelSettings,
        tools: list[Tool],
        output_schema: AgentOutputSchemaBase | None,
        handoffs: list[Handoff],
        previous_response_id: str | None = None,
        conversation_id: str | None = None,
        stream: Literal[True] | Literal[False] = False,
        prompt: ResponsePromptParam | None = None,
    ) -> Response | AsyncStream[ResponseStreamEvent]:
        list_input = ItemHelpers.input_to_new_input_list(input)
        list_input = _to_dump_compatible(list_input)
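
        # Tri-state: True when parallel calls are requested and tools exist,
        # False when explicitly disabled, and NOT_GIVEN otherwise so the API
        # default applies.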
        parallel_tool_calls = (
            True
            if model_settings.parallel_tool_calls and tools and len(tools) > 0
            else False
            if model_settings.parallel_tool_calls is False
            else NOT_GIVEN
        )

        tool_choice = Converter.convert_tool_choice(model_settings.tool_choice)
        converted_tools = Converter.convert_tools(tools, handoffs)
        converted_tools_payload = _to_dump_compatible(converted_tools.tools)
        response_format = Converter.get_response_format(output_schema)
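
        # Merge the includes required by the converted tools with any requested
        # via ModelSettings; a set keeps the combined list free of duplicates.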
        include_set: set[str] = set(converted_tools.includes)
        if model_settings.response_include is not None:
            include_set.update(model_settings.response_include)
        if model_settings.top_logprobs is not None:
            include_set.add("message.output_text.logprobs")
        include = cast(list[ResponseIncludable], list(include_set))

        if _debug.DONT_LOG_MODEL_DATA:
            logger.debug("Calling LLM")
        else:
            input_json = json.dumps(
                list_input,
                indent=2,
                ensure_ascii=False,
            )
            tools_json = json.dumps(
                converted_tools_payload,
                indent=2,
                ensure_ascii=False,
            )
            logger.debug(
                f"Calling LLM {self.model} with input:\n"
                f"{input_json}\n"
                f"Tools:\n{tools_json}\n"
                f"Stream: {stream}\n"
                f"Tool choice: {tool_choice}\n"
                f"Response format: {response_format}\n"
                f"Previous response id: {previous_response_id}\n"
                f"Conversation id: {conversation_id}\n"
            )
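
        # top_logprobs and verbosity are not passed as first-class arguments
        # here: top_logprobs travels via extra_args, while verbosity is merged
        # into the text config (response_format), reusing the one built from
        # the output schema when present.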
        extra_args = dict(model_settings.extra_args or {})
        if model_settings.top_logprobs is not None:
            extra_args["top_logprobs"] = model_settings.top_logprobs
        if model_settings.verbosity is not None:
            if response_format != NOT_GIVEN:
                response_format["verbosity"] = model_settings.verbosity
            else:
                response_format = {"verbosity": model_settings.verbosity}

        return await self._client.responses.create(
            previous_response_id=self._non_null_or_not_given(previous_response_id),
            conversation=self._non_null_or_not_given(conversation_id),
            instructions=self._non_null_or_not_given(system_instructions),
            model=self.model,
            input=list_input,
            include=include,
            tools=converted_tools_payload,
            prompt=self._non_null_or_not_given(prompt),
            temperature=self._non_null_or_not_given(model_settings.temperature),
            top_p=self._non_null_or_not_given(model_settings.top_p),
            truncation=self._non_null_or_not_given(model_settings.truncation),
            max_output_tokens=self._non_null_or_not_given(model_settings.max_tokens),
            tool_choice=tool_choice,
            parallel_tool_calls=parallel_tool_calls,
            stream=stream,
            extra_headers=self._merge_headers(model_settings),
            extra_query=model_settings.extra_query,
            extra_body=model_settings.extra_body,
            text=response_format,
            store=self._non_null_or_not_given(model_settings.store),
            reasoning=self._non_null_or_not_given(model_settings.reasoning),
            metadata=self._non_null_or_not_given(model_settings.metadata),
            **extra_args,
        )
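
    # Lazily constructs a default client if none is attached; with the current
    # __init__ signature a client is always provided, so this is defensive.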
    def _get_client(self) -> AsyncOpenAI:
        if self._client is None:
            self._client = AsyncOpenAI()
        return self._client
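
    # Later entries win: per-context overrides beat ModelSettings headers,
    # which beat the library's baseline User-Agent header.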
    def _merge_headers(self, model_settings: ModelSettings):
        return {
            **_HEADERS,
            **(model_settings.extra_headers or {}),
            **(_HEADERS_OVERRIDE.get() or {}),
        }


@dataclass
class ConvertedTools:
    tools: list[ToolParam]
    includes: list[ResponseIncludable]


class Converter:
    @classmethod
    def convert_tool_choice(
        cls, tool_choice: Literal["auto", "required", "none"] | str | MCPToolChoice | None
    ) -> response_create_params.ToolChoice | NotGiven:
        if tool_choice is None:
            return NOT_GIVEN
        elif isinstance(tool_choice, MCPToolChoice):
            return {
                "server_label": tool_choice.server_label,
                "type": "mcp",
                "name": tool_choice.name,
            }
        elif tool_choice == "required":
            return "required"
        elif tool_choice == "auto":
            return "auto"
        elif tool_choice == "none":
            return "none"
        elif tool_choice == "file_search":
            return {
                "type": "file_search",
            }
        elif tool_choice == "web_search":
            return {
                "type": "web_search",
            }
        elif tool_choice == "web_search_preview":
            return {
                "type": "web_search_preview",
            }
        elif tool_choice == "computer_use_preview":
            return {
                "type": "computer_use_preview",
            }
        elif tool_choice == "image_generation":
            return {
                "type": "image_generation",
            }
        elif tool_choice == "code_interpreter":
            return {
                "type": "code_interpreter",
            }
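        # A bare "mcp" tool choice carries no server label; use MCPToolChoice
        # (handled above) to target a specific MCP server by name.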
        elif tool_choice == "mcp":
            return {"type": "mcp"}
        else:
            return {
                "type": "function",
                "name": tool_choice,
            }

    @classmethod
    def get_response_format(
        cls, output_schema: AgentOutputSchemaBase | None
    ) -> ResponseTextConfigParam | NotGiven:
        if output_schema is None or output_schema.is_plain_text():
            return NOT_GIVEN
        else:
            return {
                "format": {
                    "type": "json_schema",
                    "name": "final_output",
                    "schema": output_schema.json_schema(),
                    "strict": output_schema.is_strict_json_schema(),
                }
            }

    @classmethod
    def convert_tools(
        cls,
        tools: list[Tool],
        handoffs: list[Handoff[Any, Any]],
    ) -> ConvertedTools:
        converted_tools: list[ToolParam] = []
        includes: list[ResponseIncludable] = []

        computer_tools = [tool for tool in tools if isinstance(tool, ComputerTool)]
        if len(computer_tools) > 1:
            raise UserError(f"You can only provide one computer tool. Got {len(computer_tools)}")

        for tool in tools:
            converted_tool, include = cls._convert_tool(tool)
            converted_tools.append(converted_tool)
            if include:
                includes.append(include)

        for handoff in handoffs:
            converted_tools.append(cls._convert_handoff_tool(handoff))

        return ConvertedTools(tools=converted_tools, includes=includes)

    @classmethod
    def _convert_tool(cls, tool: Tool) -> tuple[ToolParam, ResponseIncludable | None]:
        """Returns the converted tool param, plus any `ResponseIncludable` the tool requires."""

        if isinstance(tool, FunctionTool):
            converted_tool: ToolParam = {
                "name": tool.name,
                "parameters": tool.params_json_schema,
                "strict": tool.strict_json_schema,
                "type": "function",
                "description": tool.description,
            }
            includes: ResponseIncludable | None = None
        elif isinstance(tool, WebSearchTool):
            converted_tool = {
                "type": "web_search",
                "filters": tool.filters.model_dump() if tool.filters is not None else None,
                "user_location": tool.user_location,
                "search_context_size": tool.search_context_size,
            }
            includes = None
        elif isinstance(tool, FileSearchTool):
            converted_tool = {
                "type": "file_search",
                "vector_store_ids": tool.vector_store_ids,
            }
            if tool.max_num_results:
                converted_tool["max_num_results"] = tool.max_num_results
            if tool.ranking_options:
                converted_tool["ranking_options"] = tool.ranking_options
            if tool.filters:
                converted_tool["filters"] = tool.filters
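
            # Requesting "file_search_call.results" asks the API to include the
            # raw search results alongside the tool call in the response output.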
            includes = "file_search_call.results" if tool.include_search_results else None
        elif isinstance(tool, ComputerTool):
            converted_tool = {
                "type": "computer_use_preview",
                "environment": tool.computer.environment,
                "display_width": tool.computer.dimensions[0],
                "display_height": tool.computer.dimensions[1],
            }
            includes = None
        elif isinstance(tool, HostedMCPTool):
            converted_tool = tool.tool_config
            includes = None
        elif isinstance(tool, ImageGenerationTool):
            converted_tool = tool.tool_config
            includes = None
        elif isinstance(tool, CodeInterpreterTool):
            converted_tool = tool.tool_config
            includes = None
        elif isinstance(tool, LocalShellTool):
            converted_tool = {
                "type": "local_shell",
            }
            includes = None
        else:
            raise UserError(f"Unknown tool type: {type(tool)}, tool: {tool}")

        return converted_tool, includes
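
    # Handoffs are surfaced to the model as ordinary function tools; when the
    # model calls one, the runner treats it as an agent handoff rather than a
    # normal function invocation.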
    @classmethod
    def _convert_handoff_tool(cls, handoff: Handoff) -> ToolParam:
        return {
            "name": handoff.tool_name,
            "parameters": handoff.input_json_schema,
            "strict": handoff.strict_json_schema,
            "type": "function",
            "description": handoff.tool_description,
        }