// llama_core/search.rs

//! Define APIs for web search operations.

use crate::{error::LlamaCoreError, metadata::ggml::GgmlMetadata, CHAT_GRAPHS};
use reqwest::{Client, Url};
use serde::{Deserialize, Serialize};
use serde_json::Value;

/// Possible input/output Content Types. Currently only supports JSON.
#[derive(Debug, Eq, PartialEq)]
pub enum ContentType {
    /// `application/json` request/response bodies.
    JSON,
}

14impl std::fmt::Display for ContentType {
15    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
16        write!(
17            f,
18            "{}",
19            match &self {
20                ContentType::JSON => "application/json",
21            }
22        )
23    }
24}
25
26/// The base Search Configuration holding all relevant information to access a search api and retrieve results.
27#[derive(Debug)]
28pub struct SearchConfig {
29    /// The search engine we're currently focusing on. Currently only one supported, to ensure stability.
30    #[allow(dead_code)]
31    pub search_engine: String,
32    /// The total number of results.
33    pub max_search_results: u8,
34    /// The size limit of every search result.
35    pub size_limit_per_result: u16,
36    /// The endpoint for the search API.
37    pub endpoint: String,
38    /// The content type of the input.
39    pub content_type: ContentType,
40    /// The (expected) content type of the output.
41    pub output_content_type: ContentType,
42    /// Method expected by the api endpoint.
43    pub method: String,
44    /// Additional headers for any other purpose.
45    pub additional_headers: Option<std::collections::HashMap<String, String>>,
46    /// Callback function to parse the output of the api-service. Implementation left to the user.
47    pub parser: fn(&serde_json::Value) -> Result<SearchOutput, Box<dyn std::error::Error>>,
48    /// Prompts for use with summarization functionality. If set to `None`, use hard-coded prompts.
49    pub summarization_prompts: Option<(String, String)>,
50    /// Context size for summary generation. If `None`, will use the 4 char ~ 1 token metric to generate summary.
51    pub summarize_ctx_size: Option<usize>,
52}
53
54/// output format for individual results in the final output.
55#[derive(Serialize, Deserialize)]
56pub struct SearchResult {
57    pub url: String,
58    pub site_name: String,
59    pub text_content: String,
60}
61
62/// Final output format for consumption by the LLM.
63#[derive(Serialize, Deserialize)]
64pub struct SearchOutput {
65    pub results: Vec<SearchResult>,
66}
67
68impl SearchConfig {
69    /// Wrapper for the parser() function.
70    pub fn parse_into_results(
71        &self,
72        raw_results: &serde_json::Value,
73    ) -> Result<SearchOutput, Box<dyn std::error::Error>> {
74        (self.parser)(raw_results)
75    }
76
77    #[allow(clippy::too_many_arguments)]
78    pub fn new(
79        search_engine: String,
80        max_search_results: u8,
81        size_limit_per_result: u16,
82        endpoint: String,
83        content_type: ContentType,
84        output_content_type: ContentType,
85        method: String,
86        additional_headers: Option<std::collections::HashMap<String, String>>,
87        parser: fn(&serde_json::Value) -> Result<SearchOutput, Box<dyn std::error::Error>>,
88        summarization_prompts: Option<(String, String)>,
89        summarize_ctx_size: Option<usize>,
90    ) -> SearchConfig {
91        SearchConfig {
92            search_engine,
93            max_search_results,
94            size_limit_per_result,
95            endpoint,
96            content_type,
97            output_content_type,
98            method,
99            additional_headers,
100            parser,
101            summarization_prompts,
102            summarize_ctx_size,
103        }
104    }
105    /// Perform a web search with a `Serialize`-able input. The `search_input` is used as is to query the search endpoint.
106    pub async fn perform_search<T: Serialize>(
107        &self,
108        search_input: &T,
109    ) -> Result<SearchOutput, LlamaCoreError> {
110        let client = Client::new();
111        let url = match Url::parse(&self.endpoint) {
112            Ok(url) => url,
113            Err(_) => {
114                let msg = "Malformed endpoint url";
115                #[cfg(feature = "logging")]
116                error!(target: "stdout", "perform_search: {}", msg);
117                return Err(LlamaCoreError::Search(format!(
118                    "When parsing endpoint url: {}",
119                    msg
120                )));
121            }
122        };
123
124        let method_as_string = match reqwest::Method::from_bytes(self.method.as_bytes()) {
125            Ok(method) => method,
126            _ => {
127                let msg = "Non Standard or unknown method";
128                #[cfg(feature = "logging")]
129                error!(target: "stdout", "perform_search: {}", msg);
130                return Err(LlamaCoreError::Search(format!(
131                    "When converting method from bytes: {}",
132                    msg
133                )));
134            }
135        };
136
137        let mut req = client.request(method_as_string.clone(), url);
138
139        // check headers.
140        req = req.headers(
141            match (&self.additional_headers.clone().unwrap_or_default()).try_into() {
142                Ok(headers) => headers,
143                Err(_) => {
144                    let msg = "Failed to convert headers from HashMaps to HeaderMaps";
145                    #[cfg(feature = "logging")]
146                    error!(target: "stdout", "perform_search: {}", msg);
147                    return Err(LlamaCoreError::Search(format!(
148                        "On converting headers: {}",
149                        msg
150                    )));
151                }
152            },
153        );
154
155        // For POST requests, search_input goes into the request body. For GET requests, in the
156        // params.
157        req = match method_as_string {
158            reqwest::Method::POST => match self.content_type {
159                ContentType::JSON => req.json(search_input),
160            },
161            reqwest::Method::GET => req.query(search_input),
162            _ => {
163                let msg = format!(
164                    "Unsupported request method: {}",
165                    method_as_string.to_owned()
166                );
167                #[cfg(feature = "logging")]
168                error!(target: "stdout", "perform_search: {}", msg);
169                return Err(LlamaCoreError::Search(msg));
170            }
171        };
172
173        let res = match req.send().await {
174            Ok(r) => r,
175            Err(e) => {
176                let msg = e.to_string();
177                #[cfg(feature = "logging")]
178                error!(target: "stdout", "perform_search: {}", msg);
179                return Err(LlamaCoreError::Search(format!(
180                    "When recieving response: {}",
181                    msg
182                )));
183            }
184        };
185
186        match res.content_length() {
187            Some(length) => {
188                if length == 0 {
189                    let msg = "Empty response from server";
190                    #[cfg(feature = "logging")]
191                    error!(target: "stdout", "perform_search: {}", msg);
192                    return Err(LlamaCoreError::Search(format!(
193                        "Unexpected content length: {}",
194                        msg
195                    )));
196                }
197            }
198            None => {
199                let msg = "Content length returned None";
200                #[cfg(feature = "logging")]
201                error!(target: "stdout", "perform_search: {}", msg);
202                return Err(LlamaCoreError::Search(format!(
203                    "Content length field not found: {}",
204                    msg
205                )));
206            }
207        }
208
209        // start parsing the output.
210        //
211        // only checking for JSON as the output content type since it's the most common and widely
212        // supported.
213        let raw_results: Value;
214        match self.output_content_type {
215            ContentType::JSON => {
216                let body_text = match res.text().await {
217                    Ok(body) => body,
218                    Err(e) => {
219                        let msg = e.to_string();
220                        #[cfg(feature = "logging")]
221                        error!(target: "stdout", "perform_search: {}", msg);
222                        return Err(LlamaCoreError::Search(format!(
223                            "When accessing response body: {}",
224                            msg
225                        )));
226                    }
227                };
228                println!("{}", body_text);
229                raw_results = match serde_json::from_str(body_text.as_str()) {
230                    Ok(value) => value,
231                    Err(e) => {
232                        let msg = e.to_string();
233                        #[cfg(feature = "logging")]
234                        error!(target: "stdout", "perform_search: {}", msg);
235                        return Err(LlamaCoreError::Search(format!(
236                            "When converting to a JSON object: {}",
237                            msg
238                        )));
239                    }
240                };
241            }
242        };
243
244        // start cleaning the output.
245
246        // produce SearchOutput instance with the raw results obtained from the endpoint.
247        let mut search_output: SearchOutput = match self.parse_into_results(&raw_results) {
248            Ok(search_output) => search_output,
249            Err(e) => {
250                let msg = e.to_string();
251                #[cfg(feature = "logging")]
252                error!(target: "stdout", "perform_search: {}", msg);
253                return Err(LlamaCoreError::Search(format!(
254                    "When calling parse_into_results: {}",
255                    msg
256                )));
257            }
258        };
259
260        // apply maximum search result limit.
261        search_output
262            .results
263            .truncate(self.max_search_results as usize);
264
265        // apply per result character limit.
266        //
267        // since the clipping only happens when split_at_checked() returns Some, the results will
268        // remain unchanged should split_at_checked() return None.
269        for result in search_output.results.iter_mut() {
270            if let Some(clipped_content) = result
271                .text_content
272                .split_at_checked(self.size_limit_per_result as usize)
273            {
274                result.text_content = clipped_content.0.to_string();
275            }
276        }
277
278        // Search Output cleaned and finalized.
279        Ok(search_output)
280    }
281    /// Perform a search and summarize the corresponding search results
282    pub async fn summarize_search<T: Serialize>(
283        &self,
284        search_input: &T,
285    ) -> Result<String, LlamaCoreError> {
286        let search_output = self.perform_search(&search_input).await?;
287
288        let summarization_prompts = self.summarization_prompts.clone().unwrap_or((
289            "The following are search results I found on the internet:\n\n".to_string(),
290            "\n\nTo sum up them up: ".to_string(),
291        ));
292
293        // the fallback context size limit for the search summary to be generated.
294        let summarize_ctx_size = self
295            .summarize_ctx_size
296            .unwrap_or((self.size_limit_per_result * self.max_search_results as u16) as usize);
297
298        summarize(
299            search_output,
300            summarize_ctx_size,
301            summarization_prompts.0,
302            summarization_prompts.1,
303        )
304    }
305}
306
307/// Summarize the search output provided
308fn summarize(
309    search_output: SearchOutput,
310    summarize_ctx_size: usize,
311    initial_prompt: String,
312    final_prompt: String,
313) -> Result<String, LlamaCoreError> {
314    let mut search_output_string: String = String::new();
315
316    // Add the text content of every result together.
317    search_output
318        .results
319        .iter()
320        .for_each(|result| search_output_string.push_str(result.text_content.as_str()));
321
322    // Error on embedding running mode.
323    let running_mode = crate::running_mode()?;
324    if running_mode.contains(crate::RunningMode::EMBEDDINGS) {
325        let err_msg = "Summarization is not supported in the EMBEDDINGS running mode.";
326
327        #[cfg(feature = "logging")]
328        error!(target: "stdout", "{}", err_msg);
329
330        return Err(LlamaCoreError::Search(err_msg.into()));
331    }
332
333    // Get graphs and pick the first graph.
334    let chat_graphs = match CHAT_GRAPHS.get() {
335        Some(chat_graphs) => chat_graphs,
336        None => {
337            let err_msg = "Fail to get the underlying value of `CHAT_GRAPHS`.";
338
339            #[cfg(feature = "logging")]
340            error!(target: "stdout", "{}", err_msg);
341
342            return Err(LlamaCoreError::Search(err_msg.into()));
343        }
344    };
345
346    let mut chat_graphs = chat_graphs.lock().map_err(|e| {
347        let err_msg = format!("Fail to acquire the lock of `CHAT_GRAPHS`. {}", e);
348
349        #[cfg(feature = "logging")]
350        error!(target: "stdout", "{}", &err_msg);
351
352        LlamaCoreError::Search(err_msg)
353    })?;
354
355    // Prepare input prompt.
356    let input = initial_prompt + search_output_string.as_str() + final_prompt.as_str();
357    let tensor_data = input.as_bytes().to_vec();
358
359    // Use first available chat graph
360    let graph: &mut crate::Graph<GgmlMetadata> = match chat_graphs.values_mut().next() {
361        Some(graph) => graph,
362        None => {
363            let err_msg = "No available chat graph.";
364
365            #[cfg(feature = "logging")]
366            error!(target: "stdout", "{}", err_msg);
367
368            return Err(LlamaCoreError::Search(err_msg.into()));
369        }
370    };
371
372    graph
373        .set_input(0, wasmedge_wasi_nn::TensorType::U8, &[1], &tensor_data)
374        .expect("Failed to set prompt as the input tensor");
375
376    #[cfg(feature = "logging")]
377    info!(target: "stdout", "Generating a summary for search results...");
378    // Execute the inference.
379    graph.compute().expect("Failed to complete inference");
380
381    // Retrieve the output.
382    let mut output_buffer = vec![0u8; summarize_ctx_size];
383    let mut output_size = graph
384        .get_output(0, &mut output_buffer)
385        .expect("Failed to get output tensor");
386    output_size = std::cmp::min(summarize_ctx_size, output_size);
387
388    // Compute lossy UTF-8 output (text only).
389    let output = String::from_utf8_lossy(&output_buffer[..output_size]).to_string();
390
391    #[cfg(feature = "logging")]
392    info!(target: "stdout", "Summary generated.");
393
394    Ok(output)
395}