use crate::{
    error::{BackendError, LlamaCoreError},
    BaseMetadata, Graph, CHAT_GRAPHS, EMBEDDING_GRAPHS, MAX_BUFFER_SIZE,
};
use bitflags::bitflags;
use chat_prompts::PromptTemplateType;
use serde_json::Value;

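/// Generate a unique chat completion id in the format `"chatcmpl-<uuid>"`.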
pub(crate) fn gen_chat_id() -> String {
    format!("chatcmpl-{}", uuid::Uuid::new_v4())
}

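/// Return the names of the available chat models.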
pub fn chat_model_names() -> Result<Vec<String>, LlamaCoreError> {
    #[cfg(feature = "logging")]
    info!(target: "stdout", "Get the names of the chat models.");

    let chat_graphs = match CHAT_GRAPHS.get() {
        Some(chat_graphs) => chat_graphs,
        None => {
            let err_msg = "Fail to get the underlying value of `CHAT_GRAPHS`.";

            #[cfg(feature = "logging")]
            error!(target: "stdout", "{}", err_msg);

            return Err(LlamaCoreError::Operation(err_msg.into()));
        }
    };

    let chat_graphs = chat_graphs.lock().map_err(|e| {
        let err_msg = format!("Fail to acquire the lock of `CHAT_GRAPHS`. {}", e);

        #[cfg(feature = "logging")]
        error!(target: "stdout", "{}", &err_msg);

        LlamaCoreError::Operation(err_msg)
    })?;

    let mut model_names = Vec::new();
    for model_name in chat_graphs.keys() {
        model_names.push(model_name.clone());
    }

    Ok(model_names)
}

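/// Return the names of the available embedding models.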
pub fn embedding_model_names() -> Result<Vec<String>, LlamaCoreError> {
    #[cfg(feature = "logging")]
    info!(target: "stdout", "Get the names of the embedding models.");

    let embedding_graphs = match EMBEDDING_GRAPHS.get() {
        Some(embedding_graphs) => embedding_graphs,
        None => {
            return Err(LlamaCoreError::Operation(String::from(
                "Fail to get the underlying value of `EMBEDDING_GRAPHS`.",
            )));
        }
    };

    let embedding_graphs = match embedding_graphs.lock() {
        Ok(embedding_graphs) => embedding_graphs,
        Err(e) => {
            let err_msg = format!("Fail to acquire the lock of `EMBEDDING_GRAPHS`. {}", e);

            #[cfg(feature = "logging")]
            error!(target: "stdout", "{}", &err_msg);

            return Err(LlamaCoreError::Operation(err_msg));
        }
    };

    let mut model_names = Vec::new();
    for model_name in embedding_graphs.keys() {
        model_names.push(model_name.clone());
    }

    Ok(model_names)
}

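/// Get the prompt template type of the chat model with the given name.
///
/// If `name` is `None` or does not match a loaded model, the first available
/// chat model is used instead.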
pub fn chat_prompt_template(name: Option<&str>) -> Result<PromptTemplateType, LlamaCoreError> {
    #[cfg(feature = "logging")]
    match name {
        Some(name) => {
            info!(target: "stdout", "Get the chat prompt template type from the chat model named {}.", name)
        }
        None => {
            info!(target: "stdout", "Get the chat prompt template type from the default chat model.")
        }
    }

    let chat_graphs = match CHAT_GRAPHS.get() {
        Some(chat_graphs) => chat_graphs,
        None => {
            let err_msg = "Fail to get the underlying value of `CHAT_GRAPHS`.";

            #[cfg(feature = "logging")]
            error!(target: "stdout", "{}", err_msg);

            return Err(LlamaCoreError::Operation(err_msg.into()));
        }
    };

    let chat_graphs = chat_graphs.lock().map_err(|e| {
        let err_msg = format!("Fail to acquire the lock of `CHAT_GRAPHS`. {}", e);

        #[cfg(feature = "logging")]
        error!(target: "stdout", "{}", &err_msg);

        LlamaCoreError::Operation(err_msg)
    })?;

    match name {
        Some(model_name) => match chat_graphs.contains_key(model_name) {
            true => {
                let graph = chat_graphs.get(model_name).unwrap();
                let prompt_template = graph.metadata.prompt_template();

                #[cfg(feature = "logging")]
                info!(target: "stdout", "prompt_template: {}", &prompt_template);

                Ok(prompt_template)
            }
            false => match chat_graphs.iter().next() {
                Some((_, graph)) => {
                    let prompt_template = graph.metadata.prompt_template();

                    #[cfg(feature = "logging")]
                    info!(target: "stdout", "prompt_template: {}", &prompt_template);

                    Ok(prompt_template)
                }
                None => {
                    let err_msg = "There is no model available in the chat graphs.";

                    #[cfg(feature = "logging")]
                    error!(target: "stdout", "{}", &err_msg);

                    Err(LlamaCoreError::Operation(err_msg.into()))
                }
            },
        },
        None => match chat_graphs.iter().next() {
            Some((_, graph)) => {
                let prompt_template = graph.metadata.prompt_template();

                #[cfg(feature = "logging")]
                info!(target: "stdout", "prompt_template: {}", &prompt_template);

                Ok(prompt_template)
            }
            None => {
                let err_msg = "There is no model available in the chat graphs.";

                #[cfg(feature = "logging")]
                error!(target: "stdout", "{}", &err_msg);

                Err(LlamaCoreError::Operation(err_msg.into()))
            }
        },
    }
}

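/// Read the output tensor at `index` of the given graph into a byte buffer.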
pub(crate) fn get_output_buffer<M>(
    graph: &Graph<M>,
    index: usize,
) -> Result<Vec<u8>, LlamaCoreError>
where
    M: BaseMetadata + serde::Serialize + Clone + Default,
{
    let mut output_buffer: Vec<u8> = Vec::with_capacity(MAX_BUFFER_SIZE);

    let output_size: usize = graph.get_output(index, &mut output_buffer).map_err(|e| {
        let err_msg = format!("Fail to get the generated output tensor. {msg}", msg = e);

        #[cfg(feature = "logging")]
        error!(target: "stdout", "{}", &err_msg);

        LlamaCoreError::Backend(BackendError::GetOutput(err_msg))
    })?;

    unsafe {
        // SAFETY: the backend has written `output_size` bytes into the buffer's allocation.
        output_buffer.set_len(output_size);
    }

    Ok(output_buffer)
}

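/// Read the output tensor at `index` of the given graph in the stream mode.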
pub(crate) fn get_output_buffer_single<M>(
    graph: &Graph<M>,
    index: usize,
) -> Result<Vec<u8>, LlamaCoreError>
where
    M: BaseMetadata + serde::Serialize + Clone + Default,
{
    #[cfg(feature = "logging")]
    info!(target: "stdout", "Get output buffer generated by the model named {} in the stream mode.", graph.name());

    let mut output_buffer: Vec<u8> = Vec::with_capacity(MAX_BUFFER_SIZE);

    let output_size: usize = graph
        .get_output_single(index, &mut output_buffer)
        .map_err(|e| {
            let err_msg = format!(
                "Fail to get the generated output tensor in the stream mode. {msg}",
                msg = e
            );

            #[cfg(feature = "logging")]
            error!(target: "stdout", "{}", &err_msg);

            LlamaCoreError::Backend(BackendError::GetOutput(err_msg))
        })?;

    unsafe {
        // SAFETY: the backend has written `output_size` bytes into the buffer's allocation.
        output_buffer.set_len(output_size);
    }

    Ok(output_buffer)
}

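/// Set the input tensor at `idx` of the given graph from a byte slice.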
pub(crate) fn set_tensor_data_u8<M>(
    graph: &mut Graph<M>,
    idx: usize,
    tensor_data: &[u8],
) -> Result<(), LlamaCoreError>
where
    M: BaseMetadata + serde::Serialize + Clone + Default,
{
    if graph
        .set_input(idx, wasmedge_wasi_nn::TensorType::U8, &[1], tensor_data)
        .is_err()
    {
        let err_msg = format!("Fail to set input tensor at index {}", idx);

        #[cfg(feature = "logging")]
        error!(target: "stdout", "{}", &err_msg);

        return Err(LlamaCoreError::Operation(err_msg));
    };

    Ok(())
}

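/// Get the token usage (prompt and completion token counts) reported by the given graph.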
pub(crate) fn get_token_info_by_graph<M>(graph: &Graph<M>) -> Result<TokenInfo, LlamaCoreError>
where
    M: BaseMetadata + serde::Serialize + Clone + Default,
{
    #[cfg(feature = "logging")]
    info!(target: "stdout", "Get token info from the model named {}", graph.name());

    let output_buffer = get_output_buffer(graph, 1)?;
    let token_info: Value = match serde_json::from_slice(&output_buffer[..]) {
        Ok(token_info) => token_info,
        Err(e) => {
            let err_msg = format!("Fail to deserialize token info: {msg}", msg = e);

            #[cfg(feature = "logging")]
            error!(target: "stdout", "{}", &err_msg);

            return Err(LlamaCoreError::Operation(err_msg));
        }
    };

    let prompt_tokens = match token_info["input_tokens"].as_u64() {
        Some(prompt_tokens) => prompt_tokens,
        None => {
            let err_msg = "Fail to convert `input_tokens` to u64.";

            #[cfg(feature = "logging")]
            error!(target: "stdout", "{}", err_msg);

            return Err(LlamaCoreError::Operation(err_msg.into()));
        }
    };
    let completion_tokens = match token_info["output_tokens"].as_u64() {
        Some(completion_tokens) => completion_tokens,
        None => {
            let err_msg = "Fail to convert `output_tokens` to u64.";

            #[cfg(feature = "logging")]
            error!(target: "stdout", "{}", err_msg);

            return Err(LlamaCoreError::Operation(err_msg.into()));
        }
    };

    Ok(TokenInfo {
        prompt_tokens,
        completion_tokens,
    })
}

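/// Get the token usage of the chat model with the given name.
///
/// If `name` is `None` or does not match a loaded model, the first available
/// chat model is used instead.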
pub(crate) fn get_token_info_by_graph_name(
    name: Option<&String>,
) -> Result<TokenInfo, LlamaCoreError> {
    let chat_graphs = match CHAT_GRAPHS.get() {
        Some(chat_graphs) => chat_graphs,
        None => {
            let err_msg = "Fail to get the underlying value of `CHAT_GRAPHS`.";

            #[cfg(feature = "logging")]
            error!(target: "stdout", "{}", err_msg);

            return Err(LlamaCoreError::Operation(err_msg.into()));
        }
    };

    let chat_graphs = chat_graphs.lock().map_err(|e| {
        let err_msg = format!("Fail to acquire the lock of `CHAT_GRAPHS`. {}", e);

        #[cfg(feature = "logging")]
        error!(target: "stdout", "{}", &err_msg);

        LlamaCoreError::Operation(err_msg)
    })?;

    match name {
        Some(model_name) => match chat_graphs.contains_key(model_name) {
            true => {
                let graph = chat_graphs.get(model_name).unwrap();
                get_token_info_by_graph(graph)
            }
            false => match chat_graphs.iter().next() {
                Some((_, graph)) => get_token_info_by_graph(graph),
                None => {
                    let err_msg = "There is no model available in the chat graphs.";

                    #[cfg(feature = "logging")]
                    error!(target: "stdout", "{}", &err_msg);

                    Err(LlamaCoreError::Operation(err_msg.into()))
                }
            },
        },
        None => match chat_graphs.iter().next() {
            Some((_, graph)) => get_token_info_by_graph(graph),
            None => {
                let err_msg = "There is no model available in the chat graphs.";

                #[cfg(feature = "logging")]
                error!(target: "stdout", "{}", &err_msg);

                Err(LlamaCoreError::Operation(err_msg.into()))
            }
        },
    }
}

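/// Token usage of a single request: the numbers of prompt tokens and completion tokens.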
#[derive(Debug)]
pub(crate) struct TokenInfo {
    pub(crate) prompt_tokens: u64,
    pub(crate) completion_tokens: u64,
}

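/// Map a Rust element type to its corresponding `wasmedge_wasi_nn::TensorType`.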
pub(crate) trait TensorType {
    fn tensor_type() -> wasmedge_wasi_nn::TensorType;
    fn shape(shape: impl AsRef<[usize]>) -> Vec<usize> {
        shape.as_ref().to_vec()
    }
}

impl TensorType for u8 {
    fn tensor_type() -> wasmedge_wasi_nn::TensorType {
        wasmedge_wasi_nn::TensorType::U8
    }
}

impl TensorType for f32 {
    fn tensor_type() -> wasmedge_wasi_nn::TensorType {
        wasmedge_wasi_nn::TensorType::F32
    }
}

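/// Set the input tensor at `idx` of the given graph from a typed slice with the given shape.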
pub(crate) fn set_tensor_data<T, M>(
    graph: &mut Graph<M>,
    idx: usize,
    tensor_data: &[T],
    shape: impl AsRef<[usize]>,
) -> Result<(), LlamaCoreError>
where
    T: TensorType,
    M: BaseMetadata + serde::Serialize + Clone + Default,
{
    if graph
        .set_input(idx, T::tensor_type(), &T::shape(shape), tensor_data)
        .is_err()
    {
        let err_msg = format!("Fail to set input tensor at index {}", idx);

        #[cfg(feature = "logging")]
        error!(target: "stdout", "{}", &err_msg);

        return Err(LlamaCoreError::Operation(err_msg));
    };

    Ok(())
}

bitflags! {
    /// Running modes of the core context. The modes are bitflags and can be combined,
    /// e.g. `RunningMode::CHAT | RunningMode::EMBEDDINGS`.
    #[derive(Debug, Clone, Copy, PartialEq, Eq)]
    pub struct RunningMode: u32 {
        const UNSET = 0b00000000;
        const CHAT = 0b00000001;
        const EMBEDDINGS = 0b00000010;
        const TTS = 0b00000100;
        const RAG = 0b00001000;
    }
}

impl std::fmt::Display for RunningMode {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        let mut mode = String::new();

        if self.contains(RunningMode::CHAT) {
            mode.push_str("chat, ");
        }
        if self.contains(RunningMode::EMBEDDINGS) {
            mode.push_str("embeddings, ");
        }
        if self.contains(RunningMode::TTS) {
            mode.push_str("tts, ");
        }
        if self.contains(RunningMode::RAG) {
            mode.push_str("rag, ");
        }

        mode = mode.trim_end_matches(", ").to_string();

        write!(f, "{}", mode)
    }
}