llama_core/embeddings.rs

//! Define APIs for computing embeddings.

use crate::{
    error::{BackendError, LlamaCoreError},
    metadata::ggml::GgmlMetadata,
    running_mode,
    utils::{get_output_buffer, get_token_info_by_graph, set_tensor_data_u8},
    Graph, RunningMode, CHAT_GRAPHS, EMBEDDING_GRAPHS, OUTPUT_TENSOR,
};
use endpoints::{
    common::Usage,
    embeddings::{EmbeddingObject, EmbeddingRequest, EmbeddingsResponse, InputText},
};
use serde::{Deserialize, Serialize};
use text_splitter::{MarkdownSplitter, TextSplitter};
use tiktoken_rs::cl100k_base;

/// Compute embeddings for the given input.
///
/// # Arguments
///
/// * `embedding_request` - The embedding request.
///
/// # Returns
///
/// The embeddings response.
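///
/// # Example
///
/// A minimal usage sketch (marked `ignore` so it is not compiled as a doc test).
/// It assumes that an embedding model has already been loaded into
/// `EMBEDDING_GRAPHS` (or `CHAT_GRAPHS`) and that the incoming request body is an
/// OpenAI-style JSON payload that deserializes into an `EmbeddingRequest`; the
/// `handle` function and the `llama_core` crate path below are only illustrative:
///
/// ```ignore
/// use endpoints::embeddings::EmbeddingRequest;
/// use llama_core::embeddings::embeddings;
///
/// async fn handle(request_body: &str) -> Result<(), Box<dyn std::error::Error>> {
///     // Deserialize the OpenAI-style request body into an `EmbeddingRequest`.
///     let embedding_request: EmbeddingRequest = serde_json::from_str(request_body)?;
///
///     // Compute the embeddings. The response contains one `EmbeddingObject`
///     // per input chunk, plus the accumulated token usage.
///     let response = embeddings(&embedding_request).await?;
///     println!("computed {} embedding vectors", response.data.len());
///
///     Ok(())
/// }
/// ```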
pub async fn embeddings(
    embedding_request: &EmbeddingRequest,
) -> Result<EmbeddingsResponse, LlamaCoreError> {
    #[cfg(feature = "logging")]
    info!(target: "stdout", "Computing embeddings");

    let running_mode = running_mode()?;
    if !running_mode.contains(RunningMode::EMBEDDINGS) && !running_mode.contains(RunningMode::RAG) {
        let err_msg = "Computing embeddings is only supported in the embeddings and rag modes.";

        #[cfg(feature = "logging")]
        error!(target: "stdout", "{err_msg}");

        return Err(LlamaCoreError::Operation(err_msg.into()));
    }

    let model_name = &embedding_request.model;

    let embedding_response = {
        // For the general embedding scenario, the embedding model is the same as the chat model.
        // For the RAG scenario, the embedding model is different from the chat model.
        let embedding_graphs = match EMBEDDING_GRAPHS.get() {
            Some(embedding_graphs) => embedding_graphs,
            None => match CHAT_GRAPHS.get() {
                Some(chat_graphs) => chat_graphs,
                None => {
                    let err_msg = "No embedding model is available.";

                    #[cfg(feature = "logging")]
                    error!(target: "stdout", "{err_msg}");

                    return Err(LlamaCoreError::Operation(err_msg.into()));
                }
            },
        };

        let mut embedding_graphs = embedding_graphs.lock().map_err(|e| {
            let err_msg = format!("Failed to acquire the lock of `EMBEDDING_GRAPHS`. {e}");

            #[cfg(feature = "logging")]
            error!(target: "stdout", "{}", &err_msg);

            LlamaCoreError::Operation(err_msg)
        })?;

        let graph = match model_name {
            Some(model_name) if embedding_graphs.contains_key(model_name) => {
                embedding_graphs.get_mut(model_name).unwrap()
            }
            _ => match embedding_graphs.iter_mut().next() {
                Some((_, graph)) => graph,
                None => {
                    let err_msg = "No available model found in the embedding graphs.";

                    #[cfg(feature = "logging")]
                    error!(target: "stdout", "{}", &err_msg);

                    return Err(LlamaCoreError::Operation(err_msg.into()));
                }
            },
        };

        // check if the `embedding` option of metadata is enabled
        if !graph.metadata.embeddings {
            graph.metadata.embeddings = true;
            graph.update_metadata()?;
        }

        // compute embeddings
        let (data, usage) = match &embedding_request.input {
            InputText::String(text) => compute_embeddings(graph, &[text.to_owned()])?,
            InputText::ArrayOfStrings(texts) => compute_embeddings(graph, texts.as_slice())?,
            InputText::ArrayOfTokens(tokens) => {
                let texts: Vec<String> = tokens.iter().map(|t| t.to_string()).collect();
                compute_embeddings(graph, texts.as_slice())?
            }
            InputText::ArrayOfTokenArrays(token_arrays) => {
                let texts: Vec<String> = token_arrays
                    .iter()
                    .map(|tokens| {
                        tokens
                            .iter()
                            .map(|t| t.to_string())
                            .collect::<Vec<String>>()
                            .join(" ")
                    })
                    .collect();
                compute_embeddings(graph, texts.as_slice())?
            }
        };

        EmbeddingsResponse {
            object: String::from("list"),
            data,
            model: graph.name().to_owned(),
            usage,
        }
    };

    #[cfg(feature = "logging")]
    info!(target: "stdout", "Reset the model metadata");

    // reset the model metadata
    reset_model_metadata(model_name.as_ref())?;

    Ok(embedding_response)
}
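/// Compute embeddings for a batch of input chunks on the given graph.
///
/// For each chunk, the text is written to the input tensor, the graph is executed,
/// and the backend output is parsed as an [`Embedding`] JSON object. Token usage is
/// accumulated across all chunks and returned alongside the embedding objects.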
fn compute_embeddings(
    graph: &mut Graph<GgmlMetadata>,
    input: &[String],
) -> Result<(Vec<EmbeddingObject>, Usage), LlamaCoreError> {
    #[cfg(feature = "logging")]
    info!(target: "stdout", "Compute embeddings for {} chunks", input.len());

    // compute embeddings
    let mut embeddings: Vec<EmbeddingObject> = Vec::new();
    let mut usage = Usage::default();
    for (idx, input) in input.iter().enumerate() {
        // set input
        let tensor_data = input.as_bytes().to_vec();
        graph
            .set_input(0, wasmedge_wasi_nn::TensorType::U8, &[1], &tensor_data)
            .map_err(|e| {
                let err_msg = e.to_string();

                #[cfg(feature = "logging")]
                error!(target: "stdout", "{}", &err_msg);

                LlamaCoreError::Backend(BackendError::SetInput(err_msg))
            })?;

        #[cfg(feature = "logging")]
        debug!(target: "stdout", "compute embeddings for chunk {}", idx + 1);

        match graph.compute() {
            Ok(_) => {
                // Retrieve the output.
                let output_buffer = get_output_buffer(graph, OUTPUT_TENSOR)?;

                // convert inference result to string
                let output = std::str::from_utf8(&output_buffer[..]).map_err(|e| {
                    let err_msg = format!(
                        "Failed to decode the buffer of the inference result to a utf-8 string. Reason: {e}"
                    );

                    #[cfg(feature = "logging")]
                    error!(target: "stdout", "{}", &err_msg);

                    LlamaCoreError::Operation(err_msg)
                })?;

                // deserialize the embedding data
                let embedding = serde_json::from_str::<Embedding>(output).map_err(|e| {
                    let err_msg = format!("Failed to deserialize the embedding data. Reason: {e}");

                    #[cfg(feature = "logging")]
                    error!(target: "stdout", "{}", &err_msg);

                    LlamaCoreError::Operation(err_msg)
                })?;

                let embedding_object = EmbeddingObject {
                    index: idx as u64,
                    object: String::from("embedding"),
                    embedding: embedding.data,
                };

                embeddings.push(embedding_object);

                // retrieve the number of prompt and completion tokens
                let token_info = get_token_info_by_graph(graph)?;

                usage.prompt_tokens += token_info.prompt_tokens;
                usage.completion_tokens += token_info.completion_tokens;
                usage.total_tokens = usage.prompt_tokens + usage.completion_tokens;
            }
            Err(e) => {
                let err_msg = format!("Failed to compute embeddings. Reason: {e}");

                #[cfg(feature = "logging")]
                error!(target: "stdout", "{}", &err_msg);

                return Err(LlamaCoreError::Backend(BackendError::Compute(err_msg)));
            }
        }
    }

    #[cfg(feature = "logging")]
    info!(target: "stdout", "token usage of embeddings: {} prompt tokens, {} completion tokens", usage.prompt_tokens, usage.completion_tokens);

    Ok((embeddings, usage))
}
/// Get the dimension of the embedding model.
///
/// # Arguments
///
/// * `name` - The name of the embedding model. If `None`, the dimension of the first model will be returned.
///
/// # Returns
///
/// The dimension of the embedding model.
///
/// # Errors
///
/// * The model does not exist in the embedding graphs.
/// * No embedding model is available.
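///
/// # Example
///
/// A minimal sketch (marked `ignore` so it is not compiled as a doc test); it
/// assumes an embedding model has already been loaded into `EMBEDDING_GRAPHS`,
/// and the model name shown is only illustrative:
///
/// ```ignore
/// // Query the dimension of the first loaded embedding model.
/// let dim = llama_core::embeddings::dimension(None)?;
/// println!("embedding dimension: {dim}");
///
/// // Or query a specific model by name.
/// let dim = llama_core::embeddings::dimension(Some("my-embedding-model"))?;
/// ```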
pub fn dimension(name: Option<&str>) -> Result<u64, LlamaCoreError> {
    let embedding_graphs = match EMBEDDING_GRAPHS.get() {
        Some(embedding_graphs) => embedding_graphs,
        None => {
            let err_msg = "Failed to get the underlying value of `EMBEDDING_GRAPHS`.";

            #[cfg(feature = "logging")]
            error!(target: "stdout", "{err_msg}");

            return Err(LlamaCoreError::Operation(err_msg.into()));
        }
    };

    let embedding_graphs = embedding_graphs.lock().map_err(|e| {
        let err_msg = format!("Failed to acquire the lock of `EMBEDDING_GRAPHS`. {e}");

        #[cfg(feature = "logging")]
        error!(target: "stdout", "{}", &err_msg);

        LlamaCoreError::Operation(err_msg)
    })?;

    match name {
        Some(model_name) => match embedding_graphs.get(model_name) {
            Some(graph) => Ok(graph.metadata.ctx_size),
            None => {
                let err_msg =
                    format!("The model `{model_name}` does not exist in the embedding graphs.");

                #[cfg(feature = "logging")]
                error!(target: "stdout", "{}", &err_msg);

                Err(LlamaCoreError::Operation(err_msg))
            }
        },
        None => {
            if !embedding_graphs.is_empty() {
                let graph = match embedding_graphs.values().next() {
                    Some(graph) => graph,
                    None => {
                        let err_msg = "Failed to get the underlying value of `EMBEDDING_GRAPHS`.";

                        #[cfg(feature = "logging")]
                        error!(target: "stdout", "{err_msg}");

                        return Err(LlamaCoreError::Operation(err_msg.into()));
                    }
                };

                Ok(graph.metadata.ctx_size)
            } else {
                let err_msg = "There is no model available in the embedding graphs.";

                #[cfg(feature = "logging")]
                error!(target: "stdout", "{}", &err_msg);

                Err(LlamaCoreError::Operation(err_msg.into()))
            }
        }
    }
}
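/// Intermediate representation of a single embedding as returned by the backend.
///
/// The backend reports the vector length as `n_embedding` and the vector itself as
/// `embedding`, e.g. `{"n_embedding": 3, "embedding": [0.1, -0.2, 0.3]}` (values shown
/// are only illustrative); the `serde(rename)` attributes below map those fields onto
/// `len` and `data`.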
#[derive(Debug, Serialize, Deserialize)]
struct Embedding {
    #[serde(rename = "n_embedding")]
    len: u64,
    #[serde(rename = "embedding")]
    data: Vec<f64>,
}
/// Generate a list of chunks from a given text. Each chunk contains at most `chunk_capacity` tokens.
///
/// # Arguments
///
/// * `text` - A reference to the text to split.
///
/// * `ty` - Type of the text: `txt` for plain text content or `md` for markdown content.
///
/// * `chunk_capacity` - The maximum number of tokens each chunk may contain.
///
/// # Returns
///
/// A vector of strings.
///
/// # Errors
///
/// Returns an error if the operation fails.
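///
/// # Example
///
/// A minimal sketch (marked `ignore` so it is not compiled as a doc test); the
/// input text and chunk capacity are only illustrative:
///
/// ```ignore
/// use llama_core::embeddings::chunk_text;
///
/// // Split a markdown document into chunks of at most 512 tokens each.
/// let document = "# Title\n\nSome long markdown content ...";
/// let chunks = chunk_text(document, "md", 512)?;
/// println!("{} chunks", chunks.len());
/// ```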
pub fn chunk_text(
    text: impl AsRef<str>,
    ty: impl AsRef<str>,
    chunk_capacity: usize,
) -> Result<Vec<String>, LlamaCoreError> {
    if ty.as_ref().to_lowercase().as_str() != "txt" && ty.as_ref().to_lowercase().as_str() != "md" {
        let err_msg = "Failed to upload the target file. Only files with 'txt' and 'md' extensions are supported.";

        #[cfg(feature = "logging")]
        error!(target: "stdout", "{err_msg}");

        return Err(LlamaCoreError::Operation(err_msg.into()));
    }

    match ty.as_ref().to_lowercase().as_str() {
        "txt" => {
            #[cfg(feature = "logging")]
            info!(target: "stdout", "Chunk the plain text contents.");

            let tokenizer = cl100k_base().map_err(|e| {
                let err_msg = e.to_string();

                #[cfg(feature = "logging")]
                error!(target: "stdout", "{}", &err_msg);

                LlamaCoreError::Operation(err_msg)
            })?;

            // create a text splitter
            let splitter = TextSplitter::new(tokenizer).with_trim_chunks(true);

            let chunks = splitter
                .chunks(text.as_ref(), chunk_capacity)
                .map(|s| s.to_string())
                .collect::<Vec<_>>();

            #[cfg(feature = "logging")]
            info!(target: "stdout", "Number of chunks: {}", chunks.len());

            Ok(chunks)
        }
        "md" => {
            #[cfg(feature = "logging")]
            info!(target: "stdout", "Chunk the markdown contents.");

            let tokenizer = cl100k_base().map_err(|e| {
                let err_msg = e.to_string();

                #[cfg(feature = "logging")]
                error!(target: "stdout", "{}", &err_msg);

                LlamaCoreError::Operation(err_msg)
            })?;

            // create a markdown splitter
            let splitter = MarkdownSplitter::new(tokenizer).with_trim_chunks(true);

            let chunks = splitter
                .chunks(text.as_ref(), chunk_capacity)
                .map(|s| s.to_string())
                .collect::<Vec<_>>();

            #[cfg(feature = "logging")]
            info!(target: "stdout", "Number of chunks: {}", chunks.len());

            Ok(chunks)
        }
        _ => {
            let err_msg =
                "Failed to upload the target file. Only text and markdown files are supported.";

            #[cfg(feature = "logging")]
            error!(target: "stdout", "{err_msg}");

            Err(LlamaCoreError::Operation(err_msg.into()))
        }
    }
}
/// Get a copy of the metadata of the model.
fn get_model_metadata(model_name: Option<&String>) -> Result<GgmlMetadata, LlamaCoreError> {
    let embedding_graphs = match EMBEDDING_GRAPHS.get() {
        Some(embedding_graphs) => embedding_graphs,
        None => {
            let err_msg = "Failed to get the underlying value of `EMBEDDING_GRAPHS`.";

            #[cfg(feature = "logging")]
            error!(target: "stdout", "{err_msg}");

            return Err(LlamaCoreError::Operation(err_msg.into()));
        }
    };

    let embedding_graphs = embedding_graphs.lock().map_err(|e| {
        let err_msg = format!("Failed to acquire the lock of `EMBEDDING_GRAPHS`. {e}");

        #[cfg(feature = "logging")]
        error!(target: "stdout", "{}", &err_msg);

        LlamaCoreError::Operation(err_msg)
    })?;

    match model_name {
        Some(model_name) => match embedding_graphs.contains_key(model_name) {
            true => {
                let graph = embedding_graphs.get(model_name).unwrap();
                Ok(graph.metadata.clone())
            }
            false => match embedding_graphs.iter().next() {
                Some((_, graph)) => Ok(graph.metadata.clone()),
                None => {
                    let err_msg = "There is no model available in the embedding graphs.";

                    #[cfg(feature = "logging")]
                    error!(target: "stdout", "{}", &err_msg);

                    Err(LlamaCoreError::Operation(err_msg.into()))
                }
            },
        },
        None => match embedding_graphs.iter().next() {
            Some((_, graph)) => Ok(graph.metadata.clone()),
            None => {
                let err_msg = "There is no model available in the embedding graphs.";

                #[cfg(feature = "logging")]
                error!(target: "stdout", "{err_msg}");

                Err(LlamaCoreError::Operation(err_msg.into()))
            }
        },
    }
}
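/// Serialize the given metadata to JSON and write it into input tensor index 1 of the
/// matching embedding graph to update its metadata. If `model_name` is `None` or not
/// found, the first available graph is updated instead.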
fn update_model_metadata(
    model_name: Option<&String>,
    metadata: &GgmlMetadata,
) -> Result<(), LlamaCoreError> {
    let config = match serde_json::to_string(metadata) {
        Ok(config) => config,
        Err(e) => {
            let err_msg = format!("Failed to serialize metadata to a JSON string. {e}");

            #[cfg(feature = "logging")]
            error!(target: "stdout", "{}", &err_msg);

            return Err(LlamaCoreError::Operation(err_msg));
        }
    };

    let embedding_graphs = match EMBEDDING_GRAPHS.get() {
        Some(embedding_graphs) => embedding_graphs,
        None => {
            let err_msg = "Failed to get the underlying value of `EMBEDDING_GRAPHS`.";

            #[cfg(feature = "logging")]
            error!(target: "stdout", "{err_msg}");

            return Err(LlamaCoreError::Operation(err_msg.into()));
        }
    };

    let mut embedding_graphs = embedding_graphs.lock().map_err(|e| {
        let err_msg = format!("Failed to acquire the lock of `EMBEDDING_GRAPHS`. Reason: {e}");

        #[cfg(feature = "logging")]
        error!(target: "stdout", "{}", &err_msg);

        LlamaCoreError::Operation(err_msg)
    })?;

    match model_name {
        Some(model_name) => {
            match embedding_graphs.contains_key(model_name) {
                true => {
                    let graph = embedding_graphs.get_mut(model_name).unwrap();
                    // update metadata
                    set_tensor_data_u8(graph, 1, config.as_bytes())
                }
                false => match embedding_graphs.iter_mut().next() {
                    Some((_, graph)) => {
                        // update metadata
                        set_tensor_data_u8(graph, 1, config.as_bytes())
                    }
                    None => {
                        let err_msg = "There is no model available in the embedding graphs.";

                        #[cfg(feature = "logging")]
                        error!(target: "stdout", "{}", &err_msg);

                        Err(LlamaCoreError::Operation(err_msg.into()))
                    }
                },
            }
        }
        None => {
            match embedding_graphs.iter_mut().next() {
                Some((_, graph)) => {
                    // update metadata
                    set_tensor_data_u8(graph, 1, config.as_bytes())
                }
                None => {
                    let err_msg = "There is no model available in the embedding graphs.";

                    #[cfg(feature = "logging")]
                    error!(target: "stdout", "{err_msg}");

                    Err(LlamaCoreError::Operation(err_msg.into()))
                }
            }
        }
    }
}
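/// Get the current copy of the model's metadata and write it back to the backend graph
/// (see [`get_model_metadata`] and [`update_model_metadata`]).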
fn reset_model_metadata(model_name: Option<&String>) -> Result<(), LlamaCoreError> {
    // get metadata
    let metadata = get_model_metadata(model_name)?;

    // update model with the original metadata
    update_model_metadata(model_name, &metadata)
}