|
|
from __future__ import annotations |
|
|
|
|
|
import contextlib |
|
|
import inspect |
|
|
import logging |
|
|
import re |
|
|
from dataclasses import dataclass |
|
|
from typing import Annotated, Any, Callable, Literal, get_args, get_origin, get_type_hints |
|
|
|
|
|
from griffe import Docstring, DocstringSectionKind |
|
|
from pydantic import BaseModel, Field, create_model |
|
|
from pydantic.fields import FieldInfo |
|
|
|
|
|
from .exceptions import UserError |
|
|
from .run_context import RunContextWrapper |
|
|
from .strict_schema import ensure_strict_json_schema |
|
|
from .tool_context import ToolContext |
|
|
|
|
|
|
|
|
@dataclass
class FuncSchema:
    """
    Schema extracted from a Python function so that it can be offered to an LLM as a tool.
    """

    name: str
    """Name of the function."""
    description: str | None
    """Description of the function, if one is available."""
    params_pydantic_model: type[BaseModel]
    """Pydantic model representing the function's parameters."""
    params_json_schema: dict[str, Any]
    """JSON schema for the function's parameters, derived from the Pydantic model."""
    signature: inspect.Signature
    """Signature of the function."""
    takes_context: bool = False
    """Whether the function takes a RunContextWrapper argument (must be the first argument)."""
    strict_json_schema: bool = True
    """Whether the JSON schema is in strict mode. Setting this to True is **strongly**
    recommended, as it increases the likelihood of correct JSON input."""

    def to_call_args(self, data: BaseModel) -> tuple[list[Any], dict[str, Any]]:
        """
        Turn validated model data into ``(args, kwargs)`` for calling the original function.

        Each value is read from ``data`` by parameter name; an attribute that is absent is
        treated as ``None``.
        """
        args: list[Any] = []
        kwargs: dict[str, Any] = {}
        var_positional_consumed = False

        parameters = list(self.signature.parameters.items())
        if self.takes_context:
            # The leading context parameter is never read from `data`, so drop it up front.
            parameters = parameters[1:]

        positional_kinds = (
            inspect.Parameter.POSITIONAL_ONLY,
            inspect.Parameter.POSITIONAL_OR_KEYWORD,
        )

        for param_name, parameter in parameters:
            value = getattr(data, param_name, None)
            kind = parameter.kind

            if kind == inspect.Parameter.VAR_POSITIONAL:
                # *args: splice the collected values into the positional list.
                args.extend(value or [])
                var_positional_consumed = True
                continue

            if kind == inspect.Parameter.VAR_KEYWORD:
                # **kwargs: merge the collected mapping into the keyword dict.
                kwargs.update(value or {})
                continue

            # Positional parameters stay positional until *args has been emitted; everything
            # else (including keyword-only parameters) is passed by name.
            if kind in positional_kinds and not var_positional_consumed:
                args.append(value)
            else:
                kwargs[param_name] = value

        return args, kwargs
|
|
|
|
|
|
|
|
@dataclass
class FuncDocumentation:
    """Metadata about a Python function, as extracted from its docstring."""

    name: str
    """Name of the function, taken from `__name__`."""
    description: str | None
    """Description of the function derived from the docstring, or None if unavailable."""
    param_descriptions: dict[str, str] | None
    """Per-parameter descriptions derived from the docstring, or None if there are none."""
