# deep_research-personal / agents/tool_guardrails.py
# Uploaded by Akashmj22122002 ("Upload folder using huggingface_hub"), commit 14edff4 (verified).
from __future__ import annotations
import inspect
from collections.abc import Awaitable
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, overload
from typing_extensions import TypedDict, TypeVar
from .exceptions import UserError
from .tool_context import ToolContext
from .util._types import MaybeAwaitable
if TYPE_CHECKING:
from .agent import Agent
@dataclass
class ToolInputGuardrailResult:
    """Record of one tool input guardrail execution.

    Pairs the guardrail object that was evaluated with the output its
    guardrail function produced.
    """

    guardrail: ToolInputGuardrail[Any]
    """The tool input guardrail that was evaluated."""

    output: ToolGuardrailFunctionOutput
    """What the guardrail function returned for this evaluation."""
@dataclass
class ToolOutputGuardrailResult:
    """Record of one tool output guardrail execution.

    Pairs the guardrail object that was evaluated with the output its
    guardrail function produced.
    """

    guardrail: ToolOutputGuardrail[Any]
    """The tool output guardrail that was evaluated."""

    output: ToolGuardrailFunctionOutput
    """What the guardrail function returned for this evaluation."""
class RejectContentBehavior(TypedDict, total=True):
    """Behavior that rejects the tool call/output while letting the run continue.

    ``message`` is what the model is shown in place of the tool result.
    """

    type: Literal["reject_content"]
    message: str
class RaiseExceptionBehavior(TypedDict, total=True):
    """Behavior that halts execution by raising an exception."""

    type: Literal["raise_exception"]
class AllowBehavior(TypedDict, total=True):
    """Behavior that lets the tool execution proceed untouched."""

    type: Literal["allow"]
@dataclass
class ToolGuardrailFunctionOutput:
    """The output of a tool guardrail function.

    Prefer the ``allow``, ``reject_content`` and ``raise_exception`` factory
    classmethods over building the ``behavior`` dictionary by hand.
    """

    output_info: Any = None
    """
    Optional data about checks performed. For example, the guardrail could include
    information about the checks it performed and granular results. Defaults to
    ``None``, matching the defaults used by the factory classmethods.
    """

    behavior: RejectContentBehavior | RaiseExceptionBehavior | AllowBehavior = field(
        # A factory (not a shared dict literal) so each instance gets its own dict.
        default_factory=lambda: AllowBehavior(type="allow")
    )
    """
    Defines how the system should respond when this guardrail result is processed.
    - allow: Allow normal tool execution to continue without interference (default)
    - reject_content: Reject the tool call/output but continue execution with a message to the model
    - raise_exception: Halt execution by raising a ToolGuardrailTripwireTriggered exception
    """

    @classmethod
    def allow(cls, output_info: Any = None) -> ToolGuardrailFunctionOutput:
        """Create a guardrail output that allows the tool execution to continue normally.

        Args:
            output_info: Optional data about checks performed.

        Returns:
            ToolGuardrailFunctionOutput configured to allow normal execution.
        """
        return cls(output_info=output_info, behavior=AllowBehavior(type="allow"))

    @classmethod
    def reject_content(cls, message: str, output_info: Any = None) -> ToolGuardrailFunctionOutput:
        """Create a guardrail output that rejects the tool call/output but continues execution.

        Args:
            message: Message to send to the model instead of the tool result.
            output_info: Optional data about checks performed.

        Returns:
            ToolGuardrailFunctionOutput configured to reject the content.
        """
        return cls(
            output_info=output_info,
            behavior=RejectContentBehavior(type="reject_content", message=message),
        )

    @classmethod
    def raise_exception(cls, output_info: Any = None) -> ToolGuardrailFunctionOutput:
        """Create a guardrail output that raises an exception to halt execution.

        Args:
            output_info: Optional data about checks performed.

        Returns:
            ToolGuardrailFunctionOutput configured to raise an exception.
        """
        return cls(output_info=output_info, behavior=RaiseExceptionBehavior(type="raise_exception"))
