llama_core/embeddings.rs

//! Define APIs for computing embeddings.

use crate::{
    error::{BackendError, LlamaCoreError},
    metadata::ggml::GgmlMetadata,
    running_mode,
    utils::{get_output_buffer, get_token_info_by_graph, set_tensor_data_u8},
    Graph, RunningMode, CHAT_GRAPHS, EMBEDDING_GRAPHS, OUTPUT_TENSOR,
};
use endpoints::{
    common::Usage,
    embeddings::{EmbeddingObject, EmbeddingRequest, EmbeddingsResponse, InputText},
};
use serde::{Deserialize, Serialize};
use text_splitter::{MarkdownSplitter, TextSplitter};
use tiktoken_rs::cl100k_base;

/// Compute embeddings for the given input.
///
/// # Arguments
///
/// * `embedding_request` - The embedding request.
///
/// # Returns
///
/// The embeddings response.
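/// # Example
///
/// A rough usage sketch (marked `ignore`, so it is not compiled as a doctest).
/// It assumes an embedding model has already been loaded into `EMBEDDING_GRAPHS`;
/// the model name and the exact field set of `EmbeddingRequest` (including a
/// `Default` impl) are illustrative assumptions, not guaranteed by this crate.
///
/// ```ignore
/// use endpoints::embeddings::{EmbeddingRequest, InputText};
///
/// // Build a request for a single input string (hypothetical model name).
/// let request = EmbeddingRequest {
///     model: Some("all-MiniLM-L6-v2".to_string()),
///     input: InputText::String("Hello, world!".to_string()),
///     ..Default::default()
/// };
///
/// let response = embeddings(&request).await?;
/// println!("computed {} embedding(s)", response.data.len());
/// ```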
pub async fn embeddings(
    embedding_request: &EmbeddingRequest,
) -> Result<EmbeddingsResponse, LlamaCoreError> {
    #[cfg(feature = "logging")]
    info!(target: "stdout", "Computing embeddings");

    let running_mode = running_mode()?;
    if !running_mode.contains(RunningMode::EMBEDDINGS) && !running_mode.contains(RunningMode::RAG) {
        let err_msg = "Computing embeddings is only supported in the embeddings and rag modes.";

        #[cfg(feature = "logging")]
        error!(target: "stdout", "{}", err_msg);

        return Err(LlamaCoreError::Operation(err_msg.into()));
    }

    let model_name = &embedding_request.model;

    let embedding_response = {
        // For the general embedding scenario, the embedding model is the same as the chat model.
        // For the RAG scenario, the embedding model is different from the chat model.
        let embedding_graphs = match EMBEDDING_GRAPHS.get() {
            Some(embedding_graphs) => embedding_graphs,
            None => match CHAT_GRAPHS.get() {
                Some(chat_graphs) => chat_graphs,
                None => {
                    let err_msg = "No embedding model is available.";

                    #[cfg(feature = "logging")]
                    error!(target: "stdout", "{}", err_msg);

                    return Err(LlamaCoreError::Operation(err_msg.into()));
                }
            },
        };

        let mut embedding_graphs = embedding_graphs.lock().map_err(|e| {
            let err_msg = format!("Fail to acquire the lock of `EMBEDDING_GRAPHS`. {}", e);

            #[cfg(feature = "logging")]
            error!(target: "stdout", "{}", &err_msg);

            LlamaCoreError::Operation(err_msg)
        })?;

        let graph = match model_name {
            Some(model_name) if embedding_graphs.contains_key(model_name) => {
                embedding_graphs.get_mut(model_name).unwrap()
            }
            _ => match embedding_graphs.iter_mut().next() {
                Some((_, graph)) => graph,
                None => {
                    let err_msg = "No available model found in the embedding graphs.";

                    #[cfg(feature = "logging")]
                    error!(target: "stdout", "{}", &err_msg);

                    return Err(LlamaCoreError::Operation(err_msg.into()));
                }
            },
        };

        // check if the `embedding` option of metadata is enabled
        if !graph.metadata.embeddings {
            graph.metadata.embeddings = true;
            graph.update_metadata()?;
        }

        // compute embeddings
        let (data, usage) = match &embedding_request.input {
            InputText::String(text) => compute_embeddings(graph, &[text.to_owned()])?,
            InputText::ArrayOfStrings(texts) => compute_embeddings(graph, texts.as_slice())?,
            InputText::ArrayOfTokens(tokens) => {
                let texts: Vec<String> = tokens.iter().map(|t| t.to_string()).collect();
                compute_embeddings(graph, texts.as_slice())?
            }
            InputText::ArrayOfTokenArrays(token_arrays) => {
                let texts: Vec<String> = token_arrays
                    .iter()
                    .map(|tokens| {
                        tokens
                            .iter()
                            .map(|t| t.to_string())
                            .collect::<Vec<String>>()
                            .join(" ")
                    })
                    .collect();
                compute_embeddings(graph, texts.as_slice())?
            }
        };

        EmbeddingsResponse {
            object: String::from("list"),
            data,
            model: graph.name().to_owned(),
            usage,
        }
    };

    #[cfg(feature = "logging")]
    info!(target: "stdout", "Reset the model metadata");

    // reset the model metadata
    reset_model_metadata(model_name.as_ref())?;

    Ok(embedding_response)
}

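/// Compute embeddings for a batch of input chunks with the given graph,
/// accumulating the token usage reported by the backend across all chunks.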
fn compute_embeddings(
    graph: &mut Graph<GgmlMetadata>,
    input: &[String],
) -> Result<(Vec<EmbeddingObject>, Usage), LlamaCoreError> {
    #[cfg(feature = "logging")]
    info!(target: "stdout", "Compute embeddings for {} chunks", input.len());

    // compute embeddings
    let mut embeddings: Vec<EmbeddingObject> = Vec::new();
    let mut usage = Usage::default();
    for (idx, input) in input.iter().enumerate() {
        // set input
        let tensor_data = input.as_bytes().to_vec();
        graph
            .set_input(0, wasmedge_wasi_nn::TensorType::U8, &[1], &tensor_data)
            .map_err(|e| {
                let err_msg = e.to_string();

                #[cfg(feature = "logging")]
                error!(target: "stdout", "{}", &err_msg);

                LlamaCoreError::Backend(BackendError::SetInput(err_msg))
            })?;

        #[cfg(feature = "logging")]
        debug!(target: "stdout", "compute embeddings for chunk {}", idx + 1);

        match graph.compute() {
            Ok(_) => {
                // Retrieve the output.
                let output_buffer = get_output_buffer(graph, OUTPUT_TENSOR)?;

                // convert inference result to string
                let output = std::str::from_utf8(&output_buffer[..]).map_err(|e| {
                    let err_msg = format!(
                        "Failed to decode the buffer of the inference result to a utf-8 string. Reason: {}",
                        e
                    );

                    #[cfg(feature = "logging")]
                    error!(target: "stdout", "{}", &err_msg);

                    LlamaCoreError::Operation(err_msg)
                })?;

                // deserialize the embedding data
                let embedding = serde_json::from_str::<Embedding>(output).map_err(|e| {
                    let err_msg =
                        format!("Failed to deserialize the embedding data. Reason: {}", e);

                    #[cfg(feature = "logging")]
                    error!(target: "stdout", "{}", &err_msg);

                    LlamaCoreError::Operation(err_msg)
                })?;

                let embedding_object = EmbeddingObject {
                    index: idx as u64,
                    object: String::from("embedding"),
                    embedding: embedding.data,
                };

                embeddings.push(embedding_object);

                // retrieve the number of prompt and completion tokens
                let token_info = get_token_info_by_graph(graph)?;

                usage.prompt_tokens += token_info.prompt_tokens;
                usage.completion_tokens += token_info.completion_tokens;
                usage.total_tokens = usage.prompt_tokens + usage.completion_tokens;
            }
            Err(e) => {
                let err_msg = format!("Failed to compute embeddings. Reason: {}", e);

                #[cfg(feature = "logging")]
                error!(target: "stdout", "{}", &err_msg);

                return Err(LlamaCoreError::Backend(BackendError::Compute(err_msg)));
            }
        }
    }

    #[cfg(feature = "logging")]
    info!(target: "stdout", "token usage of embeddings: {} prompt tokens, {} completion tokens", usage.prompt_tokens, usage.completion_tokens);

    Ok((embeddings, usage))
}

/// Get the dimension of the embedding model.
///
/// # Arguments
///
/// * `name` - The name of the embedding model. If `None`, the dimension of the first model will be returned.
///
/// # Returns
///
/// The dimension of the embedding model.
///
/// # Errors
///
/// * The model does not exist in the embedding graphs.
/// * No embedding model is available.
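/// # Example
///
/// A minimal sketch; it assumes at least one embedding model has already been
/// registered in `EMBEDDING_GRAPHS`, and the model name is illustrative.
///
/// ```ignore
/// // Dimension of the first available embedding model.
/// let dim = dimension(None)?;
///
/// // Dimension of a specific model.
/// let dim = dimension(Some("all-MiniLM-L6-v2"))?;
/// ```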
pub fn dimension(name: Option<&str>) -> Result<u64, LlamaCoreError> {
    let embedding_graphs = match EMBEDDING_GRAPHS.get() {
        Some(embedding_graphs) => embedding_graphs,
        None => {
            let err_msg = "Fail to get the underlying value of `EMBEDDING_GRAPHS`.";

            #[cfg(feature = "logging")]
            error!(target: "stdout", "{}", err_msg);

            return Err(LlamaCoreError::Operation(err_msg.into()));
        }
    };

    let embedding_graphs = embedding_graphs.lock().map_err(|e| {
        let err_msg = format!("Fail to acquire the lock of `EMBEDDING_GRAPHS`. {}", e);

        #[cfg(feature = "logging")]
        error!(target: "stdout", "{}", &err_msg);

        LlamaCoreError::Operation(err_msg)
    })?;

    match name {
        Some(model_name) => match embedding_graphs.get(model_name) {
            Some(graph) => Ok(graph.metadata.ctx_size),
            None => {
                let err_msg = format!(
                    "The model `{}` does not exist in the embedding graphs.",
                    model_name
                );

                #[cfg(feature = "logging")]
                error!(target: "stdout", "{}", &err_msg);

                Err(LlamaCoreError::Operation(err_msg))
            }
        },
        None => {
            if !embedding_graphs.is_empty() {
                let graph = match embedding_graphs.values().next() {
                    Some(graph) => graph,
                    None => {
                        let err_msg = "Fail to get the underlying value of `EMBEDDING_GRAPHS`.";

                        #[cfg(feature = "logging")]
                        error!(target: "stdout", "{}", err_msg);

                        return Err(LlamaCoreError::Operation(err_msg.into()));
                    }
                };

                Ok(graph.metadata.ctx_size)
            } else {
                let err_msg = "There is no model available in the embedding graphs.";

                #[cfg(feature = "logging")]
                error!(target: "stdout", "{}", &err_msg);

                Err(LlamaCoreError::Operation(err_msg.into()))
            }
        }
    }
}

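/// Intermediate type used to deserialize the embedding payload returned by the
/// backend. Given the `serde` renames below, the expected JSON has the shape
/// `{"n_embedding": ..., "embedding": [...]}`. A minimal sketch with
/// illustrative values:
///
/// ```ignore
/// let raw = r#"{"n_embedding": 3, "embedding": [0.1, -0.2, 0.3]}"#;
/// let parsed: Embedding = serde_json::from_str(raw).unwrap();
/// assert_eq!(parsed.len, 3);
/// assert_eq!(parsed.data.len(), 3);
/// ```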
#[derive(Debug, Serialize, Deserialize)]
struct Embedding {
    #[serde(rename = "n_embedding")]
    len: u64,
    #[serde(rename = "embedding")]
    data: Vec<f64>,
}

/// Generate a list of chunks from the given text. Each chunk contains at most `chunk_capacity` tokens.
///
/// # Arguments
///
/// * `text` - A reference to the text to chunk.
///
/// * `ty` - Type of the text, `txt` for plain text content or `md` for markdown content.
///
/// * `chunk_capacity` - The maximum number of tokens each chunk contains.
///
/// # Returns
///
/// A vector of strings.
///
/// # Errors
///
/// Returns an error if the operation fails.
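/// # Example
///
/// A minimal sketch; the exact chunk boundaries depend on the `cl100k_base`
/// tokenizer used internally.
///
/// ```ignore
/// let text = "# Notes\n\nFirst paragraph.\n\nSecond paragraph.";
/// let chunks = chunk_text(text, "md", 100)?;
/// assert!(!chunks.is_empty());
/// ```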
pub fn chunk_text(
    text: impl AsRef<str>,
    ty: impl AsRef<str>,
    chunk_capacity: usize,
) -> Result<Vec<String>, LlamaCoreError> {
    if ty.as_ref().to_lowercase().as_str() != "txt" && ty.as_ref().to_lowercase().as_str() != "md" {
        let err_msg = "Failed to upload the target file. Only files with 'txt' and 'md' extensions are supported.";

        #[cfg(feature = "logging")]
        error!(target: "stdout", "{}", err_msg);

        return Err(LlamaCoreError::Operation(err_msg.into()));
    }

    match ty.as_ref().to_lowercase().as_str() {
        "txt" => {
            #[cfg(feature = "logging")]
            info!(target: "stdout", "Chunk the plain text contents.");

            let tokenizer = cl100k_base().map_err(|e| {
                let err_msg = e.to_string();

                #[cfg(feature = "logging")]
                error!(target: "stdout", "{}", &err_msg);

                LlamaCoreError::Operation(err_msg)
            })?;

            // create a text splitter
            let splitter = TextSplitter::new(tokenizer).with_trim_chunks(true);

            let chunks = splitter
                .chunks(text.as_ref(), chunk_capacity)
                .map(|s| s.to_string())
                .collect::<Vec<_>>();

            #[cfg(feature = "logging")]
            info!(target: "stdout", "Number of chunks: {}", chunks.len());

            Ok(chunks)
        }
        "md" => {
            #[cfg(feature = "logging")]
            info!(target: "stdout", "Chunk the markdown contents.");

            let tokenizer = cl100k_base().map_err(|e| {
                let err_msg = e.to_string();

                #[cfg(feature = "logging")]
                error!(target: "stdout", "{}", &err_msg);

                LlamaCoreError::Operation(err_msg)
            })?;

            // create a markdown splitter
            let splitter = MarkdownSplitter::new(tokenizer).with_trim_chunks(true);

            let chunks = splitter
                .chunks(text.as_ref(), chunk_capacity)
                .map(|s| s.to_string())
                .collect::<Vec<_>>();

            #[cfg(feature = "logging")]
            info!(target: "stdout", "Number of chunks: {}", chunks.len());

            Ok(chunks)
        }
        _ => {
            let err_msg =
                "Failed to upload the target file. Only text and markdown files are supported.";

            #[cfg(feature = "logging")]
            error!(target: "stdout", "{}", err_msg);

            Err(LlamaCoreError::Operation(err_msg.into()))
        }
    }
}

/// Get a copy of the metadata of the model.
fn get_model_metadata(model_name: Option<&String>) -> Result<GgmlMetadata, LlamaCoreError> {
    let embedding_graphs = match EMBEDDING_GRAPHS.get() {
        Some(embedding_graphs) => embedding_graphs,
        None => {
            let err_msg = "Fail to get the underlying value of `EMBEDDING_GRAPHS`.";

            #[cfg(feature = "logging")]
            error!(target: "stdout", "{}", err_msg);

            return Err(LlamaCoreError::Operation(err_msg.into()));
        }
    };

    let embedding_graphs = embedding_graphs.lock().map_err(|e| {
        let err_msg = format!("Fail to acquire the lock of `EMBEDDING_GRAPHS`. {}", e);

        #[cfg(feature = "logging")]
        error!(target: "stdout", "{}", &err_msg);

        LlamaCoreError::Operation(err_msg)
    })?;

    match model_name {
        Some(model_name) => match embedding_graphs.contains_key(model_name) {
            true => {
                let graph = embedding_graphs.get(model_name).unwrap();
                Ok(graph.metadata.clone())
            }
            false => match embedding_graphs.iter().next() {
                Some((_, graph)) => Ok(graph.metadata.clone()),
                None => {
                    let err_msg = "There is no model available in the embedding graphs.";

                    #[cfg(feature = "logging")]
                    error!(target: "stdout", "{}", &err_msg);

                    Err(LlamaCoreError::Operation(err_msg.into()))
                }
            },
        },
        None => match embedding_graphs.iter().next() {
            Some((_, graph)) => Ok(graph.metadata.clone()),
            None => {
                let err_msg = "There is no model available in the embedding graphs.";

                #[cfg(feature = "logging")]
                error!(target: "stdout", "{}", err_msg);

                Err(LlamaCoreError::Operation(err_msg.into()))
            }
        },
    }
}

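/// Serialize the given metadata to a JSON string and write it into the
/// metadata tensor (index 1) of the target model graph. If `model_name` is
/// `None` or not found, the first available graph is updated.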
fn update_model_metadata(
    model_name: Option<&String>,
    metadata: &GgmlMetadata,
) -> Result<(), LlamaCoreError> {
    let config = match serde_json::to_string(metadata) {
        Ok(config) => config,
        Err(e) => {
            let err_msg = format!("Fail to serialize metadata to a JSON string. {}", e);

            #[cfg(feature = "logging")]
            error!(target: "stdout", "{}", &err_msg);

            return Err(LlamaCoreError::Operation(err_msg));
        }
    };

    let embedding_graphs = match EMBEDDING_GRAPHS.get() {
        Some(embedding_graphs) => embedding_graphs,
        None => {
            let err_msg = "Fail to get the underlying value of `EMBEDDING_GRAPHS`.";

            #[cfg(feature = "logging")]
            error!(target: "stdout", "{}", err_msg);

            return Err(LlamaCoreError::Operation(err_msg.into()));
        }
    };

    let mut embedding_graphs = embedding_graphs.lock().map_err(|e| {
        let err_msg = format!(
            "Fail to acquire the lock of `EMBEDDING_GRAPHS`. Reason: {}",
            e
        );

        #[cfg(feature = "logging")]
        error!(target: "stdout", "{}", &err_msg);

        LlamaCoreError::Operation(err_msg)
    })?;

    match model_name {
        Some(model_name) => {
            match embedding_graphs.contains_key(model_name) {
                true => {
                    let graph = embedding_graphs.get_mut(model_name).unwrap();
                    // update metadata
                    set_tensor_data_u8(graph, 1, config.as_bytes())
                }
                false => match embedding_graphs.iter_mut().next() {
                    Some((_, graph)) => {
                        // update metadata
                        set_tensor_data_u8(graph, 1, config.as_bytes())
                    }
                    None => {
                        let err_msg = "There is no model available in the embedding graphs.";

                        #[cfg(feature = "logging")]
                        error!(target: "stdout", "{}", &err_msg);

                        Err(LlamaCoreError::Operation(err_msg.into()))
                    }
                },
            }
        }
        None => {
            match embedding_graphs.iter_mut().next() {
                Some((_, graph)) => {
                    // update metadata
                    set_tensor_data_u8(graph, 1, config.as_bytes())
                }
                None => {
                    let err_msg = "There is no model available in the embedding graphs.";

                    #[cfg(feature = "logging")]
                    error!(target: "stdout", "{}", err_msg);

                    Err(LlamaCoreError::Operation(err_msg.into()))
                }
            }
        }
    }
}

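/// Reset the model metadata by re-applying a copy of the model's current
/// metadata to the backend graph.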
fn reset_model_metadata(model_name: Option<&String>) -> Result<(), LlamaCoreError> {
    // get metadata
    let metadata = get_model_metadata(model_name)?;

    // update model with the original metadata
    update_model_metadata(model_name, &metadata)
}