|
|
|
|
|
|
|
|
# Names of the docstring conventions this module can detect; the chosen value is passed to
# griffe's `Docstring` as its `parser` argument.
DocstringStyle = Literal["google", "numpy", "sphinx"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _detect_docstring_style(doc: str) -> DocstringStyle: |
|
|
scores: dict[DocstringStyle, int] = {"sphinx": 0, "numpy": 0, "google": 0} |
|
|
|
|
|
|
|
|
sphinx_patterns = [r"^:param\s", r"^:type\s", r"^:return:", r"^:rtype:"] |
|
|
for pattern in sphinx_patterns: |
|
|
if re.search(pattern, doc, re.MULTILINE): |
|
|
scores["sphinx"] += 1 |
|
|
|
|
|
|
|
|
|
|
|
numpy_patterns = [ |
|
|
r"^Parameters\s*\n\s*-{3,}", |
|
|
r"^Returns\s*\n\s*-{3,}", |
|
|
r"^Yields\s*\n\s*-{3,}", |
|
|
] |
|
|
for pattern in numpy_patterns: |
|
|
if re.search(pattern, doc, re.MULTILINE): |
|
|
scores["numpy"] += 1 |
|
|
|
|
|
|
|
|
google_patterns = [r"^(Args|Arguments):", r"^(Returns):", r"^(Raises):"] |
|
|
for pattern in google_patterns: |
|
|
if re.search(pattern, doc, re.MULTILINE): |
|
|
scores["google"] += 1 |
|
|
|
|
|
max_score = max(scores.values()) |
|
|
if max_score == 0: |
|
|
return "google" |
|
|
|
|
|
|
|
|
styles: list[DocstringStyle] = ["sphinx", "numpy", "google"] |
|
|
|
|
|
for style in styles: |
|
|
if scores[style] == max_score: |
|
|
return style |
|
|
|
|
|
return "google" |
|
|
|
|
|
|
|
|
@contextlib.contextmanager |
|
|
def _suppress_griffe_logging(): |
|
|
|
|
|
logger = logging.getLogger("griffe") |
|
|
previous_level = logger.getEffectiveLevel() |
|
|
logger.setLevel(logging.ERROR) |
|
|
try: |
|
|
yield |
|
|
finally: |
|
|
logger.setLevel(previous_level) |
|
|
|
|
|
|
|
|
def generate_func_documentation(
    func: Callable[..., Any], style: DocstringStyle | None = None
) -> FuncDocumentation:
    """
    Pulls the description and per-parameter descriptions out of a function's docstring, in
    preparation for sending the function to an LLM as a tool.

    Args:
        func: The function whose docstring should be read.
        style: The docstring style to parse with. When omitted, the style is auto-detected
            from the docstring text.

    Returns:
        A FuncDocumentation holding the function's name, its description, and its parameter
        descriptions. Description and parameter descriptions are None when the function has
        no docstring.
    """
    func_name = func.__name__
    raw_doc = inspect.getdoc(func)
    if not raw_doc:
        return FuncDocumentation(name=func_name, description=None, param_descriptions=None)

    # Parse with griffe's logging turned down to ERROR.
    with _suppress_griffe_logging():
        parser_style = style or _detect_docstring_style(raw_doc)
        sections = Docstring(raw_doc, lineno=1, parser=parser_style).parse()

    summary: str | None = None
    summary_found = False
    per_param: dict[str, str] = {}

    for section in sections:
        if section.kind == DocstringSectionKind.text:
            # Only the first text section supplies the description.
            if not summary_found:
                summary = section.value
                summary_found = True
        elif section.kind == DocstringSectionKind.parameters:
            for documented in section.value:
                per_param[documented.name] = documented.description

    return FuncDocumentation(
        name=func_name,
        description=summary,
        # An empty mapping is reported as None.
        param_descriptions=per_param or None,
    )
|
|
|
|
|
|
|
|
def _strip_annotated(annotation: Any) -> tuple[Any, tuple[Any, ...]]: |
|
|
"""Returns the underlying annotation and any metadata from typing.Annotated.""" |
|
|
|
|
|
metadata: tuple[Any, ...] = () |
|
|
ann = annotation |
|
|
|
|
|
while get_origin(ann) is Annotated: |
|
|
args = get_args(ann) |
|
|
if not args: |
|
|
break |
|
|
ann = args[0] |
|
|
metadata = (*metadata, *args[1:]) |
|
|
|
|
|
return ann, metadata |
|
|
|
|
|
|
|
|
def _extract_description_from_metadata(metadata: tuple[Any, ...]) -> str | None: |
|
|
"""Extracts a human readable description from Annotated metadata if present.""" |
|
|
|
|
|
for item in metadata: |
|
|
if isinstance(item, str): |
|
|
return item |
|
|
return None |
|
|
|
|
|
|
|
|
def _is_context_annotation(ann: Any) -> bool:
    """Returns True if `ann` is RunContextWrapper or ToolContext, bare or parameterized."""
    origin = get_origin(ann) or ann
    return origin is RunContextWrapper or origin is ToolContext


def function_schema(
    func: Callable[..., Any],
    docstring_style: DocstringStyle | None = None,
    name_override: str | None = None,
    description_override: str | None = None,
    use_docstring_info: bool = True,
    strict_json_schema: bool = True,
) -> FuncSchema:
    """
    Given a Python function, extracts a `FuncSchema` from it, capturing the name, description,
    parameter descriptions, and other metadata.

    Args:
        func: The function to extract the schema from.
        docstring_style: The style of the docstring to use for parsing. If not provided, we will
            attempt to auto-detect the style.
        name_override: If provided, use this name instead of the function's `__name__`.
        description_override: If provided, use this description instead of the one derived from the
            docstring.
        use_docstring_info: If True, uses the docstring to generate the description and parameter
            descriptions.
        strict_json_schema: Whether the JSON schema is in strict mode. If True, we'll ensure that
            the schema adheres to the "strict" standard the OpenAI API expects. We **strongly**
            recommend setting this to True, as it increases the likelihood of the LLM producing
            correct JSON input.

    Returns:
        A `FuncSchema` object containing the function's name, description, parameter descriptions,
        and other metadata.

    Raises:
        UserError: If a RunContextWrapper/ToolContext parameter appears anywhere other than the
            first position.
    """
    # Sentinel used by `inspect` for "no annotation" / "no default". It is compared by identity:
    # `==` would invoke the annotation's or default's own `__eq__`, which may be overloaded
    # (e.g. array-like defaults) and raise or return a non-bool.
    empty = inspect.Parameter.empty

    # 1. Grab docstring info
    if use_docstring_info:
        doc_info = generate_func_documentation(func, docstring_style)
        param_descs = dict(doc_info.param_descriptions or {})
    else:
        doc_info = None
        param_descs = {}

    # Resolve the type hints, unwrapping Annotated[...] and harvesting any string metadata as a
    # parameter description.
    type_hints: dict[str, Any] = {}
    for name, annotation in get_type_hints(func, include_extras=True).items():
        if name == "return":
            continue

        stripped_ann, metadata = _strip_annotated(annotation)
        type_hints[name] = stripped_ann

        description = _extract_description_from_metadata(metadata)
        if description is not None:
            # setdefault: a description already taken from the docstring wins.
            param_descs.setdefault(name, description)

    # name_override takes precedence, whether or not docstring info is enabled.
    func_name = name_override or (doc_info.name if doc_info else func.__name__)

    # 2. Inspect function signature and get type hints
    sig = inspect.signature(func)
    params = list(sig.parameters.items())
    takes_context = False
    filtered_params = []

    if params:
        first_name, first_param = params[0]
        # Prefer the evaluated type hint if available
        ann = type_hints.get(first_name, first_param.annotation)
        if ann is not empty and _is_context_annotation(ann):
            # The context parameter is not part of the generated model.
            takes_context = True
        else:
            filtered_params.append((first_name, first_param))

    # A context parameter is only allowed in the first position.
    for name, param in params[1:]:
        ann = type_hints.get(name, param.annotation)
        if ann is not empty and _is_context_annotation(ann):
            raise UserError(
                f"RunContextWrapper/ToolContext param found at non-first position in function"
                f" {func.__name__}"
            )
        filtered_params.append((name, param))

    # Field definitions for create_model:
    #   field_name -> (type_annotation, Field(...))
    fields: dict[str, Any] = {}

    for name, param in filtered_params:
        ann = type_hints.get(name, param.annotation)
        default = param.default

        # If there's no type hint, assume `Any`
        if ann is empty:
            ann = Any

        # Description from the docstring or Annotated metadata, if any
        field_description = param_descs.get(name, None)

        if param.kind == param.VAR_POSITIONAL:
            # *args is modelled as a list field
            if get_origin(ann) is tuple:
                # *args: tuple[int, ...] -> list[int]; any other tuple form -> list[Any]
                args_of_tuple = get_args(ann)
                if len(args_of_tuple) == 2 and args_of_tuple[1] is Ellipsis:
                    ann = list[args_of_tuple[0]]  # type: ignore
                else:
                    ann = list[Any]
            else:
                # *args: int -> list[int]
                ann = list[ann]  # type: ignore

            # Defaults to an empty list
            fields[name] = (
                ann,
                Field(default_factory=list, description=field_description),
            )

        elif param.kind == param.VAR_KEYWORD:
            # **kwargs is modelled as a dict field
            if get_origin(ann) is dict:
                # **kwargs: dict[K, V] -> dict[K, V]; a malformed dict form -> dict[str, Any]
                dict_args = get_args(ann)
                if len(dict_args) == 2:
                    ann = dict[dict_args[0], dict_args[1]]  # type: ignore
                else:
                    ann = dict[str, Any]
            else:
                # **kwargs: int -> dict[str, int]
                ann = dict[str, ann]  # type: ignore

            # Defaults to an empty dict
            fields[name] = (
                ann,
                Field(default_factory=dict, description=field_description),
            )

        else:
            # Normal parameter
            if default is empty:
                # No default: required field
                fields[name] = (
                    ann,
                    Field(..., description=field_description),
                )
            elif isinstance(default, FieldInfo):
                # The default is itself a Field(...): keep its settings and only fill in the
                # description, preferring the one we extracted over the Field's own.
                fields[name] = (
                    ann,
                    FieldInfo.merge_field_infos(
                        default, description=field_description or default.description
                    ),
                )
            else:
                # Plain default value
                fields[name] = (
                    ann,
                    Field(default=default, description=field_description),
                )

    # 3. Dynamically build a Pydantic model
    dynamic_model = create_model(f"{func_name}_args", __base__=BaseModel, **fields)

    # 4. Build JSON schema from that model
    json_schema = dynamic_model.model_json_schema()
    if strict_json_schema:
        json_schema = ensure_strict_json_schema(json_schema)

    # 5. Return as a FuncSchema dataclass
    return FuncSchema(
        name=func_name,
        # description_override wins whether or not docstring info is enabled.
        description=description_override or (doc_info.description if doc_info else None),
        params_pydantic_model=dynamic_model,
        params_json_schema=json_schema,
        signature=sig,
        takes_context=takes_context,
        strict_json_schema=strict_json_schema,
    )
|
|
|