@dataclass
class ToolInputGuardrailData:
    """Arguments handed to a tool input guardrail function.

    Bundles the tool context for the current execution with the agent that
    is running the tool.
    """

    context: ToolContext[Any]
    """Tool context describing the tool execution currently in progress."""

    agent: Agent[Any]
    """Agent on whose behalf the tool is being executed."""
@dataclass
class ToolOutputGuardrailData(ToolInputGuardrailData):
    """Arguments handed to a tool output guardrail function.

    Inherits ``context`` and ``agent`` from ToolInputGuardrailData and adds
    the value the tool produced.
    """

    output: Any
    """Value produced by the tool function that is being checked."""
# Covariant type parameter used by the generic ToolInputGuardrail /
# ToolOutputGuardrail classes below.
TContext_co = TypeVar("TContext_co", covariant=True, bound=Any)
@dataclass
class ToolInputGuardrail(Generic[TContext_co]):
    """A guardrail that runs before a function tool is invoked."""

    guardrail_function: Callable[
        [ToolInputGuardrailData], MaybeAwaitable[ToolGuardrailFunctionOutput]
    ]
    """
    The function that implements the guardrail logic. It may be synchronous or
    return an awaitable; ``run`` handles both.
    """

    name: str | None = None
    """
    Optional name for the guardrail. If not provided, uses the function name.
    """

    def get_name(self) -> str:
        """Return the guardrail's display name.

        Uses ``name`` when it is set and non-empty. Otherwise falls back to the
        guardrail function's ``__name__``; callables that have no ``__name__``
        (e.g. ``functools.partial`` objects or instances defining ``__call__``)
        fall back to their class name instead of raising ``AttributeError``.
        """
        if self.name:
            return self.name
        func = self.guardrail_function
        return getattr(func, "__name__", None) or type(func).__name__

    async def run(self, data: ToolInputGuardrailData) -> ToolGuardrailFunctionOutput:
        """Invoke the guardrail function on ``data`` and return its output.

        If the function returns an awaitable (async guardrail), it is awaited
        before the result is returned.

        Raises:
            UserError: If ``guardrail_function`` is not callable.
        """
        if not callable(self.guardrail_function):
            raise UserError(f"Guardrail function must be callable, got {self.guardrail_function}")

        result = self.guardrail_function(data)
        if inspect.isawaitable(result):
            return await result
        return result
@dataclass
class ToolOutputGuardrail(Generic[TContext_co]):
    """A guardrail that runs after a function tool is invoked."""

    guardrail_function: Callable[
        [ToolOutputGuardrailData], MaybeAwaitable[ToolGuardrailFunctionOutput]
    ]
    """
    The function that implements the guardrail logic. It may be synchronous or
    return an awaitable; ``run`` handles both.
    """

    name: str | None = None
    """
    Optional name for the guardrail. If not provided, uses the function name.
    """

    def get_name(self) -> str:
        """Return the guardrail's display name.

        Uses ``name`` when it is set and non-empty. Otherwise falls back to the
        guardrail function's ``__name__``; callables that have no ``__name__``
        (e.g. ``functools.partial`` objects or instances defining ``__call__``)
        fall back to their class name instead of raising ``AttributeError``.
        """
        if self.name:
            return self.name
        func = self.guardrail_function
        return getattr(func, "__name__", None) or type(func).__name__

    async def run(self, data: ToolOutputGuardrailData) -> ToolGuardrailFunctionOutput:
        """Invoke the guardrail function on ``data`` and return its output.

        If the function returns an awaitable (async guardrail), it is awaited
        before the result is returned.

        Raises:
            UserError: If ``guardrail_function`` is not callable.
        """
        if not callable(self.guardrail_function):
            raise UserError(f"Guardrail function must be callable, got {self.guardrail_function}")

        result = self.guardrail_function(data)
        if inspect.isawaitable(result):
            return await result
        return result
# Decorators

