|
|
from __future__ import annotations |
|
|
|
|
|
import inspect |
|
|
from collections.abc import Awaitable |
|
|
from dataclasses import dataclass, field |
|
|
from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, overload |
|
|
|
|
|
from typing_extensions import TypedDict, TypeVar |
|
|
|
|
|
from .exceptions import UserError |
|
|
from .tool_context import ToolContext |
|
|
from .util._types import MaybeAwaitable |
|
|
|
|
|
if TYPE_CHECKING: |
|
|
from .agent import Agent |
|
|
|
|
|
|
|
|
@dataclass
class ToolInputGuardrailResult:
    """Outcome of running one tool input guardrail.

    Pairs the guardrail object that was executed with the value its
    guardrail function returned.
    """

    guardrail: ToolInputGuardrail[Any]
    """The input guardrail that was executed."""

    output: ToolGuardrailFunctionOutput
    """Value returned by the guardrail's function."""
|
|
|
|
|
|
|
|
@dataclass
class ToolOutputGuardrailResult:
    """Outcome of running one tool output guardrail.

    Pairs the guardrail object that was executed with the value its
    guardrail function returned.
    """

    guardrail: ToolOutputGuardrail[Any]
    """The output guardrail that was executed."""

    output: ToolGuardrailFunctionOutput
    """Value returned by the guardrail's function."""
|
|
|
|
|
|
|
|
class RejectContentBehavior(TypedDict):
    """Reject the tool call or its output, but let the run continue.

    The model is sent ``message`` in place of the tool result.
    """

    # Discriminator identifying this behavior.
    type: Literal["reject_content"]
    # Text sent to the model instead of the tool result.
    message: str
|
|
|
|
|
|
|
|
class RaiseExceptionBehavior(TypedDict):
    """Stop the run by raising an exception."""

    # Discriminator identifying this behavior.
    type: Literal["raise_exception"]
|
|
|
|
|
|
|
|
class AllowBehavior(TypedDict):
    """Let the tool execution proceed as normal."""

    # Discriminator identifying this behavior.
    type: Literal["allow"]
|
|
|
|
|
|
|
|
def _default_allow_behavior() -> AllowBehavior:
    """Build a fresh ``allow`` behavior; used as the dataclass default factory."""
    behavior: AllowBehavior = {"type": "allow"}
    return behavior


@dataclass
class ToolGuardrailFunctionOutput:
    """Value returned by a tool guardrail function.

    Bundles optional diagnostic data with the behavior the system should
    apply. Prefer the ``allow`` / ``reject_content`` / ``raise_exception``
    constructors over building ``behavior`` by hand.
    """

    output_info: Any
    """
    Optional details about the checks that were performed, e.g. which rules
    ran and their individual results.
    """

    behavior: RejectContentBehavior | RaiseExceptionBehavior | AllowBehavior = field(
        default_factory=_default_allow_behavior
    )
    """
    How the system should react when this guardrail result is processed:

    - allow: tool execution continues untouched (the default)
    - reject_content: the tool call/output is rejected, execution continues and
      the model receives a message instead
    - raise_exception: execution halts with a ToolGuardrailTripwireTriggered exception
    """

    @classmethod
    def allow(cls, output_info: Any = None) -> ToolGuardrailFunctionOutput:
        """Build an output that lets the tool run normally.

        Args:
            output_info: Optional details about the checks that were performed.

        Returns:
            A ToolGuardrailFunctionOutput whose behavior is ``allow``.
        """
        behavior: AllowBehavior = {"type": "allow"}
        return cls(output_info=output_info, behavior=behavior)

    @classmethod
    def reject_content(cls, message: str, output_info: Any = None) -> ToolGuardrailFunctionOutput:
        """Build an output that rejects the tool call/output without stopping the run.

        Args:
            message: Text handed to the model in place of the tool result.
            output_info: Optional details about the checks that were performed.

        Returns:
            A ToolGuardrailFunctionOutput whose behavior is ``reject_content``.
        """
        behavior: RejectContentBehavior = {"type": "reject_content", "message": message}
        return cls(output_info=output_info, behavior=behavior)

    @classmethod
    def raise_exception(cls, output_info: Any = None) -> ToolGuardrailFunctionOutput:
        """Build an output that stops execution by raising an exception.

        Args:
            output_info: Optional details about the checks that were performed.

        Returns:
            A ToolGuardrailFunctionOutput whose behavior is ``raise_exception``.
        """
        behavior: RaiseExceptionBehavior = {"type": "raise_exception"}
        return cls(output_info=output_info, behavior=behavior)
