from __future__ import annotations import json from collections.abc import AsyncIterator from contextvars import ContextVar from dataclasses import dataclass from typing import TYPE_CHECKING, Any, Literal, cast, overload from openai import NOT_GIVEN, APIStatusError, AsyncOpenAI, AsyncStream, NotGiven from openai.types import ChatModel from openai.types.responses import ( Response, ResponseCompletedEvent, ResponseIncludable, ResponseStreamEvent, ResponseTextConfigParam, ToolParam, response_create_params, ) from openai.types.responses.response_prompt_param import ResponsePromptParam from .. import _debug from ..agent_output import AgentOutputSchemaBase from ..exceptions import UserError from ..handoffs import Handoff from ..items import ItemHelpers, ModelResponse, TResponseInputItem from ..logger import logger from ..model_settings import MCPToolChoice from ..tool import ( CodeInterpreterTool, ComputerTool, FileSearchTool, FunctionTool, HostedMCPTool, ImageGenerationTool, LocalShellTool, Tool, WebSearchTool, ) from ..tracing import SpanError, response_span from ..usage import Usage from ..util._json import _to_dump_compatible from ..version import __version__ from .interface import Model, ModelTracing if TYPE_CHECKING: from ..model_settings import ModelSettings _USER_AGENT = f"Agents/Python {__version__}" _HEADERS = {"User-Agent": _USER_AGENT} # Override headers used by the Responses API. _HEADERS_OVERRIDE: ContextVar[dict[str, str] | None] = ContextVar( "openai_responses_headers_override", default=None ) class OpenAIResponsesModel(Model): """ Implementation of `Model` that uses the OpenAI Responses API. """ def __init__( self, model: str | ChatModel, openai_client: AsyncOpenAI, ) -> None: self.model = model self._client = openai_client def _non_null_or_not_given(self, value: Any) -> Any: return value if value is not None else NOT_GIVEN async def get_response( self, system_instructions: str | None, input: str | list[TResponseInputItem], model_settings: ModelSettings, tools: list[Tool], output_schema: AgentOutputSchemaBase | None, handoffs: list[Handoff], tracing: ModelTracing, previous_response_id: str | None = None, conversation_id: str | None = None, prompt: ResponsePromptParam | None = None, ) -> ModelResponse: with response_span(disabled=tracing.is_disabled()) as span_response: try: response = await self._fetch_response( system_instructions, input, model_settings, tools, output_schema, handoffs, previous_response_id=previous_response_id, conversation_id=conversation_id, stream=False, prompt=prompt, ) if _debug.DONT_LOG_MODEL_DATA: logger.debug("LLM responded") else: logger.debug( "LLM resp:\n" f"""{ json.dumps( [x.model_dump() for x in response.output], indent=2, ensure_ascii=False, ) }\n""" ) usage = ( Usage( requests=1, input_tokens=response.usage.input_tokens, output_tokens=response.usage.output_tokens, total_tokens=response.usage.total_tokens, input_tokens_details=response.usage.input_tokens_details, output_tokens_details=response.usage.output_tokens_details, ) if response.usage else Usage() ) if tracing.include_data(): span_response.span_data.response = response span_response.span_data.input = input except Exception as e: span_response.set_error( SpanError( message="Error getting response", data={ "error": str(e) if tracing.include_data() else e.__class__.__name__, }, ) ) request_id = e.request_id if isinstance(e, APIStatusError) else None logger.error(f"Error getting response: {e}. (request_id: {request_id})") raise return ModelResponse( output=response.output, usage=usage, response_id=response.id, ) async def stream_response( self, system_instructions: str | None, input: str | list[TResponseInputItem], model_settings: ModelSettings, tools: list[Tool], output_schema: AgentOutputSchemaBase | None, handoffs: list[Handoff], tracing: ModelTracing, previous_response_id: str | None = None, conversation_id: str | None = None, prompt: ResponsePromptParam | None = None, ) -> AsyncIterator[ResponseStreamEvent]: """ Yields a partial message as it is generated, as well as the usage information. """ with response_span(disabled=tracing.is_disabled()) as span_response: try: stream = await self._fetch_response( system_instructions, input, model_settings, tools, output_schema, handoffs, previous_response_id=previous_response_id, conversation_id=conversation_id, stream=True, prompt=prompt, ) final_response: Response | None = None async for chunk in stream: if isinstance(chunk, ResponseCompletedEvent): final_response = chunk.response yield chunk if final_response and tracing.include_data(): span_response.span_data.response = final_response span_response.span_data.input = input except Exception as e: span_response.set_error( SpanError( message="Error streaming response", data={ "error": str(e) if tracing.include_data() else e.__class__.__name__, }, ) ) logger.error(f"Error streaming response: {e}") raise @overload async def _fetch_response( self, system_instructions: str | None, input: str | list[TResponseInputItem], model_settings: ModelSettings, tools: list[Tool], output_schema: AgentOutputSchemaBase | None, handoffs: list[Handoff], previous_response_id: str | None, conversation_id: str | None, stream: Literal[True], prompt: ResponsePromptParam | None = None, ) -> AsyncStream[ResponseStreamEvent]: ... @overload async def _fetch_response( self, system_instructions: str | None, input: str | list[TResponseInputItem], model_settings: ModelSettings, tools: list[Tool], output_schema: AgentOutputSchemaBase | None, handoffs: list[Handoff], previous_response_id: str | None, conversation_id: str | None, stream: Literal[False], prompt: ResponsePromptParam | None = None, ) -> Response: ... async def _fetch_response( self, system_instructions: str | None, input: str | list[TResponseInputItem], model_settings: ModelSettings, tools: list[Tool], output_schema: AgentOutputSchemaBase | None, handoffs: list[Handoff], previous_response_id: str | None = None, conversation_id: str | None = None, stream: Literal[True] | Literal[False] = False, prompt: ResponsePromptParam | None = None, ) -> Response | AsyncStream[ResponseStreamEvent]: list_input = ItemHelpers.input_to_new_input_list(input) list_input = _to_dump_compatible(list_input) parallel_tool_calls = ( True if model_settings.parallel_tool_calls and tools and len(tools) > 0 else False if model_settings.parallel_tool_calls is False else NOT_GIVEN ) tool_choice = Converter.convert_tool_choice(model_settings.tool_choice) converted_tools = Converter.convert_tools(tools, handoffs) converted_tools_payload = _to_dump_compatible(converted_tools.tools) response_format = Converter.get_response_format(output_schema) include_set: set[str] = set(converted_tools.includes) if model_settings.response_include is not None: include_set.update(model_settings.response_include) if model_settings.top_logprobs is not None: include_set.add("message.output_text.logprobs") include = cast(list[ResponseIncludable], list(include_set)) if _debug.DONT_LOG_MODEL_DATA: logger.debug("Calling LLM") else: input_json = json.dumps( list_input, indent=2, ensure_ascii=False, ) tools_json = json.dumps( converted_tools_payload, indent=2, ensure_ascii=False, ) logger.debug( f"Calling LLM {self.model} with input:\n" f"{input_json}\n" f"Tools:\n{tools_json}\n" f"Stream: {stream}\n" f"Tool choice: {tool_choice}\n" f"Response format: {response_format}\n" f"Previous response id: {previous_response_id}\n" f"Conversation id: {conversation_id}\n" ) extra_args = dict(model_settings.extra_args or {}) if model_settings.top_logprobs is not None: extra_args["top_logprobs"] = model_settings.top_logprobs if model_settings.verbosity is not None: if response_format != NOT_GIVEN: response_format["verbosity"] = model_settings.verbosity # type: ignore [index] else: response_format = {"verbosity": model_settings.verbosity} return await self._client.responses.create( previous_response_id=self._non_null_or_not_given(previous_response_id), conversation=self._non_null_or_not_given(conversation_id), instructions=self._non_null_or_not_given(system_instructions), model=self.model, input=list_input, include=include, tools=converted_tools_payload, prompt=self._non_null_or_not_given(prompt), temperature=self._non_null_or_not_given(model_settings.temperature), top_p=self._non_null_or_not_given(model_settings.top_p), truncation=self._non_null_or_not_given(model_settings.truncation), max_output_tokens=self._non_null_or_not_given(model_settings.max_tokens), tool_choice=tool_choice, parallel_tool_calls=parallel_tool_calls, stream=stream, extra_headers=self._merge_headers(model_settings), extra_query=model_settings.extra_query, extra_body=model_settings.extra_body, text=response_format, store=self._non_null_or_not_given(model_settings.store), reasoning=self._non_null_or_not_given(model_settings.reasoning), metadata=self._non_null_or_not_given(model_settings.metadata), **extra_args, ) def _get_client(self) -> AsyncOpenAI: if self._client is None: self._client = AsyncOpenAI() return self._client def _merge_headers(self, model_settings: ModelSettings): return { **_HEADERS, **(model_settings.extra_headers or {}), **(_HEADERS_OVERRIDE.get() or {}), } @dataclass class ConvertedTools: tools: list[ToolParam] includes: list[ResponseIncludable] class Converter: @classmethod def convert_tool_choice( cls, tool_choice: Literal["auto", "required", "none"] | str | MCPToolChoice | None ) -> response_create_params.ToolChoice | NotGiven: if tool_choice is None: return NOT_GIVEN elif isinstance(tool_choice, MCPToolChoice): return { "server_label": tool_choice.server_label, "type": "mcp", "name": tool_choice.name, } elif tool_choice == "required": return "required" elif tool_choice == "auto": return "auto" elif tool_choice == "none": return "none" elif tool_choice == "file_search": return { "type": "file_search", } elif tool_choice == "web_search": return { # TODO: revist the type: ignore comment when ToolChoice is updated in the future "type": "web_search", # type: ignore [typeddict-item] } elif tool_choice == "web_search_preview": return { "type": "web_search_preview", } elif tool_choice == "computer_use_preview": return { "type": "computer_use_preview", } elif tool_choice == "image_generation": return { "type": "image_generation", } elif tool_choice == "code_interpreter": return { "type": "code_interpreter", } elif tool_choice == "mcp": # Note that this is still here for backwards compatibility, # but migrating to MCPToolChoice is recommended. return {"type": "mcp"} # type: ignore [typeddict-item] else: return { "type": "function", "name": tool_choice, } @classmethod def get_response_format( cls, output_schema: AgentOutputSchemaBase | None ) -> ResponseTextConfigParam | NotGiven: if output_schema is None or output_schema.is_plain_text(): return NOT_GIVEN else: return { "format": { "type": "json_schema", "name": "final_output", "schema": output_schema.json_schema(), "strict": output_schema.is_strict_json_schema(), } } @classmethod def convert_tools( cls, tools: list[Tool], handoffs: list[Handoff[Any, Any]], ) -> ConvertedTools: converted_tools: list[ToolParam] = [] includes: list[ResponseIncludable] = [] computer_tools = [tool for tool in tools if isinstance(tool, ComputerTool)] if len(computer_tools) > 1: raise UserError(f"You can only provide one computer tool. Got {len(computer_tools)}") for tool in tools: converted_tool, include = cls._convert_tool(tool) converted_tools.append(converted_tool) if include: includes.append(include) for handoff in handoffs: converted_tools.append(cls._convert_handoff_tool(handoff)) return ConvertedTools(tools=converted_tools, includes=includes) @classmethod def _convert_tool(cls, tool: Tool) -> tuple[ToolParam, ResponseIncludable | None]: """Returns converted tool and includes""" if isinstance(tool, FunctionTool): converted_tool: ToolParam = { "name": tool.name, "parameters": tool.params_json_schema, "strict": tool.strict_json_schema, "type": "function", "description": tool.description, } includes: ResponseIncludable | None = None elif isinstance(tool, WebSearchTool): # TODO: revist the type: ignore comment when ToolParam is updated in the future converted_tool = { "type": "web_search", "filters": tool.filters.model_dump() if tool.filters is not None else None, # type: ignore [typeddict-item] "user_location": tool.user_location, "search_context_size": tool.search_context_size, } includes = None elif isinstance(tool, FileSearchTool): converted_tool = { "type": "file_search", "vector_store_ids": tool.vector_store_ids, } if tool.max_num_results: converted_tool["max_num_results"] = tool.max_num_results if tool.ranking_options: converted_tool["ranking_options"] = tool.ranking_options if tool.filters: converted_tool["filters"] = tool.filters includes = "file_search_call.results" if tool.include_search_results else None elif isinstance(tool, ComputerTool): converted_tool = { "type": "computer_use_preview", "environment": tool.computer.environment, "display_width": tool.computer.dimensions[0], "display_height": tool.computer.dimensions[1], } includes = None elif isinstance(tool, HostedMCPTool): converted_tool = tool.tool_config includes = None elif isinstance(tool, ImageGenerationTool): converted_tool = tool.tool_config includes = None elif isinstance(tool, CodeInterpreterTool): converted_tool = tool.tool_config includes = None elif isinstance(tool, LocalShellTool): converted_tool = { "type": "local_shell", } includes = None else: raise UserError(f"Unknown tool type: {type(tool)}, tool") return converted_tool, includes @classmethod def _convert_handoff_tool(cls, handoff: Handoff) -> ToolParam: return { "name": handoff.tool_name, "parameters": handoff.input_json_schema, "strict": handoff.strict_json_schema, "type": "function", "description": handoff.tool_description, }