llama_core/chat.rs

//! Define APIs for chat completion.

use crate::{
    error,
    metadata::ggml::GgmlMetadata,
    running_mode,
    utils::{
        gen_chat_id, get_output_buffer, get_output_buffer_single, get_token_info_by_graph,
        get_token_info_by_graph_name, set_tensor_data_u8,
    },
    Graph, RunningMode, CACHED_UTF8_ENCODINGS, CHAT_GRAPHS, OUTPUT_TENSOR,
};
use chat_prompts::{
    chat::{BuildChatPrompt, ChatPrompt},
    PromptTemplateType,
};
use either::{Either, Left, Right};
use endpoints::{
    chat::{
        ChatCompletionChunk, ChatCompletionChunkChoice, ChatCompletionChunkChoiceDelta,
        ChatCompletionObject, ChatCompletionObjectChoice, ChatCompletionObjectMessage,
        ChatCompletionRequest, ChatCompletionRequestMessage, ChatCompletionRole,
        ChatCompletionUserMessageContent, ContentPart, Function, ToolCall, ToolCallForChunk,
        ToolChoice,
    },
    common::{FinishReason, Usage},
};
use error::{BackendError, LlamaCoreError};
use futures::StreamExt;
use std::{
    collections::VecDeque,
    fs::{self, File},
    path::Path,
    pin::Pin,
    sync::Mutex,
    task::{Context, Poll},
    time::SystemTime,
};

/// Processes a chat-completion request and returns either a stream of ChatCompletionChunk instances or a ChatCompletionObject instance.
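///
/// # Example
///
/// A minimal sketch of driving this API (illustrative only: it assumes the chat
/// graphs have already been initialized elsewhere, that `ChatCompletionRequest`
/// implements `Default`, and that the model name matches a loaded model):
///
/// ```ignore
/// use either::Either;
///
/// let mut request = ChatCompletionRequest {
///     model: Some("llama-3-8b-instruct".to_string()),
///     stream: Some(false),
///     ..Default::default()
/// };
///
/// match chat(&mut request).await? {
///     // stream mode: forward each SSE-formatted chunk string to the client
///     Either::Left(_stream) => { /* ... */ }
///     // non-stream mode: a complete ChatCompletionObject
///     Either::Right(completion) => {
///         let reply = completion.choices[0].message.content.as_deref().unwrap_or("");
///         println!("{}", reply);
///     }
/// }
/// ```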
pub async fn chat(
    chat_request: &mut ChatCompletionRequest,
) -> Result<
    Either<impl futures::TryStream<Ok = String, Error = LlamaCoreError>, ChatCompletionObject>,
    LlamaCoreError,
> {
    #[cfg(feature = "logging")]
    {
        debug!(target: "stdout", "tool choice: {:?}", chat_request.tool_choice.as_ref());
        debug!(target: "stdout", "tools: {:?}", chat_request.tools.as_ref());
        debug!(target: "stdout", "stream mode: {:?}", chat_request.stream);
    }

    let model_name = chat_request.model.clone();

    let result = match chat_request.stream {
        Some(true) => match chat_stream(chat_request).await {
            Ok(stream) => Ok(Left(stream)),
            Err(e) => Err(e),
        },
        Some(false) | None => match chat_once(chat_request).await {
            Ok(chat_completion_object) => Ok(Right(chat_completion_object)),
            Err(e) => Err(e),
        },
    };

    #[cfg(feature = "logging")]
    info!(target: "stdout", "Reset the model metadata");

    // reset the model metadata
    reset_model_metadata(model_name.as_ref())?;

    result
}

/// Processes a chat-completion request and returns a stream of ChatCompletionChunk instances.
#[deprecated(since = "0.10.0", note = "Please use the `chat` function.")]
pub async fn chat_completions_stream(
    chat_request: &mut ChatCompletionRequest,
) -> Result<impl futures::TryStream<Ok = String, Error = LlamaCoreError>, LlamaCoreError> {
    chat_stream(chat_request).await
}

/// Processes a chat-completion request and returns a ChatCompletionObject instance.
#[deprecated(since = "0.10.0", note = "Please use the `chat` function.")]
pub async fn chat_completions(
    chat_request: &mut ChatCompletionRequest,
) -> Result<ChatCompletionObject, LlamaCoreError> {
    chat_once(chat_request).await
}

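/// Builds the SSE chunk stream for the stream mode.
///
/// The flow mirrors `chat_once`: verify the running mode, resolve the chat id,
/// refresh the model metadata, build the prompt, cap `n_predict` by the remaining
/// context budget, and feed the prompt to the model. Without tools the returned
/// `ChatStream` pulls tokens lazily from the graph; with tools the completion is
/// computed eagerly and replayed as pre-built chunks (see `chat_stream_by_graph`).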
async fn chat_stream(
    chat_request: &mut ChatCompletionRequest,
) -> Result<impl futures::TryStream<Ok = String, Error = LlamaCoreError>, LlamaCoreError> {
    #[cfg(feature = "logging")]
    info!(target: "stdout", "Process chat completion request in the stream mode");

    let running_mode = running_mode()?;
    if !running_mode.contains(RunningMode::CHAT) && !running_mode.contains(RunningMode::RAG) {
        let err_msg = "Chat completions are only supported in the chat or rag mode.";

        #[cfg(feature = "logging")]
        error!(target: "stdout", "{}", err_msg);

        return Err(LlamaCoreError::Operation(err_msg.to_string()));
    }

    let model_name = chat_request.model.clone();
    let id = match &chat_request.user {
        Some(id) => id.clone(),
        None => gen_chat_id(),
    };
    #[cfg(feature = "logging")]
    info!(target: "stdout", "user: {}", &id);

    #[cfg(feature = "logging")]
    info!(target: "stdout", "Check model metadata");

    // update metadata
    let mut metadata = check_model_metadata(chat_request).await?;

    // parse the `include_usage` option
    let include_usage = match chat_request.stream_options {
        Some(ref stream_options) => stream_options.include_usage.unwrap_or_default(),
        None => metadata.include_usage,
    };
    #[cfg(feature = "logging")]
    info!(target: "stdout", "include_usage: {}", include_usage);

    #[cfg(feature = "logging")]
    info!(target: "stdout", "Build the chat prompt");

    // build prompt
    let (prompt, available_completion_tokens, tool_use) =
        build_prompt(model_name.as_ref(), chat_request)?;

    #[cfg(feature = "logging")]
    {
        info!(target: "stdout", "prompt:\n{}", &prompt);
        info!(target: "stdout", "available_completion_tokens: {}", available_completion_tokens);
        info!(target: "stdout", "tool_use: {}", tool_use);
    }

    #[cfg(feature = "logging")]
    info!(target: "stdout", "Update the n_predict");

    // update metadata n_predict
    update_n_predict(chat_request, &mut metadata, available_completion_tokens).await?;

    #[cfg(feature = "logging")]
    info!(target: "stdout", "Feed the prompt to the model");

    // set prompt
    set_prompt(chat_request.model.as_ref(), &prompt)?;

    let stream = match tool_use {
        false => ChatStream::new(model_name, id, include_usage, None),
        true => {
            let chat_graphs = match CHAT_GRAPHS.get() {
                Some(chat_graphs) => chat_graphs,
                None => {
                    let err_msg = "Failed to get the underlying value of `CHAT_GRAPHS`.";

                    #[cfg(feature = "logging")]
                    error!(target: "stdout", "{}", &err_msg);

                    return Err(LlamaCoreError::Operation(err_msg.into()));
                }
            };

            let mut chat_graphs = chat_graphs.lock().map_err(|e| {
                let err_msg = format!("Failed to acquire the lock of `CHAT_GRAPHS`. {}", e);

                #[cfg(feature = "logging")]
                error!(target: "stdout", "{}", &err_msg);

                LlamaCoreError::Operation(err_msg)
            })?;

            match model_name {
                Some(model_name) => match chat_graphs.contains_key(&model_name) {
                    true => {
                        let graph = chat_graphs.get_mut(&model_name).unwrap();
                        chat_stream_by_graph(graph, id, include_usage)?
                    }
                    false => match chat_graphs.iter_mut().next() {
                        Some((_, graph)) => chat_stream_by_graph(graph, id, include_usage)?,
                        None => {
                            let err_msg = "There is no model available in the chat graphs.";

                            #[cfg(feature = "logging")]
                            error!(target: "stdout", "{}", &err_msg);

                            return Err(LlamaCoreError::Operation(err_msg.into()));
                        }
                    },
                },
                None => match chat_graphs.iter_mut().next() {
                    Some((_, graph)) => chat_stream_by_graph(graph, id, include_usage)?,
                    None => {
                        let err_msg = "There is no model available in the chat graphs.";

                        #[cfg(feature = "logging")]
                        error!(target: "stdout", "{}", &err_msg);

                        return Err(LlamaCoreError::Operation(err_msg.into()));
                    }
                },
            }
        }
    };

    #[cfg(feature = "logging")]
    info!(target: "stdout", "End of the chat completion stream.");

    Ok(stream)
}

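/// Handles the tool-use variant of the stream mode.
///
/// The completion is computed eagerly on the given graph, the output is
/// post-processed and parsed for tool calls, and the result is packaged as three
/// pre-serialized SSE chunks (the delta carrying content/tool calls, a usage
/// chunk, and the terminating `data: [DONE]`) that the returned `ChatStream`
/// simply replays. The `ContextFull` and `PromptTooLong` backend errors are not
/// propagated: the partial output is emitted with `finish_reason: length`.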
fn chat_stream_by_graph(
    graph: &mut Graph<GgmlMetadata>,
    id: impl Into<String>,
    include_usage: bool,
) -> Result<ChatStream, LlamaCoreError> {
    #[cfg(feature = "logging")]
    info!(target: "stdout", "Handle chat request with available tools by the model named {}.", graph.name());

    let id = id.into();

    match graph.compute() {
        Ok(_) => {
            // Retrieve the output.
            let output_buffer = get_output_buffer(graph, OUTPUT_TENSOR)?;
            let output = std::str::from_utf8(&output_buffer[..]).map_err(|e| {
                let err_msg = format!(
                    "Failed to decode the buffer of the inference result to a utf-8 string. {}",
                    e
                );

                #[cfg(feature = "logging")]
                error!(target: "stdout", "{}", &err_msg);

                LlamaCoreError::Operation(err_msg)
            })?;

            #[cfg(feature = "logging")]
            info!(target: "stdout", "raw generation:\n{}", output);

            // post-process
            let message = post_process(output, &graph.metadata.prompt_template).map_err(|e| {
                LlamaCoreError::Operation(format!("Failed to post-process the output. {}", e))
            })?;

            #[cfg(feature = "logging")]
            info!(target: "stdout", "post-processed generation:\n{}", &message);

            // retrieve the number of prompt and completion tokens
            let token_info = get_token_info_by_graph(graph)?;

            #[cfg(feature = "logging")]
            info!(target: "stdout", "prompt tokens: {}, completion tokens: {}", token_info.prompt_tokens, token_info.completion_tokens);

            let usage = Some(Usage {
                prompt_tokens: token_info.prompt_tokens,
                completion_tokens: token_info.completion_tokens,
                total_tokens: token_info.prompt_tokens + token_info.completion_tokens,
            });

            let created = SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .map_err(|e| {
                    let err_msg = format!("Failed to get the current time. Reason: {}", e);

                    #[cfg(feature = "logging")]
                    error!(target: "stdout", "{}", &err_msg);

                    LlamaCoreError::Operation(err_msg)
                })?;

            if graph.metadata.prompt_template != PromptTemplateType::MistralTool
                && graph.metadata.prompt_template != PromptTemplateType::ChatMLTool
                && graph.metadata.prompt_template != PromptTemplateType::GroqLlama3Tool
                && graph.metadata.prompt_template != PromptTemplateType::Llama3Tool
                && graph.metadata.prompt_template != PromptTemplateType::InternLM2Tool
                && graph.metadata.prompt_template != PromptTemplateType::NemotronTool
                && graph.metadata.prompt_template != PromptTemplateType::FunctionaryV32
                && graph.metadata.prompt_template != PromptTemplateType::FunctionaryV31
                && graph.metadata.prompt_template != PromptTemplateType::MistralSmallTool
            {
                let err_msg = format!("Unsupported prompt template: {}. The tool use is only supported for 'mistral-tool', 'chatml-tool', 'groq-llama3-tool', 'llama-3-tool', 'internlm-2-tool', 'nemotron-tool', 'functionary-31', 'functionary-32', and 'mistral-small-tool' prompt templates.", graph.metadata.prompt_template);

                #[cfg(feature = "logging")]
                error!(target: "stdout", "{}", &err_msg);

                return Err(LlamaCoreError::Operation(err_msg));
            }

            let parsed_result = parse_tool_calls(&message, graph.metadata.prompt_template)?;

            let content = match parsed_result.content {
                Some(content) => Some(content),
                None => Some(parsed_result.raw),
            };

            let tool_calls: Vec<ToolCallForChunk> = parsed_result
                .tool_calls
                .into_iter()
                .enumerate()
                .map(|(index, tool_call)| ToolCallForChunk {
                    index,
                    id: tool_call.id,
                    ty: tool_call.ty,
                    function: tool_call.function,
                })
                .collect();

            // tool_calls chunk
            let tool_call_chunk = {
                let chat_completion_chunk = ChatCompletionChunk {
                    id: id.clone(),
                    object: "chat.completion.chunk".to_string(),
                    created: created.as_secs(),
                    model: graph.name().to_owned(),
                    system_fingerprint: "fp_44709d6fcb".to_string(),
                    choices: vec![ChatCompletionChunkChoice {
                        index: 0,
                        delta: ChatCompletionChunkChoiceDelta {
                            role: ChatCompletionRole::Assistant,
                            content,
                            tool_calls,
                        },
                        logprobs: None,
                        finish_reason: None,
                    }],
                    usage: None,
                };
                let chunk_str = serde_json::to_string(&chat_completion_chunk).map_err(|e| {
                    let err_msg =
                        format!("Failed to serialize chat completion chunk. Reason: {}", e);

                    #[cfg(feature = "logging")]
                    error!(target: "stdout", "{}", &err_msg);

                    LlamaCoreError::Operation(err_msg)
                })?;

                format!("data: {}\n\n", chunk_str)
            };

            // usage chunk
            let usage_chunk = {
                let chat_completion_chunk = ChatCompletionChunk {
                    id: id.clone(),
                    object: "chat.completion.chunk".to_string(),
                    created: created.as_secs(),
                    model: graph.name().to_owned(),
                    system_fingerprint: "fp_44709d6fcb".to_string(),
                    choices: vec![],
                    usage,
                };
                let chunk_str = serde_json::to_string(&chat_completion_chunk).map_err(|e| {
                    let err_msg =
                        format!("Failed to serialize chat completion chunk. Reason: {}", e);

                    #[cfg(feature = "logging")]
                    error!(target: "stdout", "{}", &err_msg);

                    LlamaCoreError::Operation(err_msg)
                })?;

                format!("data: {}\n\n", chunk_str)
            };

            // ending chunk
            let ending_chunk = "data: [DONE]\n\n".to_string();

            let chunks = vec![tool_call_chunk, usage_chunk, ending_chunk];

            Ok(ChatStream::new(
                Some(graph.name().to_owned()),
                id,
                include_usage,
                Some(chunks),
            ))
        }
        Err(wasmedge_wasi_nn::Error::BackendError(wasmedge_wasi_nn::BackendError::ContextFull)) => {
            // Retrieve the output.
            let output_buffer = get_output_buffer(graph, OUTPUT_TENSOR)?;
            let output = std::str::from_utf8(&output_buffer[..]).map_err(|e| {
                let err_msg = format!(
                    "Failed to decode the buffer of the inference result to a utf-8 string. {}",
                    e
                );

                #[cfg(feature = "logging")]
                error!(target: "stdout", "{}", &err_msg);

                LlamaCoreError::Operation(err_msg)
            })?;

            // post-process
            let message = post_process(output, &graph.metadata.prompt_template).map_err(|e| {
                let err_msg = format!("Failed to post-process the output. {}", e);

                #[cfg(feature = "logging")]
                error!(target: "stdout", "{}", &err_msg);

                LlamaCoreError::Operation(err_msg)
            })?;

            // retrieve the number of prompt and completion tokens
            let token_info = get_token_info_by_graph(graph)?;

            #[cfg(feature = "logging")]
            info!(target: "stdout", "prompt tokens: {}, completion tokens: {}", token_info.prompt_tokens, token_info.completion_tokens);

            let usage = Some(Usage {
                prompt_tokens: token_info.prompt_tokens,
                completion_tokens: token_info.completion_tokens,
                total_tokens: token_info.prompt_tokens + token_info.completion_tokens,
            });

            let created = SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .map_err(|e| {
                    let err_msg = format!("Failed to get the current time. Reason: {}", e);

                    #[cfg(feature = "logging")]
                    error!(target: "stdout", "{}", &err_msg);

                    LlamaCoreError::Operation(err_msg)
                })?;

            // context full chunk
            let context_full_chunk = {
                let chat_completion_chunk = ChatCompletionChunk {
                    id: id.clone(),
                    object: "chat.completion.chunk".to_string(),
                    created: created.as_secs(),
                    model: graph.name().to_owned(),
                    system_fingerprint: "fp_44709d6fcb".to_string(),
                    choices: vec![ChatCompletionChunkChoice {
                        index: 0,
                        delta: ChatCompletionChunkChoiceDelta {
                            role: ChatCompletionRole::Assistant,
                            content: Some(message),
                            tool_calls: vec![],
                        },
                        logprobs: None,
                        finish_reason: Some(FinishReason::length),
                    }],
                    usage: None,
                };

                // serialize chat completion chunk
                let chunk_str = serde_json::to_string(&chat_completion_chunk).map_err(|e| {
                    let err_msg =
                        format!("Failed to serialize chat completion chunk. Reason: {}", e);

                    #[cfg(feature = "logging")]
                    error!(target: "stdout", "{}", &err_msg);

                    LlamaCoreError::Operation(err_msg)
                })?;

                format!("data: {}\n\n", chunk_str)
            };

            // usage chunk
            let usage_chunk = {
                let chat_completion_chunk = ChatCompletionChunk {
                    id: id.clone(),
                    object: "chat.completion.chunk".to_string(),
                    created: created.as_secs(),
                    model: graph.name().to_owned(),
                    system_fingerprint: "fp_44709d6fcb".to_string(),
                    choices: vec![],
                    usage,
                };

                // serialize chat completion chunk
                let chunk_str = serde_json::to_string(&chat_completion_chunk).map_err(|e| {
                    let err_msg =
                        format!("Failed to serialize chat completion chunk. Reason: {}", e);

                    #[cfg(feature = "logging")]
                    error!(target: "stdout", "{}", &err_msg);

                    LlamaCoreError::Operation(err_msg)
                })?;

                format!("data: {}\n\n", chunk_str)
            };

            // ending chunk
            let ending_chunk = "data: [DONE]\n\n".to_string();

            let chunks = vec![context_full_chunk, usage_chunk, ending_chunk];

            Ok(ChatStream::new(
                Some(graph.name().to_owned()),
                id,
                include_usage,
                Some(chunks),
            ))
        }
        Err(wasmedge_wasi_nn::Error::BackendError(
            wasmedge_wasi_nn::BackendError::PromptTooLong,
        )) => {
            #[cfg(feature = "logging")]
            warn!(target: "stdout", "The prompt is too long. Please reduce the length of your input and try again.");

            // Retrieve the output.
            let output_buffer = get_output_buffer(graph, OUTPUT_TENSOR)?;
            let output = std::str::from_utf8(&output_buffer[..]).map_err(|e| {
                let err_msg = format!(
                    "Failed to decode the buffer of the inference result to a utf-8 string. {}",
                    e
                );

                #[cfg(feature = "logging")]
                error!(target: "stdout", "{}", &err_msg);

                LlamaCoreError::Operation(err_msg)
            })?;

            // post-process
            let message = post_process(output, &graph.metadata.prompt_template).map_err(|e| {
                let err_msg = format!("Failed to post-process the output. {}", e);

                #[cfg(feature = "logging")]
                error!(target: "stdout", "{}", &err_msg);

                LlamaCoreError::Operation(err_msg)
            })?;

            // retrieve the number of prompt and completion tokens
            let token_info = get_token_info_by_graph(graph)?;

            #[cfg(feature = "logging")]
            info!(target: "stdout", "prompt tokens: {}, completion tokens: {}", token_info.prompt_tokens, token_info.completion_tokens);

            let usage = Some(Usage {
                prompt_tokens: token_info.prompt_tokens,
                completion_tokens: token_info.completion_tokens,
                total_tokens: token_info.prompt_tokens + token_info.completion_tokens,
            });

            let created = SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .map_err(|e| {
                    let err_msg = format!("Failed to get the current time. Reason: {}", e);

                    #[cfg(feature = "logging")]
                    error!(target: "stdout", "{}", &err_msg);

                    LlamaCoreError::Operation(err_msg)
                })?;

            // prompt too long chunk
            let prompt_too_long_chunk = {
                let chat_completion_chunk = ChatCompletionChunk {
                    id: id.clone(),
                    object: "chat.completion.chunk".to_string(),
                    created: created.as_secs(),
                    model: graph.name().to_owned(),
                    system_fingerprint: "fp_44709d6fcb".to_string(),
                    choices: vec![ChatCompletionChunkChoice {
                        index: 0,
                        delta: ChatCompletionChunkChoiceDelta {
                            role: ChatCompletionRole::Assistant,
                            content: Some(message),
                            tool_calls: vec![],
                        },
                        logprobs: None,
                        finish_reason: Some(FinishReason::length),
                    }],
                    usage: None,
                };

                // serialize chat completion chunk
                let chunk_str = serde_json::to_string(&chat_completion_chunk).map_err(|e| {
                    let err_msg =
                        format!("Failed to serialize chat completion chunk. Reason: {}", e);

                    #[cfg(feature = "logging")]
                    error!(target: "stdout", "{}", &err_msg);

                    LlamaCoreError::Operation(err_msg)
                })?;

                format!("data: {}\n\n", chunk_str)
            };

            // usage chunk
            let usage_chunk = {
                let chat_completion_chunk = ChatCompletionChunk {
                    id: id.clone(),
                    object: "chat.completion.chunk".to_string(),
                    created: created.as_secs(),
                    model: graph.name().to_owned(),
                    system_fingerprint: "fp_44709d6fcb".to_string(),
                    choices: vec![],
                    usage,
                };

                // serialize chat completion chunk
                let chunk_str = serde_json::to_string(&chat_completion_chunk).map_err(|e| {
                    let err_msg =
                        format!("Failed to serialize chat completion chunk. Reason: {}", e);

                    #[cfg(feature = "logging")]
                    error!(target: "stdout", "{}", &err_msg);

                    LlamaCoreError::Operation(err_msg)
                })?;

                format!("data: {}\n\n", chunk_str)
            };

            // ending chunk
            let ending_chunk = "data: [DONE]\n\n".to_string();

            let chunks = vec![prompt_too_long_chunk, usage_chunk, ending_chunk];

            Ok(ChatStream::new(
                Some(graph.name().to_owned()),
                id,
                include_usage,
                Some(chunks),
            ))
        }
        Err(e) => {
            let err_msg = format!("Failed to compute the chat completion. Reason: {}", e);

            #[cfg(feature = "logging")]
            error!(target: "stdout", "{}", &err_msg);

            Err(LlamaCoreError::Backend(BackendError::Compute(err_msg)))
        }
    }
}

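/// Runs a chat completion in non-stream mode.
///
/// Verifies the running mode, resolves the chat id, refreshes the model metadata,
/// builds the prompt, caps `n_predict` by the remaining context budget, feeds the
/// prompt to the model, and then computes the full `ChatCompletionObject` in one
/// pass via `compute`.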
async fn chat_once(
    chat_request: &mut ChatCompletionRequest,
) -> Result<ChatCompletionObject, LlamaCoreError> {
    #[cfg(feature = "logging")]
    info!(target: "stdout", "Processing chat completion request in non-stream mode");

    let running_mode = running_mode()?;
    if !running_mode.contains(RunningMode::CHAT) && !running_mode.contains(RunningMode::RAG) {
        let err_msg = "Chat completions are only supported in the chat or rag mode.";

        #[cfg(feature = "logging")]
        error!(target: "stdout", "{}", err_msg);

        return Err(LlamaCoreError::Operation(err_msg.to_string()));
    }

    let model_name = chat_request.model.clone();
    let id = match &chat_request.user {
        Some(id) => id.clone(),
        None => gen_chat_id(),
    };

    #[cfg(feature = "logging")]
    info!(target: "stdout", "user: {}", &id);

    #[cfg(feature = "logging")]
    info!(target: "stdout", "Check model metadata");

    // update metadata
    let mut metadata = check_model_metadata(chat_request).await?;

    #[cfg(feature = "logging")]
    info!(target: "stdout", "Build the chat prompt");

    // build prompt
    let (prompt, available_completion_tokens, tool_use) =
        build_prompt(model_name.as_ref(), chat_request)?;

    #[cfg(feature = "logging")]
    {
        info!(target: "stdout", "prompt:\n{}", &prompt);
        info!(target: "stdout", "available_completion_tokens: {}", available_completion_tokens);
        info!(target: "stdout", "tool_use: {}", tool_use);
    }

    #[cfg(feature = "logging")]
    info!(target: "stdout", "Update n_predict");

    // update metadata n_predict
    update_n_predict(chat_request, &mut metadata, available_completion_tokens).await?;

    #[cfg(feature = "logging")]
    info!(target: "stdout", "Feed the prompt to the model");

    // feed the prompt to the model
    set_prompt(model_name.as_ref(), &prompt)?;

    #[cfg(feature = "logging")]
    info!(target: "stdout", "Compute chat completion.");

    // compute
    let res = compute(model_name.as_ref(), id, tool_use);

    #[cfg(feature = "logging")]
    info!(target: "stdout", "End of the chat completion");

    res
}

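/// Resolves the target graph and delegates to `compute_by_graph`.
///
/// The graph named `model_name` is used when it exists; otherwise the first
/// registered chat graph is used as a fallback. If no graph is available, an
/// operation error is returned.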
fn compute(
    model_name: Option<&String>,
    id: impl Into<String>,
    tool_use: bool,
) -> Result<ChatCompletionObject, LlamaCoreError> {
    let chat_graphs = match CHAT_GRAPHS.get() {
        Some(chat_graphs) => chat_graphs,
        None => {
            let err_msg = "Failed to get the underlying value of `CHAT_GRAPHS`.";

            #[cfg(feature = "logging")]
            error!(target: "stdout", "{}", &err_msg);

            return Err(LlamaCoreError::Operation(err_msg.into()));
        }
    };

    let mut chat_graphs = chat_graphs.lock().map_err(|e| {
        let err_msg = format!("Failed to acquire the lock of `CHAT_GRAPHS`. {}", e);

        #[cfg(feature = "logging")]
        error!(target: "stdout", "{}", &err_msg);

        LlamaCoreError::Operation(err_msg)
    })?;

    match model_name {
        Some(model_name) => match chat_graphs.contains_key(model_name) {
            true => {
                let graph = chat_graphs.get_mut(model_name).unwrap();
                compute_by_graph(graph, id, tool_use)
            }
            false => match chat_graphs.iter_mut().next() {
                Some((_, graph)) => compute_by_graph(graph, id, tool_use),
                None => {
                    let err_msg = "There is no model available in the chat graphs.";

                    #[cfg(feature = "logging")]
                    error!(target: "stdout", "{}", &err_msg);

                    Err(LlamaCoreError::Operation(err_msg.into()))
                }
            },
        },
        None => match chat_graphs.iter_mut().next() {
            Some((_, graph)) => compute_by_graph(graph, id, tool_use),
            None => {
                let err_msg = "There is no model available in the chat graphs.";

                #[cfg(feature = "logging")]
                error!(target: "stdout", "{}", &err_msg);

                Err(LlamaCoreError::Operation(err_msg.into()))
            }
        },
    }
}

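/// Runs the inference on a single graph and assembles the `ChatCompletionObject`.
///
/// On success the output is post-processed and, when tools are in play, parsed
/// for tool calls (finish reason `tool_calls` or `stop`). The `ContextFull` and
/// `PromptTooLong` backend errors are not propagated: the partial output is
/// returned with finish reason `length` instead.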
fn compute_by_graph(
    graph: &mut Graph<GgmlMetadata>,
    id: impl Into<String>,
    tool_use: bool,
) -> Result<ChatCompletionObject, LlamaCoreError> {
    #[cfg(feature = "logging")]
    info!(target: "stdout", "Compute chat completion by the model named {}.", graph.name());

    match graph.compute() {
        Ok(_) => {
            // Retrieve the output.
            let output_buffer = get_output_buffer(graph, OUTPUT_TENSOR)?;
            let output = std::str::from_utf8(&output_buffer[..]).map_err(|e| {
                let err_msg = format!(
                    "Failed to decode the buffer of the inference result to a utf-8 string. {}",
                    e
                );

                #[cfg(feature = "logging")]
                error!(target: "stdout", "{}", &err_msg);

                LlamaCoreError::Operation(err_msg)
            })?;

            #[cfg(feature = "logging")]
            info!(target: "stdout", "raw generation: {}", output);

            // post-process
            let message = post_process(output, &graph.metadata.prompt_template).map_err(|e| {
                LlamaCoreError::Operation(format!("Failed to post-process the output. {}", e))
            })?;

            #[cfg(feature = "logging")]
            info!(target: "stdout", "post-processed generation:\n{}", &message);

            // retrieve the number of prompt and completion tokens
            let token_info = get_token_info_by_graph(graph)?;

            #[cfg(feature = "logging")]
            info!(target: "stdout", "prompt tokens: {}, completion tokens: {}", token_info.prompt_tokens, token_info.completion_tokens);

            let created = SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .map_err(|e| {
                    let err_msg = format!("Failed to get the current time. Reason: {}", e);

                    #[cfg(feature = "logging")]
                    error!(target: "stdout", "{}", &err_msg);

                    LlamaCoreError::Operation(err_msg)
                })?;

            match tool_use {
                true => {
                    if graph.metadata.prompt_template != PromptTemplateType::MistralTool
                        && graph.metadata.prompt_template != PromptTemplateType::ChatMLTool
                        && graph.metadata.prompt_template != PromptTemplateType::GroqLlama3Tool
                        && graph.metadata.prompt_template != PromptTemplateType::Llama3Tool
                        && graph.metadata.prompt_template != PromptTemplateType::InternLM2Tool
                        && graph.metadata.prompt_template != PromptTemplateType::NemotronTool
                        && graph.metadata.prompt_template != PromptTemplateType::FunctionaryV32
                        && graph.metadata.prompt_template != PromptTemplateType::FunctionaryV31
                        && graph.metadata.prompt_template != PromptTemplateType::MistralSmallTool
                    {
                        let err_msg = format!("Unsupported prompt template: {}. The tool use is only supported for 'mistral-tool', 'chatml-tool', 'groq-llama3-tool', 'llama-3-tool', 'internlm-2-tool', 'nemotron-tool', 'functionary-31', 'functionary-32', and 'mistral-small-tool' prompt templates.", graph.metadata.prompt_template);

                        #[cfg(feature = "logging")]
                        error!(target: "stdout", "{}", &err_msg);

                        return Err(LlamaCoreError::Operation(err_msg));
                    }

                    let parsed_result = parse_tool_calls(&message, graph.metadata.prompt_template)?;

                    let finish_reason = if parsed_result.tool_calls.is_empty() {
                        FinishReason::stop
                    } else {
                        FinishReason::tool_calls
                    };

                    let content = match parsed_result.content {
                        Some(content) => Some(content),
                        None => Some(parsed_result.raw),
                    };

                    // create ChatCompletionResponse
                    Ok(ChatCompletionObject {
                        id: id.into(),
                        object: String::from("chat.completion"),
                        created: created.as_secs(),
                        model: graph.name().to_owned(),
                        choices: vec![ChatCompletionObjectChoice {
                            index: 0,
                            message: ChatCompletionObjectMessage {
                                role: ChatCompletionRole::Assistant,
                                content,
                                tool_calls: parsed_result.tool_calls,
                                function_call: None,
                            },
                            finish_reason,
                            logprobs: None,
                        }],
                        usage: Usage {
                            prompt_tokens: token_info.prompt_tokens,
                            completion_tokens: token_info.completion_tokens,
                            total_tokens: token_info.prompt_tokens + token_info.completion_tokens,
                        },
                    })
                }
                false => {
                    // create ChatCompletionResponse
                    Ok(ChatCompletionObject {
                        id: id.into(),
                        object: String::from("chat.completion"),
                        created: created.as_secs(),
                        model: graph.name().to_owned(),
                        choices: vec![ChatCompletionObjectChoice {
                            index: 0,
                            message: ChatCompletionObjectMessage {
                                role: ChatCompletionRole::Assistant,
                                content: Some(message),
                                tool_calls: vec![],
                                function_call: None,
                            },
                            finish_reason: FinishReason::stop,
                            logprobs: None,
                        }],
                        usage: Usage {
                            prompt_tokens: token_info.prompt_tokens,
                            completion_tokens: token_info.completion_tokens,
                            total_tokens: token_info.prompt_tokens + token_info.completion_tokens,
                        },
                    })
                }
            }
        }
        Err(wasmedge_wasi_nn::Error::BackendError(wasmedge_wasi_nn::BackendError::ContextFull)) => {
            // Retrieve the output.
            let output_buffer = get_output_buffer(graph, OUTPUT_TENSOR)?;
            let output = std::str::from_utf8(&output_buffer[..]).map_err(|e| {
                let err_msg = format!(
                    "Failed to decode the buffer of the inference result to a utf-8 string. {}",
                    e
                );

                #[cfg(feature = "logging")]
                error!(target: "stdout", "{}", &err_msg);

                LlamaCoreError::Operation(err_msg)
            })?;

            // post-process
            let message = post_process(output, &graph.metadata.prompt_template).map_err(|e| {
                let err_msg = format!("Failed to post-process the output. {}", e);

                #[cfg(feature = "logging")]
                error!(target: "stdout", "{}", &err_msg);

                LlamaCoreError::Operation(err_msg)
            })?;

            // retrieve the number of prompt and completion tokens
            let token_info = get_token_info_by_graph(graph)?;

            #[cfg(feature = "logging")]
            info!(target: "stdout", "prompt tokens: {}, completion tokens: {}", token_info.prompt_tokens, token_info.completion_tokens);

            let created = SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .map_err(|e| {
                    let err_msg = format!("Failed to get the current time. Reason: {}", e);

                    #[cfg(feature = "logging")]
                    error!(target: "stdout", "{}", &err_msg);

                    LlamaCoreError::Operation(err_msg)
                })?;

            // create ChatCompletionResponse
            Ok(ChatCompletionObject {
                id: id.into(),
                object: String::from("chat.completion"),
                created: created.as_secs(),
                model: graph.name().to_owned(),
                choices: vec![ChatCompletionObjectChoice {
                    index: 0,
                    message: ChatCompletionObjectMessage {
                        role: ChatCompletionRole::Assistant,
                        content: Some(message),
                        tool_calls: vec![],
                        function_call: None,
                    },
                    finish_reason: FinishReason::length,
                    logprobs: None,
                }],
                usage: Usage {
                    prompt_tokens: token_info.prompt_tokens,
                    completion_tokens: token_info.completion_tokens,
                    total_tokens: token_info.prompt_tokens + token_info.completion_tokens,
                },
            })
        }
        Err(wasmedge_wasi_nn::Error::BackendError(
            wasmedge_wasi_nn::BackendError::PromptTooLong,
        )) => {
            #[cfg(feature = "logging")]
            warn!(target: "stdout", "The prompt is too long. Please reduce the length of your input and try again.");

            // Retrieve the output.
            let output_buffer = get_output_buffer(graph, OUTPUT_TENSOR)?;
            let output = std::str::from_utf8(&output_buffer[..]).map_err(|e| {
                let err_msg = format!(
                    "Failed to decode the buffer of the inference result to a utf-8 string. {}",
                    e
                );

                #[cfg(feature = "logging")]
                error!(target: "stdout", "{}", &err_msg);

                LlamaCoreError::Operation(err_msg)
            })?;

            // post-process
            let message = post_process(output, &graph.metadata.prompt_template).map_err(|e| {
                let err_msg = format!("Failed to post-process the output. {}", e);

                #[cfg(feature = "logging")]
                error!(target: "stdout", "{}", &err_msg);

                LlamaCoreError::Operation(err_msg)
            })?;

            // retrieve the number of prompt and completion tokens
            let token_info = get_token_info_by_graph(graph)?;

            #[cfg(feature = "logging")]
            info!(target: "stdout", "prompt tokens: {}, completion tokens: {}", token_info.prompt_tokens, token_info.completion_tokens);

            let created = SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .map_err(|e| {
                    let err_msg = format!("Failed to get the current time. Reason: {}", e);

                    #[cfg(feature = "logging")]
                    error!(target: "stdout", "{}", &err_msg);

                    LlamaCoreError::Operation(err_msg)
                })?;

            // create ChatCompletionResponse
            Ok(ChatCompletionObject {
                id: id.into(),
                object: String::from("chat.completion"),
                created: created.as_secs(),
                model: graph.name().to_owned(),
                choices: vec![ChatCompletionObjectChoice {
                    index: 0,
                    message: ChatCompletionObjectMessage {
                        role: ChatCompletionRole::Assistant,
                        content: Some(message),
                        tool_calls: vec![],
                        function_call: None,
                    },
                    finish_reason: FinishReason::length,
                    logprobs: None,
                }],
                usage: Usage {
                    prompt_tokens: token_info.prompt_tokens,
                    completion_tokens: token_info.completion_tokens,
                    total_tokens: token_info.prompt_tokens + token_info.completion_tokens,
                },
            })
        }
        Err(e) => {
            let err_msg = format!("Failed to compute the chat completion. Reason: {}", e);

            #[cfg(feature = "logging")]
            error!(target: "stdout", "{}", &err_msg);

            Err(LlamaCoreError::Backend(BackendError::Compute(err_msg)))
        }
    }
}

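/// Parses tool calls out of the post-processed model output, according to the
/// prompt template in use.
///
/// For instance, with the `mistral-tool` template the model is expected to emit a
/// JSON array of calls; an illustrative output such as
///
/// ```text
/// [{"name": "get_weather", "arguments": {"city": "Berlin"}}]
/// ```
///
/// is parsed into one `ToolCall` whose `function.name` is `"get_weather"` and
/// whose `function.arguments` holds the serialized arguments object. Other
/// templates wrap each call in markers such as `<tool_call>...</tool_call>`
/// instead of a bare array.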
1054fn parse_tool_calls(
1055    input: &str,
1056    prompt_template: PromptTemplateType,
1057) -> Result<ParseResult, LlamaCoreError> {
1058    match prompt_template {
1059        PromptTemplateType::MistralTool => match regex::Regex::new(r"\[\{.*?\}\]") {
1060            Ok(re) => {
1061                let mut values: Vec<serde_json::Value> = vec![];
1062                for cap in re.captures_iter(input) {
1063                    let matched = &cap[0];
1064
1065                    #[cfg(feature = "logging")]
1066                    info!(target: "stdout", "captured: {}", matched);
1067
1068                    match serde_json::from_str::<Vec<serde_json::Value>>(matched) {
1069                        Ok(group) => values.extend(group),
1070                        Err(e) => {
1071                            let err_msg = format!(
1072                                "Failed to deserialize generated tool calls. Reason: {}",
1073                                e
1074                            );
1075
1076                            #[cfg(feature = "logging")]
1077                            error!(target: "stdout", "{}", &err_msg);
1078
1079                            return Err(LlamaCoreError::Operation(err_msg));
1080                        }
1081                    }
1082                }
1083
1084                let mut tool_calls: Vec<ToolCall> = vec![];
1085                for value in values.iter() {
1086                    let name = match value.get("name") {
1087                        Some(name) => name.to_string().replace("\"", ""),
1088                        None => {
1089                            let err_msg = format!(
1090                                "Failed to get the name of the function. Tool call: {:?}",
1091                                value
1092                            );
1093
1094                            #[cfg(feature = "logging")]
1095                            error!(target: "stdout", "{}", &err_msg);
1096
1097                            return Err(LlamaCoreError::Operation(err_msg));
1098                        }
1099                    };
1100
1101                    let arguments = match value.get("arguments") {
1102                        Some(arguments) => arguments.to_string(),
1103                        None => {
1104                            let err_msg = format!(
1105                                "Failed to get the arguments of the function. Tool call: {:?}",
1106                                value
1107                            );
1108
1109                            #[cfg(feature = "logging")]
1110                            error!(target: "stdout", "{}", &err_msg);
1111
1112                            return Err(LlamaCoreError::Operation(err_msg));
1113                        }
1114                    };
1115
1116                    let function = Function { name, arguments };
1117
1118                    let tool_call = ToolCall {
1119                        id: "call_abc123".to_string(),
1120                        ty: "function".to_string(),
1121                        function,
1122                    };
1123
1124                    tool_calls.push(tool_call);
1125                }
1126
1127                let parsed = ParseResult {
1128                    raw: input.to_owned(),
1129                    content: None,
1130                    tool_calls,
1131                };
1132
1133                #[cfg(feature = "logging")]
1134                info!(target: "stdout", "parsed result: {:?}", parsed);
1135
1136                Ok(parsed)
1137            }
1138            Err(e) => {
1139                let err_msg = format!("Failed to create a regex pattern. Reason: {}", e);
1140
1141                #[cfg(feature = "logging")]
1142                error!(target: "stdout", "{}", &err_msg);
1143
1144                Err(LlamaCoreError::Operation(err_msg))
1145            }
1146        },
1147        PromptTemplateType::ChatMLTool => {
1148            match regex::Regex::new(r"<tool_call>(.*?)</tool_call>") {
1149                Ok(re) => {
1150                    let mut values: Vec<serde_json::Value> = vec![];
1151                    for cap in re.captures_iter(input) {
1152                        let matched = cap[1].replace("\\n", ""); // Remove "\\n" from the captured group
1153
1154                        #[cfg(feature = "logging")]
1155                        info!(target: "stdout", "captured: {}", &matched);
1156
1157                        match serde_json::from_str::<serde_json::Value>(&matched) {
1158                            Ok(value) => values.push(value),
1159                            Err(e) => {
1160                                let err_msg = format!(
1161                                    "Failed to deserialize generated tool calls. Reason: {}",
1162                                    e
1163                                );
1164
1165                                #[cfg(feature = "logging")]
1166                                error!(target: "stdout", "{}", &err_msg);
1167
1168                                return Err(LlamaCoreError::Operation(err_msg));
1169                            }
1170                        }
1171                    }
1172
1173                    let mut tool_calls: Vec<ToolCall> = vec![];
1174                    for value in values.iter() {
1175                        let name = match value.get("name") {
1176                            Some(name) => name.to_string().replace("\"", ""),
1177                            None => {
1178                                let err_msg = format!(
1179                                    "Failed to get the name of the function. Tool call: {:?}",
1180                                    value
1181                                );
1182
1183                                #[cfg(feature = "logging")]
1184                                error!(target: "stdout", "{}", &err_msg);
1185
1186                                return Err(LlamaCoreError::Operation(err_msg));
1187                            }
1188                        };
1189
1190                        let arguments = match value.get("arguments") {
1191                            Some(arguments) => arguments.to_string(),
1192                            None => {
1193                                let err_msg = format!(
1194                                    "Failed to get the arguments of the function. Tool call: {:?}",
1195                                    value
1196                                );
1197
1198                                #[cfg(feature = "logging")]
1199                                error!(target: "stdout", "{}", &err_msg);
1200
1201                                return Err(LlamaCoreError::Operation(err_msg));
1202                            }
1203                        };
1204
1205                        let function = Function { name, arguments };
1206
1207                        let tool_call = ToolCall {
1208                            id: "call_abc123".to_string(),
1209                            ty: "function".to_string(),
1210                            function,
1211                        };
1212
1213                        tool_calls.push(tool_call);
1214                    }
1215
1216                    let parsed = ParseResult {
1217                        raw: input.to_owned(),
1218                        content: None,
1219                        tool_calls,
1220                    };
1221
1222                    #[cfg(feature = "logging")]
1223                    info!(target: "stdout", "parsed result: {:?}", parsed);
1224
1225                    Ok(parsed)
1226                }
1227                Err(e) => {
1228                    let err_msg = format!("Failed to create a regex pattern. Reason: {}", e);
1229
1230                    #[cfg(feature = "logging")]
1231                    error!(target: "stdout", "{}", &err_msg);
1232
1233                    Err(LlamaCoreError::Operation(err_msg))
1234                }
1235            }
1236        }
1237        PromptTemplateType::GroqLlama3Tool => {
1238            #[cfg(feature = "logging")]
1239            info!(target: "stdout", "raw input: {}", input);
1240
1241            match regex::Regex::new(r"(?s)<tool_call>((.|\r|\n)*?)</tool_call>") {
1242                Ok(re) => {
1243                    let mut values: Vec<serde_json::Value> = vec![];
1244                    for cap in re.captures_iter(input) {
1245                        let matched = cap[1].trim();
1246
1247                        #[cfg(feature = "logging")]
1248                        info!(target: "stdout", "captured: {}", matched);
1249
1250                        match serde_json::from_str::<serde_json::Value>(matched) {
1251                            Ok(value) => values.push(value),
1252                            Err(e) => {
1253                                let err_msg = format!(
1254                                    "Failed to deserialize generated tool calls. Reason: {}",
1255                                    e
1256                                );
1257
1258                                #[cfg(feature = "logging")]
1259                                error!(target: "stdout", "{}", &err_msg);
1260
1261                                return Err(LlamaCoreError::Operation(err_msg));
1262                            }
1263                        }
1264                    }
1265
1266                    let mut tool_calls: Vec<ToolCall> = vec![];
1267                    for value in values.iter() {
1268                        let name = match value.get("name") {
1269                            Some(name) => name.to_string().replace("\"", ""),
1270                            None => {
1271                                let err_msg = format!(
1272                                    "Failed to get the name of the function. Tool call: {:?}",
1273                                    value
1274                                );
1275
1276                                #[cfg(feature = "logging")]
1277                                error!(target: "stdout", "{}", &err_msg);
1278
1279                                return Err(LlamaCoreError::Operation(err_msg));
1280                            }
1281                        };
1282
1283                        let arguments = match value.get("arguments") {
1284                            Some(arguments) => arguments.to_string(),
1285                            None => {
1286                                let err_msg = format!(
1287                                    "Failed to get the arguments of the function. Tool call: {:?}",
1288                                    value
1289                                );
1290
1291                                #[cfg(feature = "logging")]
1292                                error!(target: "stdout", "{}", &err_msg);
1293
1294                                return Err(LlamaCoreError::Operation(err_msg));
1295                            }
1296                        };
1297
1298                        let function = Function { name, arguments };
1299
1300                        let tool_call = ToolCall {
1301                            id: "call_abc123".to_string(),
1302                            ty: "function".to_string(),
1303                            function,
1304                        };
1305
1306                        tool_calls.push(tool_call);
1307                    }
1308
1309                    let parsed = ParseResult {
1310                        raw: input.to_owned(),
1311                        content: None,
1312                        tool_calls,
1313                    };
1314
1315                    #[cfg(feature = "logging")]
1316                    info!(target: "stdout", "parsed result: {:?}", parsed);
1317
1318                    Ok(parsed)
1319                }
1320                Err(e) => {
1321                    let err_msg = format!("Failed to create a regex pattern. Reason: {}", e);
1322
1323                    #[cfg(feature = "logging")]
1324                    error!(target: "stdout", "{}", &err_msg);
1325
1326                    Err(LlamaCoreError::Operation(err_msg))
1327                }
1328            }
1329        }
1330        PromptTemplateType::Llama3Tool => {
1331            #[cfg(feature = "logging")]
1332            info!(target: "stdout", "raw input: {}", input);
1333
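            // Expected shape (inferred from the regex and parsing below; the values are
            // illustrative): the whole output is a single bare JSON object such as
            // {"name": "get_weather", "parameters": {"city": "Paris"}}.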
1334            let re = match regex::Regex::new(r"^\{(.|\r|\n)*\}$") {
1335                Ok(re) => re,
1336                Err(e) => {
1337                    let err_msg = format!("Failed to create a regex pattern. Reason: {}", e);
1338
1339                    #[cfg(feature = "logging")]
1340                    error!(target: "stdout", "{}", &err_msg);
1341
1342                    return Err(LlamaCoreError::Operation(err_msg));
1343                }
1344            };
1345
1346            if re.is_match(input) {
1347                match serde_json::from_str::<serde_json::Value>(input) {
1348                    Ok(value) => {
1349                        let values: Vec<serde_json::Value> = vec![value];
1350
1351                        let mut tool_calls: Vec<ToolCall> = vec![];
1352                        for value in values.iter() {
1353                            let name = match value.get("name") {
1354                                Some(name) => name.to_string().replace("\"", ""),
1355                                None => {
1356                                    let err_msg = format!(
1357                                        "Failed to get the name of the function. Tool call: {:?}",
1358                                        value
1359                                    );
1360
1361                                    #[cfg(feature = "logging")]
1362                                    error!(target: "stdout", "{}", &err_msg);
1363
1364                                    return Err(LlamaCoreError::Operation(err_msg));
1365                                }
1366                            };
1367
1368                            let arguments = match value.get("parameters") {
1369                                Some(arguments) => arguments.to_string(),
1370                                None => {
1371                                    let err_msg = format!(
1372                                        "Failed to get the arguments of the function. Tool call: {:?}",
1373                                        value
1374                                    );
1375
1376                                    #[cfg(feature = "logging")]
1377                                    error!(target: "stdout", "{}", &err_msg);
1378
1379                                    return Err(LlamaCoreError::Operation(err_msg));
1380                                }
1381                            };
1382
1383                            let function = Function { name, arguments };
1384
1385                            let tool_call = ToolCall {
1386                                id: "call_abc123".to_string(),
1387                                ty: "function".to_string(),
1388                                function,
1389                            };
1390
1391                            tool_calls.push(tool_call);
1392                        }
1393
1394                        let parsed = ParseResult {
1395                            raw: input.to_owned(),
1396                            content: None,
1397                            tool_calls,
1398                        };
1399
1400                        #[cfg(feature = "logging")]
1401                        info!(target: "stdout", "parsed result: {:?}", parsed);
1402
1403                        Ok(parsed)
1404                    }
1405                    Err(e) => {
1406                        let err_msg =
1407                            format!("Failed to deserialize generated tool calls. Reason: {}", e);
1408
1409                        #[cfg(feature = "logging")]
1410                        error!(target: "stdout", "{}", &err_msg);
1411
1412                        Err(LlamaCoreError::Operation(err_msg))
1413                    }
1414                }
1415            } else {
1416                let parsed = ParseResult {
1417                    raw: input.to_owned(),
1418                    content: None,
1419                    tool_calls: vec![],
1420                };
1421
1422                #[cfg(feature = "logging")]
1423                info!(target: "stdout", "parsed result: {:?}", parsed);
1424
1425                Ok(parsed)
1426            }
1427        }
1428        PromptTemplateType::InternLM2Tool => {
1429            #[cfg(feature = "logging")]
1430            info!(target: "stdout", "raw input: {}", input);
1431
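            // Expected shape (inferred from the splitting and parsing below; the values are
            // illustrative): each tool call is wrapped as
            // <|action_start|><|plugin|>{"name": "get_weather", "parameters": {"city": "Paris"}}<|action_end|>,
            // and any text outside those markers is kept as plain assistant content.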
1432            let blocks: Vec<&str> = input.trim().split("<|action_start|><|plugin|>").collect();
1433
1434            #[cfg(feature = "logging")]
1435            info!(target: "stdout", "blocks: {:?}", blocks);
1436
1437            let mut tool_calls: Vec<ToolCall> = vec![];
1438            let mut content = String::new();
1439            for block in blocks {
1440                let block = block.trim();
1441                if !block.is_empty() {
1442                    if block.ends_with("<|action_end|>") {
1443                        let value = block.trim().trim_end_matches("<|action_end|>");
1444
1445                        #[cfg(feature = "logging")]
1446                        info!(target: "stdout", "tool call: {}", value);
1447
1448                        match serde_json::from_str::<serde_json::Value>(value) {
1449                            Ok(value) => {
1450                                let name = match value.get("name") {
1451                                    Some(name) => name.to_string().replace("\"", ""),
1452                                    None => {
1453                                        let err_msg = format!(
1454                                            "Failed to get the name of the function. Tool call: {:?}",
1455                                            value
1456                                        );
1457
1458                                        #[cfg(feature = "logging")]
1459                                        error!(target: "stdout", "{}", &err_msg);
1460
1461                                        return Err(LlamaCoreError::Operation(err_msg));
1462                                    }
1463                                };
1464
1465                                let arguments = match value.get("parameters") {
1466                                    Some(arguments) => arguments.to_string(),
1467                                    None => {
1468                                        let err_msg = format!(
1469                                            "Failed to get the arguments of the function. Tool call: {:?}",
1470                                            value
1471                                        );
1472
1473                                        #[cfg(feature = "logging")]
1474                                        error!(target: "stdout", "{}", &err_msg);
1475
1476                                        return Err(LlamaCoreError::Operation(err_msg));
1477                                    }
1478                                };
1479
1480                                let function = Function { name, arguments };
1481
1482                                let tool_call = ToolCall {
1483                                    id: "call_abc123".to_string(),
1484                                    ty: "function".to_string(),
1485                                    function,
1486                                };
1487
1488                                tool_calls.push(tool_call);
1489                            }
1490                            Err(e) => {
1491                                let err_msg = format!(
1492                                    "Failed to deserialize generated tool calls. Reason: {}",
1493                                    e
1494                                );
1495
1496                                #[cfg(feature = "logging")]
1497                                error!(target: "stdout", "{}", &err_msg);
1498
1499                                return Err(LlamaCoreError::Operation(err_msg));
1500                            }
1501                        }
1502                    } else {
1503                        content.push_str(block);
1504                        content.push('\n');
1505                    }
1506                }
1507            }
1508
1509            let parsed = match content.is_empty() {
1510                true => ParseResult {
1511                    raw: input.to_owned(),
1512                    content: None,
1513                    tool_calls,
1514                },
1515                false => ParseResult {
1516                    raw: input.to_owned(),
1517                    content: Some(content.trim().to_owned()),
1518                    tool_calls,
1519                },
1520            };
1521
1522            #[cfg(feature = "logging")]
1523            info!(target: "stdout", "parsed result: {:?}", parsed);
1524
1525            Ok(parsed)
1526        }
1527        PromptTemplateType::NemotronTool => {
1528            #[cfg(feature = "logging")]
1529            info!(target: "stdout", "raw input: {}", input);
1530
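            // Expected shape (inferred from the regex and parsing below; the values are
            // illustrative): each tool call is a JSON object with "name" and "arguments" keys,
            // wrapped as <toolcall> {"name": "get_weather", "arguments": {"city": "Paris"}} </toolcall>.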
1531            match regex::Regex::new(r"(?s)<toolcall>\s*(.*?)\s*</toolcall>") {
1532                Ok(re) => {
1533                    let mut values: Vec<serde_json::Value> = vec![];
1534                    for cap in re.captures_iter(input) {
1535                        #[cfg(feature = "logging")]
1536                        info!(target: "stdout", "captured: {}", &cap[0]);
1537
1538                        #[cfg(feature = "logging")]
1539                        info!(target: "stdout", "extracted: {}", &cap[1]);
1540
1541                        let matched = cap[1].trim();
1542
1543                        #[cfg(feature = "logging")]
1544                        info!(target: "stdout", "trimmed: {}", matched);
1545
1546                        match serde_json::from_str::<serde_json::Value>(matched) {
1547                            Ok(value) => values.push(value),
1548                            Err(e) => {
1549                                let err_msg = format!(
1550                                    "Failed to deserialize generated tool calls. Reason: {}",
1551                                    e
1552                                );
1553
1554                                #[cfg(feature = "logging")]
1555                                error!(target: "stdout", "{}", &err_msg);
1556
1557                                return Err(LlamaCoreError::Operation(err_msg));
1558                            }
1559                        }
1560                    }
1561
1562                    let mut tool_calls: Vec<ToolCall> = vec![];
1563                    for value in values.iter() {
1564                        let name = match value.get("name") {
1565                            Some(name) => name.to_string().replace("\"", ""),
1566                            None => {
1567                                let err_msg = format!(
1568                                    "Failed to get the name of the function. Tool call: {:?}",
1569                                    value
1570                                );
1571
1572                                #[cfg(feature = "logging")]
1573                                error!(target: "stdout", "{}", &err_msg);
1574
1575                                return Err(LlamaCoreError::Operation(err_msg));
1576                            }
1577                        };
1578
1579                        let arguments = match value.get("arguments") {
1580                            Some(arguments) => arguments.to_string(),
1581                            None => {
1582                                let err_msg = format!(
1583                                    "Failed to get the arguments of the function. Tool call: {:?}",
1584                                    value
1585                                );
1586
1587                                #[cfg(feature = "logging")]
1588                                error!(target: "stdout", "{}", &err_msg);
1589
1590                                return Err(LlamaCoreError::Operation(err_msg));
1591                            }
1592                        };
1593
1594                        let function = Function { name, arguments };
1595
1596                        let tool_call = ToolCall {
1597                            id: "call_abc123".to_string(),
1598                            ty: "function".to_string(),
1599                            function,
1600                        };
1601
1602                        tool_calls.push(tool_call);
1603                    }
1604
1605                    let parsed = ParseResult {
1606                        raw: input.to_owned(),
1607                        content: None,
1608                        tool_calls,
1609                    };
1610
1611                    #[cfg(feature = "logging")]
1612                    info!(target: "stdout", "parsed result: {:?}", parsed);
1613
1614                    Ok(parsed)
1615                }
1616                Err(e) => {
1617                    let err_msg = format!("Failed to create a regex pattern. Reason: {}", e);
1618
1619                    #[cfg(feature = "logging")]
1620                    error!(target: "stdout", "{}", &err_msg);
1621
1622                    Err(LlamaCoreError::Operation(err_msg))
1623                }
1624            }
1625        }
1626        PromptTemplateType::FunctionaryV32 => {
1627            #[cfg(feature = "logging")]
1628            info!(target: "stdout", "raw input: {}", input);
1629
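            // Expected shape (inferred from the regex below; the values are illustrative):
            // each tool call looks like >>>get_weather{"city": "Paris"}<|eot_id|>, where
            // capture 1 is the function name and capture 2 is the argument body between the braces.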
1630            match regex::Regex::new(r">>>\s*(\w+)\s*\{(.*)\}<\|eot_id\|>") {
1631                Ok(re) => {
1632                    let mut tool_calls: Vec<ToolCall> = vec![];
1633                    for cap in re.captures_iter(input) {
1634                        #[cfg(feature = "logging")]
1635                        info!(target: "stdout", "func_name: {}", &cap[1]);
1636
1637                        #[cfg(feature = "logging")]
1638                        info!(target: "stdout", "arguments: {}", &cap[2]);
1639
1640                        let tool_call = ToolCall {
1641                            id: "call_abc123".to_string(),
1642                            ty: "function".to_string(),
1643                            function: Function {
1644                                name: cap[1].to_string(),
1645                                arguments: cap[2].to_string(),
1646                            },
1647                        };
1648
1649                        tool_calls.push(tool_call);
1650                    }
1651
1652                    let parsed = ParseResult {
1653                        raw: input.to_owned(),
1654                        content: None,
1655                        tool_calls,
1656                    };
1657
1658                    #[cfg(feature = "logging")]
1659                    info!(target: "stdout", "parsed result: {:?}", parsed);
1660
1661                    Ok(parsed)
1662                }
1663                Err(e) => {
1664                    let warn_msg = format!("Failed to create a regex pattern. Reason: {}", e);
1665
1666                    #[cfg(feature = "logging")]
1667                    warn!(target: "stdout", "{}", &warn_msg);
1668
1669                    Ok(ParseResult {
1670                        raw: input.to_owned(),
1671                        content: None,
1672                        tool_calls: vec![],
1673                    })
1674                }
1675            }
1676        }
1677        PromptTemplateType::FunctionaryV31 => {
1678            #[cfg(feature = "logging")]
1679            info!(target: "stdout", "raw input: {}", input);
1680
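            // Expected shape (inferred from the regex below; the values are illustrative):
            // each tool call looks like <function=get_weather>{"city": "Paris"}</function>, where
            // capture 1 is the function name and capture 2 is the JSON argument object.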
1681            match regex::Regex::new(r"<function=(\w+)>\s*(\{.*?\})</function>") {
1682                Ok(re) => {
1683                    let mut tool_calls: Vec<ToolCall> = vec![];
1684                    for cap in re.captures_iter(input) {
1685                        #[cfg(feature = "logging")]
1686                        info!(target: "stdout", "func_name: {}", &cap[1]);
1687
1688                        #[cfg(feature = "logging")]
1689                        info!(target: "stdout", "arguments: {}", &cap[2]);
1690
1691                        let tool_call = ToolCall {
1692                            id: "call_abc123".to_string(),
1693                            ty: "function".to_string(),
1694                            function: Function {
1695                                name: cap[1].to_string(),
1696                                arguments: cap[2].to_string(),
1697                            },
1698                        };
1699
1700                        tool_calls.push(tool_call);
1701                    }
1702
1703                    let parsed = ParseResult {
1704                        raw: input.to_owned(),
1705                        content: None,
1706                        tool_calls,
1707                    };
1708
1709                    #[cfg(feature = "logging")]
1710                    info!(target: "stdout", "parsed result: {:?}", parsed);
1711
1712                    Ok(parsed)
1713                }
1714                Err(e) => {
1715                    let warn_msg = format!("Failed to create a regex pattern. Reason: {}", e);
1716
1717                    #[cfg(feature = "logging")]
1718                    warn!(target: "stdout", "{}", &warn_msg);
1719
1720                    Ok(ParseResult {
1721                        raw: input.to_owned(),
1722                        content: None,
1723                        tool_calls: vec![],
1724                    })
1725                }
1726            }
1727        }
1728        PromptTemplateType::MistralSmallTool => {
1729            #[cfg(feature = "logging")]
1730            info!(target: "stdout", "raw input: {}", input);
1731
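            // Expected shape (inferred from the regex and parsing below; the values are
            // illustrative): a [TOOL_CALLS] marker followed by a JSON array such as
            // [TOOL_CALLS] [{"name": "get_weather", "arguments": {"city": "Paris"}}].
            // Each element may either nest a "function" object or carry "name"/"arguments"
            // directly; both shapes are handled below.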
1732            match regex::Regex::new(r"\[TOOL_CALLS\]\s*(\[(.*?)\])") {
1733                Ok(re) => {
1734                    let mut values: Vec<serde_json::Value> = vec![];
1735                    if let Some(cap) = re.captures(input) {
1736                        let matched = cap[1].trim();
1737
1738                        #[cfg(feature = "logging")]
1739                        info!(target: "stdout", "captured: {}", matched);
1740
1741                        match serde_json::from_str::<Vec<serde_json::Value>>(matched) {
1742                            Ok(vals) => values = vals,
1743                            Err(e) => {
1744                                let err_msg = format!(
1745                                    "Failed to deserialize generated tool calls. Reason: {}",
1746                                    e
1747                                );
1748
1749                                #[cfg(feature = "logging")]
1750                                error!(target: "stdout", "{}", &err_msg);
1751
1752                                return Err(LlamaCoreError::Operation(err_msg));
1753                            }
1754                        }
1755                    };
1756
1757                    let mut tool_calls: Vec<ToolCall> = vec![];
1758                    for value in values.iter() {
1759                        if let Some(object_map) = value.as_object() {
1760                            if object_map.contains_key("function") {
1761                                let mut function = Function {
1762                                    name: String::new(),
1763                                    arguments: String::new(),
1764                                };
1765
1766                                let value = object_map.get("function").unwrap();
1767                                let func_map = value.as_object().unwrap();
1768                                if func_map.contains_key("name") {
1769                                    let func_name = func_map.get("name").unwrap().as_str().unwrap();
1770                                    #[cfg(feature = "logging")]
                                    info!(target: "stdout", "Function name: {:?}", func_name);
1771
1772                                    function.name = func_name.to_string();
1773                                }
1774                                if func_map.contains_key("arguments") {
1775                                    let args = func_map.get("arguments").unwrap();
1776                                    let arguments = args.to_string();
1777                                    #[cfg(feature = "logging")]
                                    info!(target: "stdout", "Arguments: {:?}", arguments);
1778
1779                                    function.arguments = arguments;
1780                                }
1781
1782                                let tool_call = ToolCall {
1783                                    id: "call_abc123".to_string(),
1784                                    ty: "function".to_string(),
1785                                    function,
1786                                };
1787
1788                                tool_calls.push(tool_call);
1789                            } else if object_map.contains_key("name") {
1790                                let mut function = Function {
1791                                    name: String::new(),
1792                                    arguments: String::new(),
1793                                };
1794
1795                                let name = object_map.get("name").unwrap().as_str().unwrap();
1796                                #[cfg(feature = "logging")]
                                info!(target: "stdout", "name: {:?}", name);
1797                                function.name = name.to_string();
1798
1799                                if object_map.contains_key("arguments") {
1800                                    let args = object_map.get("arguments").unwrap();
1801                                    let arguments = args.to_string();
1802                                    #[cfg(feature = "logging")]
                                    info!(target: "stdout", "Arguments: {:?}", arguments);
1803
1804                                    function.arguments = arguments;
1805                                }
1806
1807                                let tool_call = ToolCall {
1808                                    id: "call_abc123".to_string(),
1809                                    ty: "function".to_string(),
1810                                    function,
1811                                };
1812
1813                                tool_calls.push(tool_call);
1814                            }
1815                        }
1861                    }
1862
1863                    let parsed = ParseResult {
1864                        raw: input.to_owned(),
1865                        content: None,
1866                        tool_calls,
1867                    };
1868
1869                    #[cfg(feature = "logging")]
1870                    info!(target: "stdout", "parsed result: {:?}", parsed);
1871
1872                    Ok(parsed)
1873                }
1874                Err(e) => {
1875                    let err_msg = format!("Failed to create a regex pattern. Reason: {}", e);
1876
1877                    #[cfg(feature = "logging")]
1878                    error!(target: "stdout", "{}", &err_msg);
1879
1880                    Err(LlamaCoreError::Operation(err_msg))
1881                }
1882            }
1883        }
1884        _ => {
1885            let err_msg = format!(
1886                "Tool use is only supported for the following prompt templates: {}, {}, {}, {}, {}, {}, {}, {}, and {}.",
1887                PromptTemplateType::MistralTool,
1888                PromptTemplateType::ChatMLTool,
1889                PromptTemplateType::GroqLlama3Tool,
1890                PromptTemplateType::Llama3Tool,
1891                PromptTemplateType::InternLM2Tool,
1892                PromptTemplateType::NemotronTool,
1893                PromptTemplateType::FunctionaryV32,
                PromptTemplateType::FunctionaryV31,
1894                PromptTemplateType::MistralSmallTool,
1895            );
1896
1897            #[cfg(feature = "logging")]
1898            error!(target: "stdout", "{}", &err_msg);
1899
1900            Err(LlamaCoreError::Operation(err_msg))
1901        }
1902    }
1903}
1904
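/// Reconcile the cached model metadata with the per-request options (image input,
/// temperature, top_p, frequency/presence penalties), disable embeddings for chat,
/// and push the updated metadata to the target graph when anything changed.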
1905async fn check_model_metadata(
1906    chat_request: &ChatCompletionRequest,
1907) -> Result<GgmlMetadata, LlamaCoreError> {
1908    let mut should_update = false;
1909    let mut metadata = get_model_metadata(chat_request.model.as_ref())?;
1910
1911    // check if necessary to update `image`
1912    if metadata.prompt_template.is_image_supported() {
1913        if let Some(ChatCompletionRequestMessage::User(user_message)) = chat_request.messages.last()
1914        {
1915            if let ChatCompletionUserMessageContent::Parts(parts) = user_message.content() {
1916                for part in parts {
1917                    if let ContentPart::Image(image_part) = part {
1918                        let image = image_part.image();
1919
1920                        if image.is_url() {
1921                            #[cfg(feature = "logging")]
1922                            info!(target: "stdout", "The image is provided in URL format.");
1923
1924                            // download the image
1925                            let img_path_str = download_image(&image.url).await?;
1926
1927                            #[cfg(feature = "logging")]
1928                            info!(target: "stdout", "The image is saved to {}", img_path_str);
1929
1930                            // update metadata image
1931                            metadata.image = Some(img_path_str);
1932
1933                            if !should_update {
1934                                should_update = true;
1935                            }
1936
1937                            // TODO: now only support a single image
1938
1939                            break;
1940                        } else {
1941                            #[cfg(feature = "logging")]
1942                            info!(target: "stdout", "The image is provided in base64 format.");
1943
1944                            // TODO: now only support a single image
1945
1946                            break;
1947                        }
1948                    }
1949                }
1950            }
1951        }
1952    }
1953
1954    // check if necessary to update temperature
1955    if let Some(temp) = chat_request.temperature {
1956        if metadata.temperature != temp {
1957            // update temperature
1958            metadata.temperature = temp;
1959
1960            if !should_update {
1961                should_update = true;
1962            }
1963        }
1964    }
1965
1966    // check if necessary to update top_p
1967    if let Some(top_p) = chat_request.top_p {
1968        if metadata.top_p != top_p {
1969            // update top_p
1970            metadata.top_p = top_p;
1971
1972            if !should_update {
1973                should_update = true;
1974            }
1975        }
1976    }
1977
1978    // check if necessary to update frequency_penalty
1979    if let Some(frequency_penalty) = chat_request.frequency_penalty {
1980        if metadata.frequency_penalty != frequency_penalty {
1981            // update frequency_penalty
1982            metadata.frequency_penalty = frequency_penalty;
1983
1984            if !should_update {
1985                should_update = true;
1986            }
1987        }
1988    }
1989
1990    // check if necessary to update presence_penalty
1991    if let Some(presence_penalty) = chat_request.presence_penalty {
1992        if metadata.presence_penalty != presence_penalty {
1993            // update presence_penalty
1994            metadata.presence_penalty = presence_penalty;
1995
1996            if !should_update {
1997                should_update = true;
1998            }
1999        }
2000    }
2001
2002    // check if the `embedding` option is disabled
2003    if metadata.embeddings {
2004        metadata.embeddings = false;
2005
2006        if !should_update {
2007            should_update = true;
2008        }
2009    }
2010
2011    if should_update {
2012        #[cfg(feature = "logging")]
2013        info!(target: "stdout", "Update the model metadata.");
2014
2015        // update the target graph with the new metadata
2016        update_model_metadata(chat_request.model.as_ref(), &metadata)?;
2017    }
2018
2019    Ok(metadata)
2020}
2021
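/// Adjust `n_predict` before inference: the request's `max_completion_tokens` takes
/// priority, then the number of completion tokens still available in the context,
/// then the existing `n_predict`; push the change to the target graph when needed.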
2022async fn update_n_predict(
2023    chat_request: &ChatCompletionRequest,
2024    metadata: &mut GgmlMetadata,
2025    available_completion_tokens: u64,
2026) -> Result<(), LlamaCoreError> {
2027    let mut should_update = false;
2028
2029    #[cfg(feature = "logging")]
2030    info!(target: "stdout", "n_predict: {}", metadata.n_predict);
2031
2032    // From high to low priority
2033    // 1. chat_request.max_tokens & chat_request.max_completion_tokens
2034    // 2. available_completion_tokens
2035    // 3. n_predict
2036
2037    if let Some(max_completion_tokens) = chat_request.max_completion_tokens {
2038        if metadata.n_predict != max_completion_tokens {
2039            #[cfg(feature = "logging")]
2040            info!(target: "stdout", "Update n_predict with max_completion_tokens from {} to {}", metadata.n_predict, max_completion_tokens);
2041
2042            metadata.n_predict = max_completion_tokens;
2043
2044            if !should_update {
2045                should_update = true;
2046            }
2047        }
2048    }
2049
2050    // TODO: remove this condition after [Issue #3958 on WasmEdge](https://github.com/WasmEdge/WasmEdge/issues/3958) is fixed
2051    if metadata.n_predict == -2 {
2052        #[cfg(feature = "logging")]
2053        info!(target: "stdout", "Update n_predict with available_completion_tokens from {} to {}", metadata.n_predict, available_completion_tokens);
2054
2055        // update n_predict
2056        metadata.n_predict = available_completion_tokens as i32;
2057
2058        if !should_update {
2059            should_update = true;
2060        }
2061    }
2062
2063    if metadata.n_predict == -1
2064        || (metadata.n_predict > 0 && metadata.n_predict < available_completion_tokens as i32)
2065        || (metadata.n_predict < 0 && metadata.n_predict != -2)
2066    // TODO: remove this condition after [Issue #3958 on WasmEdge](https://github.com/WasmEdge/WasmEdge/issues/3958) is fixed
2067    {
2068        #[cfg(feature = "logging")]
2069        info!(target: "stdout", "Update n_predict with available_completion_tokens from {} to {}", metadata.n_predict, available_completion_tokens);
2070
2071        // update n_predict
2072        metadata.n_predict = available_completion_tokens as i32;
2073
2074        if !should_update {
2075            should_update = true;
2076        }
2077    }
2078
2079    if should_update {
2080        #[cfg(feature = "logging")]
2081        info!(target: "stdout", "Update the model metadata.");
2082
2083        // update the target graph with the new metadata
2084        update_model_metadata(chat_request.model.as_ref(), metadata)?;
2085    }
2086
2087    Ok(())
2088}
2089
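/// Strip template-specific stop and end-of-turn tokens (e.g. `</s>`, `<|im_end|>`,
/// `<|eot_id|>`) from the raw model output and trim the surrounding whitespace.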
2090fn post_process(
2091    output: impl AsRef<str>,
2092    template_ty: &PromptTemplateType,
2093) -> Result<String, String> {
2094    let output = if *template_ty == PromptTemplateType::Baichuan2 {
2095        if output.as_ref().contains("用户:") {
2096            output.as_ref().trim_end_matches("用户:").trim().to_owned()
2097        } else {
2098            output.as_ref().trim().to_owned()
2099        }
2100    } else if *template_ty == PromptTemplateType::OpenChat {
2101        if output.as_ref().contains("<|end_of_turn|>") {
2102            output
2103                .as_ref()
2104                .trim_end_matches("<|end_of_turn|>")
2105                .trim()
2106                .to_owned()
2107        } else {
2108            output.as_ref().trim().to_owned()
2109        }
2110    } else if *template_ty == PromptTemplateType::GemmaInstruct {
2111        let s = output.as_ref().trim();
2112        if s.ends_with("<end_of_turn>") {
2113            s.trim_end_matches("<end_of_turn>").trim().to_owned()
2114        } else {
2115            s.to_owned()
2116        }
2117    } else if *template_ty == PromptTemplateType::ChatML
2118        || *template_ty == PromptTemplateType::ChatMLTool
2119        || *template_ty == PromptTemplateType::InternLM2Tool
2120        || *template_ty == PromptTemplateType::MiniCPMV
2121    {
2122        let mut s = output.as_ref().trim();
2123        if s.ends_with("<|endoftext|>") {
2124            s = s.trim_end_matches("<|endoftext|>").trim();
2125        }
2126
2127        if s.contains("<|im_start|>") && s.contains("<|im_end|>") {
2128            let idx_start = s.find("<|im_start|>").unwrap();
2129            let idx_end = s.find("<|im_end|>").unwrap();
2130
2131            match idx_start <= idx_end {
2132                true => s.split("<|im_start|>").collect::<Vec<_>>()[0]
2133                    .trim()
2134                    .to_owned(),
2135                false => s.split("<|im_end|>").collect::<Vec<_>>()[0]
2136                    .trim()
2137                    .to_owned(),
2138            }
2139        } else if s.contains("<|im_start|>") {
2140            s.split("<|im_start|>").collect::<Vec<_>>()[0]
2141                .trim()
2142                .to_owned()
2143        } else if s.contains("<|im_end|>") {
2144            let output = s.trim_end_matches("<|im_end|>").trim();
2145            if output.starts_with(": ") {
2146                output.trim_start_matches(": ").to_owned()
2147            } else {
2148                output.to_owned()
2149            }
2150        } else {
2151            s.to_owned()
2152        }
2153    } else if *template_ty == PromptTemplateType::Zephyr
2154        || *template_ty == PromptTemplateType::MistralLite
2155        || *template_ty == PromptTemplateType::MistralTool
2156        || *template_ty == PromptTemplateType::MistralInstruct
2157        || *template_ty == PromptTemplateType::MistralSmallChat
2158        || *template_ty == PromptTemplateType::MistralSmallTool
2159        || *template_ty == PromptTemplateType::BreezeInstruct
2160    {
2161        if output.as_ref().contains("</s><") {
2162            output.as_ref().trim_end_matches("</s><").trim().to_owned()
2163        } else if output.as_ref().contains("</s>") {
2164            output
2165                .as_ref()
2166                .strip_suffix("</s>")
2167                .unwrap()
2168                .trim()
2169                .to_owned()
2170        } else {
2171            output.as_ref().trim().to_owned()
2172        }
2173    } else if *template_ty == PromptTemplateType::DeepseekChat {
2174        if output.as_ref().contains("<|end_of_sentence|>") {
2175            output
2176                .as_ref()
2177                .trim_end_matches("<|end_of_sentence|>")
2178                .trim()
2179                .replace("<|end_of_sentence|>", " ")
2180                .trim()
2181                .to_owned()
2182        } else {
2183            output.as_ref().trim().to_owned()
2184        }
2185    } else if *template_ty == PromptTemplateType::HumanAssistant {
2186        if output.as_ref().contains("Human:") {
2187            output.as_ref().trim_end_matches("Human:").trim().to_owned()
2188        } else {
2189            output.as_ref().trim().to_owned()
2190        }
2191    } else if *template_ty == PromptTemplateType::SolarInstruct {
2192        let s = output.as_ref().trim();
2193
2194        if s.starts_with("### Answer") {
2195            let s = s.trim_start_matches("###").trim();
2196
2197            if s.starts_with("Answer:\n") {
2198                s.replace("Answer:\n", "Answer: ")
2199            } else {
2200                s.to_owned()
2201            }
2202        } else {
2203            s.to_owned()
2204        }
2205    } else if *template_ty == PromptTemplateType::Llama2Chat
2206        || *template_ty == PromptTemplateType::NemotronTool
2207        || *template_ty == PromptTemplateType::NemotronChat
2208    {
2209        let s = output.as_ref().trim();
2210        if s.ends_with("</s>") {
2211            s.trim_end_matches("</s>").trim().to_owned()
2212        } else {
2213            s.to_owned()
2214        }
2215    } else if *template_ty == PromptTemplateType::Llama3Chat
2216        || *template_ty == PromptTemplateType::GroqLlama3Tool
2217        || *template_ty == PromptTemplateType::Llama3Tool
2218        || *template_ty == PromptTemplateType::FunctionaryV32
2219    {
2220        let s = output.as_ref().trim();
2221        if s.ends_with("<|eot_id|>") {
2222            s.trim_end_matches("<|eot_id|>").trim().to_owned()
2223        } else {
2224            s.to_owned()
2225        }
2226    } else if *template_ty == PromptTemplateType::Phi3Chat {
2227        let s = output.as_ref().trim();
2228        if s.ends_with("<|end|>") {
2229            s.trim_end_matches("<|end|>").trim().to_owned()
2230        } else {
2231            s.to_owned()
2232        }
2233    } else if *template_ty == PromptTemplateType::Phi4Chat {
2234        let s = output.as_ref().trim();
2235        if s.ends_with("<|im_end|>") {
2236            s.trim_end_matches("<|im_end|>").trim().to_owned()
2237        } else {
2238            s.to_owned()
2239        }
2240    } else if *template_ty == PromptTemplateType::FunctionaryV31 {
2241        let mut s = output.as_ref().trim();
2242        if s.ends_with("<|eot_id|>") {
2243            s = s.trim_end_matches("<|eot_id|>").trim();
2244        }
2245        if s.ends_with("<|eom_id|>") {
2246            s = s.trim_end_matches("<|eom_id|>").trim();
2247        }
2248        s.to_owned()
2249    } else if *template_ty == PromptTemplateType::MoxinChat {
2250        let s = output.as_ref().trim();
2251        if s.ends_with("</s>") {
2252            s.trim_end_matches("</s>").trim().to_owned()
2253        } else if s.ends_with("[INST]") {
2254            s.trim_end_matches("[INST]").trim().to_owned()
2255        } else {
2256            s.to_owned()
2257        }
2258    } else if *template_ty == PromptTemplateType::Falcon3 {
2259        let s = output.as_ref().trim();
2260        if s.ends_with("<|endoftext|>") {
2261            s.trim_end_matches("<|endoftext|>").trim().to_owned()
2262        } else {
2263            s.to_owned()
2264        }
2265    } else if *template_ty == PromptTemplateType::Megrez {
2266        let s = output.as_ref().trim();
2267        if s.ends_with("<|turn_end|>") {
2268            s.trim_end_matches("<|turn_end|>").trim().to_owned()
2269        } else {
2270            s.to_owned()
2271        }
2272    } else if *template_ty == PromptTemplateType::Qwen2vl {
2273        let mut s = output.as_ref().trim();
2274
2275        if s.starts_with(":") {
2276            s = s.trim_start_matches(":").trim();
2277        }
2278
2279        if s.ends_with("<|im_end|>") {
2280            s.trim_end_matches("<|im_end|>").trim().to_owned()
2281        } else {
2282            s.to_owned()
2283        }
2284    } else if *template_ty == PromptTemplateType::VicunaLlava {
2285        let s = output.as_ref().trim();
2286        if s.ends_with("</s>") {
2287            s.trim_end_matches("</s>").trim().to_owned()
2288        } else {
2289            s.to_owned()
2290        }
2291    } else {
2292        output.as_ref().trim().to_owned()
2293    };
2294
2295    Ok(output)
2296}
2297
2298/// Build the chat prompt from the chat messages.
2299///
2300/// # Arguments
2301///
2302/// * `model_name`: The name of the model.
2303///
2304/// * `chat_request`: The chat request.
2305///
2306/// # Returns
2307///
2308/// A tuple containing the prompt, the number of available tokens for completions, and a boolean indicating whether tools are used.
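///
/// # Example
///
/// A minimal sketch: the model name is illustrative and the chat graphs must already be
/// initialized, so the snippet is marked `ignore` and is not run as a doc test.
///
/// ```ignore
/// // `chat_request` is a previously constructed ChatCompletionRequest.
/// let (prompt, available_completion_tokens, tool_use) =
///     build_prompt(Some(&"llama-3-8b".to_string()), &mut chat_request)?;
/// ```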
2309fn build_prompt(
2310    model_name: Option<&String>,
2311    chat_request: &mut ChatCompletionRequest,
2312) -> Result<(String, u64, bool), LlamaCoreError> {
2313    let metadata = get_model_metadata(model_name)?;
2314    let ctx_size = metadata.ctx_size as u64;
2315    let chat_prompt = ChatPrompt::from(metadata.prompt_template);
2316
2317    // compute max prompt tokens, which is 80% of the context size
2318    let max_prompt_tokens = ctx_size * 4 / 5;
2319
2320    loop {
2321        // ! DO NOT REMOVE
2322        {
2323            // // build prompt
2324            // let prompt = match chat_prompt.build(&mut chat_request.messages) {
2325            //     Ok(prompt) => prompt,
2326            //     Err(e) => {
2327            //         let err_msg = format!("Fail to build chat prompts. Reason: {}", e);
2328
2329            //         #[cfg(feature = "logging")]
2330            //         error!(target: "stdout", "{}", &err_msg);
2331
2332            //         return Err(LlamaCoreError::Operation(err_msg));
2333            //     }
2334            // };
2335        }
2336
2337        if chat_request.messages.is_empty() {
2338            let err_msg = "The messages in the chat request are empty.";
2339
2340            #[cfg(feature = "logging")]
2341            error!(target: "stdout", "{}", err_msg);
2342
2343            return Err(LlamaCoreError::Operation(err_msg.to_owned()));
2344        }
2345
2346        #[cfg(feature = "logging")]
2347        {
2348            let mut role_chain = String::new();
2349            for (idx, message) in chat_request.messages.iter().enumerate() {
2350                if idx == chat_request.messages.len() - 1 {
2351                    role_chain.push_str(&format!("{}", message.role()));
2352                } else {
2353                    role_chain.push_str(&format!("{} -> ", message.role()));
2354                }
2355            }
2356            info!(target: "stdout", "Role chain: {}", role_chain);
2357        }
2358
2359        let (prompt, tool_use) = match chat_request.tool_choice.as_ref() {
2360            Some(tool_choice) => match tool_choice {
2361                ToolChoice::None => {
2362                    match chat_prompt.build_with_tools(&mut chat_request.messages, Some(&[])) {
2363                        Ok(prompt) => (prompt, false),
2364                        Err(e) => {
2365                            let err_msg = format!("Fail to build chat prompts. Reason: {}", e);
2366
2367                            #[cfg(feature = "logging")]
2368                            error!(target: "stdout", "{}", &err_msg);
2369
2370                            return Err(LlamaCoreError::Operation(err_msg));
2371                        }
2372                    }
2373                }
2374                _ => match chat_request.tools.as_ref() {
2375                    Some(tools) => match chat_prompt
2376                        .build_with_tools(&mut chat_request.messages, Some(tools.as_slice()))
2377                    {
2378                        Ok(prompt) => (prompt, true),
2379                        Err(e) => {
2380                            let err_msg = format!("Fail to build chat prompts. Reason: {}", e);
2381
2382                            #[cfg(feature = "logging")]
2383                            error!(target: "stdout", "{}", &err_msg);
2384
2385                            return Err(LlamaCoreError::Operation(err_msg));
2386                        }
2387                    },
2388                    None => {
2389                        #[cfg(feature = "logging")]
2390                        warn!(target: "stdout", "A tool choice was specified, but no tools were provided; building the prompt without tools.");
2391
2392                        match chat_prompt.build_with_tools(&mut chat_request.messages, None) {
2393                            Ok(prompt) => (prompt, false),
2394                            Err(e) => {
2395                                let err_msg = format!("Fail to build chat prompts. Reason: {}", e);
2396
2397                                #[cfg(feature = "logging")]
2398                                error!(target: "stdout", "{}", &err_msg);
2399
2400                                return Err(LlamaCoreError::Operation(err_msg));
2401                            }
2402                        }
2403                    }
2404                },
2405            },
2406            None => match chat_prompt.build_with_tools(&mut chat_request.messages, None) {
2407                Ok(prompt) => (prompt, false),
2408                Err(e) => {
2409                    let err_msg = format!("Fail to build chat prompts. Reason: {}", e);
2410
2411                    #[cfg(feature = "logging")]
2412                    error!(target: "stdout", "{}", &err_msg);
2413
2414                    return Err(LlamaCoreError::Operation(err_msg));
2415                }
2416            },
2417        };
2418        #[cfg(feature = "logging")]
2419        info!(target: "stdout", "Try to set prompt: {}", prompt);
2420
2421        // set prompt
2422        set_prompt(model_name, &prompt)?;
2423
2424        // Retrieve the number of prompt tokens.
2425        let token_info = get_token_info_by_graph_name(model_name)?;
2426
2427        match token_info.prompt_tokens > max_prompt_tokens {
2428            true => {
2429                match chat_request.messages[0].role() {
2430                    ChatCompletionRole::System => {
2431                        if chat_request.messages.len() > 2 {
2432                            #[cfg(feature = "logging")]
2433                            info!(target: "stdout", "Prune chat history: current length {}", chat_request.messages.len());
2434
2435                            // remove user_1 if it exists
2436                            // For example, `system -> user_1 -> ... -> user_2 -> ... -> user_latest` will be converted to `system -> ... -> user_2 -> ... -> user_latest`
2437                            if chat_request.messages[1].role() == ChatCompletionRole::User {
2438                                let user_message = chat_request.messages.remove(1);
2439
2440                                #[cfg(feature = "logging")]
2441                                info!(target: "stdout", "Remove a user message from the chat history: {:?}", user_message);
2442                            }
2443
2444                            // remove all messages until the message is of `user`
2445                            // For example, `system -> ... -> user_2 -> ... -> user_latest` will be converted to `system -> user_2 -> ... -> user_latest`
2446                            while chat_request.messages[1].role() != ChatCompletionRole::User {
2447                                let message = chat_request.messages.remove(1);
2448
2449                                #[cfg(feature = "logging")]
2450                                info!(target: "stdout", "Remove a {} message from the chat history: {:?}", message.role(), message);
2451
2452                                if chat_request.messages.len() == 1 {
2453                                    let err_msg = format!("The last message in the chat history should be a user message, but found a {} message.", message.role());
2454
2455                                    #[cfg(feature = "logging")]
2456                                    error!(target: "stdout", "{}", err_msg);
2457
2458                                    return Err(LlamaCoreError::Operation(err_msg));
2459                                }
2460                            }
2461                        } else if token_info.prompt_tokens > ctx_size {
2462                            let err_msg = format!(
2463                                    "The number of prompt tokens ({}) is greater than the context size ({}). Please increase the context size, or simplify the input message.",
2464                                    token_info.prompt_tokens, ctx_size
2465                                );
2466
2467                            #[cfg(feature = "logging")]
2468                            error!(target: "stdout", "{}", &err_msg);
2469
2470                            return Err(LlamaCoreError::Operation(err_msg));
2471                        } else {
2472                            return Ok((prompt, ctx_size - token_info.prompt_tokens, tool_use));
2473                        }
2474                    }
2475                    ChatCompletionRole::User => {
2476                        if chat_request.messages.len() > 1 {
2477                            // user_1 -> ... -> user_2 -> ... -> user_latest
2478
2479                            // remove user_1 if it exists
2480                            // For example, `user_1 -> ... -> user_2 -> ... -> user_latest` will be converted to `... -> user_2 -> ... -> user_latest`
2481                            if chat_request.messages[0].role() == ChatCompletionRole::User {
2482                                let user_message = chat_request.messages.remove(0);
2483
2484                                #[cfg(feature = "logging")]
2485                                info!(target: "stdout", "Remove a user message from the chat history: {:?}", user_message);
2486                            }
2487
2488                            // remove all messages until the message is of `user`
2489                            // For example, `... -> user_2 -> ... -> user_latest` will be converted to `user_2 -> ... -> user_latest`
2490                            while chat_request.messages[0].role() != ChatCompletionRole::User {
2491                                let message = chat_request.messages.remove(0);
2492
2493                                #[cfg(feature = "logging")]
2494                                info!(target: "stdout", "Remove a {} message from the chat history: {:?}", message.role(), message);
2495
2496                                if chat_request.messages.is_empty() {
2497                                    let err_msg = format!("The last message in the chat history should be a user message, but found a {} message.", message.role());
2498
2499                                    #[cfg(feature = "logging")]
2500                                    error!(target: "stdout", "{}", err_msg);
2501
2502                                    return Err(LlamaCoreError::Operation(err_msg));
2503                                }
2504                            }
2505                        } else if token_info.prompt_tokens > ctx_size {
2506                            let err_msg = format!(
2507                                    "The number of prompt tokens ({}) is greater than the context size ({}). Please increase the context size, or simplify the input message.",
2508                                    token_info.prompt_tokens, ctx_size
2509                                );
2510
2511                            #[cfg(feature = "logging")]
2512                            error!(target: "stdout", "{}", &err_msg);
2513
2514                            return Err(LlamaCoreError::Operation(err_msg));
2515                        } else {
2516                            return Ok((prompt, ctx_size - token_info.prompt_tokens, tool_use));
2517                        }
2518                    }
2519                    _ => {
2520                        #[cfg(feature = "logging")]
2521                        info!(target: "stdout", "remove a {} message from the message queue", chat_request.messages[0].role());
2522
2523                        chat_request.messages.remove(0);
2524                    }
2525                }
2526
2527                continue;
2528            }
2529            false => return Ok((prompt, ctx_size - max_prompt_tokens, tool_use)),
2530        }
2531    }
2532}
2533
2534/// Downloads an image from the given URL and returns the path of the saved image file.
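///
/// # Example
///
/// A minimal sketch: the URL is illustrative and the call needs network access, so the
/// snippet is marked `ignore` and is not run as a doc test.
///
/// ```ignore
/// let img_path = download_image("https://example.com/cat.png").await?;
/// ```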
2535async fn download_image(image_url: impl AsRef<str>) -> Result<String, LlamaCoreError> {
2536    let image_url = image_url.as_ref();
2537    let url = reqwest::Url::parse(image_url).map_err(|e| {
2538        let err_msg = format!("Fail to parse the image URL: {}. Reason: {}", image_url, e);
2539
2540        #[cfg(feature = "logging")]
2541        error!(target: "stdout", "{}", &err_msg);
2542
2543        LlamaCoreError::Operation(err_msg)
2544    })?;
2545
2546    let response = reqwest::get(url).await.map_err(|e| {
2547        let err_msg = format!(
2548            "Fail to download the image from the URL: {}. Reason: {}",
2549            image_url, e
2550        );
2551
2552        #[cfg(feature = "logging")]
2553        error!(target: "stdout", "{}", &err_msg);
2554
2555        LlamaCoreError::Operation(err_msg)
2556    })?;
2557
2558    let fname = response
2559        .url()
2560        .path_segments()
2561        .and_then(|mut segments| segments.next_back())
2562        .and_then(|name| if name.is_empty() { None } else { Some(name) })
2563        .ok_or(LlamaCoreError::Operation(format!(
2564            "Fail to get the file name: {}",
2565            image_url
2566        )))?
2567        .to_string();
2568
2569    // create a unique file id
2570    let id = format!("file_{}", uuid::Uuid::new_v4());
2571    // save the file
2572    let path = Path::new("archives");
2573    if !path.exists() {
2574        fs::create_dir(path).unwrap();
2575    }
2576    let file_path = path.join(&id);
2577    if !file_path.exists() {
2578        fs::create_dir(&file_path).unwrap();
2579    }
2580    let img_path = file_path.join(&fname);
2581    let mut dest: File = File::create(&img_path).map_err(|e| {
2582        let err_msg = format!("Failed to create archive document {}. {}", &fname, e);
2583
2584        // log
2585        #[cfg(feature = "logging")]
2586        error!(target: "stdout", "{}", &err_msg);
2587
2588        LlamaCoreError::Operation(err_msg)
2589    })?;
2590
2591    let mut content = response.bytes_stream();
2592    while let Some(Ok(item)) = content.next().await {
2593        std::io::copy(&mut item.as_ref(), &mut dest).map_err(|e| {
2594            let err_msg = format!(
2595                "Fail to write the image content to the file: {}. Reason: {}",
2596                &fname, e
2597            );
2598
2599            #[cfg(feature = "logging")]
2600            error!(target: "stdout", "{}", &err_msg);
2601
2602            LlamaCoreError::Operation(err_msg)
2603        })?;
2604    }
2605
2606    Ok(img_path.as_path().to_string_lossy().to_string())
2607}
2608
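/// Write the prompt bytes into input tensor 0 of the target chat graph, falling back
/// to the first available graph when `model_name` is `None` or not found.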
2609fn set_prompt(model_name: Option<&String>, prompt: impl AsRef<str>) -> Result<(), LlamaCoreError> {
2610    let chat_graphs = match CHAT_GRAPHS.get() {
2611        Some(chat_graphs) => chat_graphs,
2612        None => {
2613            let err_msg = "Fail to get the underlying value of `CHAT_GRAPHS`.";
2614
2615            #[cfg(feature = "logging")]
2616            error!(target: "stdout", "{}", &err_msg);
2617
2618            return Err(LlamaCoreError::Operation(err_msg.into()));
2619        }
2620    };
2621
2622    let mut chat_graphs = chat_graphs.lock().map_err(|e| {
2623        let err_msg = format!("Fail to acquire the lock of `CHAT_GRAPHS`. {}", e);
2624
2625        #[cfg(feature = "logging")]
2626        error!(target: "stdout", "{}", &err_msg);
2627
2628        LlamaCoreError::Operation(err_msg)
2629    })?;
2630
2631    match model_name {
2632        Some(model_name) => {
2633            #[cfg(feature = "logging")]
2634            info!(target: "stdout", "Set prompt to the chat model named {}", model_name);
2635
2636            match chat_graphs.contains_key(model_name) {
2637                true => {
2638                    let graph = chat_graphs.get_mut(model_name).unwrap();
2639                    let tensor_data = prompt.as_ref().as_bytes().to_vec();
2640                    set_tensor_data_u8(graph, 0, &tensor_data)
2641                }
2642                false => match chat_graphs.iter_mut().next() {
2643                    Some((_, graph)) => {
2644                        let tensor_data = prompt.as_ref().as_bytes().to_vec();
2645                        set_tensor_data_u8(graph, 0, &tensor_data)
2646                    }
2647                    None => {
2648                        let err_msg = "There is no model available in the chat graphs.";
2649
2650                        #[cfg(feature = "logging")]
2651                        error!(target: "stdout", "{}", &err_msg);
2652
2653                        Err(LlamaCoreError::Operation(err_msg.into()))
2654                    }
2655                },
2656            }
2657        }
2658        None => {
2659            #[cfg(feature = "logging")]
2660            info!(target: "stdout", "Set prompt to the default chat model.");
2661
2662            match chat_graphs.iter_mut().next() {
2663                Some((_, graph)) => {
2664                    let tensor_data = prompt.as_ref().as_bytes().to_vec();
2665                    set_tensor_data_u8(graph, 0, &tensor_data)
2666                }
2667                None => {
2668                    let err_msg = "There is no model available in the chat graphs while trying to set prompt to the default model.";
2669
2670                    #[cfg(feature = "logging")]
2671                    error!(target: "stdout", "{}", err_msg);
2672
2673                    Err(LlamaCoreError::Operation(err_msg.into()))
2674                }
2675            }
2676        }
2677    }
2678}
2679
2680/// Get a copy of the metadata of the model.
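///
/// The returned value is a clone; changes to it take effect only after they are
/// written back with `update_model_metadata`. A minimal sketch of that
/// read-modify-write pattern (the field tweak is an assumption for illustration):
///
/// ```ignore
/// let mut metadata = get_model_metadata(model_name)?;
/// metadata.temperature = 0.2; // assumed field, for illustration only
/// update_model_metadata(model_name, &metadata)?;
/// ```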
2681fn get_model_metadata(model_name: Option<&String>) -> Result<GgmlMetadata, LlamaCoreError> {
2682    let chat_graphs = match CHAT_GRAPHS.get() {
2683        Some(chat_graphs) => chat_graphs,
2684        None => {
2685            let err_msg = "Fail to get the underlying value of `CHAT_GRAPHS`.";
2686
2687            #[cfg(feature = "logging")]
2688            error!(target: "stdout", "{}", err_msg);
2689
2690            return Err(LlamaCoreError::Operation(err_msg.into()));
2691        }
2692    };
2693
2694    let chat_graphs = chat_graphs.lock().map_err(|e| {
2695        let err_msg = format!("Fail to acquire the lock of `CHAT_GRAPHS`. {}", e);
2696
2697        #[cfg(feature = "logging")]
2698        error!(target: "stdout", "{}", &err_msg);
2699
2700        LlamaCoreError::Operation(err_msg)
2701    })?;
2702
2703    match model_name {
2704        Some(model_name) => match chat_graphs.contains_key(model_name) {
2705            true => {
2706                let graph = chat_graphs.get(model_name).unwrap();
2707                Ok(graph.metadata.clone())
2708            }
2709            false => match chat_graphs.iter().next() {
2710                Some((_, graph)) => Ok(graph.metadata.clone()),
2711                None => {
2712                    let err_msg = "There is no model available in the chat graphs.";
2713
2714                    #[cfg(feature = "logging")]
2715                    error!(target: "stdout", "{}", &err_msg);
2716
2717                    Err(LlamaCoreError::Operation(err_msg.into()))
2718                }
2719            },
2720        },
2721        None => match chat_graphs.iter().next() {
2722            Some((_, graph)) => Ok(graph.metadata.clone()),
2723            None => {
2724                let err_msg = "There is no model available in the chat graphs.";
2725
2726                #[cfg(feature = "logging")]
2727                error!(target: "stdout", "{}", err_msg);
2728
2729                Err(LlamaCoreError::Operation(err_msg.into()))
2730            }
2731        },
2732    }
2733}
2734
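/// Serializes `metadata` to JSON and writes it into input tensor index 1 of the
/// target chat graph, which is how runtime configuration changes reach the
/// backend. Falls back to the first available graph when `model_name` is `None`
/// or not registered.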
2735fn update_model_metadata(
2736    model_name: Option<&String>,
2737    metadata: &GgmlMetadata,
2738) -> Result<(), LlamaCoreError> {
2739    let config = match serde_json::to_string(metadata) {
2740        Ok(config) => config,
2741        Err(e) => {
2742            let err_msg = format!("Fail to serialize metadata to a JSON string. {}", e);
2743
2744            #[cfg(feature = "logging")]
2745            error!(target: "stdout", "{}", &err_msg);
2746
2747            return Err(LlamaCoreError::Operation(err_msg));
2748        }
2749    };
2750
2751    let chat_graphs = match CHAT_GRAPHS.get() {
2752        Some(chat_graphs) => chat_graphs,
2753        None => {
2754            let err_msg = "Fail to get the underlying value of `CHAT_GRAPHS`.";
2755
2756            #[cfg(feature = "logging")]
2757            error!(target: "stdout", "{}", err_msg);
2758
2759            return Err(LlamaCoreError::Operation(err_msg.into()));
2760        }
2761    };
2762
2763    let mut chat_graphs = chat_graphs.lock().map_err(|e| {
2764        let err_msg = format!("Fail to acquire the lock of `CHAT_GRAPHS`. Reason: {}", e);
2765
2766        #[cfg(feature = "logging")]
2767        error!(target: "stdout", "{}", &err_msg);
2768
2769        LlamaCoreError::Operation(err_msg)
2770    })?;
2771
2772    match model_name {
2773        Some(model_name) => {
2774            match chat_graphs.contains_key(model_name) {
2775                true => {
2776                    let graph = chat_graphs.get_mut(model_name).unwrap();
2777                    // update metadata
2778                    set_tensor_data_u8(graph, 1, config.as_bytes())
2779                }
2780                false => match chat_graphs.iter_mut().next() {
2781                    Some((_, graph)) => {
2782                        // update metadata
2783                        set_tensor_data_u8(graph, 1, config.as_bytes())
2784                    }
2785                    None => {
2786                        let err_msg = "There is no model available in the chat graphs.";
2787
2788                        #[cfg(feature = "logging")]
2789                        error!(target: "stdout", "{}", &err_msg);
2790
2791                        Err(LlamaCoreError::Operation(err_msg.into()))
2792                    }
2793                },
2794            }
2795        }
2796        None => {
2797            match chat_graphs.iter_mut().next() {
2798                Some((_, graph)) => {
2799                    // update metadata
2800                    set_tensor_data_u8(graph, 1, config.as_bytes())
2801                }
2802                None => {
2803                    let err_msg = "There is no model available in the chat graphs.";
2804
2805                    #[cfg(feature = "logging")]
2806                    error!(target: "stdout", "{}", err_msg);
2807
2808                    Err(LlamaCoreError::Operation(err_msg.into()))
2809                }
2810            }
2811        }
2812    }
2813}
2814
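/// Re-applies the metadata cached on the `Graph` (as returned by
/// `get_model_metadata`), discarding any per-request overrides that were pushed
/// to the backend during the previous request.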
2815fn reset_model_metadata(model_name: Option<&String>) -> Result<(), LlamaCoreError> {
2816    // get metadata
2817    let metadata = get_model_metadata(model_name)?;
2818
2819    // update model with the original metadata
2820    update_model_metadata(model_name, &metadata)
2821}
2822
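/// Drives the chunks emitted after the backend reports `ContextFull`: a final
/// message chunk with `finish_reason: length`, an optional usage chunk, the
/// terminating `data: [DONE]` line, and finally the end of the stream.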
2823#[derive(Debug, Clone, Copy, PartialEq, Eq)]
2824enum ContextFullState {
2825    Message,
2826    Usage,
2827    Done,
2828    EndOfSequence,
2829}
2830
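/// Drives the normal streaming path: `Usage`/`NoUsage` while content chunks are
/// produced (depending on whether the client asked for usage), `Done` once the
/// final `data: [DONE]` line is due, and `EndOfSequence` when the stream is over.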
2831#[derive(Debug, Clone, Copy, PartialEq, Eq)]
2832enum StreamState {
2833    Usage,
2834    NoUsage,
2835    Done,
2836    EndOfSequence,
2837}
2838
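/// Same shape as `ContextFullState`, but driven by the backend's `PromptTooLong`
/// error: an empty delta with `finish_reason: length`, an optional usage chunk,
/// `data: [DONE]`, then the end of the stream.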
2839#[derive(Debug, Clone, Copy, PartialEq, Eq)]
2840enum PromptTooLongState {
2841    Message,
2842    Usage,
2843    Done,
2844    EndOfSequence,
2845}
2846
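/// Server-Sent Events stream handed back to the client in stream mode.
///
/// When `cache` is `Some`, the pre-computed SSE lines are drained first; otherwise
/// every poll computes the next chunk on demand via `compute_stream`. Dropping the
/// stream finalizes the backend context with `Graph::finish_single`.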
2847struct ChatStream {
2848    id: String,
2849    model: Option<String>,
2850    include_usage: bool,
2851    context_full_state: ContextFullState,
2852    prompt_too_long_state: PromptTooLongState,
2853    stream_state: StreamState,
2854    cache: Option<VecDeque<String>>,
2855}
2856impl ChatStream {
2857    fn new(
2858        model: Option<String>,
2859        id: String,
2860        include_usage: bool,
2861        cache: Option<Vec<String>>,
2862    ) -> Self {
2863        let stream_state = if include_usage {
2864            StreamState::Usage
2865        } else {
2866            StreamState::NoUsage
2867        };
2868
2869        ChatStream {
2870            id,
2871            model,
2872            include_usage,
2873            context_full_state: ContextFullState::Message,
2874            prompt_too_long_state: PromptTooLongState::Message,
2875            stream_state,
2876            cache: cache.map(VecDeque::from),
2877        }
2878    }
2879}
2880impl Drop for ChatStream {
2881    fn drop(&mut self) {
2882        if self.cache.is_none() {
2883            #[cfg(feature = "logging")]
2884            info!(target: "stdout", "Clean up the context of the stream work environment.");
2885
2886            match &self.model {
2887                Some(model_name) => {
2888                    match CHAT_GRAPHS.get() {
2889                        Some(chat_graphs) => match chat_graphs.lock() {
2890                            Ok(mut chat_graphs) => match chat_graphs.contains_key(model_name) {
2891                                true => {
2892                                    let graph = chat_graphs.get_mut(model_name).unwrap();
2893
2894                                    if let Err(e) = graph.finish_single() {
2895                                        let err_msg = format!(
2896                                            "Failed to clean up the context. Reason: {}",
2897                                            e
2898                                        );
2899
2900                                        #[cfg(feature = "logging")]
2901                                        error!(target: "stdout", "{}", &err_msg);
2902
2903                                        #[cfg(not(feature = "logging"))]
                                        println!("[ERROR][llama_core] {}", &err_msg);
2908                                    }
2909                                }
2910                                false => match chat_graphs.iter_mut().next() {
2911                                    Some((_, graph)) => {
2912                                        if let Err(e) = graph.finish_single() {
2913                                            let err_msg = format!(
2914                                                "Failed to clean up the context. Reason: {}",
2915                                                e
2916                                            );
2917
2918                                            #[cfg(feature = "logging")]
2919                                            error!(target: "stdout", "{}", &err_msg);
2920
2921                                            #[cfg(not(feature = "logging"))]
                                            println!("[ERROR][llama_core] {}", &err_msg);
2926                                        }
2927                                    }
2928                                    None => {
2929                                        let err_msg =
2930                                            "There is no model available in the chat graphs.";
2931
2932                                        #[cfg(feature = "logging")]
2933                                        error!(target: "stdout", "{}", &err_msg);
2934
2935                                        #[cfg(not(feature = "logging"))]
2936                                        println!(
2937                                            "[ERROR][llama_core] Failed to clean up the context. Reason: {}",
2938                                            &err_msg
2939                                        );
2940                                    }
2941                                },
2942                            },
2943                            Err(e) => {
2944                                let err_msg =
2945                                    format!("Fail to acquire the lock of `CHAT_GRAPHS`. {}", e);
2946
2947                                #[cfg(feature = "logging")]
2948                                error!(target: "stdout", "{}", &err_msg);
2949
2950                                #[cfg(not(feature = "logging"))]
                                println!(
                                    "[ERROR][llama_core] Failed to clean up the context. Reason: {}",
                                    &err_msg
                                );
2955                            }
2956                        },
2957                        None => {
2958                            let err_msg = "Fail to get the underlying value of `CHAT_GRAPHS`.";
2959
2960                            #[cfg(feature = "logging")]
2961                            error!(target: "stdout", "{}", &err_msg);
2962
2963                            #[cfg(not(feature = "logging"))]
2964                            println!(
2965                                "[ERROR][llama_core] Failed to clean up the context. Reason: {}",
2966                                &err_msg
2967                            );
2968                        }
2969                    };
2970                }
2971                None => {
2972                    match CHAT_GRAPHS.get() {
2973                        Some(chat_graphs) => match chat_graphs.lock() {
2974                            Ok(mut chat_graphs) => match chat_graphs.iter_mut().next() {
2975                                Some((_, graph)) => {
2976                                    if let Err(e) = graph.finish_single() {
2977                                        let err_msg = format!(
2978                                            "Failed to clean up the context. Reason: {}",
2979                                            e
2980                                        );
2981
2982                                        #[cfg(feature = "logging")]
2983                                        error!(target: "stdout", "{}", &err_msg);
2984
2985                                        #[cfg(not(feature = "logging"))]
                                        println!("[ERROR][llama_core] {}", &err_msg);
2990                                    }
2991                                }
2992                                None => {
2993                                    let err_msg = "There is no model available in the chat graphs.";
2994
2995                                    #[cfg(feature = "logging")]
2996                                    error!(target: "stdout", "{}", err_msg);
2997
2998                                    #[cfg(not(feature = "logging"))]
                                    println!(
                                        "[ERROR][llama_core] Failed to clean up the context. Reason: {}",
                                        err_msg
                                    );
3003                                }
3004                            },
3005                            Err(e) => {
3006                                let err_msg =
3007                                    format!("Fail to acquire the lock of `CHAT_GRAPHS`. {}", e);
3008
3009                                #[cfg(feature = "logging")]
3010                                error!(target: "stdout", "{}", &err_msg);
3011
3012                                #[cfg(not(feature = "logging"))]
                                println!(
                                    "[ERROR][llama_core] Failed to clean up the context. Reason: {}",
                                    &err_msg
                                );
3017                            }
3018                        },
3019                        None => {
3020                            let err_msg = "Fail to get the underlying value of `CHAT_GRAPHS`.";
3021
3022                            #[cfg(feature = "logging")]
3023                            error!(target: "stdout", "{}", &err_msg);
3024
3025                            #[cfg(not(feature = "logging"))]
3026                            println!(
3027                                "[ERROR][llama_core] Failed to clean up the context. Reason: {}",
3028                                &err_msg
3029                            );
3030                        }
3031                    };
3032                }
3033            }
3034
3035            #[cfg(feature = "logging")]
3036            info!(target: "stdout", "Cleanup done!");
3037        }
3038    }
3039}
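// The stream yields Server-Sent-Events lines of the form `data: {json}\n\n`,
// followed by a final `data: [DONE]\n\n`; the `[GGML] End of sequence` sentinel
// (or an empty item) ends polling. A minimal consumer sketch, assuming a stream
// obtained from the public `chat` function and that `ChatCompletionChunk`
// implements `Deserialize`:
//
//   while let Some(Ok(line)) = stream.next().await {
//       if line.trim() == "data: [DONE]" {
//           break;
//       }
//       if let Some(json) = line.strip_prefix("data: ") {
//           let chunk: ChatCompletionChunk = serde_json::from_str(json.trim())?;
//           // ...inspect chunk.choices[0].delta.content...
//       }
//   }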
3040impl futures::Stream for ChatStream {
3041    type Item = Result<String, LlamaCoreError>;
3042
3043    fn poll_next(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
3044        if self.cache.is_none() {
3045            let this = self.get_mut();
3046            let res = compute_stream(
3047                this.model.clone(),
3048                this.id.clone(),
3049                this.include_usage,
3050                &mut this.prompt_too_long_state,
3051                &mut this.context_full_state,
3052                &mut this.stream_state,
3053            );
3054
3055            match res {
3056                Ok(x) => {
3057                    #[cfg(feature = "logging")]
3058                    info!(target: "stdout", "next item: {}", &x);
3059
3060                    if x != "[GGML] End of sequence" && !x.is_empty() {
3061                        Poll::Ready(Some(Ok(x)))
3062                    } else {
3063                        // stopped
3064                        Poll::Ready(None)
3065                    }
3066                }
3067                Err(e) => Poll::Ready(Some(Err(e))),
3068            }
3069        } else {
3070            let this = self.get_mut();
3071
3072            let x = this.cache.as_mut().unwrap().pop_front();
3073
3074            #[cfg(feature = "logging")]
3075            info!(target: "stdout", "Get the next item from the cache: {:?}", &x);
3076
3077            match x {
3078                Some(x) => Poll::Ready(Some(Ok(x))),
3079                None => Poll::Ready(None),
3080            }
3081        }
3082    }
3083}
3084
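/// Computes the next chunk of a chat stream and returns it as an SSE-formatted
/// string (`data: {json}\n\n`, `data: [DONE]\n\n`, or the `[GGML] End of sequence`
/// sentinel).
///
/// The three state machines are owned by the calling `ChatStream` and determine
/// how backend results (`Ok`, `EndOfSequence`, `ContextFull`, `PromptTooLong`) are
/// translated into content, usage, and termination chunks.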
3085fn compute_stream(
3086    model_name: Option<String>,
3087    id: String,
3088    include_usage: bool,
3089    prompt_too_long_state: &mut PromptTooLongState,
3090    context_full_state: &mut ContextFullState,
3091    stream_state: &mut StreamState,
3092) -> Result<String, LlamaCoreError> {
3093    #[cfg(feature = "logging")]
3094    info!(target: "stdout", "Compute the chat stream chunk.");
3095
3096    #[cfg(feature = "logging")]
3097    debug!(target: "stdout", "prompt_too_long_state: {:?}", *prompt_too_long_state);
3098    #[cfg(feature = "logging")]
3099    debug!(target: "stdout", "context_full_state: {:?}", *context_full_state);
3100    #[cfg(feature = "logging")]
3101    debug!(target: "stdout", "stream_state: {:?}", *stream_state);
3102
3103    if *prompt_too_long_state == PromptTooLongState::EndOfSequence
3104        || *context_full_state == ContextFullState::EndOfSequence
3105        || *stream_state == StreamState::EndOfSequence
3106    {
3107        #[cfg(feature = "logging")]
3108        info!(target: "stdout", "Return the chat stream chunk!");
3109
3110        return Ok("[GGML] End of sequence".to_string());
3111    }
3112
3113    let chat_graphs = match CHAT_GRAPHS.get() {
3114        Some(chat_graphs) => chat_graphs,
3115        None => {
3116            let err_msg = "Fail to get the underlying value of `CHAT_GRAPHS`.";
3117
3118            #[cfg(feature = "logging")]
3119            error!(target: "stdout", "{}", &err_msg);
3120
3121            return Err(LlamaCoreError::Operation(err_msg.into()));
3122        }
3123    };
3124
3125    let mut chat_graphs = chat_graphs.lock().map_err(|e| {
3126        let err_msg = format!("Fail to acquire the lock of `CHAT_GRAPHS`. {}", e);
3127
3128        #[cfg(feature = "logging")]
3129        error!(target: "stdout", "{}", &err_msg);
3130
3131        LlamaCoreError::Operation(err_msg)
3132    })?;
3133
3134    // get graph
3135    let res = match &model_name {
3136        Some(model_name) => {
3137            match chat_graphs.contains_key(model_name) {
3138                true => {
3139                    let graph = chat_graphs.get_mut(model_name).unwrap();
3140                    // compute
3141                    match graph.compute_single() {
3142                        Ok(_) => {
3143                            #[cfg(feature = "logging")]
3144                            debug!(target: "stdout", "Compute the chat stream chunk successfully.");
3145
3146                            match stream_state {
3147                                StreamState::Usage | StreamState::NoUsage => {
3148                                    // Retrieve the output
3149                                    let output_buffer =
3150                                        get_output_buffer_single(graph, OUTPUT_TENSOR)?;
3151
3152                                    #[cfg(feature = "logging")]
3153                                    info!(target: "stdout", "retrieved the output buffer");
3154
3155                                    // decode the output buffer to a utf8 string
3156                                    let output = match String::from_utf8(output_buffer.clone()) {
3157                                        Ok(token) => token,
3158                                        Err(_) => {
3159                                            let mutex = CACHED_UTF8_ENCODINGS
3160                                                .get_or_init(|| Mutex::new(Vec::new()));
                                            let mut cached_encodings = mutex.lock().map_err(|e| {
                                                let err_msg = format!(
                                                    "Fail to acquire the lock of `CACHED_UTF8_ENCODINGS`. Reason: {}",
                                                    e
                                                );

                                                #[cfg(feature = "logging")]
                                                error!(target: "stdout", "{}", &err_msg);

                                                LlamaCoreError::Operation(err_msg)
                                            })?;
3173
3174                                            // cache the bytes for future decoding
3175                                            cached_encodings.extend_from_slice(&output_buffer[..]);
3176
3177                                            match String::from_utf8(cached_encodings.to_vec()) {
3178                                                Ok(token) => {
3179                                                    // clear CACHED_UTF8_ENCODINGS
3180                                                    cached_encodings.clear();
3181
3182                                                    token
3183                                                }
3184                                                Err(e) => {
                                                    // TODO: temporary guard to keep the cached encodings from growing without bound
3186                                                    if cached_encodings.len() > 4 {
3187                                                        let err_msg = format!("Fail to convert a vector of bytes to string. The length of the utf8 bytes exceeds 4. {}", e);
3188
3189                                                        #[cfg(feature = "logging")]
3190                                                        error!(target: "stdout", "{}", &err_msg);
3191
3192                                                        #[cfg(feature = "logging")]
3193                                                        error!(target: "stdout", "The cached buffer: {:?}", &cached_encodings[..]);
3194
3195                                                        // let token = String::from_utf8_lossy(
3196                                                        //     &cached_encodings,
3197                                                        // )
3198                                                        // .to_string();
3199
3200                                                        // clear CACHED_UTF8_ENCODINGS
3201                                                        cached_encodings.clear();
3202
3203                                                        String::from("")
3204                                                    } else {
3205                                                        let warn_msg = format!("Fail to convert a vector of bytes to string. {}", e);
3206
3207                                                        #[cfg(feature = "logging")]
3208                                                        warn!(target: "stdout", "{}", &warn_msg);
3209
3210                                                        String::from("")
3211                                                    }
3212                                                }
3213                                            }
3214                                        }
3215                                    };
3216
3217                                    #[cfg(feature = "logging")]
3218                                    info!(target: "stdout", "decoded the output buffer");
3219
3220                                    let created = SystemTime::now()
3221                                        .duration_since(std::time::UNIX_EPOCH)
3222                                        .map_err(|e| {
3223                                            let err_msg = format!(
3224                                                "Failed to get the current time. Reason: {}",
3225                                                e
3226                                            );
3227
3228                                            #[cfg(feature = "logging")]
3229                                            error!(target: "stdout", "{}", &err_msg);
3230
3231                                            LlamaCoreError::Operation(err_msg)
3232                                        })?;
3233
3234                                    let chat_completion_chunk = ChatCompletionChunk {
3235                                        id,
3236                                        object: "chat.completion.chunk".to_string(),
3237                                        created: created.as_secs(),
3238                                        model: graph.name().to_owned(),
3239                                        system_fingerprint: "fp_44709d6fcb".to_string(),
3240                                        choices: vec![ChatCompletionChunkChoice {
3241                                            index: 0,
3242                                            delta: ChatCompletionChunkChoiceDelta {
3243                                                role: ChatCompletionRole::Assistant,
3244                                                content: Some(output),
3245                                                tool_calls: vec![],
3246                                            },
3247                                            logprobs: None,
3248                                            finish_reason: None,
3249                                        }],
3250                                        usage: None,
3251                                    };
3252
3253                                    #[cfg(feature = "logging")]
3254                                    info!(target: "stdout", "created chat completion chunk");
3255
3256                                    // serialize chat completion chunk
3257                                    let chunk_str = serde_json::to_string(&chat_completion_chunk)
3258                                        .map_err(|e| {
3259                                        let err_msg = format!(
3260                                            "Failed to serialize chat completion chunk. Reason: {}",
3261                                            e
3262                                        );
3263
3264                                        #[cfg(feature = "logging")]
3265                                        error!(target: "stdout", "{}", &err_msg);
3266
3267                                        LlamaCoreError::Operation(err_msg)
3268                                    })?;
3269
3270                                    Ok(format!("data: {}\n\n", chunk_str))
3271                                }
3272                                StreamState::Done => {
3273                                    *stream_state = StreamState::EndOfSequence;
3274
3275                                    Ok("data: [DONE]\n\n".to_string())
3276                                }
3277                                StreamState::EndOfSequence => {
3278                                    Ok("[GGML] End of sequence".to_string())
3279                                }
3280                            }
3281                        }
3282                        Err(wasmedge_wasi_nn::Error::BackendError(
3283                            wasmedge_wasi_nn::BackendError::EndOfSequence,
3284                        )) => {
3285                            #[cfg(feature = "logging")]
3286                            debug!(target: "stdout", "End of sequence");
3287
3288                            match stream_state {
3289                                StreamState::Usage => {
3290                                    *stream_state = StreamState::Done;
3291
3292                                    // retrieve the number of prompt and completion tokens
3293                                    let token_info = get_token_info_by_graph(graph)?;
3294
3295                                    let usage = Some(Usage {
3296                                        prompt_tokens: token_info.prompt_tokens,
3297                                        completion_tokens: token_info.completion_tokens,
3298                                        total_tokens: token_info.prompt_tokens
3299                                            + token_info.completion_tokens,
3300                                    });
3301
3302                                    #[cfg(feature = "logging")]
3303                                    info!(target: "stdout", "token_info: {} prompt tokens, {} completion tokens", token_info.prompt_tokens, token_info.completion_tokens);
3304
3305                                    let created = SystemTime::now()
3306                                        .duration_since(std::time::UNIX_EPOCH)
3307                                        .map_err(|e| {
3308                                            let err_msg = format!(
3309                                                "Failed to get the current time. Reason: {}",
3310                                                e
3311                                            );
3312
3313                                            #[cfg(feature = "logging")]
3314                                            error!(target: "stdout", "{}", &err_msg);
3315
3316                                            LlamaCoreError::Operation(err_msg)
3317                                        })?;
3318
3319                                    let chat_completion_chunk = ChatCompletionChunk {
3320                                        id,
3321                                        object: "chat.completion.chunk".to_string(),
3322                                        created: created.as_secs(),
3323                                        model: graph.name().to_owned(),
3324                                        system_fingerprint: "fp_44709d6fcb".to_string(),
3325                                        choices: vec![],
3326                                        usage,
3327                                    };
3328
3329                                    // serialize chat completion chunk
3330                                    let chunk_str = serde_json::to_string(&chat_completion_chunk)
3331                                        .map_err(|e| {
3332                                        let err_msg = format!(
3333                                            "Failed to serialize chat completion chunk. Reason: {}",
3334                                            e
3335                                        );
3336
3337                                        #[cfg(feature = "logging")]
3338                                        error!(target: "stdout", "{}", &err_msg);
3339
3340                                        LlamaCoreError::Operation(err_msg)
3341                                    })?;
3342
3343                                    Ok(format!("data: {}\n\n", chunk_str))
3344                                }
3345                                StreamState::Done | StreamState::NoUsage => {
3346                                    *stream_state = StreamState::EndOfSequence;
3347
3348                                    Ok("data: [DONE]\n\n".to_string())
3349                                }
3350                                StreamState::EndOfSequence => {
3351                                    Ok("[GGML] End of sequence".to_string())
3352                                }
3353                            }
3354                        }
3355                        Err(wasmedge_wasi_nn::Error::BackendError(
3356                            wasmedge_wasi_nn::BackendError::ContextFull,
3357                        )) => {
3358                            #[cfg(feature = "logging")]
3359                            debug!(target: "stdout", "Context full");
3360
3361                            match context_full_state {
3362                                ContextFullState::Message => {
3363                                    match include_usage {
3364                                        true => *context_full_state = ContextFullState::Usage,
3365                                        false => *context_full_state = ContextFullState::Done,
3366                                    }
3367
3368                                    let created = SystemTime::now()
3369                                        .duration_since(std::time::UNIX_EPOCH)
3370                                        .map_err(|e| {
3371                                            let err_msg = format!(
3372                                                "Failed to get the current time. Reason: {}",
3373                                                e
3374                                            );
3375
3376                                            #[cfg(feature = "logging")]
3377                                            error!(target: "stdout", "{}", &err_msg);
3378
3379                                            LlamaCoreError::Operation(err_msg)
3380                                        })?;
3381
3382                                    let chat_completion_chunk = ChatCompletionChunk {
3383                                        id,
3384                                        object: "chat.completion.chunk".to_string(),
3385                                        created: created.as_secs(),
3386                                        model: graph.name().to_owned(),
3387                                        system_fingerprint: "fp_44709d6fcb".to_string(),
3388                                        choices: vec![ChatCompletionChunkChoice {
3389                                            index: 0,
3390                                            delta: ChatCompletionChunkChoiceDelta {
3391                                                role: ChatCompletionRole::Assistant,
3392                                                content: Some(
3393                                                    "<|WASMEDGE-GGML-CONTEXT-FULL|>".to_string(),
3394                                                ),
3395                                                tool_calls: vec![],
3396                                            },
3397                                            logprobs: None,
3398                                            finish_reason: Some(FinishReason::length),
3399                                        }],
3400                                        usage: None,
3401                                    };
3402
3403                                    // serialize chat completion chunk
3404                                    let chunk_str = serde_json::to_string(&chat_completion_chunk)
3405                                        .map_err(|e| {
3406                                        let err_msg = format!(
3407                                            "Failed to serialize chat completion chunk. Reason: {}",
3408                                            e
3409                                        );
3410
3411                                        #[cfg(feature = "logging")]
3412                                        error!(target: "stdout", "{}", &err_msg);
3413
3414                                        LlamaCoreError::Operation(err_msg)
3415                                    })?;
3416
3417                                    Ok(format!("data: {}\n\n", chunk_str))
3418                                }
3419                                ContextFullState::Usage => {
3420                                    *context_full_state = ContextFullState::Done;
3421
3422                                    // retrieve the number of prompt and completion tokens
3423                                    let token_info = get_token_info_by_graph(graph)?;
3424
3425                                    let usage = Some(Usage {
3426                                        prompt_tokens: token_info.prompt_tokens,
3427                                        completion_tokens: token_info.completion_tokens,
3428                                        total_tokens: token_info.prompt_tokens
3429                                            + token_info.completion_tokens,
3430                                    });
3431
3432                                    let created = SystemTime::now()
3433                                        .duration_since(std::time::UNIX_EPOCH)
3434                                        .map_err(|e| {
3435                                            let err_msg = format!(
3436                                                "Failed to get the current time. Reason: {}",
3437                                                e
3438                                            );
3439
3440                                            #[cfg(feature = "logging")]
3441                                            error!(target: "stdout", "{}", &err_msg);
3442
3443                                            LlamaCoreError::Operation(err_msg)
3444                                        })?;
3445
3446                                    let chat_completion_chunk = ChatCompletionChunk {
3447                                        id,
3448                                        object: "chat.completion.chunk".to_string(),
3449                                        created: created.as_secs(),
3450                                        model: graph.name().to_owned(),
3451                                        system_fingerprint: "fp_44709d6fcb".to_string(),
3452                                        choices: vec![],
3453                                        usage,
3454                                    };
3455
3456                                    // serialize chat completion chunk
3457                                    let chunk_str = serde_json::to_string(&chat_completion_chunk)
3458                                        .map_err(|e| {
3459                                        let err_msg = format!(
3460                                            "Failed to serialize chat completion chunk. Reason: {}",
3461                                            e
3462                                        );
3463
3464                                        #[cfg(feature = "logging")]
3465                                        error!(target: "stdout", "{}", &err_msg);
3466
3467                                        LlamaCoreError::Operation(err_msg)
3468                                    })?;
3469
3470                                    Ok(format!("data: {}\n\n", chunk_str))
3471                                }
3472                                ContextFullState::Done => {
3473                                    *context_full_state = ContextFullState::EndOfSequence;
3474
3475                                    Ok("data: [DONE]\n\n".to_string())
3476                                }
3477                                ContextFullState::EndOfSequence => {
3478                                    Ok("[GGML] End of sequence".to_string())
3479                                }
3480                            }
3481                        }
3482                        Err(wasmedge_wasi_nn::Error::BackendError(
3483                            wasmedge_wasi_nn::BackendError::PromptTooLong,
3484                        )) => {
3485                            #[cfg(feature = "logging")]
3486                            debug!(target: "stdout", "Prompt too long");
3487
3488                            match prompt_too_long_state {
3489                                PromptTooLongState::Message => {
3490                                    match include_usage {
3491                                        true => *prompt_too_long_state = PromptTooLongState::Usage,
3492                                        false => *prompt_too_long_state = PromptTooLongState::Done,
3493                                    }
3494
3495                                    let created = SystemTime::now()
3496                                        .duration_since(std::time::UNIX_EPOCH)
3497                                        .map_err(|e| {
3498                                            let err_msg = format!(
3499                                                "Failed to get the current time. Reason: {}",
3500                                                e
3501                                            );
3502
3503                                            #[cfg(feature = "logging")]
3504                                            error!(target: "stdout", "{}", &err_msg);
3505
3506                                            LlamaCoreError::Operation(err_msg)
3507                                        })?;
3508
3509                                    let chat_completion_chunk = ChatCompletionChunk {
3510                                        id,
3511                                        object: "chat.completion.chunk".to_string(),
3512                                        created: created.as_secs(),
3513                                        model: graph.name().to_owned(),
3514                                        system_fingerprint: "fp_44709d6fcb".to_string(),
3515                                        choices: vec![ChatCompletionChunkChoice {
3516                                            index: 0,
3517                                            delta: ChatCompletionChunkChoiceDelta {
3518                                                role: ChatCompletionRole::Assistant,
3519                                                content: None,
3520                                                tool_calls: vec![],
3521                                            },
3522                                            logprobs: None,
3523                                            finish_reason: Some(FinishReason::length),
3524                                        }],
3525                                        usage: None,
3526                                    };
3527
3528                                    // serialize chat completion chunk
3529                                    let chunk_str = serde_json::to_string(&chat_completion_chunk)
3530                                        .map_err(|e| {
3531                                        let err_msg = format!(
3532                                            "Failed to serialize chat completion chunk. Reason: {}",
3533                                            e
3534                                        );
3535
3536                                        #[cfg(feature = "logging")]
3537                                        error!(target: "stdout", "{}", &err_msg);
3538
3539                                        LlamaCoreError::Operation(err_msg)
3540                                    })?;
3541
3542                                    Ok(format!("data: {}\n\n", chunk_str))
3543                                }
3544                                PromptTooLongState::Usage => {
3545                                    *prompt_too_long_state = PromptTooLongState::Done;
3546
3547                                    // retrieve the number of prompt and completion tokens
3548                                    let token_info = get_token_info_by_graph(graph)?;
3549
3550                                    let usage = Some(Usage {
3551                                        prompt_tokens: token_info.prompt_tokens,
3552                                        completion_tokens: token_info.completion_tokens,
3553                                        total_tokens: token_info.prompt_tokens
3554                                            + token_info.completion_tokens,
3555                                    });
3556
3557                                    let created = SystemTime::now()
3558                                        .duration_since(std::time::UNIX_EPOCH)
3559                                        .map_err(|e| {
3560                                            let err_msg = format!(
3561                                                "Failed to get the current time. Reason: {}",
3562                                                e
3563                                            );
3564
3565                                            #[cfg(feature = "logging")]
3566                                            error!(target: "stdout", "{}", &err_msg);
3567
3568                                            LlamaCoreError::Operation(err_msg)
3569                                        })?;
3570
3571                                    let chat_completion_chunk = ChatCompletionChunk {
3572                                        id,
3573                                        object: "chat.completion.chunk".to_string(),
3574                                        created: created.as_secs(),
3575                                        model: graph.name().to_owned(),
3576                                        system_fingerprint: "fp_44709d6fcb".to_string(),
3577                                        choices: vec![],
3578                                        usage,
3579                                    };
3580
3581                                    // serialize chat completion chunk
3582                                    let chunk_str = serde_json::to_string(&chat_completion_chunk)
3583                                        .map_err(|e| {
3584                                        let err_msg = format!(
3585                                            "Failed to serialize chat completion chunk. Reason: {}",
3586                                            e
3587                                        );
3588
3589                                        #[cfg(feature = "logging")]
3590                                        error!(target: "stdout", "{}", &err_msg);
3591
3592                                        LlamaCoreError::Operation(err_msg)
3593                                    })?;
3594
3595                                    Ok(format!("data: {}\n\n", chunk_str))
3596                                }
3597                                PromptTooLongState::Done => {
3598                                    *prompt_too_long_state = PromptTooLongState::EndOfSequence;
3599
3600                                    Ok("data: [DONE]\n\n".to_string())
3601                                }
3602                                PromptTooLongState::EndOfSequence => {
3603                                    Ok("[GGML] End of sequence".to_string())
3604                                }
3605                            }
3606                        }
3607                        Err(e) => {
3608                            let err_msg =
3609                                format!("Failed to compute the chat completion. Reason: {}", e);
3610
3611                            #[cfg(feature = "logging")]
3612                            error!(target: "stdout", "{}", &err_msg);
3613
3614                            Err(LlamaCoreError::Backend(BackendError::ComputeSingle(
3615                                err_msg,
3616                            )))
3617                        }
3618                    }
3619                }
3620                false => {
3621                    match chat_graphs.iter_mut().next() {
3622                        Some((_, graph)) => {
3623                            // compute
3624                            match graph.compute_single() {
3625                                Ok(_) => {
3626                                    #[cfg(feature = "logging")]
3627                                    debug!(target: "stdout", "Compute the chat stream chunk successfully.");
3628
3629                                    match stream_state {
3630                                        StreamState::Usage | StreamState::NoUsage => {
3631                                            // Retrieve the output
3632                                            let output_buffer =
3633                                                get_output_buffer_single(graph, OUTPUT_TENSOR)?;
3634
3635                                            #[cfg(feature = "logging")]
3636                                            info!(target: "stdout", "retrieved the output buffer");
3637
3638                                            // decode the output buffer to a utf8 string
3639                                            let output = match String::from_utf8(
3640                                                output_buffer.clone(),
3641                                            ) {
3642                                                Ok(token) => token,
3643                                                Err(_) => {
3644                                                    let mutex = CACHED_UTF8_ENCODINGS
3645                                                        .get_or_init(|| Mutex::new(Vec::new()));
                                                    let mut cached_encodings = mutex.lock().map_err(|e| {
                                                        let err_msg = format!(
                                                            "Fail to acquire the lock of `CACHED_UTF8_ENCODINGS`. Reason: {}",
                                                            e
                                                        );

                                                        #[cfg(feature = "logging")]
                                                        error!(target: "stdout", "{}", &err_msg);

                                                        LlamaCoreError::Operation(err_msg)
                                                    })?;
3658
3659                                                    // cache the bytes for future decoding
3660                                                    cached_encodings
3661                                                        .extend_from_slice(&output_buffer[..]);
3662
3663                                                    match String::from_utf8(
3664                                                        cached_encodings.to_vec(),
3665                                                    ) {
3666                                                        Ok(token) => {
3667                                                            // clear encodings
3668                                                            cached_encodings.clear();
3669
3670                                                            token
3671                                                        }
3672                                                        Err(e) => {
                                                            // TODO: temporary guard to keep the cached encodings from growing without bound
3674                                                            if cached_encodings.len() > 4 {
3675                                                                let err_msg = format!("Fail to convert a vector of bytes to string. The length of the utf8 bytes exceeds 4. {}", e);
3676
3677                                                                #[cfg(feature = "logging")]
3678                                                                error!(target: "stdout", "{}", &err_msg);
3679
3680                                                                #[cfg(feature = "logging")]
3681                                                                error!(target: "stdout", "The cached buffer: {:?}", &cached_encodings[..]);
3682
3683                                                                // let token =
3684                                                                //     String::from_utf8_lossy(
3685                                                                //         &cached_encodings,
3686                                                                //     )
3687                                                                //     .to_string();
3688
3689                                                                // clear CACHED_UTF8_ENCODINGS
3690                                                                cached_encodings.clear();
3691
3692                                                                String::from("")
3693                                                            } else {
3694                                                                let warn_msg = format!("Fail to convert a vector of bytes to string. {}", e);
3695
3696                                                                #[cfg(feature = "logging")]
3697                                                                warn!(target: "stdout", "{}", &warn_msg);
3698
3699                                                                String::from("")
3700                                                            }
3701                                                        }
3702                                                    }
3703                                                }
3704                                            };
3705
3706                                            #[cfg(feature = "logging")]
3707                                            info!(target: "stdout", "decoded the output buffer");
3708
3709                                            let created = SystemTime::now()
3710                                                .duration_since(std::time::UNIX_EPOCH)
3711                                                .map_err(|e| {
3712                                                    let err_msg = format!(
                                                        "Failed to get the current time. Reason: {}",
                                                        e
                                                    );
3716
3717                                                    #[cfg(feature = "logging")]
3718                                                    error!(target: "stdout", "{}", &err_msg);
3719
3720                                                    LlamaCoreError::Operation(err_msg)
3721                                                })?;
3722
3723                                            let chat_completion_chunk = ChatCompletionChunk {
3724                                                id,
3725                                                object: "chat.completion.chunk".to_string(),
3726                                                created: created.as_secs(),
3727                                                model: graph.name().to_owned(),
3728                                                system_fingerprint: "fp_44709d6fcb".to_string(),
3729                                                choices: vec![ChatCompletionChunkChoice {
3730                                                    index: 0,
3731                                                    delta: ChatCompletionChunkChoiceDelta {
3732                                                        role: ChatCompletionRole::Assistant,
3733                                                        content: Some(output),
3734                                                        tool_calls: vec![],
3735                                                    },
3736                                                    logprobs: None,
3737                                                    finish_reason: None,
3738                                                }],
3739                                                usage: None,
3740                                            };
3741
3742                                            #[cfg(feature = "logging")]
3743                                            info!(target: "stdout", "created chat completion chunk");
3744
3745                                            // serialize chat completion chunk
3746                                            let chunk_str =
3747                                                serde_json::to_string(&chat_completion_chunk)
3748                                                    .map_err(|e| {
3749                                                        let err_msg = format!(
3750                                            "Failed to serialize chat completion chunk. Reason: {}",
3751                                            e
3752                                        );
3753
3754                                                        #[cfg(feature = "logging")]
3755                                                        error!(target: "stdout", "{}", &err_msg);
3756
3757                                                        LlamaCoreError::Operation(err_msg)
3758                                                    })?;
3759
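                                            // frame the serialized chunk as a server-sent event: "data: <json>\n\n"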
3760                                            Ok(format!("data: {}\n\n", chunk_str))
3761                                        }
3762                                        StreamState::Done => {
3763                                            *stream_state = StreamState::EndOfSequence;
3764
3765                                            Ok("data: [DONE]\n\n".to_string())
3766                                        }
3767                                        StreamState::EndOfSequence => {
3768                                            Ok("[GGML] End of sequence".to_string())
3769                                        }
3770                                    }
3771                                }
3772                                Err(wasmedge_wasi_nn::Error::BackendError(
3773                                    wasmedge_wasi_nn::BackendError::EndOfSequence,
3774                                )) => {
3775                                    #[cfg(feature = "logging")]
3776                                    debug!(target: "stdout", "End of sequence");
3777
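                                    // the backend reports end of sequence: send the usage chunk if it was requested, then terminate with [DONE]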
3778                                    match stream_state {
3779                                        StreamState::Usage => {
3780                                            *stream_state = StreamState::Done;
3781
3782                                            // retrieve the number of prompt and completion tokens
3783                                            let token_info = get_token_info_by_graph(graph)?;
3784
3785                                            let usage = Some(Usage {
3786                                                prompt_tokens: token_info.prompt_tokens,
3787                                                completion_tokens: token_info.completion_tokens,
3788                                                total_tokens: token_info.prompt_tokens
3789                                                    + token_info.completion_tokens,
3790                                            });
3791
3792                                            #[cfg(feature = "logging")]
3793                                            info!(target: "stdout", "token_info: {} prompt tokens, {} completion tokens", token_info.prompt_tokens, token_info.completion_tokens);
3794
3795                                            let created = SystemTime::now()
3796                                                .duration_since(std::time::UNIX_EPOCH)
3797                                                .map_err(|e| {
3798                                                    let err_msg = format!(
3799                                                "Failed to get the current time. Reason: {}",
3800                                                e
3801                                            );
3802
3803                                                    #[cfg(feature = "logging")]
3804                                                    error!(target: "stdout", "{}", &err_msg);
3805
3806                                                    LlamaCoreError::Operation(err_msg)
3807                                                })?;
3808
3809                                            let chat_completion_chunk = ChatCompletionChunk {
3810                                                id,
3811                                                object: "chat.completion.chunk".to_string(),
3812                                                created: created.as_secs(),
3813                                                model: graph.name().to_owned(),
3814                                                system_fingerprint: "fp_44709d6fcb".to_string(),
3815                                                choices: vec![],
3816                                                usage,
3817                                            };
3818
3819                                            // serialize chat completion chunk
3820                                            let chunk_str =
3821                                                serde_json::to_string(&chat_completion_chunk)
3822                                                    .map_err(|e| {
3823                                                        let err_msg = format!(
3824                                            "Failed to serialize chat completion chunk. Reason: {}",
3825                                            e
3826                                        );
3827
3828                                                        #[cfg(feature = "logging")]
3829                                                        error!(target: "stdout", "{}", &err_msg);
3830
3831                                                        LlamaCoreError::Operation(err_msg)
3832                                                    })?;
3833
3834                                            Ok(format!("data: {}\n\n", chunk_str))
3835                                        }
3836                                        StreamState::Done | StreamState::NoUsage => {
3837                                            *stream_state = StreamState::EndOfSequence;
3838
3839                                            Ok("data: [DONE]\n\n".to_string())
3840                                        }
3841                                        StreamState::EndOfSequence => {
3842                                            Ok("[GGML] End of sequence".to_string())
3843                                        }
3844                                    }
3845                                }
3846                                Err(wasmedge_wasi_nn::Error::BackendError(
3847                                    wasmedge_wasi_nn::BackendError::ContextFull,
3848                                )) => {
3849                                    #[cfg(feature = "logging")]
3850                                    debug!(target: "stdout", "Context full");
3851
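                                    // the context window is full: emit a marker chunk with finish_reason `length`, then (optionally) a usage chunk, then [DONE]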
3852                                    match context_full_state {
3853                                        ContextFullState::Message => {
3854                                            match include_usage {
3855                                                true => {
3856                                                    *context_full_state = ContextFullState::Usage
3857                                                }
3858                                                false => {
3859                                                    *context_full_state = ContextFullState::Done
3860                                                }
3861                                            }
3862
3863                                            let created = SystemTime::now()
3864                                                .duration_since(std::time::UNIX_EPOCH)
3865                                                .map_err(|e| {
3866                                                    let err_msg = format!(
3867                                                "Failed to get the current time. Reason: {}",
3868                                                e
3869                                            );
3870
3871                                                    #[cfg(feature = "logging")]
3872                                                    error!(target: "stdout", "{}", &err_msg);
3873
3874                                                    LlamaCoreError::Operation(err_msg)
3875                                                })?;
3876
3877                                            let chat_completion_chunk = ChatCompletionChunk {
3878                                                id,
3879                                                object: "chat.completion.chunk".to_string(),
3880                                                created: created.as_secs(),
3881                                                model: graph.name().to_owned(),
3882                                                system_fingerprint: "fp_44709d6fcb".to_string(),
3883                                                choices: vec![ChatCompletionChunkChoice {
3884                                                    index: 0,
3885                                                    delta: ChatCompletionChunkChoiceDelta {
3886                                                        role: ChatCompletionRole::Assistant,
3887                                                        content: Some(
3888                                                            "<|WASMEDGE-GGML-CONTEXT-FULL|>"
3889                                                                .to_string(),
3890                                                        ),
3891                                                        tool_calls: vec![],
3892                                                    },
3893                                                    logprobs: None,
3894                                                    finish_reason: Some(FinishReason::length),
3895                                                }],
3896                                                usage: None,
3897                                            };
3898
3899                                            // serialize chat completion chunk
3900                                            let chunk_str =
3901                                                serde_json::to_string(&chat_completion_chunk)
3902                                                    .map_err(|e| {
3903                                                        let err_msg = format!(
3904                                            "Failed to serialize chat completion chunk. Reason: {}",
3905                                            e
3906                                        );
3907
3908                                                        #[cfg(feature = "logging")]
3909                                                        error!(target: "stdout", "{}", &err_msg);
3910
3911                                                        LlamaCoreError::Operation(err_msg)
3912                                                    })?;
3913
3914                                            Ok(format!("data: {}\n\n", chunk_str))
3915                                        }
3916                                        ContextFullState::Usage => {
3917                                            *context_full_state = ContextFullState::Done;
3918
3919                                            // retrieve the number of prompt and completion tokens
3920                                            let token_info = get_token_info_by_graph(graph)?;
3921
3922                                            let usage = Some(Usage {
3923                                                prompt_tokens: token_info.prompt_tokens,
3924                                                completion_tokens: token_info.completion_tokens,
3925                                                total_tokens: token_info.prompt_tokens
3926                                                    + token_info.completion_tokens,
3927                                            });
3928
3929                                            let created = SystemTime::now()
3930                                                .duration_since(std::time::UNIX_EPOCH)
3931                                                .map_err(|e| {
3932                                                    let err_msg = format!(
3933                                                "Failed to get the current time. Reason: {}",
3934                                                e
3935                                            );
3936
3937                                                    #[cfg(feature = "logging")]
3938                                                    error!(target: "stdout", "{}", &err_msg);
3939
3940                                                    LlamaCoreError::Operation(err_msg)
3941                                                })?;
3942
3943                                            let chat_completion_chunk = ChatCompletionChunk {
3944                                                id,
3945                                                object: "chat.completion.chunk".to_string(),
3946                                                created: created.as_secs(),
3947                                                model: graph.name().to_owned(),
3948                                                system_fingerprint: "fp_44709d6fcb".to_string(),
3949                                                choices: vec![],
3950                                                usage,
3951                                            };
3952
3953                                            // serialize chat completion chunk
3954                                            let chunk_str =
3955                                                serde_json::to_string(&chat_completion_chunk)
3956                                                    .map_err(|e| {
3957                                                        let err_msg = format!(
3958                                            "Failed to serialize chat completion chunk. Reason: {}",
3959                                            e
3960                                        );
3961
3962                                                        #[cfg(feature = "logging")]
3963                                                        error!(target: "stdout", "{}", &err_msg);
3964
3965                                                        LlamaCoreError::Operation(err_msg)
3966                                                    })?;
3967
3968                                            Ok(format!("data: {}\n\n", chunk_str))
3969                                        }
3970                                        ContextFullState::Done => {
3971                                            *context_full_state = ContextFullState::EndOfSequence;
3972
3973                                            Ok("data: [DONE]\n\n".to_string())
3974                                        }
3975                                        ContextFullState::EndOfSequence => {
3976                                            Ok("[GGML] End of sequence".to_string())
3977                                        }
3978                                    }
3979                                }
3980                                Err(wasmedge_wasi_nn::Error::BackendError(
3981                                    wasmedge_wasi_nn::BackendError::PromptTooLong,
3982                                )) => {
3983                                    #[cfg(feature = "logging")]
3984                                    debug!(target: "stdout", "Prompt too long");
3985
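                                    // the prompt alone exceeds the context window: emit an empty chunk with finish_reason `length`, then (optionally) a usage chunk, then [DONE]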
3986                                    match prompt_too_long_state {
3987                                        PromptTooLongState::Message => {
3988                                            match include_usage {
3989                                                true => {
3990                                                    *prompt_too_long_state =
3991                                                        PromptTooLongState::Usage
3992                                                }
3993                                                false => {
3994                                                    *prompt_too_long_state =
3995                                                        PromptTooLongState::Done
3996                                                }
3997                                            }
3998
3999                                            let created = SystemTime::now()
4000                                                .duration_since(std::time::UNIX_EPOCH)
4001                                                .map_err(|e| {
4002                                                    let err_msg = format!(
4003                                                "Failed to get the current time. Reason: {}",
4004                                                e
4005                                            );
4006
4007                                                    #[cfg(feature = "logging")]
4008                                                    error!(target: "stdout", "{}", &err_msg);
4009
4010                                                    LlamaCoreError::Operation(err_msg)
4011                                                })?;
4012
4013                                            let chat_completion_chunk = ChatCompletionChunk {
4014                                                id,
4015                                                object: "chat.completion.chunk".to_string(),
4016                                                created: created.as_secs(),
4017                                                model: graph.name().to_owned(),
4018                                                system_fingerprint: "fp_44709d6fcb".to_string(),
4019                                                choices: vec![ChatCompletionChunkChoice {
4020                                                    index: 0,
4021                                                    delta: ChatCompletionChunkChoiceDelta {
4022                                                        role: ChatCompletionRole::Assistant,
4023                                                        content: None,
4024                                                        tool_calls: vec![],
4025                                                    },
4026                                                    logprobs: None,
4027                                                    finish_reason: Some(FinishReason::length),
4028                                                }],
4029                                                usage: None,
4030                                            };
4031
4032                                            // serialize chat completion chunk
4033                                            let chunk_str =
4034                                                serde_json::to_string(&chat_completion_chunk)
4035                                                    .map_err(|e| {
4036                                                        let err_msg = format!(
4037                                            "Failed to serialize chat completion chunk. Reason: {}",
4038                                            e
4039                                        );
4040
4041                                                        #[cfg(feature = "logging")]
4042                                                        error!(target: "stdout", "{}", &err_msg);
4043
4044                                                        LlamaCoreError::Operation(err_msg)
4045                                                    })?;
4046
4047                                            Ok(format!("data: {}\n\n", chunk_str))
4048                                        }
4049                                        PromptTooLongState::Usage => {
4050                                            *prompt_too_long_state = PromptTooLongState::Done;
4051
4052                                            // retrieve the number of prompt and completion tokens
4053                                            let token_info = get_token_info_by_graph(graph)?;
4054
4055                                            let usage = Some(Usage {
4056                                                prompt_tokens: token_info.prompt_tokens,
4057                                                completion_tokens: token_info.completion_tokens,
4058                                                total_tokens: token_info.prompt_tokens
4059                                                    + token_info.completion_tokens,
4060                                            });
4061
4062                                            let created = SystemTime::now()
4063                                                .duration_since(std::time::UNIX_EPOCH)
4064                                                .map_err(|e| {
4065                                                    let err_msg = format!(
4066                                                "Failed to get the current time. Reason: {}",
4067                                                e
4068                                            );
4069
4070                                                    #[cfg(feature = "logging")]
4071                                                    error!(target: "stdout", "{}", &err_msg);
4072
4073                                                    LlamaCoreError::Operation(err_msg)
4074                                                })?;
4075
4076                                            let chat_completion_chunk = ChatCompletionChunk {
4077                                                id,
4078                                                object: "chat.completion.chunk".to_string(),
4079                                                created: created.as_secs(),
4080                                                model: graph.name().to_owned(),
4081                                                system_fingerprint: "fp_44709d6fcb".to_string(),
4082                                                choices: vec![],
4083                                                usage,
4084                                            };
4085
4086                                            // serialize chat completion chunk
4087                                            let chunk_str =
4088                                                serde_json::to_string(&chat_completion_chunk)
4089                                                    .map_err(|e| {
4090                                                        let err_msg = format!(
4091                                            "Failed to serialize chat completion chunk. Reason: {}",
4092                                            e
4093                                        );
4094
4095                                                        #[cfg(feature = "logging")]
4096                                                        error!(target: "stdout", "{}", &err_msg);
4097
4098                                                        LlamaCoreError::Operation(err_msg)
4099                                                    })?;
4100
4101                                            Ok(format!("data: {}\n\n", chunk_str))
4102                                        }
4103                                        PromptTooLongState::Done => {
4104                                            *prompt_too_long_state =
4105                                                PromptTooLongState::EndOfSequence;
4106
4107                                            Ok("data: [DONE]\n\n".to_string())
4108                                        }
4109                                        PromptTooLongState::EndOfSequence => {
4110                                            Ok("[GGML] End of sequence".to_string())
4111                                        }
4112                                    }
4113                                }
4114                                Err(e) => {
4115                                    let err_msg = format!(
4116                                        "Failed to compute the chat completion. Reason: {}",
4117                                        e
4118                                    );
4119
4120                                    #[cfg(feature = "logging")]
4121                                    error!(target: "stdout", "{}", &err_msg);
4122
4123                                    Err(LlamaCoreError::Backend(BackendError::ComputeSingle(
4124                                        err_msg,
4125                                    )))
4126                                }
4127                            }
4128                        }
4129                        None => {
4130                            let err_msg = "There is no model available in the chat graphs.";
4131
4132                            #[cfg(feature = "logging")]
4133                            error!(target: "stdout", "{}", &err_msg);
4134
4135                            Err(LlamaCoreError::Operation(err_msg.into()))
4136                        }
4137                    }
4138                }
4139            }
4140        }
4141        None => {
4142            match chat_graphs.iter_mut().next() {
4143                Some((_, graph)) => {
4144                    // compute
4145                    match graph.compute_single() {
4146                        Ok(_) => {
4147                            #[cfg(feature = "logging")]
4148                            debug!(target: "stdout", "Computed the chat stream chunk successfully.");
4149
4150                            match stream_state {
4151                                StreamState::Usage | StreamState::NoUsage => {
4152                                    // Retrieve the output
4153                                    let output_buffer =
4154                                        get_output_buffer_single(graph, OUTPUT_TENSOR)?;
4155
4156                                    #[cfg(feature = "logging")]
4157                                    info!(target: "stdout", "retrieved the output buffer");
4158
4159                                    // decode the output buffer to a utf8 string
4160                                    let output = match String::from_utf8(output_buffer.clone()) {
4161                                        Ok(token) => token,
4162                                        Err(_) => {
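                                            // the bytes are not valid UTF-8 on their own (a multi-byte character can be split across tokens); buffer them and retry with the accumulated bytes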
4163                                            let mutex = CACHED_UTF8_ENCODINGS
4164                                                .get_or_init(|| Mutex::new(Vec::new()));
4165                                            let mut cached_encodings = mutex.lock().map_err(|e| {
4166                                            let err_msg = format!(
4167                                                "Failed to acquire the lock of `CACHED_UTF8_ENCODINGS`. Reason: {}",
4168                                                e
4169                                            );
4170
4171                                            #[cfg(feature = "logging")]
4172                                            error!(target: "stdout", "{}", &err_msg);
4173
4174                                            LlamaCoreError::Operation(err_msg)
4175                                        })?;
4176
4177                                            cached_encodings.extend_from_slice(&output_buffer[..]);
4178
4179                                            match String::from_utf8(cached_encodings.to_vec()) {
4180                                                Ok(token) => {
4181                                                    // clear encodings
4182                                                    cached_encodings.clear();
4183
4184                                                    token
4185                                                }
4186                                                Err(e) => {
4187                                                    // TODO: Temporary guard against unbounded growth of the cached encodings; a single UTF-8 character is at most 4 bytes.
4188                                                    if cached_encodings.len() > 4 {
4189                                                            let err_msg = format!("Failed to convert a vector of bytes to a string: the cached UTF-8 bytes exceed 4 in length. {}", e);
4190
4191                                                        #[cfg(feature = "logging")]
4192                                                        error!(target: "stdout", "{}", &err_msg);
4193
4194                                                        #[cfg(feature = "logging")]
4195                                                        error!(target: "stdout", "The cached buffer: {:?}", &cached_encodings[..]);
4196
4197                                                        // let token = String::from_utf8_lossy(
4198                                                        //     &cached_encodings,
4199                                                        // )
4200                                                        // .to_string();
4201
4202                                                        // clear CACHED_UTF8_ENCODINGS
4203                                                        cached_encodings.clear();
4204
4205                                                        String::from("")
4206                                                    } else {
4207                                                        let warn_msg = format!("Failed to convert a vector of bytes to a string. {}", e);
4208
4209                                                        #[cfg(feature = "logging")]
4210                                                        warn!(target: "stdout", "{}", &warn_msg);
4211
4212                                                        String::from("")
4213                                                    }
4214                                                }
4215                                            }
4216                                        }
4217                                    };
4218
4219                                    #[cfg(feature = "logging")]
4220                                    info!(target: "stdout", "decoded the output buffer");
4221
4222                                    let created = SystemTime::now()
4223                                        .duration_since(std::time::UNIX_EPOCH)
4224                                        .map_err(|e| {
4225                                            let err_msg = format!(
4226                                                "Failed to get the current time. Reason: {}",
4227                                                e
4228                                            );
4229
4230                                            #[cfg(feature = "logging")]
4231                                            error!(target: "stdout", "{}", &err_msg);
4232
4233                                            LlamaCoreError::Operation(err_msg)
4234                                        })?;
4235
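                                    // wrap the decoded token in an OpenAI-compatible streaming chunk: an assistant delta with no finish reason yet; the system_fingerprint is a fixed placeholder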
4236                                    let chat_completion_chunk = ChatCompletionChunk {
4237                                        id,
4238                                        object: "chat.completion.chunk".to_string(),
4239                                        created: created.as_secs(),
4240                                        model: graph.name().to_owned(),
4241                                        system_fingerprint: "fp_44709d6fcb".to_string(),
4242                                        choices: vec![ChatCompletionChunkChoice {
4243                                            index: 0,
4244                                            delta: ChatCompletionChunkChoiceDelta {
4245                                                role: ChatCompletionRole::Assistant,
4246                                                content: Some(output),
4247                                                tool_calls: vec![],
4248                                            },
4249                                            logprobs: None,
4250                                            finish_reason: None,
4251                                        }],
4252                                        usage: None,
4253                                    };
4254
4255                                    #[cfg(feature = "logging")]
4256                                    info!(target: "stdout", "created chat completion chunk");
4257
4258                                    // serialize chat completion chunk
4259                                    let chunk_str = serde_json::to_string(&chat_completion_chunk)
4260                                        .map_err(|e| {
4261                                        let err_msg = format!(
4262                                            "Failed to serialize chat completion chunk. Reason: {}",
4263                                            e
4264                                        );
4265
4266                                        #[cfg(feature = "logging")]
4267                                        error!(target: "stdout", "{}", &err_msg);
4268
4269                                        LlamaCoreError::Operation(err_msg)
4270                                    })?;
4271
4272                                    Ok(format!("data: {}\n\n", chunk_str))
4273                                }
4274                                StreamState::Done => {
4275                                    *stream_state = StreamState::EndOfSequence;
4276
4277                                    Ok("data: [DONE]\n\n".to_string())
4278                                }
4279                                StreamState::EndOfSequence => {
4280                                    Ok("[GGML] End of sequence".to_string())
4281                                }
4282                            }
4283                        }
4284                        Err(wasmedge_wasi_nn::Error::BackendError(
4285                            wasmedge_wasi_nn::BackendError::EndOfSequence,
4286                        )) => {
4287                            #[cfg(feature = "logging")]
4288                            debug!(target: "stdout", "End of sequence");
4289
4290                            match stream_state {
4291                                StreamState::Usage => {
4292                                    *stream_state = StreamState::Done;
4293
4294                                    // retrieve the number of prompt and completion tokens
4295                                    let token_info = get_token_info_by_graph(graph)?;
4296
4297                                    let usage = Some(Usage {
4298                                        prompt_tokens: token_info.prompt_tokens,
4299                                        completion_tokens: token_info.completion_tokens,
4300                                        total_tokens: token_info.prompt_tokens
4301                                            + token_info.completion_tokens,
4302                                    });
4303
4304                                    #[cfg(feature = "logging")]
4305                                    info!(target: "stdout", "token_info: {} prompt tokens, {} completion tokens", token_info.prompt_tokens, token_info.completion_tokens);
4306
4307                                    let created = SystemTime::now()
4308                                        .duration_since(std::time::UNIX_EPOCH)
4309                                        .map_err(|e| {
4310                                            let err_msg = format!(
4311                                                "Failed to get the current time. Reason: {}",
4312                                                e
4313                                            );
4314
4315                                            #[cfg(feature = "logging")]
4316                                            error!(target: "stdout", "{}", &err_msg);
4317
4318                                            LlamaCoreError::Operation(err_msg)
4319                                        })?;
4320
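                                    // final usage-only chunk: empty `choices`, token counts carried in `usage`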
4321                                    let chat_completion_chunk = ChatCompletionChunk {
4322                                        id,
4323                                        object: "chat.completion.chunk".to_string(),
4324                                        created: created.as_secs(),
4325                                        model: graph.name().to_owned(),
4326                                        system_fingerprint: "fp_44709d6fcb".to_string(),
4327                                        choices: vec![],
4328                                        usage,
4329                                    };
4330
4331                                    // serialize chat completion chunk
4332                                    let chunk_str = serde_json::to_string(&chat_completion_chunk)
4333                                        .map_err(|e| {
4334                                        let err_msg = format!(
4335                                            "Failed to serialize chat completion chunk. Reason: {}",
4336                                            e
4337                                        );
4338
4339                                        #[cfg(feature = "logging")]
4340                                        error!(target: "stdout", "{}", &err_msg);
4341
4342                                        LlamaCoreError::Operation(err_msg)
4343                                    })?;
4344
4345                                    Ok(format!("data: {}\n\n", chunk_str))
4346                                }
4347                                StreamState::Done | StreamState::NoUsage => {
4348                                    *stream_state = StreamState::EndOfSequence;
4349
4350                                    Ok("data: [DONE]\n\n".to_string())
4351                                }
4352                                StreamState::EndOfSequence => {
4353                                    Ok("[GGML] End of sequence".to_string())
4354                                }
4355                            }
4356                        }
4357                        Err(wasmedge_wasi_nn::Error::BackendError(
4358                            wasmedge_wasi_nn::BackendError::ContextFull,
4359                        )) => {
4360                            #[cfg(feature = "logging")]
4361                            debug!(target: "stdout", "Context full");
4362
4363                            match context_full_state {
4364                                ContextFullState::Message => {
4365                                    match include_usage {
4366                                        true => *context_full_state = ContextFullState::Usage,
4367                                        false => *context_full_state = ContextFullState::Done,
4368                                    }
4369
4370                                    let created = SystemTime::now()
4371                                        .duration_since(std::time::UNIX_EPOCH)
4372                                        .map_err(|e| {
4373                                            let err_msg = format!(
4374                                                "Failed to get the current time. Reason: {}",
4375                                                e
4376                                            );
4377
4378                                            #[cfg(feature = "logging")]
4379                                            error!(target: "stdout", "{}", &err_msg);
4380
4381                                            LlamaCoreError::Operation(err_msg)
4382                                        })?;
4383
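                                    // surface the context-full condition to the client: the special marker string as content and finish_reason `length`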
4384                                    let chat_completion_chunk = ChatCompletionChunk {
4385                                        id,
4386                                        object: "chat.completion.chunk".to_string(),
4387                                        created: created.as_secs(),
4388                                        model: graph.name().to_owned(),
4389                                        system_fingerprint: "fp_44709d6fcb".to_string(),
4390                                        choices: vec![ChatCompletionChunkChoice {
4391                                            index: 0,
4392                                            delta: ChatCompletionChunkChoiceDelta {
4393                                                role: ChatCompletionRole::Assistant,
4394                                                content: Some(
4395                                                    "<|WASMEDGE-GGML-CONTEXT-FULL|>".to_string(),
4396                                                ),
4397                                                tool_calls: vec![],
4398                                            },
4399                                            logprobs: None,
4400                                            finish_reason: Some(FinishReason::length),
4401                                        }],
4402                                        usage: None,
4403                                    };
4404
4405                                    // serialize chat completion chunk
4406                                    let chunk_str = serde_json::to_string(&chat_completion_chunk)
4407                                        .map_err(|e| {
4408                                        let err_msg = format!(
4409                                            "Failed to serialize chat completion chunk. Reason: {}",
4410                                            e
4411                                        );
4412
4413                                        #[cfg(feature = "logging")]
4414                                        error!(target: "stdout", "{}", &err_msg);
4415
4416                                        LlamaCoreError::Operation(err_msg)
4417                                    })?;
4418
4419                                    Ok(format!("data: {}\n\n", chunk_str))
4420                                }
4421                                ContextFullState::Usage => {
4422                                    *context_full_state = ContextFullState::Done;
4423
4424                                    // retrieve the number of prompt and completion tokens
4425                                    let token_info = get_token_info_by_graph(graph)?;
4426
4427                                    let usage = Some(Usage {
4428                                        prompt_tokens: token_info.prompt_tokens,
4429                                        completion_tokens: token_info.completion_tokens,
4430                                        total_tokens: token_info.prompt_tokens
4431                                            + token_info.completion_tokens,
4432                                    });
4433
4434                                    let created = SystemTime::now()
4435                                        .duration_since(std::time::UNIX_EPOCH)
4436                                        .map_err(|e| {
4437                                            let err_msg = format!(
4438                                                "Failed to get the current time. Reason: {}",
4439                                                e
4440                                            );
4441
4442                                            #[cfg(feature = "logging")]
4443                                            error!(target: "stdout", "{}", &err_msg);
4444
4445                                            LlamaCoreError::Operation(err_msg)
4446                                        })?;
4447
4448                                    let chat_completion_chunk = ChatCompletionChunk {
4449                                        id,
4450                                        object: "chat.completion.chunk".to_string(),
4451                                        created: created.as_secs(),
4452                                        model: graph.name().to_owned(),
4453                                        system_fingerprint: "fp_44709d6fcb".to_string(),
4454                                        choices: vec![],
4455                                        usage,
4456                                    };
4457
4458                                    // serialize chat completion chunk
4459                                    let chunk_str = serde_json::to_string(&chat_completion_chunk)
4460                                        .map_err(|e| {
4461                                        let err_msg = format!(
4462                                            "Failed to serialize chat completion chunk. Reason: {}",
4463                                            e
4464                                        );
4465
4466                                        #[cfg(feature = "logging")]
4467                                        error!(target: "stdout", "{}", &err_msg);
4468
4469                                        LlamaCoreError::Operation(err_msg)
4470                                    })?;
4471
4472                                    Ok(format!("data: {}\n\n", chunk_str))
4473                                }
4474                                ContextFullState::Done => {
4475                                    *context_full_state = ContextFullState::EndOfSequence;
4476
4477                                    Ok("data: [DONE]\n\n".to_string())
4478                                }
4479                                ContextFullState::EndOfSequence => {
4480                                    Ok("[GGML] End of sequence".to_string())
4481                                }
4482                            }
4483                        }
4484                        Err(wasmedge_wasi_nn::Error::BackendError(
4485                            wasmedge_wasi_nn::BackendError::PromptTooLong,
4486                        )) => {
4487                            #[cfg(feature = "logging")]
4488                            debug!(target: "stdout", "Prompt too long");
4489
4490                            match prompt_too_long_state {
4491                                PromptTooLongState::Message => {
4492                                    match include_usage {
4493                                        true => *prompt_too_long_state = PromptTooLongState::Usage,
4494                                        false => *prompt_too_long_state = PromptTooLongState::Done,
4495                                    }
4496
4497                                    let created = SystemTime::now()
4498                                        .duration_since(std::time::UNIX_EPOCH)
4499                                        .map_err(|e| {
4500                                            let err_msg = format!(
4501                                                "Failed to get the current time. Reason: {}",
4502                                                e
4503                                            );
4504
4505                                            #[cfg(feature = "logging")]
4506                                            error!(target: "stdout", "{}", &err_msg);
4507
4508                                            LlamaCoreError::Operation(err_msg)
4509                                        })?;
4510
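                                    // no partial content is available for an over-long prompt; only finish_reason `length` is reported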
4511                                    let chat_completion_chunk = ChatCompletionChunk {
4512                                        id,
4513                                        object: "chat.completion.chunk".to_string(),
4514                                        created: created.as_secs(),
4515                                        model: graph.name().to_owned(),
4516                                        system_fingerprint: "fp_44709d6fcb".to_string(),
4517                                        choices: vec![ChatCompletionChunkChoice {
4518                                            index: 0,
4519                                            delta: ChatCompletionChunkChoiceDelta {
4520                                                role: ChatCompletionRole::Assistant,
4521                                                content: None,
4522                                                tool_calls: vec![],
4523                                            },
4524                                            logprobs: None,
4525                                            finish_reason: Some(FinishReason::length),
4526                                        }],
4527                                        usage: None,
4528                                    };
4529
4530                                    // serialize chat completion chunk
4531                                    let chunk_str = serde_json::to_string(&chat_completion_chunk)
4532                                        .map_err(|e| {
4533                                        let err_msg = format!(
4534                                            "Failed to serialize chat completion chunk. Reason: {}",
4535                                            e
4536                                        );
4537
4538                                        #[cfg(feature = "logging")]
4539                                        error!(target: "stdout", "{}", &err_msg);
4540
4541                                        LlamaCoreError::Operation(err_msg)
4542                                    })?;
4543
4544                                    Ok(format!("data: {}\n\n", chunk_str))
4545                                }
4546                                PromptTooLongState::Usage => {
4547                                    *prompt_too_long_state = PromptTooLongState::Done;
4548
4549                                    // retrieve the number of prompt and completion tokens
4550                                    let token_info = get_token_info_by_graph(graph)?;
4551
4552                                    let usage = Some(Usage {
4553                                        prompt_tokens: token_info.prompt_tokens,
4554                                        completion_tokens: token_info.completion_tokens,
4555                                        total_tokens: token_info.prompt_tokens
4556                                            + token_info.completion_tokens,
4557                                    });
4558
4559                                    let created = SystemTime::now()
4560                                        .duration_since(std::time::UNIX_EPOCH)
4561                                        .map_err(|e| {
4562                                            let err_msg = format!(
4563                                                "Failed to get the current time. Reason: {}",
4564                                                e
4565                                            );
4566
4567                                            #[cfg(feature = "logging")]
4568                                            error!(target: "stdout", "{}", &err_msg);
4569
4570                                            LlamaCoreError::Operation(err_msg)
4571                                        })?;
4572
4573                                    let chat_completion_chunk = ChatCompletionChunk {
4574                                        id,
4575                                        object: "chat.completion.chunk".to_string(),
4576                                        created: created.as_secs(),
4577                                        model: graph.name().to_owned(),
4578                                        system_fingerprint: "fp_44709d6fcb".to_string(),
4579                                        choices: vec![],
4580                                        usage,
4581                                    };
4582
4583                                    // serialize chat completion chunk
4584                                    let chunk_str = serde_json::to_string(&chat_completion_chunk)
4585                                        .map_err(|e| {
4586                                        let err_msg = format!(
4587                                            "Failed to serialize chat completion chunk. Reason: {}",
4588                                            e
4589                                        );
4590
4591                                        #[cfg(feature = "logging")]
4592                                        error!(target: "stdout", "{}", &err_msg);
4593
4594                                        LlamaCoreError::Operation(err_msg)
4595                                    })?;
4596
4597                                    Ok(format!("data: {}\n\n", chunk_str))
4598                                }
4599                                PromptTooLongState::Done => {
4600                                    *prompt_too_long_state = PromptTooLongState::EndOfSequence;
4601
4602                                    Ok("data: [DONE]\n\n".to_string())
4603                                }
4604                                PromptTooLongState::EndOfSequence => {
4605                                    Ok("[GGML] End of sequence".to_string())
4606                                }
4607                            }
4608                        }
4609                        Err(e) => {
4610                            let err_msg =
4611                                format!("Failed to compute the chat completion. Reason: {}", e);
4612
4613                            #[cfg(feature = "logging")]
4614                            error!(target: "stdout", "{}", &err_msg);
4615
4616                            Err(LlamaCoreError::Backend(BackendError::ComputeSingle(
4617                                err_msg,
4618                            )))
4619                        }
4620                    }
4621                }
4622                None => {
4623                    let err_msg = "There is no model available in the chat graphs.";
4624
4625                    #[cfg(feature = "logging")]
4626                    error!(target: "stdout", "{}", &err_msg);
4627
4628                    Err(LlamaCoreError::Operation(err_msg.into()))
4629                }
4630            }
4631        }
4632    };
4633
4634    #[cfg(feature = "logging")]
4635    info!(target: "stdout", "Returning the chat stream chunk.");
4636
4637    res
4638}
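// Each `Ok` value produced above becomes either one server-sent event or the
// internal end-of-sequence sentinel. Illustrative shapes (not verbatim output):
//
//   data: {"object":"chat.completion.chunk", ...}
//
//   data: [DONE]
//
//   [GGML] End of sequence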
4639
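/// Outcome of parsing a raw model response for tool calls: the raw text, the
/// plain (non-tool) content, if any, and the tool calls extracted from it.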
4640#[allow(dead_code)]
4641#[derive(Debug)]
4642struct ParseResult {
4643    raw: String,
4644    content: Option<String>,
4645    tool_calls: Vec<ToolCall>,
4646}