|
|
|
|
|
|
|
|
@dataclass
class ToolInputGuardrailData:
    """Information handed to a tool input guardrail function."""

    context: ToolContext[Any]
    """
    Context of the tool invocation currently being checked.
    """

    agent: Agent[Any]
    """
    Agent on whose behalf the tool is being executed.
    """
|
|
|
|
|
|
|
|
@dataclass
class ToolOutputGuardrailData(ToolInputGuardrailData):
    """Information handed to a tool output guardrail function.

    Carries everything an input guardrail receives, plus the value the tool
    returned.
    """

    output: Any
    """
    Value produced by the tool function.
    """
|
|
|
|
|
|
|
|
# Covariant type variable used to parameterize the guardrail classes by context type.
TContext_co = TypeVar("TContext_co", covariant=True, bound=Any)
|
|
|
|
|
|
|
|
@dataclass
class ToolInputGuardrail(Generic[TContext_co]):
    """A guardrail that runs before a function tool is invoked.

    ``guardrail_function`` may be a regular function or return an awaitable;
    ``run`` handles both.
    """

    guardrail_function: Callable[
        [ToolInputGuardrailData], MaybeAwaitable[ToolGuardrailFunctionOutput]
    ]
    """
    The function that implements the guardrail logic.
    """

    name: str | None = None
    """
    Optional name for the guardrail. If not provided, uses the function name.
    """

    def get_name(self) -> str:
        """Return ``name`` if set (non-empty), otherwise a name derived from the function.

        Callables without ``__name__`` (e.g. ``functools.partial`` objects or
        instances defining ``__call__``) fall back to their type's name instead
        of raising ``AttributeError``.
        """
        if self.name:
            return self.name
        fn = self.guardrail_function
        return getattr(fn, "__name__", type(fn).__name__)

    async def run(self, data: ToolInputGuardrailData) -> ToolGuardrailFunctionOutput:
        """Invoke the guardrail function on ``data`` and return its output.

        Args:
            data: The tool context and agent to check.

        Returns:
            The guardrail function's result, awaited first if it is awaitable.

        Raises:
            UserError: If ``guardrail_function`` is not callable.
        """
        if not callable(self.guardrail_function):
            raise UserError(f"Guardrail function must be callable, got {self.guardrail_function}")

        result = self.guardrail_function(data)
        # Support both sync and async guardrail functions.
        if inspect.isawaitable(result):
            return await result
        return result
|
|
|
|
|
|
|
|
@dataclass
class ToolOutputGuardrail(Generic[TContext_co]):
    """A guardrail that runs after a function tool is invoked.

    ``guardrail_function`` may be a regular function or return an awaitable;
    ``run`` handles both.
    """

    guardrail_function: Callable[
        [ToolOutputGuardrailData], MaybeAwaitable[ToolGuardrailFunctionOutput]
    ]
    """
    The function that implements the guardrail logic.
    """

    name: str | None = None
    """
    Optional name for the guardrail. If not provided, uses the function name.
    """

    def get_name(self) -> str:
        """Return ``name`` if set (non-empty), otherwise a name derived from the function.

        Callables without ``__name__`` (e.g. ``functools.partial`` objects or
        instances defining ``__call__``) fall back to their type's name instead
        of raising ``AttributeError``.
        """
        if self.name:
            return self.name
        fn = self.guardrail_function
        return getattr(fn, "__name__", type(fn).__name__)

    async def run(self, data: ToolOutputGuardrailData) -> ToolGuardrailFunctionOutput:
        """Invoke the guardrail function on ``data`` and return its output.

        Args:
            data: The tool context, agent, and tool output to check.

        Returns:
            The guardrail function's result, awaited first if it is awaitable.

        Raises:
            UserError: If ``guardrail_function`` is not callable.
        """
        if not callable(self.guardrail_function):
            raise UserError(f"Guardrail function must be callable, got {self.guardrail_function}")

        result = self.guardrail_function(data)
        # Support both sync and async guardrail functions.
        if inspect.isawaitable(result):
            return await result
        return result
