blackccpie commited on
Commit
11c0a44
·
1 Parent(s): 16f8a25

feat : added xperimental advanced search mode.

Browse files
Files changed (3) hide show
  1. agent.py +18 -6
  2. agent_ui.py +26 -3
  3. web_tools.py +102 -14
agent.py CHANGED
@@ -43,16 +43,16 @@ class SmolAlbert(CodeAgent):
43
  Initialize the SmolAlbert agent with Tavily tools and a model.
44
  """
45
  # Set up the agent with the Tavily tool and a model
46
- search_tool = TavilySearchTool()
47
- image_search_tool = TavilyImageURLSearchTool()
48
- extract_tool = TavilyExtractTool()
49
- image_query_tool = ImageQueryTool()
50
  model = InferenceClientModel(
51
  model_id=self.model_id,
52
  provider=self.provider,
53
  token=os.getenv("HF_API_KEY"))
54
  self.agent = CodeAgent(
55
- tools=[search_tool, image_search_tool, extract_tool, image_query_tool],
56
  model=model,
57
  stream_outputs=True,
58
  instructions=(
@@ -61,9 +61,21 @@ class SmolAlbert(CodeAgent):
61
  "Example format: ... (see [1](https://example1.com)) ... (see [2](https://example2.com)) ... "
62
  "Do not invent URL(s) — only use the ones you were provided."
63
  "If the answer includes an image URL, include it as an inline Markdown image: ![image](<image_url>)"
64
- )
 
65
  )
66
 
 
 
 
 
 
 
 
 
 
 
 
67
  def run(self, task: str, additional_args: dict | None = None) -> str:
68
  """
69
  Run the agent with a given query and return the final answer.
 
43
  Initialize the SmolAlbert agent with Tavily tools and a model.
44
  """
45
  # Set up the agent with the Tavily tool and a model
46
+ self.search_tool = TavilySearchTool()
47
+ self.image_search_tool = TavilyImageURLSearchTool()
48
+ self.extract_tool = TavilyExtractTool()
49
+ self.image_query_tool = ImageQueryTool()
50
  model = InferenceClientModel(
51
  model_id=self.model_id,
52
  provider=self.provider,
53
  token=os.getenv("HF_API_KEY"))
54
  self.agent = CodeAgent(
55
+ tools=[self.search_tool, self.image_search_tool, self.extract_tool, self.image_query_tool],
56
  model=model,
57
  stream_outputs=True,
58
  instructions=(
 
61
  "Example format: ... (see [1](https://example1.com)) ... (see [2](https://example2.com)) ... "
62
  "Do not invent URL(s) — only use the ones you were provided."
63
  "If the answer includes an image URL, include it as an inline Markdown image: ![image](<image_url>)"
64
+ ),
65
+ #additional_authorized_imports=[""]
66
  )
67
 
68
+ self.advanced_mode = False
69
+
70
+ def enable_advanced_mode(self, enable: bool):
71
+ """
72
+ Enable or disable advanced mode for the search tool.
73
+ Advanced mode uses more credits but yields better results.
74
+ """
75
+ self.advanced_mode = enable
76
+ self.search_tool.enable_advanced_mode(enable)
77
+ self.extract_tool.enable_advanced_mode(enable)
78
+
79
  def run(self, task: str, additional_args: dict | None = None) -> str:
80
  """
81
  Run the agent with a given query and return the final answer.
agent_ui.py CHANGED
@@ -299,7 +299,13 @@ class AgentUI:
299
  self.agent = agent
300
  self.description = getattr(agent, "description", None)
301
 
302
- def interact_with_agent(self, prompt, verbose_messages, quiet_messages):
 
 
 
 
 
 
303
  """
304
  Interacts with the agent and streams results into two separate histories:
305
  - verbose_messages: full reasoning stream (Chatterbox)
@@ -437,6 +443,12 @@ class AgentUI:
437
  """
438
  return self.agent.get_search_credits()
439
 
 
 
 
 
 
 
440
  def create_app(self):
441
  import gradio as gr
442
 
@@ -463,6 +475,19 @@ class AgentUI:
463
  )
464
  submit_btn = gr.Button("Submit", variant="primary")
