llama_core/
utils.rs

//! Define utility functions.
2
3use crate::{
4    error::{BackendError, LlamaCoreError},
5    BaseMetadata, Graph, CHAT_GRAPHS, EMBEDDING_GRAPHS, MAX_BUFFER_SIZE,
6};
7use bitflags::bitflags;
8use chat_prompts::PromptTemplateType;
9use serde_json::Value;
10
11pub(crate) fn gen_chat_id() -> String {
12    format!("chatcmpl-{}", uuid::Uuid::new_v4())
13}
14
15/// Return the names of the chat models.
16pub fn chat_model_names() -> Result<Vec<String>, LlamaCoreError> {
17    #[cfg(feature = "logging")]
18    info!(target: "stdout", "Get the names of the chat models.");
19
20    let chat_graphs = match CHAT_GRAPHS.get() {
21        Some(chat_graphs) => chat_graphs,
22        None => {
23            let err_msg = "Fail to get the underlying value of `CHAT_GRAPHS`.";
24
25            #[cfg(feature = "logging")]
26            error!(target: "stdout", "{}", err_msg);
27
28            return Err(LlamaCoreError::Operation(err_msg.into()));
29        }
30    };
31
32    let chat_graphs = chat_graphs.lock().map_err(|e| {
33        let err_msg = format!("Fail to acquire the lock of `CHAT_GRAPHS`. {}", e);
34
35        #[cfg(feature = "logging")]
36        error!(target: "stdout", "{}", &err_msg);
37
38        LlamaCoreError::Operation(err_msg)
39    })?;
40
41    let mut model_names = Vec::new();
42    for model_name in chat_graphs.keys() {
43        model_names.push(model_name.clone());
44    }
45
46    Ok(model_names)
47}
48
49/// Return the names of the embedding models.
50pub fn embedding_model_names() -> Result<Vec<String>, LlamaCoreError> {
51    #[cfg(feature = "logging")]
52    info!(target: "stdout", "Get the names of the embedding models.");
53
54    let embedding_graphs = match EMBEDDING_GRAPHS.get() {
55        Some(embedding_graphs) => embedding_graphs,
56        None => {
57            return Err(LlamaCoreError::Operation(String::from(
58                "Fail to get the underlying value of `EMBEDDING_GRAPHS`.",
59            )));
60        }
61    };
62
63    let embedding_graphs = match embedding_graphs.lock() {
64        Ok(embedding_graphs) => embedding_graphs,
65        Err(e) => {
66            let err_msg = format!("Fail to acquire the lock of `EMBEDDING_GRAPHS`. {}", e);
67
68            #[cfg(feature = "logging")]
69            error!(target: "stdout", "{}", &err_msg);
70
71            return Err(LlamaCoreError::Operation(err_msg));
72        }
73    };
74
75    let mut model_names = Vec::new();
76    for model_name in embedding_graphs.keys() {
77        model_names.push(model_name.clone());
78    }
79
80    Ok(model_names)
81}
82
83/// Get the chat prompt template type from the given model name.
84pub fn chat_prompt_template(name: Option<&str>) -> Result<PromptTemplateType, LlamaCoreError> {
85    #[cfg(feature = "logging")]
86    match name {
87        Some(name) => {
88            info!(target: "stdout", "Get the chat prompt template type from the chat model named {}.", name)
89        }
90        None => {
91            info!(target: "stdout", "Get the chat prompt template type from the default chat model.")
92        }
93    }
94
95    let chat_graphs = match CHAT_GRAPHS.get() {
96        Some(chat_graphs) => chat_graphs,
97        None => {
98            let err_msg = "Fail to get the underlying value of `CHAT_GRAPHS`.";
99
100            #[cfg(feature = "logging")]
101            error!(target: "stdout", "{}", err_msg);
102
103            return Err(LlamaCoreError::Operation(err_msg.into()));
104        }
105    };
106
107    let chat_graphs = chat_graphs.lock().map_err(|e| {
108        let err_msg = format!("Fail to acquire the lock of `CHAT_GRAPHS`. {}", e);
109
110        #[cfg(feature = "logging")]
111        error!(target: "stdout", "{}", &err_msg);
112
113        LlamaCoreError::Operation(err_msg)
114    })?;
115
116    match name {
117        Some(model_name) => match chat_graphs.contains_key(model_name) {
118            true => {
119                let graph = chat_graphs.get(model_name).unwrap();
120                let prompt_template = graph.metadata.prompt_template();
121
122                #[cfg(feature = "logging")]
123                info!(target: "stdout", "prompt_template: {}", &prompt_template);
124
125                Ok(prompt_template)
126            }
127            false => match chat_graphs.iter().next() {
128                Some((_, graph)) => {
129                    let prompt_template = graph.metadata.prompt_template();
130
131                    #[cfg(feature = "logging")]
132                    info!(target: "stdout", "prompt_template: {}", &prompt_template);
133
134                    Ok(prompt_template)
135                }
136                None => {
137                    let err_msg = "There is no model available in the chat graphs.";
138
139                    #[cfg(feature = "logging")]
140                    error!(target: "stdout", "{}", &err_msg);
141
142                    Err(LlamaCoreError::Operation(err_msg.into()))
143                }
144            },
145        },
146        None => match chat_graphs.iter().next() {
147            Some((_, graph)) => {
148                let prompt_template = graph.metadata.prompt_template();
149
150                #[cfg(feature = "logging")]
151                info!(target: "stdout", "prompt_template: {}", &prompt_template);
152
153                Ok(prompt_template)
154            }
155            None => {
156                let err_msg = "There is no model available in the chat graphs.";
157
158                #[cfg(feature = "logging")]
159                error!(target: "stdout", "{}", &err_msg);
160
161                Err(LlamaCoreError::Operation(err_msg.into()))
162            }
163        },
164    }
165}
166
167/// Get output buffer generated by model.
168pub(crate) fn get_output_buffer<M>(
169    graph: &Graph<M>,
170    index: usize,
171) -> Result<Vec<u8>, LlamaCoreError>
172where
173    M: BaseMetadata + serde::Serialize + Clone + Default,
174{
175    let mut output_buffer: Vec<u8> = Vec::with_capacity(MAX_BUFFER_SIZE);
176
177    let output_size: usize = graph.get_output(index, &mut output_buffer).map_err(|e| {
178        let err_msg = format!("Fail to get the generated output tensor. {msg}", msg = e);
179
180        #[cfg(feature = "logging")]
181        error!(target: "stdout", "{}", &err_msg);
182
183        LlamaCoreError::Backend(BackendError::GetOutput(err_msg))
184    })?;
185
186    unsafe {
187        output_buffer.set_len(output_size);
188    }
189
190    Ok(output_buffer)
191}
192
193/// Get output buffer generated by model in the stream mode.
194pub(crate) fn get_output_buffer_single<M>(
195    graph: &Graph<M>,
196    index: usize,
197) -> Result<Vec<u8>, LlamaCoreError>
198where
199    M: BaseMetadata + serde::Serialize + Clone + Default,
200{
201    #[cfg(feature = "logging")]
202    info!(target: "stdout", "Get output buffer generated by the model named {} in the stream mode.", graph.name());
203
204    let mut output_buffer: Vec<u8> = Vec::with_capacity(MAX_BUFFER_SIZE);
205
206    let output_size: usize = graph
207        .get_output_single(index, &mut output_buffer)
208        .map_err(|e| {
209            let err_msg = format!("Fail to get plugin metadata. {msg}", msg = e);
210
211            #[cfg(feature = "logging")]
212            error!(target: "stdout", "{}", &err_msg);
213
214            LlamaCoreError::Backend(BackendError::GetOutput(err_msg))
215        })?;
216
217    unsafe {
218        output_buffer.set_len(output_size);
219    }
220
221    Ok(output_buffer)
222}
223
224pub(crate) fn set_tensor_data_u8<M>(
225    graph: &mut Graph<M>,
226    idx: usize,
227    tensor_data: &[u8],
228) -> Result<(), LlamaCoreError>
229where
230    M: BaseMetadata + serde::Serialize + Clone + Default,
231{
232    if graph
233        .set_input(idx, wasmedge_wasi_nn::TensorType::U8, &[1], tensor_data)
234        .is_err()
235    {
236        let err_msg = format!("Fail to set input tensor at index {}", idx);
237
238        #[cfg(feature = "logging")]
239        error!(target: "stdout", "{}", &err_msg);
240
241        return Err(LlamaCoreError::Operation(err_msg));
242    };
243
244    Ok(())
245}
246
247/// Get the token information from the graph.
248pub(crate) fn get_token_info_by_graph<M>(graph: &Graph<M>) -> Result<TokenInfo, LlamaCoreError>
249where
250    M: BaseMetadata + serde::Serialize + Clone + Default,
251{
252    #[cfg(feature = "logging")]
253    info!(target: "stdout", "Get token info from the model named {}", graph.name());
254
255    let output_buffer = get_output_buffer(graph, 1)?;
256    let token_info: Value = match serde_json::from_slice(&output_buffer[..]) {
257        Ok(token_info) => token_info,
258        Err(e) => {
259            let err_msg = format!("Fail to deserialize token info: {msg}", msg = e);
260
261            #[cfg(feature = "logging")]
262            error!(target: "stdout", "{}", &err_msg);
263
264            return Err(LlamaCoreError::Operation(err_msg));
265        }
266    };
267
268    let prompt_tokens = match token_info["input_tokens"].as_u64() {
269        Some(prompt_tokens) => prompt_tokens,
270        None => {
271            let err_msg = "Fail to convert `input_tokens` to u64.";
272
273            #[cfg(feature = "logging")]
274            error!(target: "stdout", "{}", err_msg);
275
276            return Err(LlamaCoreError::Operation(err_msg.into()));
277        }
278    };
279    let completion_tokens = match token_info["output_tokens"].as_u64() {
280        Some(completion_tokens) => completion_tokens,
281        None => {
282            let err_msg = "Fail to convert `output_tokens` to u64.";
283
284            #[cfg(feature = "logging")]
285            error!(target: "stdout", "{}", err_msg);
286
287            return Err(LlamaCoreError::Operation(err_msg.into()));
288        }
289    };
290
291    Ok(TokenInfo {
292        prompt_tokens,
293        completion_tokens,
294    })
295}
296
297/// Get the token information from the graph by the model name.
298pub(crate) fn get_token_info_by_graph_name(
299    name: Option<&String>,
300) -> Result<TokenInfo, LlamaCoreError> {
301    let chat_graphs = match CHAT_GRAPHS.get() {
302        Some(chat_graphs) => chat_graphs,
303        None => {
304            let err_msg = "Fail to get the underlying value of `CHAT_GRAPHS`.";
305
306            #[cfg(feature = "logging")]
307            error!(target: "stdout", "{}", err_msg);
308
309            return Err(LlamaCoreError::Operation(err_msg.into()));
310        }
311    };
312
313    let chat_graphs = chat_graphs.lock().map_err(|e| {
314        let err_msg = format!("Fail to acquire the lock of `CHAT_GRAPHS`. {}", e);
315
316        #[cfg(feature = "logging")]
317        error!(target: "stdout", "{}", &err_msg);
318
319        LlamaCoreError::Operation(err_msg)
320    })?;
321
322    match name {
323        Some(model_name) => match chat_graphs.contains_key(model_name) {
324            true => {
325                let graph = chat_graphs.get(model_name).unwrap();
326                get_token_info_by_graph(graph)
327            }
328            false => match chat_graphs.iter().next() {
329                Some((_, graph)) => get_token_info_by_graph(graph),
330                None => {
331                    let err_msg = "There is no model available in the chat graphs.";
332
333                    #[cfg(feature = "logging")]
334                    error!(target: "stdout", "{}", &err_msg);
335
336                    Err(LlamaCoreError::Operation(err_msg.into()))
337                }
338            },
339        },
340        None => match chat_graphs.iter().next() {
341            Some((_, graph)) => get_token_info_by_graph(graph),
342            None => {
343                let err_msg = "There is no model available in the chat graphs.";
344
345                #[cfg(feature = "logging")]
346                error!(target: "stdout", "{}", &err_msg);
347
348                Err(LlamaCoreError::Operation(err_msg.into()))
349            }
350        },
351    }
352}
353
/// Token counts reported for a model run.
///
/// `get_token_info_by_graph` fills these from the backend's `input_tokens` and
/// `output_tokens` JSON fields respectively.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub(crate) struct TokenInfo {
    /// Number of prompt (input) tokens.
    pub(crate) prompt_tokens: u64,
    /// Number of completion (output) tokens.
    pub(crate) completion_tokens: u64,
}
359
360pub(crate) trait TensorType {
361    fn tensor_type() -> wasmedge_wasi_nn::TensorType;
362    fn shape(shape: impl AsRef<[usize]>) -> Vec<usize> {
363        shape.as_ref().to_vec()
364    }
365}
366
367impl TensorType for u8 {
368    fn tensor_type() -> wasmedge_wasi_nn::TensorType {
369        wasmedge_wasi_nn::TensorType::U8
370    }
371}
372
373impl TensorType for f32 {
374    fn tensor_type() -> wasmedge_wasi_nn::TensorType {
375        wasmedge_wasi_nn::TensorType::F32
376    }
377}
378
379pub(crate) fn set_tensor_data<T, M>(
380    graph: &mut Graph<M>,
381    idx: usize,
382    tensor_data: &[T],
383    shape: impl AsRef<[usize]>,
384) -> Result<(), LlamaCoreError>
385where
386    T: TensorType,
387    M: BaseMetadata + serde::Serialize + Clone + Default,
388{
389    if graph
390        .set_input(idx, T::tensor_type(), &T::shape(shape), tensor_data)
391        .is_err()
392    {
393        let err_msg = format!("Fail to set input tensor at index {}", idx);
394
395        #[cfg(feature = "logging")]
396        error!(target: "stdout", "{}", &err_msg);
397
398        return Err(LlamaCoreError::Operation(err_msg));
399    };
400
401    Ok(())
402}
403
404bitflags! {
405    #[derive(Debug, Clone, Copy, PartialEq, Eq)]
406    pub struct RunningMode: u32 {
407        const UNSET = 0b00000000;
408        const CHAT = 0b00000001;
409        const EMBEDDINGS = 0b00000010;
410        const TTS = 0b00000100;
411        const RAG = 0b00001000;
412    }
413}
414impl std::fmt::Display for RunningMode {
415    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
416        let mut mode = String::new();
417
418        if self.contains(RunningMode::CHAT) {
419            mode.push_str("chat, ");
420        }
421        if self.contains(RunningMode::EMBEDDINGS) {
422            mode.push_str("embeddings, ");
423        }
424        if self.contains(RunningMode::TTS) {
425            mode.push_str("tts, ");
426        }
427
428        mode = mode.trim_end_matches(", ").to_string();
429
430        write!(f, "{}", mode)
431    }
432}