|
|
import abc |
|
|
from dataclasses import dataclass |
|
|
from typing import Any |
|
|
|
|
|
from pydantic import BaseModel, TypeAdapter |
|
|
from typing_extensions import TypedDict, get_args, get_origin |
|
|
|
|
|
from .exceptions import ModelBehaviorError, UserError |
|
|
from .strict_schema import ensure_strict_json_schema |
|
|
from .tracing import SpanError |
|
|
from .util import _error_tracing, _json |
|
|
|
|
|
# Key of the single-entry wrapper dict used by AgentOutputSchema: when an output type is
# wrapped, it is stored under this key in the generated TypedDict and read back from the
# validated dict under the same key.
_WRAPPER_DICT_KEY = "response"
|
|
|
|
|
|
|
|
class AgentOutputSchemaBase(abc.ABC):
    """Abstract interface for an agent output description.

    Implementations expose the JSON schema of the output and turn JSON text produced by the
    LLM into a validated object of the output type.
    """

    @abc.abstractmethod
    def is_plain_text(self) -> bool:
        """Return True when the output is plain text rather than a JSON object."""

    @abc.abstractmethod
    def name(self) -> str:
        """Return the name of the output type."""

    @abc.abstractmethod
    def json_schema(self) -> dict[str, Any]:
        """Return the JSON schema of the output.

        Only called when the output type is not plain text.
        """

    @abc.abstractmethod
    def is_strict_json_schema(self) -> bool:
        """Return True when the JSON schema is in strict mode.

        Strict mode constrains the JSON schema features, but guarantees valid JSON. See here
        for details:
        https://platform.openai.com/docs/guides/structured-outputs#supported-schemas
        """

    @abc.abstractmethod
    def validate_json(self, json_str: str) -> Any:
        """Validate a JSON string against the output type.

        Implementations must return the validated object, or raise a `ModelBehaviorError`
        if the JSON is invalid.
        """
|
|
|
|
|
|
|
|
@dataclass(init=False)
class AgentOutputSchema(AgentOutputSchemaBase):
    """Concrete output description: holds the JSON schema for an output type and validates /
    parses JSON produced by the LLM into that type.
    """

    output_type: type[Any]
    """The output type this schema describes."""

    _type_adapter: TypeAdapter[Any]
    """Type adapter built around the output type (or its wrapper dict), used to validate JSON."""

    _is_wrapped: bool
    """True when the output type is nested inside a single-key dictionary. This is generally
    done if the base output type cannot be represented as a JSON Schema object.
    """

    _output_schema: dict[str, Any]
    """JSON schema generated for the output."""

    _strict_json_schema: bool
    """Whether the JSON schema is in strict mode. We **strongly** recommend setting this to
    True, as it increases the likelihood of correct JSON input.
    """

    def __init__(self, output_type: type[Any], strict_json_schema: bool = True):
        """
        Args:
            output_type: The type of the output.
            strict_json_schema: Whether the JSON schema is in strict mode. We **strongly**
                recommend setting this to True, as it increases the likelihood of correct
                JSON input.

        Raises:
            UserError: If strict mode is requested and `ensure_strict_json_schema` rejects
                the generated schema.
        """
        self.output_type = output_type
        self._strict_json_schema = strict_json_schema

        adapter: TypeAdapter[Any]

        if output_type is None or output_type is str:
            # Plain-text outputs are never wrapped and skip the strict-schema conversion.
            self._is_wrapped = False
            adapter = TypeAdapter(output_type)
            self._type_adapter = adapter
            self._output_schema = adapter.json_schema()
            return

        # Anything that is not a BaseModel/dict subclass is nested under a single-key dict.
        self._is_wrapped = not _is_subclass_of_base_model_or_dict(output_type)

        if self._is_wrapped:
            OutputType = TypedDict("OutputType", {_WRAPPER_DICT_KEY: output_type})
            adapter = TypeAdapter(OutputType)
        else:
            adapter = TypeAdapter(output_type)

        self._type_adapter = adapter
        self._output_schema = adapter.json_schema()

        if not self._strict_json_schema:
            return

        try:
            self._output_schema = ensure_strict_json_schema(self._output_schema)
        except UserError as e:
            raise UserError(
                "Strict JSON schema is enabled, but the output type is not valid. "
                "Either make the output type strict, "
                "or wrap your type with AgentOutputSchema(YourType, strict_json_schema=False)"
            ) from e

    def is_plain_text(self) -> bool:
        """Return True when the output type is `None` or `str` (plain text, not JSON)."""
        declared = self.output_type
        return declared is None or declared is str

    def is_strict_json_schema(self) -> bool:
        """Return the strict-mode flag given at construction."""
        return self._strict_json_schema

    def json_schema(self) -> dict[str, Any]:
        """Return the JSON schema of the output type.

        Raises:
            UserError: If the output type is plain text, which has no JSON schema.
        """
        if self.is_plain_text():
            raise UserError("Output type is plain text, so no JSON schema is available")
        return self._output_schema

    def validate_json(self, json_str: str) -> Any:
        """Validate a JSON string against the output type and return the validated object.

        For wrapped output types the value stored under the wrapper key is returned rather
        than the wrapper dict itself.

        Raises:
            ModelBehaviorError: If a wrapped output does not validate to a dict, or the dict
                lacks the wrapper key. (Invalid JSON is presumably reported by
                `_json.validate_json` as well; confirm against that helper.)
        """
        validated = _json.validate_json(json_str, self._type_adapter, partial=False)

        if not self._is_wrapped:
            return validated

        if not isinstance(validated, dict):
            # Record the failure on the current tracing span before raising.
            span_error = SpanError(
                message="Invalid JSON",
                data={"details": f"Expected a dict, got {type(validated)}"},
            )
            _error_tracing.attach_error_to_current_span(span_error)
            raise ModelBehaviorError(
                f"Expected a dict, got {type(validated)} for JSON: {json_str}"
            )

        if _WRAPPER_DICT_KEY not in validated:
            span_error = SpanError(
                message="Invalid JSON",
                data={"details": f"Could not find key {_WRAPPER_DICT_KEY} in JSON"},
            )
            _error_tracing.attach_error_to_current_span(span_error)
            raise ModelBehaviorError(
                f"Could not find key {_WRAPPER_DICT_KEY} in JSON: {json_str}"
            )

        return validated[_WRAPPER_DICT_KEY]

    def name(self) -> str:
        """Return a readable name for the output type."""
        return _type_to_str(self.output_type)
|
|
|
|
|
|
|
|
def _is_subclass_of_base_model_or_dict(t: Any) -> bool: |
|
|
if not isinstance(t, type): |
|
|
return False |
|
|
|
|
|
|
|
|
origin = get_origin(t) |
|
|
|
|
|
allowed_types = (BaseModel, dict) |
|
|
|
|
|
return issubclass(origin or t, allowed_types) |
|
|
|
|
|
|
|
|
def _type_to_str(t: type[Any]) -> str: |
|
|
origin = get_origin(t) |
|
|
args = get_args(t) |
|
|
|
|
|
if origin is None: |
|
|
|
|
|
return t.__name__ |
|
|
elif args: |
|
|
args_str = ", ".join(_type_to_str(arg) for arg in args) |
|
|
return f"{origin.__name__}[{args_str}]" |
|
|
else: |
|
|
return str(t) |
|
|
|