465
 
 
 
 
 
 
 
 
 
 
 
 
 
 
466
  tavily_credits = gr.Textbox(
467
  label="Tavily Credits",
468
  value=self.get_tavily_credits(),
@@ -554,5 +579,3 @@ class AgentUI:
554
  verbose_chatbot.clear(self.clear_history, inputs=None, outputs=[stored_messages_verbose, stored_messages_quiet])
555
 
556
  return agent
557
-
558
- __all__ = ["stream_to_gradio", "AgentUI"]
 
299
  self.agent = agent
300
  self.description = getattr(agent, "description", None)
301
 
302
+ def set_advanced_mode(self, enabled: bool):
303
+ """
304
+ Configure the agent to enable/disable advanced mode.
305
+ """
306
+ self.agent.enable_advanced_mode(enabled)
307
+
308
+ def interact_with_agent(self, prompt: str, verbose_messages: list, quiet_messages: list):
309
  """
310
  Interacts with the agent and streams results into two separate histories:
311
  - verbose_messages: full reasoning stream (Chatterbox)
 
443
  """
444
  return self.agent.get_search_credits()
445
 
446
+ def get_advanced_mode(self) -> bool:
447
+ """
448
+ Return the agent's current advanced_mode flag for initializing the checkbox on page load.
449
+ """
450
+ return getattr(self.agent, "advanced_mode", False)
451
+
452
  def create_app(self):
453
  import gradio as gr
454
 
 
475
  )
476
  submit_btn = gr.Button("Submit", variant="primary")
477
 
478
+ # Advanced search mode checkbox
479
+ advanced_checkbox = gr.Checkbox(
480
+ label="Advanced search mode",
481
+ value=getattr(self.agent, "advanced_mode", False),
482
+ info="Toggle advanced search behavior for the agent (x2 search credits).",
483
+ container=True,
484
+ )
485
+
486
+ # call agent configuration when checkbox changes
487
+ advanced_checkbox.change(self.set_advanced_mode, advanced_checkbox, None)
488
+ # ensure the checkbox reflects the current agent state each time a page/session loads
489
+ agent.load(self.get_advanced_mode, None, advanced_checkbox)
490
+
491
  tavily_credits = gr.Textbox(
492
  label="Tavily Credits",
493
  value=self.get_tavily_credits(),
 
579
  verbose_chatbot.clear(self.clear_history, inputs=None, outputs=[stored_messages_verbose, stored_messages_quiet])
580
 
581
  return agent
 
 
web_tools.py CHANGED
@@ -46,6 +46,10 @@ class TavilyBaseClient:
46
 
47
  return f"{plan_usage}/{plan_limit}"
48
 
 
 
 
 
49
  class TavilySearchTool(TavilyBaseClient, Tool):
50
  """
51
  A tool to perform web searches using the Tavily API.
@@ -60,16 +64,63 @@ class TavilySearchTool(TavilyBaseClient, Tool):
60
  }
61
  output_type = "string"
62
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
  def forward(self, query: str):
64
- response = self._tavily_client.search(query)
 
 
 
 
 
 
 
 
65
  return response
66
 
67
  class TavilyExtractTool(TavilyBaseClient, Tool):
68
  """
69
- A tool to extract information from web pages using the Tavily API.
70
  """
71
  name = "tavily_extract"
72
- description = "Extract information from web pages using Tavily."
73
  inputs = {
74
  "url": {
75
  "type": "string",
@@ -78,8 +129,37 @@ class TavilyExtractTool(TavilyBaseClient, Tool):
78
  }
79
  output_type = "string"
80
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
81
  def forward(self, url: str):
82
- response = self._tavily_client.extract(url)
 
 
 
 
 
 
 
 
83
  return response
84
 
85
  class TavilyImageURLSearchTool(TavilyBaseClient, Tool):
@@ -97,14 +177,22 @@ class TavilyImageURLSearchTool(TavilyBaseClient, Tool):
97
  output_type = "string"
98
 
99
  def forward(self, query: str):
100
- response = self._tavily_client.search(query, include_images=True)
101
-
 
 
 
 
 
 
 
 
102
  images = response.get("images", [])
103
-
104
- if images:
105
- # Return the URL of the first image
106
- first_image = images[0]
107
- if isinstance(first_image, dict):
108
- return first_image.get("url")
109
- return first_image
110
- return "none"
 
46
 
47
  return f"{plan_usage}/{plan_limit}"
48
 
49
+ # ---------------------------------------------------------------------
50
+ # Tavily Tools
51
+ # ---------------------------------------------------------------------
52
+
53
  class TavilySearchTool(TavilyBaseClient, Tool):
54
  """
55
  A tool to perform web searches using the Tavily API.
 
64
  }
