llama_core/
search.rs

1//! Define APIs for web search operations.
2
3use crate::{error::LlamaCoreError, metadata::ggml::GgmlMetadata, CHAT_GRAPHS};
4use reqwest::{Client, Url};
5use serde::{Deserialize, Serialize};
6use serde_json::Value;
7
/// Possible input/output Content Types. Currently only supports JSON.
//
// `Clone`/`Copy` added: the enum is a field-less discriminant, so copying is
// free and lets callers pass it by value without ceremony.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub enum ContentType {
    /// Rendered as the MIME string `application/json`.
    JSON,
}

impl std::fmt::Display for ContentType {
    /// Formats the content type as its MIME string, suitable for use in
    /// HTTP `Content-Type` / `Accept` headers.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mime = match self {
            ContentType::JSON => "application/json",
        };
        f.write_str(mime)
    }
}
25
/// The base Search Configuration holding all relevant information to access a search api and retrieve results.
#[derive(Debug)]
pub struct SearchConfig {
    /// The search engine we're currently focusing on. Currently only one supported, to ensure stability.
    // NOTE(review): not read anywhere in this module yet, hence the `dead_code` allow.
    #[allow(dead_code)]
    pub search_engine: String,
    /// The maximum number of results to keep; parsed results beyond this count are truncated.
    pub max_search_results: u8,
    /// The size limit applied to each individual result's `text_content` (used as a
    /// byte index via `split_at_checked` in `perform_search`).
    pub size_limit_per_result: u16,
    /// The endpoint for the search API.
    pub endpoint: String,
    /// The content type of the input (request body). Currently only JSON is supported.
    pub content_type: ContentType,
    /// The (expected) content type of the output (response body).
    pub output_content_type: ContentType,
    /// HTTP method expected by the api endpoint (e.g. "GET" or "POST").
    pub method: String,
    /// Additional headers for any other purpose (e.g. API keys).
    pub additional_headers: Option<std::collections::HashMap<String, String>>,
    /// Callback function to parse the output of the api-service. Implementation left to the user.
    pub parser: fn(&serde_json::Value) -> Result<SearchOutput, Box<dyn std::error::Error>>,
    /// Prompts for use with summarization functionality: (prefix, suffix) wrapped
    /// around the concatenated results. If set to `None`, use hard-coded prompts.
    pub summarization_prompts: Option<(String, String)>,
    /// Context size for summary generation. If `None`, the fallback used in
    /// `summarize_search` is `size_limit_per_result * max_search_results`.
    pub summarize_ctx_size: Option<usize>,
}
53
54/// output format for individual results in the final output.
55#[derive(Serialize, Deserialize)]
56pub struct SearchResult {
57    pub url: String,
58    pub site_name: String,
59    pub text_content: String,
60}
61
62/// Final output format for consumption by the LLM.
63#[derive(Serialize, Deserialize)]
64pub struct SearchOutput {
65    pub results: Vec<SearchResult>,
66}
67
68impl SearchConfig {
69    /// Wrapper for the parser() function.
70    pub fn parse_into_results(
71        &self,
72        raw_results: &serde_json::Value,
73    ) -> Result<SearchOutput, Box<dyn std::error::Error>> {
74        (self.parser)(raw_results)
75    }
76
77    #[allow(clippy::too_many_arguments)]
78    pub fn new(
79        search_engine: String,
80        max_search_results: u8,
81        size_limit_per_result: u16,
82        endpoint: String,
83        content_type: ContentType,
84        output_content_type: ContentType,
85        method: String,
86        additional_headers: Option<std::collections::HashMap<String, String>>,
87        parser: fn(&serde_json::Value) -> Result<SearchOutput, Box<dyn std::error::Error>>,
88        summarization_prompts: Option<(String, String)>,
89        summarize_ctx_size: Option<usize>,
90    ) -> SearchConfig {
91        SearchConfig {
92            search_engine,
93            max_search_results,
94            size_limit_per_result,
95            endpoint,
96            content_type,
97            output_content_type,
98            method,
99            additional_headers,
100            parser,
101            summarization_prompts,
102            summarize_ctx_size,
103        }
104    }
105    /// Perform a web search with a `Serialize`-able input. The `search_input` is used as is to query the search endpoint.
106    pub async fn perform_search<T: Serialize>(
107        &self,
108        search_input: &T,
109    ) -> Result<SearchOutput, LlamaCoreError> {
110        let client = Client::new();
111        let url = match Url::parse(&self.endpoint) {
112            Ok(url) => url,
113            Err(_) => {
114                let msg = "Malformed endpoint url";
115                #[cfg(feature = "logging")]
116                error!(target: "stdout", "perform_search: {msg}");
117                return Err(LlamaCoreError::Search(format!(
118                    "When parsing endpoint url: {msg}"
119                )));
120            }
121        };
122
123        let method_as_string = match reqwest::Method::from_bytes(self.method.as_bytes()) {
124            Ok(method) => method,
125            _ => {
126                let msg = "Non Standard or unknown method";
127                #[cfg(feature = "logging")]
128                error!(target: "stdout", "perform_search: {msg}");
129                return Err(LlamaCoreError::Search(format!(
130                    "When converting method from bytes: {msg}"
131                )));
132            }
133        };
134
135        let mut req = client.request(method_as_string.clone(), url);
136
137        // check headers.
138        req = req.headers(
139            match (&self.additional_headers.clone().unwrap_or_default()).try_into() {
140                Ok(headers) => headers,
141                Err(_) => {
142                    let msg = "Failed to convert headers from HashMaps to HeaderMaps";
143                    #[cfg(feature = "logging")]
144                    error!(target: "stdout", "perform_search: {msg}");
145                    return Err(LlamaCoreError::Search(format!(
146                        "On converting headers: {msg}"
147                    )));
148                }
149            },
150        );
151
152        // For POST requests, search_input goes into the request body. For GET requests, in the
153        // params.
154        req = match method_as_string {
155            reqwest::Method::POST => match self.content_type {
156                ContentType::JSON => req.json(search_input),
157            },
158            reqwest::Method::GET => req.query(search_input),
159            _ => {
160                let msg = format!(
161                    "Unsupported request method: {}",
162                    method_as_string.to_owned()
163                );
164                #[cfg(feature = "logging")]
165                error!(target: "stdout", "perform_search: {msg}");
166                return Err(LlamaCoreError::Search(msg));
167            }
168        };
169
170        let res = match req.send().await {
171            Ok(r) => r,
172            Err(e) => {
173                let msg = e.to_string();
174                #[cfg(feature = "logging")]
175                error!(target: "stdout", "perform_search: {msg}");
176                return Err(LlamaCoreError::Search(format!(
177                    "When recieving response: {msg}"
178                )));
179            }
180        };
181
182        match res.content_length() {
183            Some(length) => {
184                if length == 0 {
185                    let msg = "Empty response from server";
186                    #[cfg(feature = "logging")]
187                    error!(target: "stdout", "perform_search: {msg}");
188                    return Err(LlamaCoreError::Search(format!(
189                        "Unexpected content length: {msg}"
190                    )));
191                }
192            }
193            None => {
194                let msg = "Content length returned None";
195                #[cfg(feature = "logging")]
196                error!(target: "stdout", "perform_search: {msg}");
197                return Err(LlamaCoreError::Search(format!(
198                    "Content length field not found: {msg}"
199                )));
200            }
201        }
202
203        // start parsing the output.
204        //
205        // only checking for JSON as the output content type since it's the most common and widely
206        // supported.
207        let raw_results: Value;
208        match self.output_content_type {
209            ContentType::JSON => {
210                let body_text = match res.text().await {
211                    Ok(body) => body,
212                    Err(e) => {
213                        let msg = e.to_string();
214                        #[cfg(feature = "logging")]
215                        error!(target: "stdout", "perform_search: {msg}");
216                        return Err(LlamaCoreError::Search(format!(
217                            "When accessing response body: {msg}"
218                        )));
219                    }
220                };
221                println!("{body_text}");
222                raw_results = match serde_json::from_str(body_text.as_str()) {
223                    Ok(value) => value,
224                    Err(e) => {
225                        let msg = e.to_string();
226                        #[cfg(feature = "logging")]
227                        error!(target: "stdout", "perform_search: {msg}");
228                        return Err(LlamaCoreError::Search(format!(
229                            "When converting to a JSON object: {msg}"
230                        )));
231                    }
232                };
233            }
234        };
235
236        // start cleaning the output.
237
238        // produce SearchOutput instance with the raw results obtained from the endpoint.
239        let mut search_output: SearchOutput = match self.parse_into_results(&raw_results) {
240            Ok(search_output) => search_output,
241            Err(e) => {
242                let msg = e.to_string();
243                #[cfg(feature = "logging")]
244                error!(target: "stdout", "perform_search: {msg}");
245                return Err(LlamaCoreError::Search(format!(
246                    "When calling parse_into_results: {msg}"
247                )));
248            }
249        };
250
251        // apply maximum search result limit.
252        search_output
253            .results
254            .truncate(self.max_search_results as usize);
255
256        // apply per result character limit.
257        //
258        // since the clipping only happens when split_at_checked() returns Some, the results will
259        // remain unchanged should split_at_checked() return None.
260        for result in search_output.results.iter_mut() {
261            if let Some(clipped_content) = result
262                .text_content
263                .split_at_checked(self.size_limit_per_result as usize)
264            {
265                result.text_content = clipped_content.0.to_string();
266            }
267        }
268
269        // Search Output cleaned and finalized.
270        Ok(search_output)
271    }
272    /// Perform a search and summarize the corresponding search results
273    pub async fn summarize_search<T: Serialize>(
274        &self,
275        search_input: &T,
276    ) -> Result<String, LlamaCoreError> {
277        let search_output = self.perform_search(&search_input).await?;
278
279        let summarization_prompts = self.summarization_prompts.clone().unwrap_or((
280            "The following are search results I found on the internet:\n\n".to_string(),
281            "\n\nTo sum up them up: ".to_string(),
282        ));
283
284        // the fallback context size limit for the search summary to be generated.
285        let summarize_ctx_size = self
286            .summarize_ctx_size
287            .unwrap_or((self.size_limit_per_result * self.max_search_results as u16) as usize);
288
289        summarize(
290            search_output,
291            summarize_ctx_size,
292            summarization_prompts.0,
293            summarization_prompts.1,
294        )
295    }
296}
297
298/// Summarize the search output provided
299fn summarize(
300    search_output: SearchOutput,
301    summarize_ctx_size: usize,
302    initial_prompt: String,
303    final_prompt: String,
304) -> Result<String, LlamaCoreError> {
305    let mut search_output_string: String = String::new();
306
307    // Add the text content of every result together.
308    search_output
309        .results
310        .iter()
311        .for_each(|result| search_output_string.push_str(result.text_content.as_str()));
312
313    // Error on embedding running mode.
314    let running_mode = crate::running_mode()?;
315    if running_mode.contains(crate::RunningMode::EMBEDDINGS) {
316        let err_msg = "Summarization is not supported in the EMBEDDINGS running mode.";
317
318        #[cfg(feature = "logging")]
319        error!(target: "stdout", "{err_msg}");
320
321        return Err(LlamaCoreError::Search(err_msg.into()));
322    }
323
324    // Get graphs and pick the first graph.
325    let chat_graphs = match CHAT_GRAPHS.get() {
326        Some(chat_graphs) => chat_graphs,
327        None => {
328            let err_msg = "Fail to get the underlying value of `CHAT_GRAPHS`.";
329
330            #[cfg(feature = "logging")]
331            error!(target: "stdout", "{err_msg}");
332
333            return Err(LlamaCoreError::Search(err_msg.into()));
334        }
335    };
336
337    let mut chat_graphs = chat_graphs.lock().map_err(|e| {
338        let err_msg = format!("Fail to acquire the lock of `CHAT_GRAPHS`. {e}");
339
340        #[cfg(feature = "logging")]
341        error!(target: "stdout", "{}", &err_msg);
342
343        LlamaCoreError::Search(err_msg)
344    })?;
345
346    // Prepare input prompt.
347    let input = initial_prompt + search_output_string.as_str() + final_prompt.as_str();
348    let tensor_data = input.as_bytes().to_vec();
349
350    // Use first available chat graph
351    let graph: &mut crate::Graph<GgmlMetadata> = match chat_graphs.values_mut().next() {
352        Some(graph) => graph,
353        None => {
354            let err_msg = "No available chat graph.";
355
356            #[cfg(feature = "logging")]
357            error!(target: "stdout", "{err_msg}");
358
359            return Err(LlamaCoreError::Search(err_msg.into()));
360        }
361    };
362
363    graph
364        .set_input(0, wasmedge_wasi_nn::TensorType::U8, &[1], &tensor_data)
365        .expect("Failed to set prompt as the input tensor");
366
367    #[cfg(feature = "logging")]
368    info!(target: "stdout", "Generating a summary for search results...");
369    // Execute the inference.
370    graph.compute().expect("Failed to complete inference");
371
372    // Retrieve the output.
373    let mut output_buffer = vec![0u8; summarize_ctx_size];
374    let mut output_size = graph
375        .get_output(0, &mut output_buffer)
376        .expect("Failed to get output tensor");
377    output_size = std::cmp::min(summarize_ctx_size, output_size);
378
379    // Compute lossy UTF-8 output (text only).
380    let output = String::from_utf8_lossy(&output_buffer[..output_size]).to_string();
381
382    #[cfg(feature = "logging")]
383    info!(target: "stdout", "Summary generated.");
384
385    Ok(output)
386}