use crate::{
    error::{BackendError, LlamaCoreError},
    metadata::ggml::GgmlMetadata,
    running_mode,
    utils::{get_output_buffer, get_token_info_by_graph, set_tensor_data_u8},
    Graph, RunningMode, CHAT_GRAPHS, EMBEDDING_GRAPHS, OUTPUT_TENSOR,
};
use endpoints::{
    common::Usage,
    embeddings::{EmbeddingObject, EmbeddingRequest, EmbeddingsResponse, InputText},
};
use serde::{Deserialize, Serialize};
use text_splitter::{MarkdownSplitter, TextSplitter};
use tiktoken_rs::cl100k_base;

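/// Compute embeddings for the given input.
///
/// Note: The example below is a minimal sketch, not part of this crate. It assumes the
/// embedding (or chat) graphs have already been initialized elsewhere, that
/// `EmbeddingRequest` can be deserialized from an OpenAI-compatible JSON body, and it
/// uses a hypothetical model name.
///
/// ```ignore
/// use endpoints::embeddings::EmbeddingRequest;
///
/// // Build an OpenAI-compatible embedding request from JSON.
/// let request: EmbeddingRequest = serde_json::from_str(
///     r#"{"model": "nomic-embed-text-v1.5", "input": "Hello, world!"}"#,
/// )?;
///
/// let response = embeddings(&request).await?;
/// println!("Computed {} embedding vector(s)", response.data.len());
/// ```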
pub async fn embeddings(
    embedding_request: &EmbeddingRequest,
) -> Result<EmbeddingsResponse, LlamaCoreError> {
    #[cfg(feature = "logging")]
    info!(target: "stdout", "Computing embeddings");

    let running_mode = running_mode()?;
    if !running_mode.contains(RunningMode::EMBEDDINGS) && !running_mode.contains(RunningMode::RAG) {
        let err_msg = "Computing embeddings is only supported in the embeddings and rag modes.";

        #[cfg(feature = "logging")]
        error!(target: "stdout", "{err_msg}");

        return Err(LlamaCoreError::Operation(err_msg.into()));
    }

    let model_name = &embedding_request.model;

    let embedding_response = {
        // Look up the embedding graphs, falling back to the chat graphs.
        let embedding_graphs = match EMBEDDING_GRAPHS.get() {
            Some(embedding_graphs) => embedding_graphs,
            None => match CHAT_GRAPHS.get() {
                Some(chat_graphs) => chat_graphs,
                None => {
                    let err_msg = "No embedding model is available.";

                    #[cfg(feature = "logging")]
                    error!(target: "stdout", "{err_msg}");

                    return Err(LlamaCoreError::Operation(err_msg.into()));
                }
            },
        };

        let mut embedding_graphs = embedding_graphs.lock().map_err(|e| {
            let err_msg = format!("Failed to acquire the lock of `EMBEDDING_GRAPHS`. {e}");

            #[cfg(feature = "logging")]
            error!(target: "stdout", "{}", &err_msg);

            LlamaCoreError::Operation(err_msg)
        })?;

        let graph = match model_name {
            Some(model_name) if embedding_graphs.contains_key(model_name) => {
                embedding_graphs.get_mut(model_name).unwrap()
            }
            _ => match embedding_graphs.iter_mut().next() {
                Some((_, graph)) => graph,
                None => {
                    let err_msg = "No available model found in the embedding graphs.";

                    #[cfg(feature = "logging")]
                    error!(target: "stdout", "{}", &err_msg);

                    return Err(LlamaCoreError::Operation(err_msg.into()));
                }
            },
        };

        // Enable the `embeddings` mode on the graph if it is not already enabled.
        if !graph.metadata.embeddings {
            graph.metadata.embeddings = true;
            graph.update_metadata()?;
        }

        // Compute embeddings for the requested input variant.
        let (data, usage) = match &embedding_request.input {
            InputText::String(text) => compute_embeddings(graph, &[text.to_owned()])?,
            InputText::ArrayOfStrings(texts) => compute_embeddings(graph, texts.as_slice())?,
            InputText::ArrayOfTokens(tokens) => {
                let texts: Vec<String> = tokens.iter().map(|t| t.to_string()).collect();
                compute_embeddings(graph, texts.as_slice())?
            }
            InputText::ArrayOfTokenArrays(token_arrays) => {
                let texts: Vec<String> = token_arrays
                    .iter()
                    .map(|tokens| {
                        tokens
                            .iter()
                            .map(|t| t.to_string())
                            .collect::<Vec<String>>()
                            .join(" ")
                    })
                    .collect();
                compute_embeddings(graph, texts.as_slice())?
            }
        };

        EmbeddingsResponse {
            object: String::from("list"),
            data,
            model: graph.name().to_owned(),
            usage,
        }
    };

    #[cfg(feature = "logging")]
    info!(target: "stdout", "Reset the model metadata");

    reset_model_metadata(model_name.as_ref())?;

    Ok(embedding_response)
}

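/// Compute embeddings for each input chunk and accumulate the token usage.
/// Returns one `EmbeddingObject` per chunk, in the original order.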
fn compute_embeddings(
    graph: &mut Graph<GgmlMetadata>,
    input: &[String],
) -> Result<(Vec<EmbeddingObject>, Usage), LlamaCoreError> {
    #[cfg(feature = "logging")]
    info!(target: "stdout", "Compute embeddings for {} chunks", input.len());

    let mut embeddings: Vec<EmbeddingObject> = Vec::new();
    let mut usage = Usage::default();
    for (idx, input) in input.iter().enumerate() {
        // Set the input tensor to the raw bytes of the chunk.
        let tensor_data = input.as_bytes().to_vec();
        graph
            .set_input(0, wasmedge_wasi_nn::TensorType::U8, &[1], &tensor_data)
            .map_err(|e| {
                let err_msg = e.to_string();

                #[cfg(feature = "logging")]
                error!(target: "stdout", "{}", &err_msg);

                LlamaCoreError::Backend(BackendError::SetInput(err_msg))
            })?;

        #[cfg(feature = "logging")]
        debug!(target: "stdout", "compute embeddings for chunk {}", idx + 1);

        match graph.compute() {
            Ok(_) => {
                // Retrieve the output buffer and decode it as a UTF-8 string.
                let output_buffer = get_output_buffer(graph, OUTPUT_TENSOR)?;

                let output = std::str::from_utf8(&output_buffer[..]).map_err(|e| {
                    let err_msg = format!(
                        "Failed to decode the buffer of the inference result to a utf-8 string. Reason: {e}"
                    );

                    #[cfg(feature = "logging")]
                    error!(target: "stdout", "{}", &err_msg);

                    LlamaCoreError::Operation(err_msg)
                })?;

                // Deserialize the embedding data.
                let embedding = serde_json::from_str::<Embedding>(output).map_err(|e| {
                    let err_msg = format!("Failed to deserialize the embedding data. Reason: {e}");

                    #[cfg(feature = "logging")]
                    error!(target: "stdout", "{}", &err_msg);

                    LlamaCoreError::Operation(err_msg)
                })?;

                let embedding_object = EmbeddingObject {
                    index: idx as u64,
                    object: String::from("embedding"),
                    embedding: embedding.data,
                };

                embeddings.push(embedding_object);

                // Accumulate the token usage across all chunks.
                let token_info = get_token_info_by_graph(graph)?;

                usage.prompt_tokens += token_info.prompt_tokens;
                usage.completion_tokens += token_info.completion_tokens;
                usage.total_tokens = usage.prompt_tokens + usage.completion_tokens;
            }
            Err(e) => {
                let err_msg = format!("Failed to compute embeddings. Reason: {e}");

                #[cfg(feature = "logging")]
                error!(target: "stdout", "{}", &err_msg);

                return Err(LlamaCoreError::Backend(BackendError::Compute(err_msg)));
            }
        }
    }

    #[cfg(feature = "logging")]
    info!(target: "stdout", "token usage of embeddings: {} prompt tokens, {} completion tokens", usage.prompt_tokens, usage.completion_tokens);

    Ok((embeddings, usage))
}

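/// Get the dimension of the embedding model.
///
/// # Arguments
///
/// * `name`: the name of the embedding model. If `None`, the first available model in
///   the embedding graphs is used.
///
/// Note: The example below is a minimal sketch with a hypothetical model name; it
/// assumes the embedding graphs have been initialized.
///
/// ```ignore
/// let dim = dimension(Some("nomic-embed-text-v1.5"))?;
/// println!("embedding dimension: {dim}");
/// ```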
pub fn dimension(name: Option<&str>) -> Result<u64, LlamaCoreError> {
    let embedding_graphs = match EMBEDDING_GRAPHS.get() {
        Some(embedding_graphs) => embedding_graphs,
        None => {
            let err_msg = "Failed to get the underlying value of `EMBEDDING_GRAPHS`.";

            #[cfg(feature = "logging")]
            error!(target: "stdout", "{err_msg}");

            return Err(LlamaCoreError::Operation(err_msg.into()));
        }
    };

    let embedding_graphs = embedding_graphs.lock().map_err(|e| {
        let err_msg = format!("Failed to acquire the lock of `EMBEDDING_GRAPHS`. {e}");

        #[cfg(feature = "logging")]
        error!(target: "stdout", "{}", &err_msg);

        LlamaCoreError::Operation(err_msg)
    })?;

    match name {
        Some(model_name) => match embedding_graphs.get(model_name) {
            Some(graph) => Ok(graph.metadata.ctx_size),
            None => {
                let err_msg =
                    format!("The model `{model_name}` does not exist in the embedding graphs.");

                #[cfg(feature = "logging")]
                error!(target: "stdout", "{}", &err_msg);

                Err(LlamaCoreError::Operation(err_msg))
            }
        },
        None => {
            if !embedding_graphs.is_empty() {
                let graph = match embedding_graphs.values().next() {
                    Some(graph) => graph,
                    None => {
                        let err_msg = "Failed to get the underlying value of `EMBEDDING_GRAPHS`.";

                        #[cfg(feature = "logging")]
                        error!(target: "stdout", "{err_msg}");

                        return Err(LlamaCoreError::Operation(err_msg.into()));
                    }
                };

                Ok(graph.metadata.ctx_size)
            } else {
                let err_msg = "There is no model available in the embedding graphs.";

                #[cfg(feature = "logging")]
                error!(target: "stdout", "{}", &err_msg);

                Err(LlamaCoreError::Operation(err_msg.into()))
            }
        }
    }
}

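/// The embedding data returned by the backend, deserialized from the JSON object read
/// from the output tensor (fields `n_embedding` and `embedding`).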
#[derive(Debug, Serialize, Deserialize)]
struct Embedding {
    #[serde(rename = "n_embedding")]
    len: u64,
    #[serde(rename = "embedding")]
    data: Vec<f64>,
}

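/// Generate a list of chunks from the given text, where each chunk holds at most
/// `chunk_capacity` tokens as counted by the `cl100k_base` tokenizer.
///
/// # Arguments
///
/// * `text`: the content to chunk.
///
/// * `ty`: the type of the content, either `"txt"` or `"md"`.
///
/// * `chunk_capacity`: the maximum number of tokens per chunk.
///
/// Note: The example below is a minimal sketch.
///
/// ```ignore
/// let markdown = "# Title\n\nSome markdown body that will be split into chunks.";
/// let chunks = chunk_text(markdown, "md", 128)?;
/// assert!(!chunks.is_empty());
/// ```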
pub fn chunk_text(
    text: impl AsRef<str>,
    ty: impl AsRef<str>,
    chunk_capacity: usize,
) -> Result<Vec<String>, LlamaCoreError> {
    if ty.as_ref().to_lowercase().as_str() != "txt" && ty.as_ref().to_lowercase().as_str() != "md" {
        let err_msg = "Failed to upload the target file. Only files with 'txt' and 'md' extensions are supported.";

        #[cfg(feature = "logging")]
        error!(target: "stdout", "{err_msg}");

        return Err(LlamaCoreError::Operation(err_msg.into()));
    }

    match ty.as_ref().to_lowercase().as_str() {
        "txt" => {
            #[cfg(feature = "logging")]
            info!(target: "stdout", "Chunk the plain text contents.");

            let tokenizer = cl100k_base().map_err(|e| {
                let err_msg = e.to_string();

                #[cfg(feature = "logging")]
                error!(target: "stdout", "{}", &err_msg);

                LlamaCoreError::Operation(err_msg)
            })?;

            let splitter = TextSplitter::new(tokenizer).with_trim_chunks(true);

            let chunks = splitter
                .chunks(text.as_ref(), chunk_capacity)
                .map(|s| s.to_string())
                .collect::<Vec<_>>();

            #[cfg(feature = "logging")]
            info!(target: "stdout", "Number of chunks: {}", chunks.len());

            Ok(chunks)
        }
        "md" => {
            #[cfg(feature = "logging")]
            info!(target: "stdout", "Chunk the markdown contents.");

            let tokenizer = cl100k_base().map_err(|e| {
                let err_msg = e.to_string();

                #[cfg(feature = "logging")]
                error!(target: "stdout", "{}", &err_msg);

                LlamaCoreError::Operation(err_msg)
            })?;

            let splitter = MarkdownSplitter::new(tokenizer).with_trim_chunks(true);

            let chunks = splitter
                .chunks(text.as_ref(), chunk_capacity)
                .map(|s| s.to_string())
                .collect::<Vec<_>>();

            #[cfg(feature = "logging")]
            info!(target: "stdout", "Number of chunks: {}", chunks.len());

            Ok(chunks)
        }
        _ => {
            let err_msg =
                "Failed to upload the target file. Only text and markdown files are supported.";

            #[cfg(feature = "logging")]
            error!(target: "stdout", "{err_msg}");

            Err(LlamaCoreError::Operation(err_msg.into()))
        }
    }
}

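/// Get a copy of the metadata of the model. If `model_name` is `None` or the named
/// model is not found, the metadata of the first available model is returned.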
fn get_model_metadata(model_name: Option<&String>) -> Result<GgmlMetadata, LlamaCoreError> {
    let embedding_graphs = match EMBEDDING_GRAPHS.get() {
        Some(embedding_graphs) => embedding_graphs,
        None => {
            let err_msg = "Failed to get the underlying value of `EMBEDDING_GRAPHS`.";

            #[cfg(feature = "logging")]
            error!(target: "stdout", "{err_msg}");

            return Err(LlamaCoreError::Operation(err_msg.into()));
        }
    };

    let embedding_graphs = embedding_graphs.lock().map_err(|e| {
        let err_msg = format!("Failed to acquire the lock of `EMBEDDING_GRAPHS`. {e}");

        #[cfg(feature = "logging")]
        error!(target: "stdout", "{}", &err_msg);

        LlamaCoreError::Operation(err_msg)
    })?;

    match model_name {
        Some(model_name) => match embedding_graphs.contains_key(model_name) {
            true => {
                let graph = embedding_graphs.get(model_name).unwrap();
                Ok(graph.metadata.clone())
            }
            false => match embedding_graphs.iter().next() {
                Some((_, graph)) => Ok(graph.metadata.clone()),
                None => {
                    let err_msg = "There is no model available in the embedding graphs.";

                    #[cfg(feature = "logging")]
                    error!(target: "stdout", "{}", &err_msg);

                    Err(LlamaCoreError::Operation(err_msg.into()))
                }
            },
        },
        None => match embedding_graphs.iter().next() {
            Some((_, graph)) => Ok(graph.metadata.clone()),
            None => {
                let err_msg = "There is no model available in the embedding graphs.";

                #[cfg(feature = "logging")]
                error!(target: "stdout", "{err_msg}");

                Err(LlamaCoreError::Operation(err_msg.into()))
            }
        },
    }
}

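/// Serialize the given metadata to a JSON string and write it into the target graph
/// via `set_tensor_data_u8` at input index 1. If `model_name` is `None` or the named
/// model is not found, the first available model is updated.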
fn update_model_metadata(
    model_name: Option<&String>,
    metadata: &GgmlMetadata,
) -> Result<(), LlamaCoreError> {
    let config = match serde_json::to_string(metadata) {
        Ok(config) => config,
        Err(e) => {
            let err_msg = format!("Failed to serialize metadata to a JSON string. {e}");

            #[cfg(feature = "logging")]
            error!(target: "stdout", "{}", &err_msg);

            return Err(LlamaCoreError::Operation(err_msg));
        }
    };

    let embedding_graphs = match EMBEDDING_GRAPHS.get() {
        Some(embedding_graphs) => embedding_graphs,
        None => {
            let err_msg = "Failed to get the underlying value of `EMBEDDING_GRAPHS`.";

            #[cfg(feature = "logging")]
            error!(target: "stdout", "{err_msg}");

            return Err(LlamaCoreError::Operation(err_msg.into()));
        }
    };

    let mut embedding_graphs = embedding_graphs.lock().map_err(|e| {
        let err_msg = format!("Failed to acquire the lock of `EMBEDDING_GRAPHS`. Reason: {e}");

        #[cfg(feature = "logging")]
        error!(target: "stdout", "{}", &err_msg);

        LlamaCoreError::Operation(err_msg)
    })?;

    match model_name {
        Some(model_name) => {
            match embedding_graphs.contains_key(model_name) {
                true => {
                    let graph = embedding_graphs.get_mut(model_name).unwrap();
                    // Update the metadata of the target model.
                    set_tensor_data_u8(graph, 1, config.as_bytes())
                }
                false => match embedding_graphs.iter_mut().next() {
                    Some((_, graph)) => {
                        // Update the metadata of the first available model.
                        set_tensor_data_u8(graph, 1, config.as_bytes())
                    }
                    None => {
                        let err_msg = "There is no model available in the embedding graphs.";

                        #[cfg(feature = "logging")]
                        error!(target: "stdout", "{}", &err_msg);

                        Err(LlamaCoreError::Operation(err_msg.into()))
                    }
                },
            }
        }
        None => {
            match embedding_graphs.iter_mut().next() {
                Some((_, graph)) => {
                    // Update the metadata of the first available model.
                    set_tensor_data_u8(graph, 1, config.as_bytes())
                }
                None => {
                    let err_msg = "There is no model available in the embedding graphs.";

                    #[cfg(feature = "logging")]
                    error!(target: "stdout", "{err_msg}");

                    Err(LlamaCoreError::Operation(err_msg.into()))
                }
            }
        }
    }
}

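/// Reset the metadata of the model: fetch the model's metadata and re-apply it to the
/// underlying graph.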
fn reset_model_metadata(model_name: Option<&String>) -> Result<(), LlamaCoreError> {
    // Get the current metadata of the model.
    let metadata = get_model_metadata(model_name)?;

    // Re-apply the metadata to the model's graph.
    update_model_metadata(model_name, &metadata)
}