|
|
from __future__ import annotations |
|
|
|
|
|
import abc |
|
|
import asyncio |
|
|
from collections.abc import AsyncIterator |
|
|
from dataclasses import dataclass, field |
|
|
from typing import TYPE_CHECKING, Any, cast |
|
|
|
|
|
from typing_extensions import TypeVar |
|
|
|
|
|
from ._run_impl import QueueCompleteSentinel |
|
|
from .agent import Agent |
|
|
from .agent_output import AgentOutputSchemaBase |
|
|
from .exceptions import ( |
|
|
AgentsException, |
|
|
InputGuardrailTripwireTriggered, |
|
|
MaxTurnsExceeded, |
|
|
RunErrorDetails, |
|
|
) |
|
|
from .guardrail import InputGuardrailResult, OutputGuardrailResult |
|
|
from .items import ItemHelpers, ModelResponse, RunItem, TResponseInputItem |
|
|
from .logger import logger |
|
|
from .run_context import RunContextWrapper |
|
|
from .stream_events import StreamEvent |
|
|
from .tracing import Trace |
|
|
from .util._pretty_print import ( |
|
|
pretty_print_result, |
|
|
pretty_print_run_result_streaming, |
|
|
) |
|
|
|
|
|
if TYPE_CHECKING: |
|
|
from ._run_impl import QueueCompleteSentinel |
|
|
from .agent import Agent |
|
|
from .tool_guardrails import ToolInputGuardrailResult, ToolOutputGuardrailResult |
|
|
|
|
|
# Generic type variable used by `RunResultBase.final_output_as`, so that the value it returns is
# typed as the class passed in its `cls` argument.
T = TypeVar("T")
|
|
|
|
|
|
|
|
@dataclass
class RunResultBase(abc.ABC):
    """Fields and helpers shared by every agent-run result type."""

    input: str | list[TResponseInputItem]
    """The input items as they were before run() was called. Handoff input filters are allowed to
    mutate the input, so this may be a modified version of what the caller originally passed.
    """

    new_items: list[RunItem]
    """Items produced while the agent ran, such as new messages, tool calls and the outputs of
    those tool calls.
    """

    raw_responses: list[ModelResponse]
    """The raw LLM responses generated by the model during the agent run."""

    final_output: Any
    """The output of the last agent."""

    input_guardrail_results: list[InputGuardrailResult]
    """Guardrail results for the input messages."""

    output_guardrail_results: list[OutputGuardrailResult]
    """Guardrail results for the final output of the agent."""

    tool_input_guardrail_results: list[ToolInputGuardrailResult]
    """Tool input guardrail results from all tools executed during the run."""

    tool_output_guardrail_results: list[ToolOutputGuardrailResult]
    """Tool output guardrail results from all tools executed during the run."""

    context_wrapper: RunContextWrapper[Any]
    """The context wrapper for the agent run."""

    @property
    @abc.abstractmethod
    def last_agent(self) -> Agent[Any]:
        """The agent that ran last; each concrete result type decides how this is tracked."""

    def final_output_as(self, cls: type[T], raise_if_incorrect_type: bool = False) -> T:
        """Return `final_output` typed as `cls`.

        Without `raise_if_incorrect_type` this is purely a type-checker cast and the value is
        returned untouched. With it, the value is first checked with `isinstance`.

        Args:
            cls: The type the final output should be treated as.
            raise_if_incorrect_type: When True, raise TypeError if the final output is not an
                instance of `cls`.

        Returns:
            The final output, typed as `cls`.

        Raises:
            TypeError: If `raise_if_incorrect_type` is set and the output has the wrong type.
        """
        output = self.final_output
        # `isinstance` is only evaluated when the caller asked for the runtime check.
        if not raise_if_incorrect_type or isinstance(output, cls):
            return cast(T, output)
        raise TypeError(f"Final output is not of type {cls.__name__}")

    def to_input_list(self) -> list[TResponseInputItem]:
        """Build a fresh input list: the original input followed by every newly generated item."""
        return [
            *ItemHelpers.input_to_new_input_list(self.input),
            *(run_item.to_input_item() for run_item in self.new_items),
        ]

    @property
    def last_response_id(self) -> str | None:
        """The response ID of the most recent model response, or None if there are none yet."""
        return self.raw_responses[-1].response_id if self.raw_responses else None
