|
|
from dataclasses import dataclass, field, fields |
|
|
from typing import Any, Optional |
|
|
|
|
|
from openai.types.responses import ResponseFunctionToolCall |
|
|
|
|
|
from .run_context import RunContextWrapper, TContext |
|
|
|
|
|
|
|
|
def _assert_must_pass_tool_call_id() -> str: |
|
|
raise ValueError("tool_call_id must be passed to ToolContext") |
|
|
|
|
|
|
|
|
def _assert_must_pass_tool_name() -> str: |
|
|
raise ValueError("tool_name must be passed to ToolContext") |
|
|
|
|
|
|
|
|
def _assert_must_pass_tool_arguments() -> str: |
|
|
raise ValueError("tool_arguments must be passed to ToolContext") |
|
|
|
|
|
|
|
|
@dataclass |
|
|
class ToolContext(RunContextWrapper[TContext]): |
|
|
"""The context of a tool call.""" |
|
|
|
|
|
tool_name: str = field(default_factory=_assert_must_pass_tool_name) |
|
|
"""The name of the tool being invoked.""" |
|
|
|
|
|
tool_call_id: str = field(default_factory=_assert_must_pass_tool_call_id) |
|
|
"""The ID of the tool call.""" |
|
|
|
|
|
tool_arguments: str = field(default_factory=_assert_must_pass_tool_arguments) |
|
|
"""The raw arguments string of the tool call.""" |
|
|
|
|
|
@classmethod |
|
|
def from_agent_context( |
|
|
cls, |
|
|
context: RunContextWrapper[TContext], |
|
|
tool_call_id: str, |
|
|
tool_call: Optional[ResponseFunctionToolCall] = None, |
|
|
) -> "ToolContext": |
|
|
""" |
|
|
Create a ToolContext from a RunContextWrapper. |
|
|
""" |
|
|
|
|
|
base_values: dict[str, Any] = { |
|
|
f.name: getattr(context, f.name) for f in fields(RunContextWrapper) if f.init |
|
|
} |
|
|
tool_name = tool_call.name if tool_call is not None else _assert_must_pass_tool_name() |
|
|
tool_args = ( |
|
|
tool_call.arguments if tool_call is not None else _assert_must_pass_tool_arguments() |
|
|
) |
|
|
|
|
|
return cls( |
|
|
tool_name=tool_name, tool_call_id=tool_call_id, tool_arguments=tool_args, **base_values |
|
|
) |
|
|
|