65
  output_type = "string"
66
 
67
+ # Consumes 1 Tavily credit per query
68
+ __basic_params = {
69
+ "search_depth": "basic",
70
+ "max_results": 10, # fetch up to 10 sources
71
+ "auto_parameters": False, # keep manual control
72
+ "include_raw_content": False,
73
+ }
74
+
75
+ # Consumes 2 Tavily credits per query
76
+ __advanced_params = {
77
+ "search_depth": "advanced", # 'advanced' yields better relevance
78
+ "max_results": 10, # fetch up to 10 sources
79
+ "chunks_per_source": 3, # number of content snippets to return per source
80
+ "auto_parameters": False, # keep manual control
81
+ "include_raw_content": False,
82
+ }
83
+
84
+ def __init__(self):
85
+ """
86
+ Construct the TavilySearchTool.
87
+ """
88
+ # Call superclass constructor
89
+ super().__init__()
90
+
91
+
92
+ self.params = TavilySearchTool.__basic_params
93
+
94
+ def enable_advanced_mode(self, enable: bool = True):
95
+ """
96
+ Enable or disable advanced mode for the search tool.
97
+ Advanced mode uses more credits but yields better results.
98
+ """
99
+ if enable:
100
+ self.params = TavilySearchTool.__advanced_params
101
+ else:
102
+ self.params = TavilySearchTool.__basic_params
103
+
104
+ print(f"TavilySearchTool advanced mode has been {'enabled' if enable else 'disabled'}.")
105
+
106
  def forward(self, query: str):
107
+
108
+ params = self.params
109
+ params["query"] = query
110
+
111
+ try:
112
+ response = self._tavily_client.search(**params)
113
+ except Exception as e:
114
+ return f"Error calling Tavily API: {e}"
115
+
116
  return response
117
 
118
  class TavilyExtractTool(TavilyBaseClient, Tool):
119
  """
120
+ A tool to extract raw information from web pages using the Tavily API.
121
  """
122
  name = "tavily_extract"
123
+ description = "Extract raw information from web pages using Tavily."
124
  inputs = {
125
  "url": {
126
  "type": "string",
 
129
  }
130
  output_type = "string"
131
 
132
+ def __init__(self):
133
+ """
134
+ Construct the TavilyExtractTool.
135
+ """
136
+ # Call superclass constructor
137
+ super().__init__()
138
+
139
+ self.extract_depth = "basic"
140
+
141
+ def enable_advanced_mode(self, enable: bool = True):
142
+ """
143
+ Enable or disable advanced mode for the extract tool.
144
+ Advanced mode uses more credits but yields better results (retrieves more data, including tables and embedded content).
145
+ """
146
+ if enable:
147
+ self.extract_depth = "advanced"
148
+ else:
149
+ self.extract_depth = "basic"
150
+
151
+ print(f"TavilyExtractTool advanced mode has been {'enabled' if enable else 'disabled'}.")
152
+
153
  def forward(self, url: str):
154
+ try:
155
+ response = self._tavily_client.extract(
156
+ urls=url,
157
+ extract_depth=self.extract_depth)
158
+ except Exception as e:
159
+ return f"Error calling Tavily extract API: {e}"
160
+
161
+ # Tavily's Extract API can return raw HTML + text.
162
+ # you may trim or sanitize here if needed.
163
  return response
164
 
165
  class TavilyImageURLSearchTool(TavilyBaseClient, Tool):
 
177
  output_type = "string"
178
 
179
  def forward(self, query: str):
180
+ try:
181
+ response = self._tavily_client.search(
182
+ query,
183
+ include_images=True,
184
+ include_image_descriptions=True,
185
+ max_results=5
186
+ )
187
+ except Exception as e:
188
+ return f"Error calling Tavily API: {e}"
189
+
190
  images = response.get("images", [])
191
+ if not images:
192
+ return "none"
193
+
194
+ # Return the URL of the first image
195
+ first_image = images[0]
196
+ if isinstance(first_image, dict):
197
+ return first_image.get("url", "none")
198
+ return first_image or "none"