|
|
|
|
|
|
|
|
@dataclass
class RunResult(RunResultBase):
    """An agent-run result that stores the last agent explicitly in `_last_agent`."""

    # Backing value returned by the `last_agent` property.
    _last_agent: Agent[Any]

    @property
    def last_agent(self) -> Agent[Any]:
        """The last agent that was run."""
        agent = self._last_agent
        return agent

    def __str__(self) -> str:
        """Render this result using the shared pretty-printer."""
        rendered = pretty_print_result(self)
        return rendered
|
|
|
|
|
|
|
|
@dataclass
class RunResultStreaming(RunResultBase):
    """The result of an agent run in streaming mode. You can use the `stream_events` method to
    receive semantic events as they are generated.

    The streaming method will raise:
    - A MaxTurnsExceeded exception if the agent exceeds the max_turns limit.
    - A guardrail tripwire exception (e.g. InputGuardrailTripwireTriggered) if a guardrail is
      tripped.
    - Any other `Exception` raised by the background run or guardrail tasks.
    """

    current_agent: Agent[Any]
    """The current agent that is running."""

    current_turn: int
    """The current turn number."""

    max_turns: int
    """The maximum number of turns the agent can run for."""

    final_output: Any
    """The final output of the agent. This is None until the agent has finished running."""

    _current_agent_output_schema: AgentOutputSchemaBase | None = field(repr=False)

    trace: Trace | None = field(repr=False)

    is_complete: bool = False
    """Whether the agent has finished running."""

    # Queues read by stream_events() / _check_errors(). Presumably filled by the background
    # run and guardrail tasks (not visible here) - confirm against the runner.
    _event_queue: asyncio.Queue[StreamEvent | QueueCompleteSentinel] = field(
        default_factory=asyncio.Queue, repr=False
    )
    _input_guardrail_queue: asyncio.Queue[InputGuardrailResult] = field(
        default_factory=asyncio.Queue, repr=False
    )

    # Background tasks inspected by _check_errors() and cancelled by _cleanup_tasks().
    _run_impl_task: asyncio.Task[Any] | None = field(default=None, repr=False)
    _input_guardrails_task: asyncio.Task[Any] | None = field(default=None, repr=False)
    _output_guardrails_task: asyncio.Task[Any] | None = field(default=None, repr=False)
    # Exception recorded by _check_errors() (a later check overwrites an earlier one) and
    # re-raised by stream_events().
    _stored_exception: Exception | None = field(default=None, repr=False)

    @property
    def last_agent(self) -> Agent[Any]:
        """The last agent that was run. Updates as the agent run progresses, so the true last agent
        is only available after the agent run is complete.
        """
        return self.current_agent

    def cancel(self) -> None:
        """Cancels the streaming run, stopping all background tasks and marking the run as
        complete."""
        self._cleanup_tasks()  # Cancel every background task that is still running.
        self.is_complete = True  # Makes stream_events() stop once the queue is empty.

        # Discard events nobody has consumed yet. Each discarded item is marked done so the
        # queue's unfinished-task count stays consistent with the task_done() calls made by
        # stream_events().
        while not self._event_queue.empty():
            self._event_queue.get_nowait()
            self._event_queue.task_done()

        while not self._input_guardrail_queue.empty():
            self._input_guardrail_queue.get_nowait()

        # A consumer may already be suspended in `await self._event_queue.get()` inside
        # stream_events(); draining the queue does not wake it, and nothing here guarantees the
        # cancelled background task will enqueue a completion marker. Enqueue one so that
        # consumer exits instead of waiting forever.
        self._event_queue.put_nowait(QueueCompleteSentinel())

    async def stream_events(self) -> AsyncIterator[StreamEvent]:
        """Stream deltas for new items as they are generated. We're using the types from the
        OpenAI Responses API, so these are semantic events: each event has a `type` field that
        describes the type of the event, along with the data for that event.

        This will raise:
        - A MaxTurnsExceeded exception if the agent exceeds the max_turns limit.
        - A guardrail tripwire exception (e.g. InputGuardrailTripwireTriggered) if a guardrail
          is tripped.
        - Any other `Exception` raised by the background run or guardrail tasks.
        """
        try:
            while True:
                self._check_errors()
                if self._stored_exception:
                    logger.debug("Breaking due to stored exception")
                    self.is_complete = True
                    break

                if self.is_complete and self._event_queue.empty():
                    break

                try:
                    item = await self._event_queue.get()
                except asyncio.CancelledError:
                    break

                if isinstance(item, QueueCompleteSentinel):
                    # Let still-running input guardrails finish so that an exception they raise
                    # late is seen by the _check_errors() call below.
                    await self._await_task_safely(self._input_guardrails_task)

                    self._event_queue.task_done()

                    # The queue may have been completed because of an error; record it.
                    self._check_errors()
                    break

                yield item
                self._event_queue.task_done()
        finally:
            # Let the main run task finish before cancelling the remaining background tasks.
            await self._await_task_safely(self._run_impl_task)
            self._cleanup_tasks()

        if self._stored_exception:
            raise self._stored_exception

    def _create_error_details(self) -> RunErrorDetails:
        """Return a `RunErrorDetails` object considering the current attributes of the class."""
        return RunErrorDetails(
            input=self.input,
            new_items=self.new_items,
            raw_responses=self.raw_responses,
            last_agent=self.current_agent,
            context_wrapper=self.context_wrapper,
            input_guardrail_results=self.input_guardrail_results,
            output_guardrail_results=self.output_guardrail_results,
        )

    def _check_errors(self) -> None:
        """Record any pending failure of the run in `self._stored_exception`.

        Checks, in order: the max-turns limit, queued input guardrail results whose tripwire
        fired, then the run / input-guardrail / output-guardrail tasks. A later check overwrites
        what an earlier one recorded. Nothing is raised here; stream_events() raises the stored
        exception.
        """
        if self.current_turn > self.max_turns:
            max_turns_exc = MaxTurnsExceeded(f"Max turns ({self.max_turns}) exceeded")
            max_turns_exc.run_data = self._create_error_details()
            self._stored_exception = max_turns_exc

        # Consume every input guardrail result queued so far.
        while not self._input_guardrail_queue.empty():
            guardrail_result = self._input_guardrail_queue.get_nowait()
            if guardrail_result.output.tripwire_triggered:
                tripwire_exc = InputGuardrailTripwireTriggered(guardrail_result)
                tripwire_exc.run_data = self._create_error_details()
                self._stored_exception = tripwire_exc

        for task in self._background_tasks():
            self._store_task_exception(task)

    def _background_tasks(self) -> tuple[asyncio.Task[Any] | None, ...]:
        """Return the run, input-guardrail and output-guardrail tasks, in that order."""
        return (self._run_impl_task, self._input_guardrails_task, self._output_guardrails_task)

    def _store_task_exception(self, task: asyncio.Task[Any] | None) -> None:
        """If `task` finished by raising an `Exception`, attach run details and store it."""
        # Task.exception() raises CancelledError for a cancelled task (e.g. one cancelled by
        # cancel() or _cleanup_tasks()). A cancelled task has no error to report, so skip it.
        if task is None or not task.done() or task.cancelled():
            return
        task_exc = task.exception()
        # BaseException subclasses that are not Exception are deliberately not stored.
        if not isinstance(task_exc, Exception):
            return
        if isinstance(task_exc, AgentsException) and task_exc.run_data is None:
            task_exc.run_data = self._create_error_details()
        self._stored_exception = task_exc

    def _cleanup_tasks(self):
        """Cancel every background task that has not finished yet."""
        for task in self._background_tasks():
            if task and not task.done():
                task.cancel()

    def __str__(self) -> str:
        return pretty_print_run_result_streaming(self)

    async def _await_task_safely(self, task: asyncio.Task[Any] | None) -> None:
        """Await a task if present, ignoring cancellation and storing exceptions elsewhere.

        This ensures we do not lose late guardrail exceptions while not surfacing
        CancelledError to callers of stream_events. An exception raised by the task is not
        propagated from here; it remains on the task and is picked up by `_check_errors()` if
        that runs afterwards.
        """
        if task is None or task.done():
            return
        try:
            await task
        except asyncio.CancelledError:
            # The awaited task was cancelled (e.g. by cancel() or _cleanup_tasks()).
            # NOTE(review): this also swallows a cancellation delivered to the *awaiting* task
            # itself; confirm that is intended.
            pass
        except Exception:
            logger.debug("Background task raised; exception left on the task", exc_info=True)
|
|
|