1use crate::{error::LlamaCoreError, metadata::ggml::GgmlMetadata, CHAT_GRAPHS};
4use reqwest::{Client, Url};
5use serde::{Deserialize, Serialize};
6use serde_json::Value;
7
/// Content types supported for search-engine requests and responses.
///
/// The `Display` impl elsewhere in this module renders each variant as its
/// MIME type string (e.g. `application/json`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ContentType {
    /// `application/json`
    JSON,
}
13
14impl std::fmt::Display for ContentType {
15 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
16 write!(
17 f,
18 "{}",
19 match &self {
20 ContentType::JSON => "application/json",
21 }
22 )
23 }
24}
25
/// Configuration describing how to query one web search engine and how to
/// post-process its results.
#[derive(Debug)]
pub struct SearchConfig {
    /// Name of the search engine backend. Not read anywhere in this module
    /// (hence the `dead_code` allowance); kept for identification.
    #[allow(dead_code)]
    pub search_engine: String,
    /// Maximum number of results kept from one search (the rest are
    /// truncated away in `perform_search`).
    pub max_search_results: u8,
    /// Per-result cap on `text_content`, in bytes (used as a byte index
    /// when clipping results).
    pub size_limit_per_result: u16,
    /// Search endpoint URL; parsed with `Url::parse` before each request.
    pub endpoint: String,
    /// Content type of the request body sent to the endpoint.
    pub content_type: ContentType,
    /// Content type expected in the endpoint's response.
    pub output_content_type: ContentType,
    /// HTTP method as a string (converted via `Method::from_bytes`,
    /// e.g. "POST" or "GET").
    pub method: String,
    /// Optional extra headers attached to every request.
    pub additional_headers: Option<std::collections::HashMap<String, String>>,
    /// Engine-specific callback converting the raw JSON response into a
    /// `SearchOutput`.
    pub parser: fn(&serde_json::Value) -> Result<SearchOutput, Box<dyn std::error::Error>>,
    /// Optional (prefix, suffix) prompt pair wrapped around the collected
    /// result text when summarizing.
    pub summarization_prompts: Option<(String, String)>,
    /// Optional context size for summarization; when `None`,
    /// `summarize_search` derives one from the two limits above.
    pub summarize_ctx_size: Option<usize>,
}
53
54#[derive(Serialize, Deserialize)]
56pub struct SearchResult {
57 pub url: String,
58 pub site_name: String,
59 pub text_content: String,
60}
61
62#[derive(Serialize, Deserialize)]
64pub struct SearchOutput {
65 pub results: Vec<SearchResult>,
66}
67
68impl SearchConfig {
69 pub fn parse_into_results(
71 &self,
72 raw_results: &serde_json::Value,
73 ) -> Result<SearchOutput, Box<dyn std::error::Error>> {
74 (self.parser)(raw_results)
75 }
76
77 #[allow(clippy::too_many_arguments)]
78 pub fn new(
79 search_engine: String,
80 max_search_results: u8,
81 size_limit_per_result: u16,
82 endpoint: String,
83 content_type: ContentType,
84 output_content_type: ContentType,
85 method: String,
86 additional_headers: Option<std::collections::HashMap<String, String>>,
87 parser: fn(&serde_json::Value) -> Result<SearchOutput, Box<dyn std::error::Error>>,
88 summarization_prompts: Option<(String, String)>,
89 summarize_ctx_size: Option<usize>,
90 ) -> SearchConfig {
91 SearchConfig {
92 search_engine,
93 max_search_results,
94 size_limit_per_result,
95 endpoint,
96 content_type,
97 output_content_type,
98 method,
99 additional_headers,
100 parser,
101 summarization_prompts,
102 summarize_ctx_size,
103 }
104 }
105 pub async fn perform_search<T: Serialize>(
107 &self,
108 search_input: &T,
109 ) -> Result<SearchOutput, LlamaCoreError> {
110 let client = Client::new();
111 let url = match Url::parse(&self.endpoint) {
112 Ok(url) => url,
113 Err(_) => {
114 let msg = "Malformed endpoint url";
115 #[cfg(feature = "logging")]
116 error!(target: "stdout", "perform_search: {msg}");
117 return Err(LlamaCoreError::Search(format!(
118 "When parsing endpoint url: {msg}"
119 )));
120 }
121 };
122
123 let method_as_string = match reqwest::Method::from_bytes(self.method.as_bytes()) {
124 Ok(method) => method,
125 _ => {
126 let msg = "Non Standard or unknown method";
127 #[cfg(feature = "logging")]
128 error!(target: "stdout", "perform_search: {msg}");
129 return Err(LlamaCoreError::Search(format!(
130 "When converting method from bytes: {msg}"
131 )));
132 }
133 };
134
135 let mut req = client.request(method_as_string.clone(), url);
136
137 req = req.headers(
139 match (&self.additional_headers.clone().unwrap_or_default()).try_into() {
140 Ok(headers) => headers,
141 Err(_) => {
142 let msg = "Failed to convert headers from HashMaps to HeaderMaps";
143 #[cfg(feature = "logging")]
144 error!(target: "stdout", "perform_search: {msg}");
145 return Err(LlamaCoreError::Search(format!(
146 "On converting headers: {msg}"
147 )));
148 }
149 },
150 );
151
152 req = match method_as_string {
155 reqwest::Method::POST => match self.content_type {
156 ContentType::JSON => req.json(search_input),
157 },
158 reqwest::Method::GET => req.query(search_input),
159 _ => {
160 let msg = format!(
161 "Unsupported request method: {}",
162 method_as_string.to_owned()
163 );
164 #[cfg(feature = "logging")]
165 error!(target: "stdout", "perform_search: {msg}");
166 return Err(LlamaCoreError::Search(msg));
167 }
168 };
169
170 let res = match req.send().await {
171 Ok(r) => r,
172 Err(e) => {
173 let msg = e.to_string();
174 #[cfg(feature = "logging")]
175 error!(target: "stdout", "perform_search: {msg}");
176 return Err(LlamaCoreError::Search(format!(
177 "When recieving response: {msg}"
178 )));
179 }
180 };
181
182 match res.content_length() {
183 Some(length) => {
184 if length == 0 {
185 let msg = "Empty response from server";
186 #[cfg(feature = "logging")]
187 error!(target: "stdout", "perform_search: {msg}");
188 return Err(LlamaCoreError::Search(format!(
189 "Unexpected content length: {msg}"
190 )));
191 }
192 }
193 None => {
194 let msg = "Content length returned None";
195 #[cfg(feature = "logging")]
196 error!(target: "stdout", "perform_search: {msg}");
197 return Err(LlamaCoreError::Search(format!(
198 "Content length field not found: {msg}"
199 )));
200 }
201 }
202
203 let raw_results: Value;
208 match self.output_content_type {
209 ContentType::JSON => {
210 let body_text = match res.text().await {
211 Ok(body) => body,
212 Err(e) => {
213 let msg = e.to_string();
214 #[cfg(feature = "logging")]
215 error!(target: "stdout", "perform_search: {msg}");
216 return Err(LlamaCoreError::Search(format!(
217 "When accessing response body: {msg}"
218 )));
219 }
220 };
221 println!("{body_text}");
222 raw_results = match serde_json::from_str(body_text.as_str()) {
223 Ok(value) => value,
224 Err(e) => {
225 let msg = e.to_string();
226 #[cfg(feature = "logging")]
227 error!(target: "stdout", "perform_search: {msg}");
228 return Err(LlamaCoreError::Search(format!(
229 "When converting to a JSON object: {msg}"
230 )));
231 }
232 };
233 }
234 };
235
236 let mut search_output: SearchOutput = match self.parse_into_results(&raw_results) {
240 Ok(search_output) => search_output,
241 Err(e) => {
242 let msg = e.to_string();
243 #[cfg(feature = "logging")]
244 error!(target: "stdout", "perform_search: {msg}");
245 return Err(LlamaCoreError::Search(format!(
246 "When calling parse_into_results: {msg}"
247 )));
248 }
249 };
250
251 search_output
253 .results
254 .truncate(self.max_search_results as usize);
255
256 for result in search_output.results.iter_mut() {
261 if let Some(clipped_content) = result
262 .text_content
263 .split_at_checked(self.size_limit_per_result as usize)
264 {
265 result.text_content = clipped_content.0.to_string();
266 }
267 }
268
269 Ok(search_output)
271 }
272 pub async fn summarize_search<T: Serialize>(
274 &self,
275 search_input: &T,
276 ) -> Result<String, LlamaCoreError> {
277 let search_output = self.perform_search(&search_input).await?;
278
279 let summarization_prompts = self.summarization_prompts.clone().unwrap_or((
280 "The following are search results I found on the internet:\n\n".to_string(),
281 "\n\nTo sum up them up: ".to_string(),
282 ));
283
284 let summarize_ctx_size = self
286 .summarize_ctx_size
287 .unwrap_or((self.size_limit_per_result * self.max_search_results as u16) as usize);
288
289 summarize(
290 search_output,
291 summarize_ctx_size,
292 summarization_prompts.0,
293 summarization_prompts.1,
294 )
295 }
296}
297
298fn summarize(
300 search_output: SearchOutput,
301 summarize_ctx_size: usize,
302 initial_prompt: String,
303 final_prompt: String,
304) -> Result<String, LlamaCoreError> {
305 let mut search_output_string: String = String::new();
306
307 search_output
309 .results
310 .iter()
311 .for_each(|result| search_output_string.push_str(result.text_content.as_str()));
312
313 let running_mode = crate::running_mode()?;
315 if running_mode.contains(crate::RunningMode::EMBEDDINGS) {
316 let err_msg = "Summarization is not supported in the EMBEDDINGS running mode.";
317
318 #[cfg(feature = "logging")]
319 error!(target: "stdout", "{err_msg}");
320
321 return Err(LlamaCoreError::Search(err_msg.into()));
322 }
323
324 let chat_graphs = match CHAT_GRAPHS.get() {
326 Some(chat_graphs) => chat_graphs,
327 None => {
328 let err_msg = "Fail to get the underlying value of `CHAT_GRAPHS`.";
329
330 #[cfg(feature = "logging")]
331 error!(target: "stdout", "{err_msg}");
332
333 return Err(LlamaCoreError::Search(err_msg.into()));
334 }
335 };
336
337 let mut chat_graphs = chat_graphs.lock().map_err(|e| {
338 let err_msg = format!("Fail to acquire the lock of `CHAT_GRAPHS`. {e}");
339
340 #[cfg(feature = "logging")]
341 error!(target: "stdout", "{}", &err_msg);
342
343 LlamaCoreError::Search(err_msg)
344 })?;
345
346 let input = initial_prompt + search_output_string.as_str() + final_prompt.as_str();
348 let tensor_data = input.as_bytes().to_vec();
349
350 let graph: &mut crate::Graph<GgmlMetadata> = match chat_graphs.values_mut().next() {
352 Some(graph) => graph,
353 None => {
354 let err_msg = "No available chat graph.";
355
356 #[cfg(feature = "logging")]
357 error!(target: "stdout", "{err_msg}");
358
359 return Err(LlamaCoreError::Search(err_msg.into()));
360 }
361 };
362
363 graph
364 .set_input(0, wasmedge_wasi_nn::TensorType::U8, &[1], &tensor_data)
365 .expect("Failed to set prompt as the input tensor");
366
367 #[cfg(feature = "logging")]
368 info!(target: "stdout", "Generating a summary for search results...");
369 graph.compute().expect("Failed to complete inference");
371
372 let mut output_buffer = vec![0u8; summarize_ctx_size];
374 let mut output_size = graph
375 .get_output(0, &mut output_buffer)
376 .expect("Failed to get output tensor");
377 output_size = std::cmp::min(summarize_ctx_size, output_size);
378
379 let output = String::from_utf8_lossy(&output_buffer[..output_size]).to_string();
381
382 #[cfg(feature = "logging")]
383 info!(target: "stdout", "Summary generated.");
384
385 Ok(output)
386}