1use crate::{error::LlamaCoreError, metadata::ggml::GgmlMetadata, CHAT_GRAPHS};
4use reqwest::{Client, Url};
5use serde::{Deserialize, Serialize};
6use serde_json::Value;
7
/// Supported MIME content types for search requests and responses.
///
/// The `Display` impl renders the corresponding MIME string
/// (e.g. `application/json`).
//
// `Clone`/`Copy` added: this is a trivial fieldless enum and public types
// should derive eagerly so callers can pass it by value without borrowing.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub enum ContentType {
    /// `application/json`
    JSON,
}
13
14impl std::fmt::Display for ContentType {
15 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
16 write!(
17 f,
18 "{}",
19 match &self {
20 ContentType::JSON => "application/json",
21 }
22 )
23 }
24}
25
/// Configuration for querying an external search engine endpoint and
/// post-processing its responses.
#[derive(Debug)]
pub struct SearchConfig {
    /// Name of the search engine backend; currently unread (marked `dead_code`).
    #[allow(dead_code)]
    pub search_engine: String,
    /// Maximum number of results kept from a search response.
    pub max_search_results: u8,
    /// Per-result cap (in bytes) applied to each result's `text_content`.
    pub size_limit_per_result: u16,
    /// URL the search request is sent to.
    pub endpoint: String,
    /// Content type of the request body (used for POST requests).
    pub content_type: ContentType,
    /// Expected content type of the response body.
    pub output_content_type: ContentType,
    /// HTTP method as a string, e.g. "POST" or "GET".
    pub method: String,
    /// Optional extra HTTP headers sent with the request.
    pub additional_headers: Option<std::collections::HashMap<String, String>>,
    /// Callback that converts the raw JSON response into a `SearchOutput`.
    pub parser: fn(&serde_json::Value) -> Result<SearchOutput, Box<dyn std::error::Error>>,
    /// Optional (prefix, suffix) prompts wrapped around the concatenated
    /// results when summarizing; defaults are supplied by `summarize_search`.
    pub summarization_prompts: Option<(String, String)>,
    /// Optional context size (output buffer size, in bytes) for summarization.
    pub summarize_ctx_size: Option<usize>,
}
53
54#[derive(Serialize, Deserialize)]
56pub struct SearchResult {
57 pub url: String,
58 pub site_name: String,
59 pub text_content: String,
60}
61
62#[derive(Serialize, Deserialize)]
64pub struct SearchOutput {
65 pub results: Vec<SearchResult>,
66}
67
68impl SearchConfig {
69 pub fn parse_into_results(
71 &self,
72 raw_results: &serde_json::Value,
73 ) -> Result<SearchOutput, Box<dyn std::error::Error>> {
74 (self.parser)(raw_results)
75 }
76
77 #[allow(clippy::too_many_arguments)]
78 pub fn new(
79 search_engine: String,
80 max_search_results: u8,
81 size_limit_per_result: u16,
82 endpoint: String,
83 content_type: ContentType,
84 output_content_type: ContentType,
85 method: String,
86 additional_headers: Option<std::collections::HashMap<String, String>>,
87 parser: fn(&serde_json::Value) -> Result<SearchOutput, Box<dyn std::error::Error>>,
88 summarization_prompts: Option<(String, String)>,
89 summarize_ctx_size: Option<usize>,
90 ) -> SearchConfig {
91 SearchConfig {
92 search_engine,
93 max_search_results,
94 size_limit_per_result,
95 endpoint,
96 content_type,
97 output_content_type,
98 method,
99 additional_headers,
100 parser,
101 summarization_prompts,
102 summarize_ctx_size,
103 }
104 }
105 pub async fn perform_search<T: Serialize>(
107 &self,
108 search_input: &T,
109 ) -> Result<SearchOutput, LlamaCoreError> {
110 let client = Client::new();
111 let url = match Url::parse(&self.endpoint) {
112 Ok(url) => url,
113 Err(_) => {
114 let msg = "Malformed endpoint url";
115 #[cfg(feature = "logging")]
116 error!(target: "stdout", "perform_search: {}", msg);
117 return Err(LlamaCoreError::Search(format!(
118 "When parsing endpoint url: {}",
119 msg
120 )));
121 }
122 };
123
124 let method_as_string = match reqwest::Method::from_bytes(self.method.as_bytes()) {
125 Ok(method) => method,
126 _ => {
127 let msg = "Non Standard or unknown method";
128 #[cfg(feature = "logging")]
129 error!(target: "stdout", "perform_search: {}", msg);
130 return Err(LlamaCoreError::Search(format!(
131 "When converting method from bytes: {}",
132 msg
133 )));
134 }
135 };
136
137 let mut req = client.request(method_as_string.clone(), url);
138
139 req = req.headers(
141 match (&self.additional_headers.clone().unwrap_or_default()).try_into() {
142 Ok(headers) => headers,
143 Err(_) => {
144 let msg = "Failed to convert headers from HashMaps to HeaderMaps";
145 #[cfg(feature = "logging")]
146 error!(target: "stdout", "perform_search: {}", msg);
147 return Err(LlamaCoreError::Search(format!(
148 "On converting headers: {}",
149 msg
150 )));
151 }
152 },
153 );
154
155 req = match method_as_string {
158 reqwest::Method::POST => match self.content_type {
159 ContentType::JSON => req.json(search_input),
160 },
161 reqwest::Method::GET => req.query(search_input),
162 _ => {
163 let msg = format!(
164 "Unsupported request method: {}",
165 method_as_string.to_owned()
166 );
167 #[cfg(feature = "logging")]
168 error!(target: "stdout", "perform_search: {}", msg);
169 return Err(LlamaCoreError::Search(msg));
170 }
171 };
172
173 let res = match req.send().await {
174 Ok(r) => r,
175 Err(e) => {
176 let msg = e.to_string();
177 #[cfg(feature = "logging")]
178 error!(target: "stdout", "perform_search: {}", msg);
179 return Err(LlamaCoreError::Search(format!(
180 "When recieving response: {}",
181 msg
182 )));
183 }
184 };
185
186 match res.content_length() {
187 Some(length) => {
188 if length == 0 {
189 let msg = "Empty response from server";
190 #[cfg(feature = "logging")]
191 error!(target: "stdout", "perform_search: {}", msg);
192 return Err(LlamaCoreError::Search(format!(
193 "Unexpected content length: {}",
194 msg
195 )));
196 }
197 }
198 None => {
199 let msg = "Content length returned None";
200 #[cfg(feature = "logging")]
201 error!(target: "stdout", "perform_search: {}", msg);
202 return Err(LlamaCoreError::Search(format!(
203 "Content length field not found: {}",
204 msg
205 )));
206 }
207 }
208
209 let raw_results: Value;
214 match self.output_content_type {
215 ContentType::JSON => {
216 let body_text = match res.text().await {
217 Ok(body) => body,
218 Err(e) => {
219 let msg = e.to_string();
220 #[cfg(feature = "logging")]
221 error!(target: "stdout", "perform_search: {}", msg);
222 return Err(LlamaCoreError::Search(format!(
223 "When accessing response body: {}",
224 msg
225 )));
226 }
227 };
228 println!("{}", body_text);
229 raw_results = match serde_json::from_str(body_text.as_str()) {
230 Ok(value) => value,
231 Err(e) => {
232 let msg = e.to_string();
233 #[cfg(feature = "logging")]
234 error!(target: "stdout", "perform_search: {}", msg);
235 return Err(LlamaCoreError::Search(format!(
236 "When converting to a JSON object: {}",
237 msg
238 )));
239 }
240 };
241 }
242 };
243
244 let mut search_output: SearchOutput = match self.parse_into_results(&raw_results) {
248 Ok(search_output) => search_output,
249 Err(e) => {
250 let msg = e.to_string();
251 #[cfg(feature = "logging")]
252 error!(target: "stdout", "perform_search: {}", msg);
253 return Err(LlamaCoreError::Search(format!(
254 "When calling parse_into_results: {}",
255 msg
256 )));
257 }
258 };
259
260 search_output
262 .results
263 .truncate(self.max_search_results as usize);
264
265 for result in search_output.results.iter_mut() {
270 if let Some(clipped_content) = result
271 .text_content
272 .split_at_checked(self.size_limit_per_result as usize)
273 {
274 result.text_content = clipped_content.0.to_string();
275 }
276 }
277
278 Ok(search_output)
280 }
281 pub async fn summarize_search<T: Serialize>(
283 &self,
284 search_input: &T,
285 ) -> Result<String, LlamaCoreError> {
286 let search_output = self.perform_search(&search_input).await?;
287
288 let summarization_prompts = self.summarization_prompts.clone().unwrap_or((
289 "The following are search results I found on the internet:\n\n".to_string(),
290 "\n\nTo sum up them up: ".to_string(),
291 ));
292
293 let summarize_ctx_size = self
295 .summarize_ctx_size
296 .unwrap_or((self.size_limit_per_result * self.max_search_results as u16) as usize);
297
298 summarize(
299 search_output,
300 summarize_ctx_size,
301 summarization_prompts.0,
302 summarization_prompts.1,
303 )
304 }
305}
306
307fn summarize(
309 search_output: SearchOutput,
310 summarize_ctx_size: usize,
311 initial_prompt: String,
312 final_prompt: String,
313) -> Result<String, LlamaCoreError> {
314 let mut search_output_string: String = String::new();
315
316 search_output
318 .results
319 .iter()
320 .for_each(|result| search_output_string.push_str(result.text_content.as_str()));
321
322 let running_mode = crate::running_mode()?;
324 if running_mode.contains(crate::RunningMode::EMBEDDINGS) {
325 let err_msg = "Summarization is not supported in the EMBEDDINGS running mode.";
326
327 #[cfg(feature = "logging")]
328 error!(target: "stdout", "{}", err_msg);
329
330 return Err(LlamaCoreError::Search(err_msg.into()));
331 }
332
333 let chat_graphs = match CHAT_GRAPHS.get() {
335 Some(chat_graphs) => chat_graphs,
336 None => {
337 let err_msg = "Fail to get the underlying value of `CHAT_GRAPHS`.";
338
339 #[cfg(feature = "logging")]
340 error!(target: "stdout", "{}", err_msg);
341
342 return Err(LlamaCoreError::Search(err_msg.into()));
343 }
344 };
345
346 let mut chat_graphs = chat_graphs.lock().map_err(|e| {
347 let err_msg = format!("Fail to acquire the lock of `CHAT_GRAPHS`. {}", e);
348
349 #[cfg(feature = "logging")]
350 error!(target: "stdout", "{}", &err_msg);
351
352 LlamaCoreError::Search(err_msg)
353 })?;
354
355 let input = initial_prompt + search_output_string.as_str() + final_prompt.as_str();
357 let tensor_data = input.as_bytes().to_vec();
358
359 let graph: &mut crate::Graph<GgmlMetadata> = match chat_graphs.values_mut().next() {
361 Some(graph) => graph,
362 None => {
363 let err_msg = "No available chat graph.";
364
365 #[cfg(feature = "logging")]
366 error!(target: "stdout", "{}", err_msg);
367
368 return Err(LlamaCoreError::Search(err_msg.into()));
369 }
370 };
371
372 graph
373 .set_input(0, wasmedge_wasi_nn::TensorType::U8, &[1], &tensor_data)
374 .expect("Failed to set prompt as the input tensor");
375
376 #[cfg(feature = "logging")]
377 info!(target: "stdout", "Generating a summary for search results...");
378 graph.compute().expect("Failed to complete inference");
380
381 let mut output_buffer = vec![0u8; summarize_ctx_size];
383 let mut output_size = graph
384 .get_output(0, &mut output_buffer)
385 .expect("Failed to get output tensor");
386 output_size = std::cmp::min(summarize_ctx_size, output_size);
387
388 let output = String::from_utf8_lossy(&output_buffer[..output_size]).to_string();
390
391 #[cfg(feature = "logging")]
392 info!(target: "stdout", "Summary generated.");
393
394 Ok(output)
395}