Spaces:
Sleeping
Sleeping
blackccpie
committed on
Commit
·
11c0a44
1
Parent(s):
16f8a25
feat : added experimental advanced search mode.
Browse files- agent.py +18 -6
- agent_ui.py +26 -3
- web_tools.py +102 -14
agent.py
CHANGED
|
@@ -43,16 +43,16 @@ class SmolAlbert(CodeAgent):
|
|
| 43 |
Initialize the SmolAlbert agent with Tavily tools and a model.
|
| 44 |
"""
|
| 45 |
# Set up the agent with the Tavily tool and a model
|
| 46 |
-
search_tool = TavilySearchTool()
|
| 47 |
-
image_search_tool = TavilyImageURLSearchTool()
|
| 48 |
-
extract_tool = TavilyExtractTool()
|
| 49 |
-
image_query_tool = ImageQueryTool()
|
| 50 |
model = InferenceClientModel(
|
| 51 |
model_id=self.model_id,
|
| 52 |
provider=self.provider,
|
| 53 |
token=os.getenv("HF_API_KEY"))
|
| 54 |
self.agent = CodeAgent(
|
| 55 |
-
tools=[search_tool, image_search_tool, extract_tool, image_query_tool],
|
| 56 |
model=model,
|
| 57 |
stream_outputs=True,
|
| 58 |
instructions=(
|
|
@@ -61,9 +61,21 @@ class SmolAlbert(CodeAgent):
|
|
| 61 |
"Example format: ... (see [1](https://example1.com)) ... (see [2](https://example2.com)) ... "
|
| 62 |
"Do not invent URL(s) — only use the ones you were provided."
|
| 63 |
"If the answer includes an image URL, include it as an inline Markdown image: "
|
| 64 |
-
)
|
|
|
|
| 65 |
)
|
| 66 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 67 |
def run(self, task: str, additional_args: dict | None = None) -> str:
|
| 68 |
"""
|
| 69 |
Run the agent with a given query and return the final answer.
|
|
|
|
| 43 |
Initialize the SmolAlbert agent with Tavily tools and a model.
|
| 44 |
"""
|
| 45 |
# Set up the agent with the Tavily tool and a model
|
| 46 |
+
self.search_tool = TavilySearchTool()
|
| 47 |
+
self.image_search_tool = TavilyImageURLSearchTool()
|
| 48 |
+
self.extract_tool = TavilyExtractTool()
|
| 49 |
+
self.image_query_tool = ImageQueryTool()
|
| 50 |
model = InferenceClientModel(
|
| 51 |
model_id=self.model_id,
|
| 52 |
provider=self.provider,
|
| 53 |
token=os.getenv("HF_API_KEY"))
|
| 54 |
self.agent = CodeAgent(
|
| 55 |
+
tools=[self.search_tool, self.image_search_tool, self.extract_tool, self.image_query_tool],
|
| 56 |
model=model,
|
| 57 |
stream_outputs=True,
|
| 58 |
instructions=(
|
|
|
|
| 61 |
"Example format: ... (see [1](https://example1.com)) ... (see [2](https://example2.com)) ... "
|
| 62 |
"Do not invent URL(s) — only use the ones you were provided."
|
| 63 |
"If the answer includes an image URL, include it as an inline Markdown image: "
|
| 64 |
+
),
|
| 65 |
+
#additional_authorized_imports=[""]
|
| 66 |
)
|
| 67 |
|
| 68 |
+
self.advanced_mode = False
|
| 69 |
+
|
| 70 |
+
def enable_advanced_mode(self, enable: bool):
|
| 71 |
+
"""
|
| 72 |
+
Enable or disable advanced mode for the search tool.
|
| 73 |
+
Advanced mode uses more credits but yields better results.
|
| 74 |
+
"""
|
| 75 |
+
self.advanced_mode = enable
|
| 76 |
+
self.search_tool.enable_advanced_mode(enable)
|
| 77 |
+
self.extract_tool.enable_advanced_mode(enable)
|
| 78 |
+
|
| 79 |
def run(self, task: str, additional_args: dict | None = None) -> str:
|
| 80 |
"""
|
| 81 |
Run the agent with a given query and return the final answer.
|
agent_ui.py
CHANGED
|
@@ -299,7 +299,13 @@ class AgentUI:
|
|
| 299 |
self.agent = agent
|
| 300 |
self.description = getattr(agent, "description", None)
|
| 301 |
|
| 302 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 303 |
"""
|
| 304 |
Interacts with the agent and streams results into two separate histories:
|
| 305 |
- verbose_messages: full reasoning stream (Chatterbox)
|
|
@@ -437,6 +443,12 @@ class AgentUI:
|
|
| 437 |
"""
|
| 438 |
return self.agent.get_search_credits()
|
| 439 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 440 |
def create_app(self):
|
| 441 |
import gradio as gr
|
| 442 |
|
|
@@ -463,6 +475,19 @@ class AgentUI:
|
|
| 463 |
)
|
| 464 |
submit_btn = gr.Button("Submit", variant="primary")
|
| 465 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 466 |
tavily_credits = gr.Textbox(
|
| 467 |
label="Tavily Credits",
|
| 468 |
value=self.get_tavily_credits(),
|
|
@@ -554,5 +579,3 @@ class AgentUI:
|
|
| 554 |
verbose_chatbot.clear(self.clear_history, inputs=None, outputs=[stored_messages_verbose, stored_messages_quiet])
|
| 555 |
|
| 556 |
return agent
|
| 557 |
-
|
| 558 |
-
__all__ = ["stream_to_gradio", "AgentUI"]
|
|
|
|
| 299 |
self.agent = agent
|
| 300 |
self.description = getattr(agent, "description", None)
|
| 301 |
|
| 302 |
+
def set_advanced_mode(self, enabled: bool):
|
| 303 |
+
"""
|
| 304 |
+
Configure the agent to enable/disable advanced mode.
|
| 305 |
+
"""
|
| 306 |
+
self.agent.enable_advanced_mode(enabled)
|
| 307 |
+
|
| 308 |
+
def interact_with_agent(self, prompt: str, verbose_messages: list, quiet_messages: list):
|
| 309 |
"""
|
| 310 |
Interacts with the agent and streams results into two separate histories:
|
| 311 |
- verbose_messages: full reasoning stream (Chatterbox)
|
|
|
|
| 443 |
"""
|
| 444 |
return self.agent.get_search_credits()
|
| 445 |
|
| 446 |
+
def get_advanced_mode(self) -> bool:
|
| 447 |
+
"""
|
| 448 |
+
Return the agent's current advanced_mode flag for initializing the checkbox on page load.
|
| 449 |
+
"""
|
| 450 |
+
return getattr(self.agent, "advanced_mode", False)
|
| 451 |
+
|
| 452 |
def create_app(self):
|
| 453 |
import gradio as gr
|
| 454 |
|
|
|
|
| 475 |
)
|
| 476 |
submit_btn = gr.Button("Submit", variant="primary")
|
| 477 |
|
| 478 |
+
# Advanced search mode checkbox
|
| 479 |
+
advanced_checkbox = gr.Checkbox(
|
| 480 |
+
label="Advanced search mode",
|
| 481 |
+
value=getattr(self.agent, "advanced_mode", False),
|
| 482 |
+
info="Toggle advanced search behavior for the agent (x2 search credits).",
|
| 483 |
+
container=True,
|
| 484 |
+
)
|
| 485 |
+
|
| 486 |
+
# call agent configuration when checkbox changes
|
| 487 |
+
advanced_checkbox.change(self.set_advanced_mode, advanced_checkbox, None)
|
| 488 |
+
# ensure the checkbox reflects the current agent state each time a page/session loads
|
| 489 |
+
agent.load(self.get_advanced_mode, None, advanced_checkbox)
|
| 490 |
+
|
| 491 |
tavily_credits = gr.Textbox(
|
| 492 |
label="Tavily Credits",
|
| 493 |
value=self.get_tavily_credits(),
|
|
|
|
| 579 |
verbose_chatbot.clear(self.clear_history, inputs=None, outputs=[stored_messages_verbose, stored_messages_quiet])
|
| 580 |
|
| 581 |
return agent
|
|
|
|
|
|
web_tools.py
CHANGED
|
@@ -46,6 +46,10 @@ class TavilyBaseClient:
|
|
| 46 |
|
| 47 |
return f"{plan_usage}/{plan_limit}"
|
| 48 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 49 |
class TavilySearchTool(TavilyBaseClient, Tool):
|
| 50 |
"""
|
| 51 |
A tool to perform web searches using the Tavily API.
|
|
@@ -60,16 +64,63 @@ class TavilySearchTool(TavilyBaseClient, Tool):
|
|
| 60 |
}
|
| 61 |
output_type = "string"
|
| 62 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 63 |
def forward(self, query: str):
|
| 64 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 65 |
return response
|
| 66 |
|
| 67 |
class TavilyExtractTool(TavilyBaseClient, Tool):
|
| 68 |
"""
|
| 69 |
-
A tool to extract information from web pages using the Tavily API.
|
| 70 |
"""
|
| 71 |
name = "tavily_extract"
|
| 72 |
-
description = "Extract information from web pages using Tavily."
|
| 73 |
inputs = {
|
| 74 |
"url": {
|
| 75 |
"type": "string",
|
|
@@ -78,8 +129,37 @@ class TavilyExtractTool(TavilyBaseClient, Tool):
|
|
| 78 |
}
|
| 79 |
output_type = "string"
|
| 80 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 81 |
def forward(self, url: str):
|
| 82 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 83 |
return response
|
| 84 |
|
| 85 |
class TavilyImageURLSearchTool(TavilyBaseClient, Tool):
|
|
@@ -97,14 +177,22 @@ class TavilyImageURLSearchTool(TavilyBaseClient, Tool):
|
|
| 97 |
output_type = "string"
|
| 98 |
|
| 99 |
def forward(self, query: str):
|
| 100 |
-
|
| 101 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 102 |
images = response.get("images", [])
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
return first_image
|
| 110 |
-
return "none"
|
|
|
|
| 46 |
|
| 47 |
return f"{plan_usage}/{plan_limit}"
|
| 48 |
|
| 49 |
+
# ---------------------------------------------------------------------
|
| 50 |
+
# Tavily Tools
|
| 51 |
+
# ---------------------------------------------------------------------
|
| 52 |
+
|
| 53 |
class TavilySearchTool(TavilyBaseClient, Tool):
|
| 54 |
"""
|
| 55 |
A tool to perform web searches using the Tavily API.
|
|
|
|
| 64 |
}
|
| 65 |
output_type = "string"
|
| 66 |
|
| 67 |
+
# Consumes 1 Tavily credit per query
|
| 68 |
+
__basic_params = {
|
| 69 |
+
"search_depth": "basic",
|
| 70 |
+
"max_results": 10, # fetch up to 10 sources
|
| 71 |
+
"auto_parameters": False, # keep manual control
|
| 72 |
+
"include_raw_content": False,
|
| 73 |
+
}
|
| 74 |
+
|
| 75 |
+
# Consumes 2 Tavily credits per query
|
| 76 |
+
__advanced_params = {
|
| 77 |
+
"search_depth": "advanced", # 'advanced' yields better relevance
|
| 78 |
+
"max_results": 10, # fetch up to 10 sources
|
| 79 |
+
"chunks_per_source": 3, # number of content snippets to return per source
|
| 80 |
+
"auto_parameters": False, # keep manual control
|
| 81 |
+
"include_raw_content": False,
|
| 82 |
+
}
|
| 83 |
+
|
| 84 |
+
def __init__(self):
|
| 85 |
+
"""
|
| 86 |
+
Construct the TavilySearchTool.
|
| 87 |
+
"""
|
| 88 |
+
# Call superclass constructor
|
| 89 |
+
super().__init__()
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
self.params = TavilySearchTool.__basic_params
|
| 93 |
+
|
| 94 |
+
def enable_advanced_mode(self, enable: bool = True):
|
| 95 |
+
"""
|
| 96 |
+
Enable or disable advanced mode for the search tool.
|
| 97 |
+
Advanced mode uses more credits but yields better results.
|
| 98 |
+
"""
|
| 99 |
+
if enable:
|
| 100 |
+
self.params = TavilySearchTool.__advanced_params
|
| 101 |
+
else:
|
| 102 |
+
self.params = TavilySearchTool.__basic_params
|
| 103 |
+
|
| 104 |
+
print(f"TavilySearchTool advanced mode has been {'enabled' if enable else 'disabled'}.")
|
| 105 |
+
|
| 106 |
def forward(self, query: str):
    """
    Run a Tavily web search for ``query`` using the currently selected parameter set.

    Args:
        query: The search query string forwarded to the Tavily client.

    Returns:
        The response returned by ``self._tavily_client.search`` on success,
        or an error message string if the call raised an exception.
    """
    # Build a fresh dict for every call: self.params is bound directly to a
    # class-level parameter dict, so writing the query into it in place would
    # mutate state shared by every call and every instance of the tool.
    params = {**self.params, "query": query}

    try:
        response = self._tavily_client.search(**params)
    except Exception as e:
        # Report the failure as text instead of raising out of the tool call.
        return f"Error calling Tavily API: {e}"

    return response