|
|
|
|
|
|
|
|
|
|
|
# Accepted shapes of an input guardrail function: sync, or returning an awaitable.
_ToolInputFuncSync = Callable[[ToolInputGuardrailData], ToolGuardrailFunctionOutput]
_ToolInputFuncAsync = Callable[[ToolInputGuardrailData], Awaitable[ToolGuardrailFunctionOutput]]


@overload
def tool_input_guardrail(func: _ToolInputFuncSync) -> ToolInputGuardrail[Any]: ...


@overload
def tool_input_guardrail(func: _ToolInputFuncAsync) -> ToolInputGuardrail[Any]: ...


@overload
def tool_input_guardrail(
    *, name: str | None = None
) -> Callable[[_ToolInputFuncSync | _ToolInputFuncAsync], ToolInputGuardrail[Any]]: ...


def tool_input_guardrail(
    func: _ToolInputFuncSync | _ToolInputFuncAsync | None = None,
    *,
    name: str | None = None,
) -> (
    ToolInputGuardrail[Any]
    | Callable[[_ToolInputFuncSync | _ToolInputFuncAsync], ToolInputGuardrail[Any]]
):
    """Decorator to create a ToolInputGuardrail from a function.

    Usable bare (``@tool_input_guardrail``) or called with keyword arguments
    (``@tool_input_guardrail(name="...")``).

    Args:
        func: The guardrail function, when used as a bare decorator.
        name: Optional guardrail name; defaults to the function's ``__name__``
            (or its type name for callables that have no ``__name__``).

    Returns:
        A ToolInputGuardrail when ``func`` is given, otherwise a decorator that
        produces one.
    """

    def decorator(f: _ToolInputFuncSync | _ToolInputFuncAsync) -> ToolInputGuardrail[Any]:
        # getattr fallback: partial objects / callable instances lack __name__.
        guardrail_name = name or getattr(f, "__name__", type(f).__name__)
        return ToolInputGuardrail(guardrail_function=f, name=guardrail_name)

    if func is not None:
        return decorator(func)
    return decorator
|
|
|
|
|
|
|
|
# Accepted shapes of an output guardrail function: sync, or returning an awaitable.
_ToolOutputFuncSync = Callable[[ToolOutputGuardrailData], ToolGuardrailFunctionOutput]
_ToolOutputFuncAsync = Callable[[ToolOutputGuardrailData], Awaitable[ToolGuardrailFunctionOutput]]


@overload
def tool_output_guardrail(func: _ToolOutputFuncSync) -> ToolOutputGuardrail[Any]: ...


@overload
def tool_output_guardrail(func: _ToolOutputFuncAsync) -> ToolOutputGuardrail[Any]: ...


@overload
def tool_output_guardrail(
    *, name: str | None = None
) -> Callable[[_ToolOutputFuncSync | _ToolOutputFuncAsync], ToolOutputGuardrail[Any]]: ...


def tool_output_guardrail(
    func: _ToolOutputFuncSync | _ToolOutputFuncAsync | None = None,
    *,
    name: str | None = None,
) -> (
    ToolOutputGuardrail[Any]
    | Callable[[_ToolOutputFuncSync | _ToolOutputFuncAsync], ToolOutputGuardrail[Any]]
):
    """Decorator to create a ToolOutputGuardrail from a function.

    Usable bare (``@tool_output_guardrail``) or called with keyword arguments
    (``@tool_output_guardrail(name="...")``).

    Args:
        func: The guardrail function, when used as a bare decorator.
        name: Optional guardrail name; defaults to the function's ``__name__``
            (or its type name for callables that have no ``__name__``).

    Returns:
        A ToolOutputGuardrail when ``func`` is given, otherwise a decorator that
        produces one.
    """

    def decorator(f: _ToolOutputFuncSync | _ToolOutputFuncAsync) -> ToolOutputGuardrail[Any]:
        # getattr fallback: partial objects / callable instances lack __name__.
        guardrail_name = name or getattr(f, "__name__", type(f).__name__)
        return ToolOutputGuardrail(guardrail_function=f, name=guardrail_name)

    if func is not None:
        return decorator(func)
    return decorator
|
|
|