use crate::{
    error::{BackendError, LlamaCoreError},
    metadata::ggml::GgmlMetadata,
    running_mode,
    utils::{get_output_buffer, get_token_info_by_graph, set_tensor_data_u8},
    Graph, RunningMode, CHAT_GRAPHS, EMBEDDING_GRAPHS, OUTPUT_TENSOR,
};
use endpoints::{
    common::Usage,
    embeddings::{EmbeddingObject, EmbeddingRequest, EmbeddingsResponse, InputText},
};
use serde::{Deserialize, Serialize};
use text_splitter::{MarkdownSplitter, TextSplitter};
use tiktoken_rs::cl100k_base;

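/// Compute embeddings for the input in the given embedding request.
///
/// Requires the `embeddings` or `rag` running mode. The target model is looked up by
/// `embedding_request.model` in `EMBEDDING_GRAPHS` (falling back to `CHAT_GRAPHS` when
/// no embedding graphs are registered); if the named model is not found, the first
/// available graph is used. The model metadata is reset after the computation.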
pub async fn embeddings(
    embedding_request: &EmbeddingRequest,
) -> Result<EmbeddingsResponse, LlamaCoreError> {
    #[cfg(feature = "logging")]
    info!(target: "stdout", "Computing embeddings");

    let running_mode = running_mode()?;
    if !running_mode.contains(RunningMode::EMBEDDINGS) && !running_mode.contains(RunningMode::RAG) {
        let err_msg = "Computing embeddings is only supported in the embeddings and rag modes.";

        #[cfg(feature = "logging")]
        error!(target: "stdout", "{}", err_msg);

        return Err(LlamaCoreError::Operation(err_msg.into()));
    }

    let model_name = &embedding_request.model;

    let embedding_response = {
        let embedding_graphs = match EMBEDDING_GRAPHS.get() {
            Some(embedding_graphs) => embedding_graphs,
            None => match CHAT_GRAPHS.get() {
                Some(chat_graphs) => chat_graphs,
                None => {
                    let err_msg = "No embedding model is available.";

                    #[cfg(feature = "logging")]
                    error!(target: "stdout", "{}", err_msg);

                    return Err(LlamaCoreError::Operation(err_msg.into()));
                }
            },
        };

        let mut embedding_graphs = embedding_graphs.lock().map_err(|e| {
            let err_msg = format!("Failed to acquire the lock of `EMBEDDING_GRAPHS`. {}", e);

            #[cfg(feature = "logging")]
            error!(target: "stdout", "{}", &err_msg);

            LlamaCoreError::Operation(err_msg)
        })?;

        let graph = match model_name {
            Some(model_name) if embedding_graphs.contains_key(model_name) => {
                embedding_graphs.get_mut(model_name).unwrap()
            }
            _ => match embedding_graphs.iter_mut().next() {
                Some((_, graph)) => graph,
                None => {
                    let err_msg = "No available model found in the embedding graphs.";

                    #[cfg(feature = "logging")]
                    error!(target: "stdout", "{}", &err_msg);

                    return Err(LlamaCoreError::Operation(err_msg.into()));
                }
            },
        };

        if !graph.metadata.embeddings {
            graph.metadata.embeddings = true;
            graph.update_metadata()?;
        }

        let (data, usage) = match &embedding_request.input {
            InputText::String(text) => compute_embeddings(graph, &[text.to_owned()])?,
            InputText::ArrayOfStrings(texts) => compute_embeddings(graph, texts.as_slice())?,
            InputText::ArrayOfTokens(tokens) => {
                let texts: Vec<String> = tokens.iter().map(|t| t.to_string()).collect();
                compute_embeddings(graph, texts.as_slice())?
            }
            InputText::ArrayOfTokenArrays(token_arrays) => {
                let texts: Vec<String> = token_arrays
                    .iter()
                    .map(|tokens| {
                        tokens
                            .iter()
                            .map(|t| t.to_string())
                            .collect::<Vec<String>>()
                            .join(" ")
                    })
                    .collect();
                compute_embeddings(graph, texts.as_slice())?
            }
        };

        EmbeddingsResponse {
            object: String::from("list"),
            data,
            model: graph.name().to_owned(),
            usage,
        }
    };

    #[cfg(feature = "logging")]
    info!(target: "stdout", "Reset the model metadata");

    reset_model_metadata(model_name.as_ref())?;

    Ok(embedding_response)
}

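/// Compute an embedding for each input chunk and accumulate the token usage.
///
/// Each chunk is passed to the graph as a UTF-8 byte tensor, the inference output is
/// deserialized into an `Embedding`, and the prompt and completion token counts of all
/// chunks are summed into a single `Usage`.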
fn compute_embeddings(
    graph: &mut Graph<GgmlMetadata>,
    input: &[String],
) -> Result<(Vec<EmbeddingObject>, Usage), LlamaCoreError> {
    #[cfg(feature = "logging")]
    info!(target: "stdout", "Compute embeddings for {} chunks", input.len());

    let mut embeddings: Vec<EmbeddingObject> = Vec::new();
    let mut usage = Usage::default();
    for (idx, input) in input.iter().enumerate() {
        let tensor_data = input.as_bytes().to_vec();
        graph
            .set_input(0, wasmedge_wasi_nn::TensorType::U8, &[1], &tensor_data)
            .map_err(|e| {
                let err_msg = e.to_string();

                #[cfg(feature = "logging")]
                error!(target: "stdout", "{}", &err_msg);

                LlamaCoreError::Backend(BackendError::SetInput(err_msg))
            })?;

        #[cfg(feature = "logging")]
        debug!(target: "stdout", "compute embeddings for chunk {}", idx + 1);

        match graph.compute() {
            Ok(_) => {
                let output_buffer = get_output_buffer(graph, OUTPUT_TENSOR)?;

                let output = std::str::from_utf8(&output_buffer[..]).map_err(|e| {
                    let err_msg = format!(
                        "Failed to decode the buffer of the inference result to a utf-8 string. Reason: {}",
                        e
                    );

                    #[cfg(feature = "logging")]
                    error!(target: "stdout", "{}", &err_msg);

                    LlamaCoreError::Operation(err_msg)
                })?;

                let embedding = serde_json::from_str::<Embedding>(output).map_err(|e| {
                    let err_msg =
                        format!("Failed to deserialize the embedding data. Reason: {}", e);

                    #[cfg(feature = "logging")]
                    error!(target: "stdout", "{}", &err_msg);

                    LlamaCoreError::Operation(err_msg)
                })?;

                let embedding_object = EmbeddingObject {
                    index: idx as u64,
                    object: String::from("embedding"),
                    embedding: embedding.data,
                };

                embeddings.push(embedding_object);

                let token_info = get_token_info_by_graph(graph)?;

                usage.prompt_tokens += token_info.prompt_tokens;
                usage.completion_tokens += token_info.completion_tokens;
                usage.total_tokens = usage.prompt_tokens + usage.completion_tokens;
            }
            Err(e) => {
                let err_msg = format!("Failed to compute embeddings. Reason: {}", e);

                #[cfg(feature = "logging")]
                error!(target: "stdout", "{}", &err_msg);

                return Err(LlamaCoreError::Backend(BackendError::Compute(err_msg)));
            }
        }
    }

    #[cfg(feature = "logging")]
    info!(target: "stdout", "token usage of embeddings: {} prompt tokens, {} completion tokens", usage.prompt_tokens, usage.completion_tokens);

    Ok((embeddings, usage))
}

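/// Get the dimension of the embedding model.
///
/// If `name` is `None`, the first available model in the embedding graphs is used.
/// The returned value is read from the model's `ctx_size` metadata.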
pub fn dimension(name: Option<&str>) -> Result<u64, LlamaCoreError> {
    let embedding_graphs = match EMBEDDING_GRAPHS.get() {
        Some(embedding_graphs) => embedding_graphs,
        None => {
            let err_msg = "Failed to get the underlying value of `EMBEDDING_GRAPHS`.";

            #[cfg(feature = "logging")]
            error!(target: "stdout", "{}", err_msg);

            return Err(LlamaCoreError::Operation(err_msg.into()));
        }
    };

    let embedding_graphs = embedding_graphs.lock().map_err(|e| {
        let err_msg = format!("Failed to acquire the lock of `EMBEDDING_GRAPHS`. {}", e);

        #[cfg(feature = "logging")]
        error!(target: "stdout", "{}", &err_msg);

        LlamaCoreError::Operation(err_msg)
    })?;

    match name {
        Some(model_name) => match embedding_graphs.get(model_name) {
            Some(graph) => Ok(graph.metadata.ctx_size),
            None => {
                let err_msg = format!(
                    "The model `{}` does not exist in the embedding graphs.",
                    model_name
                );

                #[cfg(feature = "logging")]
                error!(target: "stdout", "{}", &err_msg);

                Err(LlamaCoreError::Operation(err_msg))
            }
        },
        None => {
            if !embedding_graphs.is_empty() {
                let graph = match embedding_graphs.values().next() {
                    Some(graph) => graph,
                    None => {
                        let err_msg = "Failed to get the underlying value of `EMBEDDING_GRAPHS`.";

                        #[cfg(feature = "logging")]
                        error!(target: "stdout", "{}", err_msg);

                        return Err(LlamaCoreError::Operation(err_msg.into()));
                    }
                };

                Ok(graph.metadata.ctx_size)
            } else {
                let err_msg = "There is no model available in the embedding graphs.";

                #[cfg(feature = "logging")]
                error!(target: "stdout", "{}", &err_msg);

                Err(LlamaCoreError::Operation(err_msg.into()))
            }
        }
    }
}

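/// Embedding data returned by the inference backend, deserialized from the
/// `n_embedding` and `embedding` fields of the output JSON.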
#[derive(Debug, Serialize, Deserialize)]
struct Embedding {
    #[serde(rename = "n_embedding")]
    len: u64,
    #[serde(rename = "embedding")]
    data: Vec<f64>,
}

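/// Split the given text into chunks of at most `chunk_capacity` tokens, counted with the
/// `cl100k_base` tokenizer. Only plain text (`ty` of `"txt"`) and markdown (`ty` of `"md"`)
/// contents are supported.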
pub fn chunk_text(
    text: impl AsRef<str>,
    ty: impl AsRef<str>,
    chunk_capacity: usize,
) -> Result<Vec<String>, LlamaCoreError> {
    if ty.as_ref().to_lowercase().as_str() != "txt" && ty.as_ref().to_lowercase().as_str() != "md" {
        let err_msg = "Failed to upload the target file. Only files with 'txt' and 'md' extensions are supported.";

        #[cfg(feature = "logging")]
        error!(target: "stdout", "{}", err_msg);

        return Err(LlamaCoreError::Operation(err_msg.into()));
    }

    match ty.as_ref().to_lowercase().as_str() {
        "txt" => {
            #[cfg(feature = "logging")]
            info!(target: "stdout", "Chunk the plain text contents.");

            let tokenizer = cl100k_base().map_err(|e| {
                let err_msg = e.to_string();

                #[cfg(feature = "logging")]
                error!(target: "stdout", "{}", &err_msg);

                LlamaCoreError::Operation(err_msg)
            })?;

            let splitter = TextSplitter::new(tokenizer).with_trim_chunks(true);

            let chunks = splitter
                .chunks(text.as_ref(), chunk_capacity)
                .map(|s| s.to_string())
                .collect::<Vec<_>>();

            #[cfg(feature = "logging")]
            info!(target: "stdout", "Number of chunks: {}", chunks.len());

            Ok(chunks)
        }
        "md" => {
            #[cfg(feature = "logging")]
            info!(target: "stdout", "Chunk the markdown contents.");

            let tokenizer = cl100k_base().map_err(|e| {
                let err_msg = e.to_string();

                #[cfg(feature = "logging")]
                error!(target: "stdout", "{}", &err_msg);

                LlamaCoreError::Operation(err_msg)
            })?;

            let splitter = MarkdownSplitter::new(tokenizer).with_trim_chunks(true);

            let chunks = splitter
                .chunks(text.as_ref(), chunk_capacity)
                .map(|s| s.to_string())
                .collect::<Vec<_>>();

            #[cfg(feature = "logging")]
            info!(target: "stdout", "Number of chunks: {}", chunks.len());

            Ok(chunks)
        }
        _ => {
            let err_msg =
                "Failed to upload the target file. Only text and markdown files are supported.";

            #[cfg(feature = "logging")]
            error!(target: "stdout", "{}", err_msg);

            Err(LlamaCoreError::Operation(err_msg.into()))
        }
    }
}

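/// Get a copy of the metadata of the named embedding model, falling back to the first
/// available model if the name is missing or not found.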
fn get_model_metadata(model_name: Option<&String>) -> Result<GgmlMetadata, LlamaCoreError> {
    let embedding_graphs = match EMBEDDING_GRAPHS.get() {
        Some(embedding_graphs) => embedding_graphs,
        None => {
            let err_msg = "Failed to get the underlying value of `EMBEDDING_GRAPHS`.";

            #[cfg(feature = "logging")]
            error!(target: "stdout", "{}", err_msg);

            return Err(LlamaCoreError::Operation(err_msg.into()));
        }
    };

    let embedding_graphs = embedding_graphs.lock().map_err(|e| {
        let err_msg = format!("Failed to acquire the lock of `EMBEDDING_GRAPHS`. {}", e);

        #[cfg(feature = "logging")]
        error!(target: "stdout", "{}", &err_msg);

        LlamaCoreError::Operation(err_msg)
    })?;

    match model_name {
        Some(model_name) => match embedding_graphs.contains_key(model_name) {
            true => {
                let graph = embedding_graphs.get(model_name).unwrap();
                Ok(graph.metadata.clone())
            }
            false => match embedding_graphs.iter().next() {
                Some((_, graph)) => Ok(graph.metadata.clone()),
                None => {
                    let err_msg = "There is no model available in the embedding graphs.";

                    #[cfg(feature = "logging")]
                    error!(target: "stdout", "{}", &err_msg);

                    Err(LlamaCoreError::Operation(err_msg.into()))
                }
            },
        },
        None => match embedding_graphs.iter().next() {
            Some((_, graph)) => Ok(graph.metadata.clone()),
            None => {
                let err_msg = "There is no model available in the embedding graphs.";

                #[cfg(feature = "logging")]
                error!(target: "stdout", "{}", err_msg);

                Err(LlamaCoreError::Operation(err_msg.into()))
            }
        },
    }
}

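/// Serialize the given metadata to JSON and write it into input tensor `1` of the named
/// embedding model, falling back to the first available model if the name is missing or
/// not found.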
fn update_model_metadata(
    model_name: Option<&String>,
    metadata: &GgmlMetadata,
) -> Result<(), LlamaCoreError> {
    let config = match serde_json::to_string(metadata) {
        Ok(config) => config,
        Err(e) => {
            let err_msg = format!("Failed to serialize metadata to a JSON string. {}", e);

            #[cfg(feature = "logging")]
            error!(target: "stdout", "{}", &err_msg);

            return Err(LlamaCoreError::Operation(err_msg));
        }
    };

    let embedding_graphs = match EMBEDDING_GRAPHS.get() {
        Some(embedding_graphs) => embedding_graphs,
        None => {
            let err_msg = "Failed to get the underlying value of `EMBEDDING_GRAPHS`.";

            #[cfg(feature = "logging")]
            error!(target: "stdout", "{}", err_msg);

            return Err(LlamaCoreError::Operation(err_msg.into()));
        }
    };

    let mut embedding_graphs = embedding_graphs.lock().map_err(|e| {
        let err_msg = format!(
            "Failed to acquire the lock of `EMBEDDING_GRAPHS`. Reason: {}",
            e
        );

        #[cfg(feature = "logging")]
        error!(target: "stdout", "{}", &err_msg);

        LlamaCoreError::Operation(err_msg)
    })?;

    match model_name {
        Some(model_name) => match embedding_graphs.contains_key(model_name) {
            true => {
                let graph = embedding_graphs.get_mut(model_name).unwrap();
                set_tensor_data_u8(graph, 1, config.as_bytes())
            }
            false => match embedding_graphs.iter_mut().next() {
                Some((_, graph)) => set_tensor_data_u8(graph, 1, config.as_bytes()),
                None => {
                    let err_msg = "There is no model available in the embedding graphs.";

                    #[cfg(feature = "logging")]
                    error!(target: "stdout", "{}", &err_msg);

                    Err(LlamaCoreError::Operation(err_msg.into()))
                }
            },
        },
        None => match embedding_graphs.iter_mut().next() {
            Some((_, graph)) => set_tensor_data_u8(graph, 1, config.as_bytes()),
            None => {
                let err_msg = "There is no model available in the embedding graphs.";

                #[cfg(feature = "logging")]
                error!(target: "stdout", "{}", err_msg);

                Err(LlamaCoreError::Operation(err_msg.into()))
            }
        },
    }
}

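/// Reset the model metadata by re-applying the model's current metadata.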
fn reset_model_metadata(model_name: Option<&String>) -> Result<(), LlamaCoreError> {
    let metadata = get_model_metadata(model_name)?;

    update_model_metadata(model_name, &metadata)
}