_ToolInputFuncSync = Callable[[ToolInputGuardrailData], ToolGuardrailFunctionOutput]
_ToolInputFuncAsync = Callable[[ToolInputGuardrailData], Awaitable[ToolGuardrailFunctionOutput]]


@overload
def tool_input_guardrail(func: _ToolInputFuncSync) -> ToolInputGuardrail[Any]: ...


@overload
def tool_input_guardrail(func: _ToolInputFuncAsync) -> ToolInputGuardrail[Any]: ...


@overload
def tool_input_guardrail(
    *, name: str | None = None
) -> Callable[[_ToolInputFuncSync | _ToolInputFuncAsync], ToolInputGuardrail[Any]]: ...


def tool_input_guardrail(
    func: _ToolInputFuncSync | _ToolInputFuncAsync | None = None,
    *,
    name: str | None = None,
) -> (
    ToolInputGuardrail[Any]
    | Callable[[_ToolInputFuncSync | _ToolInputFuncAsync], ToolInputGuardrail[Any]]
):
    """Decorator to create a ToolInputGuardrail from a function.

    Usable bare (``@tool_input_guardrail``) or with keyword arguments
    (``@tool_input_guardrail(name="...")``).

    Args:
        func: The sync or async guardrail function, when used bare.
        name: Optional guardrail name. Defaults to the function's ``__name__``,
            or to its class name for callables without ``__name__``.

    Returns:
        A ToolInputGuardrail when ``func`` is given; otherwise a decorator that
        produces one.
    """

    def decorator(f: _ToolInputFuncSync | _ToolInputFuncAsync) -> ToolInputGuardrail[Any]:
        # Callables such as functools.partial have no __name__; fall back to the
        # class name rather than failing at decoration time.
        resolved_name = name or getattr(f, "__name__", None) or type(f).__name__
        return ToolInputGuardrail(guardrail_function=f, name=resolved_name)

    if func is not None:
        return decorator(func)
    return decorator
_ToolOutputFuncSync = Callable[[ToolOutputGuardrailData], ToolGuardrailFunctionOutput]
_ToolOutputFuncAsync = Callable[[ToolOutputGuardrailData], Awaitable[ToolGuardrailFunctionOutput]]


@overload
def tool_output_guardrail(func: _ToolOutputFuncSync) -> ToolOutputGuardrail[Any]: ...


@overload
def tool_output_guardrail(func: _ToolOutputFuncAsync) -> ToolOutputGuardrail[Any]: ...


@overload
def tool_output_guardrail(
    *, name: str | None = None
) -> Callable[[_ToolOutputFuncSync | _ToolOutputFuncAsync], ToolOutputGuardrail[Any]]: ...


def tool_output_guardrail(
    func: _ToolOutputFuncSync | _ToolOutputFuncAsync | None = None,
    *,
    name: str | None = None,
) -> (
    ToolOutputGuardrail[Any]
    | Callable[[_ToolOutputFuncSync | _ToolOutputFuncAsync], ToolOutputGuardrail[Any]]
):
    """Decorator to create a ToolOutputGuardrail from a function.

    Usable bare (``@tool_output_guardrail``) or with keyword arguments
    (``@tool_output_guardrail(name="...")``).

    Args:
        func: The sync or async guardrail function, when used bare.
        name: Optional guardrail name. Defaults to the function's ``__name__``,
            or to its class name for callables without ``__name__``.

    Returns:
        A ToolOutputGuardrail when ``func`` is given; otherwise a decorator that
        produces one.
    """

    def decorator(f: _ToolOutputFuncSync | _ToolOutputFuncAsync) -> ToolOutputGuardrail[Any]:
        # Callables such as functools.partial have no __name__; fall back to the
        # class name rather than failing at decoration time.
        resolved_name = name or getattr(f, "__name__", None) or type(f).__name__
        return ToolOutputGuardrail(guardrail_function=f, name=resolved_name)

    if func is not None:
        return decorator(func)
    return decorator