|
| 117 |
|
| 118 |
class TavilyExtractTool(TavilyBaseClient, Tool):
|
| 119 |
"""
|
| 120 |
+
A tool to extract raw information from web pages using the Tavily API.
|
| 121 |
"""
|
| 122 |
name = "tavily_extract"
|
| 123 |
+
description = "Extract raw information from web pages using Tavily."
|
| 124 |
inputs = {
|
| 125 |
"url": {
|
| 126 |
"type": "string",
|
|
|
|
| 129 |
}
|
| 130 |
output_type = "string"
|
| 131 |
|
| 132 |
+
def __init__(self):
|
| 133 |
+
"""
|
| 134 |
+
Construct the TavilyExtractTool.
|
| 135 |
+
"""
|
| 136 |
+
# Call superclass constructor
|
| 137 |
+
super().__init__()
|
| 138 |
+
|
| 139 |
+
self.extract_depth = "basic"
|
| 140 |
+
|
| 141 |
+
def enable_advanced_mode(self, enable: bool = True):
|
| 142 |
+
"""
|
| 143 |
+
Enable or disable advanced mode for the extract tool.
|
| 144 |
+
Advanced mode uses more credits but yields better results (retrieves more data, including tables and embedded content).
|
| 145 |
+
"""
|
| 146 |
+
if enable:
|
| 147 |
+
self.extract_depth = "advanced"
|
| 148 |
+
else:
|
| 149 |
+
self.extract_depth = "basic"
|
| 150 |
+
|
| 151 |
+
print(f"TavilyExtractTool advanced mode has been {'enabled' if enable else 'disabled'}.")
|
| 152 |
+
|
| 153 |
def forward(self, url: str):
|
| 154 |
+
try:
|
| 155 |
+
response = self._tavily_client.extract(
|
| 156 |
+
urls=url,
|
| 157 |
+
extract_depth=self.extract_depth)
|
| 158 |
+
except Exception as e:
|
| 159 |
+
return f"Error calling Tavily extract API: {e}"
|
| 160 |
+
|
| 161 |
+
# Tavily's Extract API can return raw HTML + text.
|
| 162 |
+
# you may trim or sanitize here if needed.
|
| 163 |
return response
|
| 164 |
|
| 165 |
class TavilyImageURLSearchTool(TavilyBaseClient, Tool):
|
|
|
|
| 177 |
output_type = "string"
|
| 178 |
|
| 179 |
def forward(self, query: str):
    """
    Search Tavily for images matching ``query`` and return the first image URL.

    Args:
        query: The search query string passed to the Tavily client.

    Returns:
        The URL of the first image as a string, the literal string "none"
        when no usable image URL is available, or an error message string
        if the Tavily call raised an exception.
    """
    try:
        response = self._tavily_client.search(
            query,
            include_images=True,
            include_image_descriptions=True,
            max_results=5,
        )
    except Exception as e:
        # Report the failure as text instead of raising out of the tool call.
        return f"Error calling Tavily API: {e}"

    images = response.get("images", [])
    if not images:
        return "none"

    # Only the first image is used; it is handled either as a dict entry
    # carrying a "url" key or as a bare value.
    first_image = images[0]
    if isinstance(first_image, dict):
        # `or "none"` also covers a "url" key that is present but None/empty,
        # so the tool always returns a string.
        return first_image.get("url") or "none"
    return first_image or "none"
|