llama_core/chat.rs

//! Define APIs for chat completion.

use crate::{
    error,
    metadata::ggml::GgmlMetadata,
    running_mode,
    utils::{
        gen_chat_id, get_output_buffer, get_output_buffer_single, get_token_info_by_graph,
        get_token_info_by_graph_name, set_tensor_data_u8,
    },
    Graph, RunningMode, CACHED_UTF8_ENCODINGS, CHAT_GRAPHS, OUTPUT_TENSOR,
};
use chat_prompts::{BuildChatPrompt, ChatPrompt, PromptTemplateType};
use either::{Either, Left, Right};
use endpoints::{
    chat::{
        ChatCompletionChunk, ChatCompletionChunkChoice, ChatCompletionChunkChoiceDelta,
        ChatCompletionObject, ChatCompletionObjectChoice, ChatCompletionObjectMessage,
        ChatCompletionRequest, ChatCompletionRequestMessage, ChatCompletionRole,
        ChatCompletionUserMessageContent, ContentPart, Function, ToolCall, ToolCallForChunk,
        ToolChoice,
    },
    common::{FinishReason, Usage},
};
use error::{BackendError, LlamaCoreError};
use std::{
    collections::VecDeque,
    pin::Pin,
    sync::{
        atomic::{AtomicBool, Ordering},
        Mutex, OnceLock,
    },
    task::{Context, Poll, Waker},
    time::SystemTime,
};

// Define a global waker queue for storing waiting ChatStreams
static CHAT_STREAM_WAKER_QUEUE: OnceLock<Mutex<VecDeque<Waker>>> = OnceLock::new();

// Define a global atomic boolean indicating whether there is an active ChatStream
static CHAT_STREAM_ACTIVE: AtomicBool = AtomicBool::new(false);

/// Processes a chat-completion request and returns either a stream of ChatCompletionChunk instances or a ChatCompletionObject instance.
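///
/// # Example
///
/// An illustrative sketch (not compiled as a doc-test); it assumes the request was obtained
/// elsewhere, e.g. deserialized from an HTTP request body, and that `ChatCompletionRequest`
/// implements `serde::Deserialize`:
///
/// ```ignore
/// let mut request: ChatCompletionRequest = serde_json::from_str(body)?;
/// match chat(&mut request).await? {
///     (Left(stream), _include_tool_calls) => {
///         // stream mode: forward each SSE-formatted chunk ("data: ...\n\n") to the client
///     }
///     (Right(completion), _include_tool_calls) => {
///         // non-stream mode: return the ChatCompletionObject to the client as JSON
///     }
/// }
/// ```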
pub async fn chat(
    chat_request: &mut ChatCompletionRequest,
) -> Result<
    (
        Either<impl futures::TryStream<Ok = String, Error = LlamaCoreError>, ChatCompletionObject>,
        bool,
    ),
    LlamaCoreError,
> {
    #[cfg(feature = "logging")]
    {
        debug!(target: "stdout", "tool choice: {:?}", chat_request.tool_choice.as_ref());
        debug!(target: "stdout", "tools: {:?}", chat_request.tools.as_ref());
        debug!(target: "stdout", "stream mode: {:?}", chat_request.stream);
    }

    let result = match chat_request.stream {
        Some(true) => match chat_stream(chat_request).await {
            Ok((stream, include_tool_calls)) => Ok((Left(stream), include_tool_calls)),
            Err(e) => Err(e),
        },
        Some(false) | None => match chat_once(chat_request).await {
            Ok((chat_completion_object, include_tool_calls)) => {
                Ok((Right(chat_completion_object), include_tool_calls))
            }
            Err(e) => Err(e),
        },
    };

    #[cfg(feature = "logging")]
    info!(target: "stdout", "Reset the model metadata");

    result
}

async fn chat_stream(
    chat_request: &mut ChatCompletionRequest,
) -> Result<
    (
        impl futures::TryStream<Ok = String, Error = LlamaCoreError>,
        bool,
    ),
    LlamaCoreError,
> {
    #[cfg(feature = "logging")]
    info!(target: "stdout", "Process chat completion request in the stream mode");

    let running_mode = running_mode()?;
    if !running_mode.contains(RunningMode::CHAT) && !running_mode.contains(RunningMode::RAG) {
        let err_msg = "The chat completion is only supported in the chat or rag mode.";

        #[cfg(feature = "logging")]
        error!(target: "stdout", "{err_msg}");

        return Err(LlamaCoreError::Operation(err_msg.to_string()));
    }

    let model_name = chat_request.model.clone();
    let id = match &chat_request.user {
        Some(id) => id.clone(),
        None => gen_chat_id(),
    };
    #[cfg(feature = "logging")]
    info!(target: "stdout", "user: {}", &id);

    #[cfg(feature = "logging")]
    info!(target: "stdout", "Check model metadata");

    // update metadata
    let mut metadata = check_model_metadata(chat_request)?;

    // parse the `include_usage` option
    let include_usage = match chat_request.stream_options {
        Some(ref stream_options) => stream_options.include_usage.unwrap_or_default(),
        None => metadata.include_usage,
    };
    #[cfg(feature = "logging")]
    info!(target: "stdout", "include_usage: {include_usage}");

    #[cfg(feature = "logging")]
    info!(target: "stdout", "Build the chat prompt");

    // build prompt
    let (prompt, avaible_completion_tokens, tool_use) =
        build_prompt(model_name.as_ref(), chat_request)?;

    #[cfg(feature = "logging")]
    {
        info!(target: "stdout", "prompt:\n{}", &prompt);
        info!(target: "stdout", "available_completion_tokens: {avaible_completion_tokens}");
        info!(target: "stdout", "tool_use: {tool_use}");
    }

    #[cfg(feature = "logging")]
    info!(target: "stdout", "Update the n_predict");

    // update metadata n_predict
    update_n_predict(chat_request, &mut metadata, avaible_completion_tokens)?;

    #[cfg(feature = "logging")]
    info!(target: "stdout", "Feed the prompt to the model");

    // set prompt
    set_prompt(chat_request.model.as_ref(), &prompt)?;

    let stream = match tool_use {
        false => (ChatStream::new(model_name, id, include_usage, None), false),
        true => {
            let chat_graphs = match CHAT_GRAPHS.get() {
                Some(chat_graphs) => chat_graphs,
                None => {
                    let err_msg = "Fail to get the underlying value of `CHAT_GRAPHS`.";

                    #[cfg(feature = "logging")]
                    error!(target: "stdout", "{}", &err_msg);

                    return Err(LlamaCoreError::Operation(err_msg.into()));
                }
            };

            let mut chat_graphs = chat_graphs.lock().map_err(|e| {
                let err_msg = format!("Fail to acquire the lock of `CHAT_GRAPHS`. {e}");

                #[cfg(feature = "logging")]
                error!(target: "stdout", "{}", &err_msg);

                LlamaCoreError::Operation(err_msg)
            })?;

            match model_name {
                Some(model_name) => match chat_graphs.contains_key(&model_name) {
                    true => {
                        let graph = chat_graphs.get_mut(&model_name).unwrap();
                        chat_stream_for_tool(graph, id, include_usage)?
                    }
                    false => match chat_graphs.iter_mut().next() {
                        Some((_, graph)) => chat_stream_for_tool(graph, id, include_usage)?,
                        None => {
                            let err_msg = "There is no model available in the chat graphs.";

                            #[cfg(feature = "logging")]
                            error!(target: "stdout", "{}", &err_msg);

                            return Err(LlamaCoreError::Operation(err_msg.into()));
                        }
                    },
                },
                None => match chat_graphs.iter_mut().next() {
                    Some((_, graph)) => chat_stream_for_tool(graph, id, include_usage)?,
                    None => {
                        let err_msg = "There is no model available in the chat graphs.";

                        #[cfg(feature = "logging")]
                        error!(target: "stdout", "{}", &err_msg);

                        return Err(LlamaCoreError::Operation(err_msg.into()));
                    }
                },
            }
        }
    };

    #[cfg(feature = "logging")]
    info!(target: "stdout", "End of the chat completion stream.");

    Ok(stream)
}

fn chat_stream_for_tool(
    graph: &mut Graph<GgmlMetadata>,
    id: impl Into<String>,
    include_usage: bool,
) -> Result<(ChatStream, bool), LlamaCoreError> {
    #[cfg(feature = "logging")]
    info!(target: "stdout", "Handle chat request with available tools by the model named {}.", graph.name());

    let id = id.into();

    match graph.compute() {
        Ok(_) => {
            // Retrieve the output.
            let output_buffer = get_output_buffer(graph, OUTPUT_TENSOR)?;
            let output = std::str::from_utf8(&output_buffer[..]).map_err(|e| {
                let err_msg = format!(
                    "Failed to decode the buffer of the inference result to a utf-8 string. {e}"
                );

                #[cfg(feature = "logging")]
                error!(target: "stdout", "{}", &err_msg);

                LlamaCoreError::Operation(err_msg)
            })?;

            #[cfg(feature = "logging")]
            info!(target: "stdout", "raw generation:\n{output}");

            // post-process
            let message = post_process(output, &graph.metadata.prompt_template).map_err(|e| {
                LlamaCoreError::Operation(format!("Failed to post-process the output. {e}"))
            })?;

            #[cfg(feature = "logging")]
            info!(target: "stdout", "post-processed generation:\n{}", &message);

            // retrieve the number of prompt and completion tokens
            let token_info = get_token_info_by_graph(graph)?;

            #[cfg(feature = "logging")]
            info!(target: "stdout", "prompt tokens: {}, completion tokens: {}", token_info.prompt_tokens, token_info.completion_tokens);

            let usage = Some(Usage {
                prompt_tokens: token_info.prompt_tokens,
                completion_tokens: token_info.completion_tokens,
                total_tokens: token_info.prompt_tokens + token_info.completion_tokens,
            });

            let created = SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .map_err(|e| {
                    let err_msg = format!("Failed to get the current time. Reason: {e}");

                    #[cfg(feature = "logging")]
                    error!(target: "stdout", "{}", &err_msg);

                    LlamaCoreError::Operation(err_msg)
                })?;

            if graph.metadata.prompt_template != PromptTemplateType::MistralTool
                && graph.metadata.prompt_template != PromptTemplateType::ChatMLTool
                && graph.metadata.prompt_template != PromptTemplateType::GroqLlama3Tool
                && graph.metadata.prompt_template != PromptTemplateType::Llama3Tool
                && graph.metadata.prompt_template != PromptTemplateType::InternLM2Tool
                && graph.metadata.prompt_template != PromptTemplateType::NemotronTool
                && graph.metadata.prompt_template != PromptTemplateType::FunctionaryV32
                && graph.metadata.prompt_template != PromptTemplateType::FunctionaryV31
                && graph.metadata.prompt_template != PromptTemplateType::MistralSmallTool
                && graph.metadata.prompt_template != PromptTemplateType::Llama4Chat
                && graph.metadata.prompt_template != PromptTemplateType::Qwen3NoThink
                && graph.metadata.prompt_template != PromptTemplateType::Smol3NoThink
                && graph.metadata.prompt_template != PromptTemplateType::Gemma3
                && graph.metadata.prompt_template != PromptTemplateType::GptOss
                && graph.metadata.prompt_template != PromptTemplateType::Qwen3Agent
            {
                let err_msg = format!("Unsupported prompt template: {}. The tool use is only supported for 'mistral-tool', 'chatml-tool', 'groq-llama3-tool', 'llama-3-tool', 'internlm-2-tool', 'nemotron-tool', 'functionary-31', 'functionary-32', 'mistral-small-tool', 'llama-4-chat', 'qwen3-no-think', 'smol-3-no-think', 'gemma-3', 'gpt-oss' and 'qwen3-agent' prompt templates.", graph.metadata.prompt_template);

                #[cfg(feature = "logging")]
                error!(target: "stdout", "{}", &err_msg);

                return Err(LlamaCoreError::Operation(err_msg));
            }

            let parsed_result = parse_tool_calls(&message, graph.metadata.prompt_template)?;

            let content = if parsed_result.tool_calls.is_empty() {
                Some(parsed_result.raw.clone())
            } else {
                parsed_result.content.clone()
            };

            let (tool_calls, include_tool_calls) = match parsed_result.tool_calls.is_empty() {
                false => {
                    let tool_calls: Vec<ToolCallForChunk> = parsed_result
                        .tool_calls
                        .into_iter()
                        .enumerate()
                        .map(|(index, tool_call)| ToolCallForChunk {
                            index,
                            id: tool_call.id,
                            ty: tool_call.ty,
                            function: tool_call.function,
                        })
                        .collect();
                    (tool_calls, true)
                }
                true => (vec![], false),
            };

            // tool_calls chunk
            let tool_call_chunk = {
                let chat_completion_chunk = ChatCompletionChunk {
                    id: id.clone(),
                    object: "chat.completion.chunk".to_string(),
                    created: created.as_secs(),
                    model: graph.name().to_owned(),
                    system_fingerprint: "fp_44709d6fcb".to_string(),
                    choices: vec![ChatCompletionChunkChoice {
                        index: 0,
                        delta: ChatCompletionChunkChoiceDelta {
                            role: ChatCompletionRole::Assistant,
                            content,
                            tool_calls,
                        },
                        logprobs: None,
                        finish_reason: None,
                    }],
                    usage: None,
                };
                let chunk_str = serde_json::to_string(&chat_completion_chunk).map_err(|e| {
                    let err_msg = format!("Failed to serialize chat completion chunk. Reason: {e}");

                    #[cfg(feature = "logging")]
                    error!(target: "stdout", "{}", &err_msg);

                    LlamaCoreError::Operation(err_msg)
                })?;

                format!("data: {chunk_str}\n\n")
            };

            // token usage chunk
            let usage_chunk = {
                let chat_completion_chunk = ChatCompletionChunk {
                    id: id.clone(),
                    object: "chat.completion.chunk".to_string(),
                    created: created.as_secs(),
                    model: graph.name().to_owned(),
                    system_fingerprint: "fp_44709d6fcb".to_string(),
                    choices: vec![],
                    usage,
                };
                let chunk_str = serde_json::to_string(&chat_completion_chunk).map_err(|e| {
                    let err_msg = format!("Failed to serialize chat completion chunk. Reason: {e}");

                    #[cfg(feature = "logging")]
                    error!(target: "stdout", "{}", &err_msg);

                    LlamaCoreError::Operation(err_msg)
                })?;

                format!("data: {chunk_str}\n\n")
            };

            // ending chunk
            let ending_chunk = "data: [DONE]\n\n".to_string();

            let chunks = vec![tool_call_chunk, usage_chunk, ending_chunk];

            let stream = ChatStream::new(
                Some(graph.name().to_owned()),
                id,
                include_usage,
                Some(chunks),
            );

            Ok((stream, include_tool_calls))
        }
        Err(wasmedge_wasi_nn::Error::BackendError(wasmedge_wasi_nn::BackendError::ContextFull)) => {
            // Retrieve the output.
            let output_buffer = get_output_buffer(graph, OUTPUT_TENSOR)?;
            let output = std::str::from_utf8(&output_buffer[..]).map_err(|e| {
                let err_msg = format!(
                    "Failed to decode the buffer of the inference result to a utf-8 string. {e}"
                );

                #[cfg(feature = "logging")]
                error!(target: "stdout", "{}", &err_msg);

                LlamaCoreError::Operation(err_msg)
            })?;

            // post-process
            let message = post_process(output, &graph.metadata.prompt_template).map_err(|e| {
                let err_msg = format!("Failed to post-process the output. {e}");

                #[cfg(feature = "logging")]
                error!(target: "stdout", "{}", &err_msg);

                LlamaCoreError::Operation(err_msg)
            })?;

            // retrieve the number of prompt and completion tokens
            let token_info = get_token_info_by_graph(graph)?;

            #[cfg(feature = "logging")]
            info!(target: "stdout", "prompt tokens: {}, completion tokens: {}", token_info.prompt_tokens, token_info.completion_tokens);

            let usage = Some(Usage {
                prompt_tokens: token_info.prompt_tokens,
                completion_tokens: token_info.completion_tokens,
                total_tokens: token_info.prompt_tokens + token_info.completion_tokens,
            });

            let created = SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .map_err(|e| {
                    let err_msg = format!("Failed to get the current time. Reason: {e}");

                    #[cfg(feature = "logging")]
                    error!(target: "stdout", "{}", &err_msg);

                    LlamaCoreError::Operation(err_msg)
                })?;

            // context full chunk
            let context_full_chunk = {
                let chat_completion_chunk = ChatCompletionChunk {
                    id: id.clone(),
                    object: "chat.completion.chunk".to_string(),
                    created: created.as_secs(),
                    model: graph.name().to_owned(),
                    system_fingerprint: "fp_44709d6fcb".to_string(),
                    choices: vec![ChatCompletionChunkChoice {
                        index: 0,
                        delta: ChatCompletionChunkChoiceDelta {
                            role: ChatCompletionRole::Assistant,
                            content: Some(message),
                            tool_calls: vec![],
                        },
                        logprobs: None,
                        finish_reason: Some(FinishReason::length),
                    }],
                    usage: None,
                };

                // serialize chat completion chunk
                let chunk_str = serde_json::to_string(&chat_completion_chunk).map_err(|e| {
                    let err_msg = format!("Failed to serialize chat completion chunk. Reason: {e}");

                    #[cfg(feature = "logging")]
                    error!(target: "stdout", "{}", &err_msg);

                    LlamaCoreError::Operation(err_msg)
                })?;

                format!("data: {chunk_str}\n\n")
            };

            // usage chunk
            let usage_chunk = {
                let chat_completion_chunk = ChatCompletionChunk {
                    id: id.clone(),
                    object: "chat.completion.chunk".to_string(),
                    created: created.as_secs(),
                    model: graph.name().to_owned(),
                    system_fingerprint: "fp_44709d6fcb".to_string(),
                    choices: vec![],
                    usage,
                };

                // serialize chat completion chunk
                let chunk_str = serde_json::to_string(&chat_completion_chunk).map_err(|e| {
                    let err_msg = format!("Failed to serialize chat completion chunk. Reason: {e}");

                    #[cfg(feature = "logging")]
                    error!(target: "stdout", "{}", &err_msg);

                    LlamaCoreError::Operation(err_msg)
                })?;

                format!("data: {chunk_str}\n\n")
            };

            // ending chunk
            let ending_chunk = "data: [DONE]\n\n".to_string();

            let chunks = vec![context_full_chunk, usage_chunk, ending_chunk];

            let stream = ChatStream::new(
                Some(graph.name().to_owned()),
                id,
                include_usage,
                Some(chunks),
            );

            Ok((stream, false))
        }
        Err(wasmedge_wasi_nn::Error::BackendError(
            wasmedge_wasi_nn::BackendError::PromptTooLong,
        )) => {
            #[cfg(feature = "logging")]
            warn!(target: "stdout", "The prompt is too long. Please reduce the length of your input and try again.");

            // Retrieve the output.
            let output_buffer = get_output_buffer(graph, OUTPUT_TENSOR)?;
            let output = std::str::from_utf8(&output_buffer[..]).map_err(|e| {
                let err_msg = format!(
                    "Failed to decode the buffer of the inference result to a utf-8 string. {e}"
                );

                #[cfg(feature = "logging")]
                error!(target: "stdout", "{}", &err_msg);

                LlamaCoreError::Operation(err_msg)
            })?;

            // post-process
            let message = post_process(output, &graph.metadata.prompt_template).map_err(|e| {
                let err_msg = format!("Failed to post-process the output. {e}");

                #[cfg(feature = "logging")]
                error!(target: "stdout", "{}", &err_msg);

                LlamaCoreError::Operation(err_msg)
            })?;

            // retrieve the number of prompt and completion tokens
            let token_info = get_token_info_by_graph(graph)?;

            #[cfg(feature = "logging")]
            info!(target: "stdout", "prompt tokens: {}, completion tokens: {}", token_info.prompt_tokens, token_info.completion_tokens);

            let usage = Some(Usage {
                prompt_tokens: token_info.prompt_tokens,
                completion_tokens: token_info.completion_tokens,
                total_tokens: token_info.prompt_tokens + token_info.completion_tokens,
            });

            let created = SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .map_err(|e| {
                    let err_msg = format!("Failed to get the current time. Reason: {e}");

                    #[cfg(feature = "logging")]
                    error!(target: "stdout", "{}", &err_msg);

                    LlamaCoreError::Operation(err_msg)
                })?;

            // prompt too long chunk
            let prompt_too_long_chunk = {
                let chat_completion_chunk = ChatCompletionChunk {
                    id: id.clone(),
                    object: "chat.completion.chunk".to_string(),
                    created: created.as_secs(),
                    model: graph.name().to_owned(),
                    system_fingerprint: "fp_44709d6fcb".to_string(),
                    choices: vec![ChatCompletionChunkChoice {
                        index: 0,
                        delta: ChatCompletionChunkChoiceDelta {
                            role: ChatCompletionRole::Assistant,
                            content: Some(message),
                            tool_calls: vec![],
                        },
                        logprobs: None,
                        finish_reason: Some(FinishReason::length),
                    }],
                    usage: None,
                };

                // serialize chat completion chunk
                let chunk_str = serde_json::to_string(&chat_completion_chunk).map_err(|e| {
                    let err_msg = format!("Failed to serialize chat completion chunk. Reason: {e}");

                    #[cfg(feature = "logging")]
                    error!(target: "stdout", "{}", &err_msg);

                    LlamaCoreError::Operation(err_msg)
                })?;

                format!("data: {chunk_str}\n\n")
            };

            // usage chunk
            let usage_chunk = {
                let chat_completion_chunk = ChatCompletionChunk {
                    id: id.clone(),
                    object: "chat.completion.chunk".to_string(),
                    created: created.as_secs(),
                    model: graph.name().to_owned(),
                    system_fingerprint: "fp_44709d6fcb".to_string(),
                    choices: vec![],
                    usage,
                };

                // serialize chat completion chunk
                let chunk_str = serde_json::to_string(&chat_completion_chunk).map_err(|e| {
                    let err_msg = format!("Failed to serialize chat completion chunk. Reason: {e}");

                    #[cfg(feature = "logging")]
                    error!(target: "stdout", "{}", &err_msg);

                    LlamaCoreError::Operation(err_msg)
                })?;

                format!("data: {chunk_str}\n\n")
            };

            // ending chunk
            let ending_chunk = "data: [DONE]\n\n".to_string();

            let chunks = vec![prompt_too_long_chunk, usage_chunk, ending_chunk];

            let stream = ChatStream::new(
                Some(graph.name().to_owned()),
                id,
                include_usage,
                Some(chunks),
            );

            Ok((stream, false))
        }
        Err(e) => {
            let err_msg = format!("Failed to compute the chat completion. Reason: {e}");

            #[cfg(feature = "logging")]
            error!(target: "stdout", "{}", &err_msg);

            Err(LlamaCoreError::Backend(BackendError::Compute(err_msg)))
        }
    }
}

async fn chat_once(
    chat_request: &mut ChatCompletionRequest,
) -> Result<(ChatCompletionObject, bool), LlamaCoreError> {
    #[cfg(feature = "logging")]
    info!(target: "stdout", "Processing chat completion request in non-stream mode");

    let running_mode = running_mode()?;
    if !running_mode.contains(RunningMode::CHAT) && !running_mode.contains(RunningMode::RAG) {
        let err_msg = "The chat completion is only supported in the chat or rag mode.";

        #[cfg(feature = "logging")]
        error!(target: "stdout", "{err_msg}");

        return Err(LlamaCoreError::Operation(err_msg.to_string()));
    }

    let model_name = chat_request.model.clone();
    let id = match &chat_request.user {
        Some(id) => id.clone(),
        None => gen_chat_id(),
    };

    #[cfg(feature = "logging")]
    info!(target: "stdout", "user: {}", &id);

    #[cfg(feature = "logging")]
    info!(target: "stdout", "Check model metadata");

    // update metadata
    let mut metadata = check_model_metadata(chat_request)?;

    #[cfg(feature = "logging")]
    info!(target: "stdout", "Build the chat prompt");

    // build prompt
    let (prompt, avaible_completion_tokens, tool_use) =
        build_prompt(model_name.as_ref(), chat_request)?;

    #[cfg(feature = "logging")]
    {
        info!(target: "stdout", "prompt:\n{}", &prompt);
        info!(target: "stdout", "available_completion_tokens: {avaible_completion_tokens}");
        info!(target: "stdout", "tool_use: {tool_use}");
    }

    #[cfg(feature = "logging")]
    info!(target: "stdout", "Update n_predict");

    // update metadata n_predict
    update_n_predict(chat_request, &mut metadata, avaible_completion_tokens)?;

    #[cfg(feature = "logging")]
    info!(target: "stdout", "Feed the prompt to the model");

    // feed the prompt to the model
    set_prompt(model_name.as_ref(), &prompt)?;

    #[cfg(feature = "logging")]
    info!(target: "stdout", "Compute chat completion.");

    // compute
    let res = compute(model_name.as_ref(), id, tool_use);

    #[cfg(feature = "logging")]
    info!(target: "stdout", "End of the chat completion");

    // reset the model metadata
    reset_model_metadata(model_name.as_ref())?;

    res
}

fn compute(
    model_name: Option<&String>,
    id: impl Into<String>,
    tool_use: bool,
) -> Result<(ChatCompletionObject, bool), LlamaCoreError> {
    let chat_graphs = match CHAT_GRAPHS.get() {
        Some(chat_graphs) => chat_graphs,
        None => {
            let err_msg = "Fail to get the underlying value of `CHAT_GRAPHS`.";

            #[cfg(feature = "logging")]
            error!(target: "stdout", "{}", &err_msg);

            return Err(LlamaCoreError::Operation(err_msg.into()));
        }
    };

    let mut chat_graphs = chat_graphs.lock().map_err(|e| {
        let err_msg = format!("Fail to acquire the lock of `CHAT_GRAPHS`. {e}");

        #[cfg(feature = "logging")]
        error!(target: "stdout", "{}", &err_msg);

        LlamaCoreError::Operation(err_msg)
    })?;

    match model_name {
        Some(model_name) => match chat_graphs.contains_key(model_name) {
            true => {
                let graph = chat_graphs.get_mut(model_name).unwrap();
                compute_by_graph(graph, id, tool_use)
            }
            false => match chat_graphs.iter_mut().next() {
                Some((_, graph)) => compute_by_graph(graph, id, tool_use),
                None => {
                    let err_msg = "There is no model available in the chat graphs.";

                    #[cfg(feature = "logging")]
                    error!(target: "stdout", "{}", &err_msg);

                    Err(LlamaCoreError::Operation(err_msg.into()))
                }
            },
        },
        None => match chat_graphs.iter_mut().next() {
            Some((_, graph)) => compute_by_graph(graph, id, tool_use),
            None => {
                let err_msg = "There is no model available in the chat graphs.";

                #[cfg(feature = "logging")]
                error!(target: "stdout", "{}", &err_msg);

                Err(LlamaCoreError::Operation(err_msg.into()))
            }
        },
    }
}

fn compute_by_graph(
    graph: &mut Graph<GgmlMetadata>,
    id: impl Into<String>,
    tool_use: bool,
) -> Result<(ChatCompletionObject, bool), LlamaCoreError> {
    #[cfg(feature = "logging")]
    info!(target: "stdout", "Compute chat completion by the model named {}.", graph.name());

    match graph.compute() {
        Ok(_) => {
            // Retrieve the output.
            let output_buffer = get_output_buffer(graph, OUTPUT_TENSOR)?;
            let output = std::str::from_utf8(&output_buffer[..]).map_err(|e| {
                let err_msg = format!(
                    "Failed to decode the buffer of the inference result to a utf-8 string. {e}"
                );

                #[cfg(feature = "logging")]
                error!(target: "stdout", "{}", &err_msg);

                LlamaCoreError::Operation(err_msg)
            })?;

            #[cfg(feature = "logging")]
            info!(target: "stdout", "raw generation: {output}");

            // post-process
            let message = post_process(output, &graph.metadata.prompt_template).map_err(|e| {
                LlamaCoreError::Operation(format!("Failed to post-process the output. {e}"))
            })?;

            #[cfg(feature = "logging")]
            info!(target: "stdout", "post-processed generation:\n{}", &message);

            // retrieve the number of prompt and completion tokens
            let token_info = get_token_info_by_graph(graph)?;

            #[cfg(feature = "logging")]
            info!(target: "stdout", "prompt tokens: {}, completion tokens: {}", token_info.prompt_tokens, token_info.completion_tokens);

            let created = SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .map_err(|e| {
                    let err_msg = format!("Failed to get the current time. Reason: {e}");

                    #[cfg(feature = "logging")]
                    error!(target: "stdout", "{}", &err_msg);

                    LlamaCoreError::Operation(err_msg)
                })?;

            match tool_use {
                true => {
                    if graph.metadata.prompt_template != PromptTemplateType::MistralTool
                        && graph.metadata.prompt_template != PromptTemplateType::ChatMLTool
                        && graph.metadata.prompt_template != PromptTemplateType::GroqLlama3Tool
                        && graph.metadata.prompt_template != PromptTemplateType::Llama3Tool
                        && graph.metadata.prompt_template != PromptTemplateType::InternLM2Tool
                        && graph.metadata.prompt_template != PromptTemplateType::NemotronTool
                        && graph.metadata.prompt_template != PromptTemplateType::FunctionaryV32
                        && graph.metadata.prompt_template != PromptTemplateType::FunctionaryV31
                        && graph.metadata.prompt_template != PromptTemplateType::MistralSmallTool
                        && graph.metadata.prompt_template != PromptTemplateType::Llama4Chat
                        && graph.metadata.prompt_template != PromptTemplateType::Qwen3NoThink
                        && graph.metadata.prompt_template != PromptTemplateType::Smol3NoThink
                        && graph.metadata.prompt_template != PromptTemplateType::Gemma3
                        && graph.metadata.prompt_template != PromptTemplateType::GptOss
                        && graph.metadata.prompt_template != PromptTemplateType::Qwen3Agent
                    {
                        let err_msg = format!("Unsupported prompt template: {}. The tool use is only supported for 'mistral-tool', 'chatml-tool', 'groq-llama3-tool', 'llama-3-tool', 'internlm-2-tool', 'nemotron-tool', 'functionary-31', 'functionary-32', 'mistral-small-tool', 'llama-4-chat', 'qwen3-no-think', 'smol-3-no-think', 'gemma-3', 'gpt-oss', and 'qwen3-agent' prompt templates.", graph.metadata.prompt_template);

                        #[cfg(feature = "logging")]
                        error!(target: "stdout", "{}", &err_msg);

                        return Err(LlamaCoreError::Operation(err_msg));
                    }

                    let parsed_result = parse_tool_calls(&message, graph.metadata.prompt_template)?;

                    let (finish_reason, content, include_tool_calls) =
                        if parsed_result.tool_calls.is_empty() {
                            (FinishReason::stop, Some(parsed_result.raw.clone()), false)
                        } else if graph.metadata.prompt_template != PromptTemplateType::Qwen3Agent {
                            (
                                FinishReason::tool_calls,
                                Some(parsed_result.raw.clone()),
                                true,
                            )
                        } else {
                            (
                                FinishReason::tool_calls,
                                parsed_result.content.clone(),
                                true,
                            )
                        };

                    let res = ChatCompletionObject {
                        id: id.into(),
                        object: String::from("chat.completion"),
                        created: created.as_secs(),
                        model: graph.name().to_owned(),
                        choices: vec![ChatCompletionObjectChoice {
                            index: 0,
                            message: ChatCompletionObjectMessage {
                                role: ChatCompletionRole::Assistant,
                                content,
                                tool_calls: parsed_result.tool_calls,
                                function_call: None,
                            },
                            finish_reason,
                            logprobs: None,
                        }],
                        usage: Usage {
                            prompt_tokens: token_info.prompt_tokens,
                            completion_tokens: token_info.completion_tokens,
                            total_tokens: token_info.prompt_tokens + token_info.completion_tokens,
                        },
                    };

                    // create ChatCompletionResponse
                    Ok((res, include_tool_calls))
                }
                false => {
                    // create ChatCompletionResponse
                    let res = ChatCompletionObject {
                        id: id.into(),
                        object: String::from("chat.completion"),
                        created: created.as_secs(),
                        model: graph.name().to_owned(),
                        choices: vec![ChatCompletionObjectChoice {
                            index: 0,
                            message: ChatCompletionObjectMessage {
                                role: ChatCompletionRole::Assistant,
                                content: Some(message),
                                tool_calls: vec![],
                                function_call: None,
                            },
                            finish_reason: FinishReason::stop,
                            logprobs: None,
                        }],
                        usage: Usage {
                            prompt_tokens: token_info.prompt_tokens,
                            completion_tokens: token_info.completion_tokens,
                            total_tokens: token_info.prompt_tokens + token_info.completion_tokens,
                        },
                    };

                    Ok((res, false))
                }
            }
        }
        Err(wasmedge_wasi_nn::Error::BackendError(wasmedge_wasi_nn::BackendError::ContextFull)) => {
            // Retrieve the output.
            let output_buffer = get_output_buffer(graph, OUTPUT_TENSOR)?;
            let output = std::str::from_utf8(&output_buffer[..]).map_err(|e| {
                let err_msg = format!(
                    "Failed to decode the buffer of the inference result to a utf-8 string. {e}"
                );

                #[cfg(feature = "logging")]
                error!(target: "stdout", "{}", &err_msg);

                LlamaCoreError::Operation(err_msg)
            })?;

            // post-process
            let message = post_process(output, &graph.metadata.prompt_template).map_err(|e| {
                let err_msg = format!("Failed to post-process the output. {e}");

                #[cfg(feature = "logging")]
                error!(target: "stdout", "{}", &err_msg);

                LlamaCoreError::Operation(err_msg)
            })?;

            // retrieve the number of prompt and completion tokens
            let token_info = get_token_info_by_graph(graph)?;

            #[cfg(feature = "logging")]
            info!(target: "stdout", "prompt tokens: {}, completion tokens: {}", token_info.prompt_tokens, token_info.completion_tokens);

            let created = SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .map_err(|e| {
                    let err_msg = format!("Failed to get the current time. Reason: {e}");

                    #[cfg(feature = "logging")]
                    error!(target: "stdout", "{}", &err_msg);

                    LlamaCoreError::Operation(err_msg)
                })?;

            // create ChatCompletionResponse
            let res = ChatCompletionObject {
                id: id.into(),
                object: String::from("chat.completion"),
                created: created.as_secs(),
                model: graph.name().to_owned(),
                choices: vec![ChatCompletionObjectChoice {
                    index: 0,
                    message: ChatCompletionObjectMessage {
                        role: ChatCompletionRole::Assistant,
                        content: Some(message),
                        tool_calls: vec![],
                        function_call: None,
                    },
                    finish_reason: FinishReason::length,
                    logprobs: None,
                }],
                usage: Usage {
                    prompt_tokens: token_info.prompt_tokens,
                    completion_tokens: token_info.completion_tokens,
                    total_tokens: token_info.prompt_tokens + token_info.completion_tokens,
                },
            };

            Ok((res, false))
        }
        Err(wasmedge_wasi_nn::Error::BackendError(
            wasmedge_wasi_nn::BackendError::PromptTooLong,
        )) => {
            #[cfg(feature = "logging")]
            warn!(target: "stdout", "The prompt is too long. Please reduce the length of your input and try again.");

            // Retrieve the output.
            let output_buffer = get_output_buffer(graph, OUTPUT_TENSOR)?;
            let output = std::str::from_utf8(&output_buffer[..]).map_err(|e| {
                let err_msg = format!(
                    "Failed to decode the buffer of the inference result to a utf-8 string. {e}"
                );

                #[cfg(feature = "logging")]
                error!(target: "stdout", "{}", &err_msg);

                LlamaCoreError::Operation(err_msg)
            })?;

            // post-process
            let message = post_process(output, &graph.metadata.prompt_template).map_err(|e| {
                let err_msg = format!("Failed to post-process the output. {e}");

                #[cfg(feature = "logging")]
                error!(target: "stdout", "{}", &err_msg);

                LlamaCoreError::Operation(err_msg)
            })?;

            // retrieve the number of prompt and completion tokens
            let token_info = get_token_info_by_graph(graph)?;

            #[cfg(feature = "logging")]
            info!(target: "stdout", "prompt tokens: {}, completion tokens: {}", token_info.prompt_tokens, token_info.completion_tokens);

            let usage = Usage {
                prompt_tokens: token_info.prompt_tokens,
                completion_tokens: token_info.completion_tokens,
                total_tokens: token_info.prompt_tokens + token_info.completion_tokens,
            };

            let created = SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .map_err(|e| {
                    let err_msg = format!("Failed to get the current time. Reason: {e}");

                    #[cfg(feature = "logging")]
                    error!(target: "stdout", "{}", &err_msg);

                    LlamaCoreError::Operation(err_msg)
                })?;

            // create ChatCompletionResponse
            let res = ChatCompletionObject {
                id: id.into(),
                object: String::from("chat.completion"),
                created: created.as_secs(),
                model: graph.name().to_owned(),
                choices: vec![ChatCompletionObjectChoice {
                    index: 0,
                    message: ChatCompletionObjectMessage {
                        role: ChatCompletionRole::Assistant,
                        content: Some(message),
                        tool_calls: vec![],
                        function_call: None,
                    },
                    finish_reason: FinishReason::length,
                    logprobs: None,
                }],
                usage,
            };

            Ok((res, false))
        }
        Err(e) => {
            let err_msg = format!("Failed to compute the chat completion. Reason: {e}");

            #[cfg(feature = "logging")]
            error!(target: "stdout", "{}", &err_msg);

            Err(LlamaCoreError::Backend(BackendError::Compute(err_msg)))
        }
    }
}

fn parse_tool_calls(
    input: &str,
    prompt_template: PromptTemplateType,
) -> Result<ParseResult, LlamaCoreError> {
    match prompt_template {
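        // Mistral-style tool calls are expected as a JSON array of objects, e.g. (illustrative):
        //   [{"name": "get_current_weather", "arguments": {"location": "Paris"}}]
        // The regex below captures each `[{...}]` group and parses it into `name`/`arguments` pairs.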
        PromptTemplateType::MistralTool => match regex::Regex::new(r"\[\{.*?\}\]") {
            Ok(re) => {
                let mut values: Vec<serde_json::Value> = vec![];
                for cap in re.captures_iter(input) {
                    let matched = &cap[0];

                    #[cfg(feature = "logging")]
                    info!(target: "stdout", "captured: {matched}");

                    match serde_json::from_str::<Vec<serde_json::Value>>(matched) {
                        Ok(group) => values.extend(group),
                        Err(e) => {
                            let err_msg =
                                format!("Failed to deserialize generated tool calls. Reason: {e}");

                            #[cfg(feature = "logging")]
                            error!(target: "stdout", "{}", &err_msg);

                            return Err(LlamaCoreError::Operation(err_msg));
                        }
                    }
                }

                let mut tool_calls: Vec<ToolCall> = vec![];
                for value in values.iter() {
                    let name = match value.get("name") {
                        Some(name) => name.to_string().replace("\"", ""),
                        None => {
                            let err_msg = format!(
                                "Failed to get the name of the function. Tool call: {value:?}"
                            );

                            #[cfg(feature = "logging")]
                            error!(target: "stdout", "{}", &err_msg);

                            return Err(LlamaCoreError::Operation(err_msg));
                        }
                    };

                    let arguments = match value.get("arguments") {
                        Some(arguments) => arguments.to_string(),
                        None => {
                            let err_msg = format!(
                                "Failed to get the arguments of the function. Tool call: {value:?}"
                            );

                            #[cfg(feature = "logging")]
                            error!(target: "stdout", "{}", &err_msg);

                            return Err(LlamaCoreError::Operation(err_msg));
                        }
                    };

                    let function = Function { name, arguments };

                    let tool_call = ToolCall {
                        id: "call_abc123".to_string(),
                        ty: "function".to_string(),
                        function,
                    };

                    tool_calls.push(tool_call);
                }

                let parsed = ParseResult {
                    raw: input.to_owned(),
                    content: None,
                    tool_calls,
                };

                #[cfg(feature = "logging")]
                info!(target: "stdout", "parsed result: {parsed:?}");

                Ok(parsed)
            }
            Err(e) => {
                let err_msg = format!("Failed to create a regex pattern. Reason: {e}");

                #[cfg(feature = "logging")]
                error!(target: "stdout", "{}", &err_msg);

                Err(LlamaCoreError::Operation(err_msg))
            }
        },
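        // ChatML-style tool calls are wrapped in `<tool_call> ... </tool_call>` tags, e.g. (illustrative):
        //   <tool_call>{"name": "get_current_weather", "arguments": {"location": "Paris"}}</tool_call>
        // Each captured group is parsed as a single JSON object after escaped newlines are removed.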
        PromptTemplateType::ChatMLTool => {
            match regex::Regex::new(r"<tool_call>(.*?)</tool_call>") {
                Ok(re) => {
                    let mut values: Vec<serde_json::Value> = vec![];
                    for cap in re.captures_iter(input) {
                        let matched = cap[1].replace("\\n", ""); // Remove "\\n" from the captured group

                        #[cfg(feature = "logging")]
                        info!(target: "stdout", "captured: {}", &matched);

                        match serde_json::from_str::<serde_json::Value>(&matched) {
                            Ok(value) => values.push(value),
                            Err(e) => {
                                let err_msg = format!(
                                    "Failed to deserialize generated tool calls. Reason: {e}"
                                );

                                #[cfg(feature = "logging")]
                                error!(target: "stdout", "{}", &err_msg);

                                return Err(LlamaCoreError::Operation(err_msg));
                            }
                        }
                    }

                    let mut tool_calls: Vec<ToolCall> = vec![];
                    for value in values.iter() {
                        let name = match value.get("name") {
                            Some(name) => name.to_string().replace("\"", ""),
                            None => {
                                let err_msg = format!(
                                    "Failed to get the name of the function. Tool call: {value:?}"
                                );

                                #[cfg(feature = "logging")]
                                error!(target: "stdout", "{}", &err_msg);

                                return Err(LlamaCoreError::Operation(err_msg));
                            }
                        };

                        let arguments = match value.get("arguments") {
                            Some(arguments) => arguments.to_string(),
                            None => {
                                let err_msg = format!(
                                    "Failed to get the arguments of the function. Tool call: {value:?}"
                                );

                                #[cfg(feature = "logging")]
                                error!(target: "stdout", "{}", &err_msg);

                                return Err(LlamaCoreError::Operation(err_msg));
                            }
                        };

                        let function = Function { name, arguments };

                        let tool_call = ToolCall {
                            id: "call_abc123".to_string(),
                            ty: "function".to_string(),
                            function,
                        };

                        tool_calls.push(tool_call);
                    }

                    let parsed = ParseResult {
                        raw: input.to_owned(),
                        content: None,
                        tool_calls,
                    };

                    #[cfg(feature = "logging")]
                    info!(target: "stdout", "parsed result: {parsed:?}");

                    Ok(parsed)
                }
                Err(e) => {
                    let err_msg = format!("Failed to create a regex pattern. Reason: {e}");

                    #[cfg(feature = "logging")]
                    error!(target: "stdout", "{}", &err_msg);

                    Err(LlamaCoreError::Operation(err_msg))
                }
            }
        }
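        // Same `<tool_call> ... </tool_call>` wrapping as the ChatML case above, but the `(?s)`
        // flag lets the captured JSON span multiple lines; it is trimmed before being parsed.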
        PromptTemplateType::GroqLlama3Tool => {
            #[cfg(feature = "logging")]
            info!(target: "stdout", "raw input: {input}");

            match regex::Regex::new(r"(?s)<tool_call>((.|\r|\n)*?)</tool_call>") {
                Ok(re) => {
                    let mut values: Vec<serde_json::Value> = vec![];
                    for cap in re.captures_iter(input) {
                        let matched = cap[1].trim();

                        #[cfg(feature = "logging")]
                        info!(target: "stdout", "captured: {matched}");

                        match serde_json::from_str::<serde_json::Value>(matched) {
                            Ok(value) => values.push(value),
                            Err(e) => {
                                let err_msg = format!(
                                    "Failed to deserialize generated tool calls. Reason: {e}"
                                );

                                #[cfg(feature = "logging")]
                                error!(target: "stdout", "{}", &err_msg);

                                return Err(LlamaCoreError::Operation(err_msg));
                            }
                        }
                    }

                    let mut tool_calls: Vec<ToolCall> = vec![];
                    for value in values.iter() {
                        let name = match value.get("name") {
                            Some(name) => name.to_string().replace("\"", ""),
                            None => {
                                let err_msg = format!(
                                    "Failed to get the name of the function. Tool call: {value:?}"
                                );

                                #[cfg(feature = "logging")]
                                error!(target: "stdout", "{}", &err_msg);

                                return Err(LlamaCoreError::Operation(err_msg));
                            }
                        };

                        let arguments = match value.get("arguments") {
                            Some(arguments) => {
                                if arguments.is_string() {
                                    arguments.as_str().unwrap().to_string()
                                } else if arguments.is_object() {
                                    let map = arguments.as_object().unwrap();

                                    #[cfg(feature = "logging")]
                                    info!(target: "stdout", "func arguments: {map:?}");

                                    serde_json::to_string(map).unwrap()
                                } else {
                                    serde_json::to_string(arguments).unwrap()
                                }
                            }
                            None => {
                                let err_msg = format!(
                                    "Failed to get the arguments of the function. Tool call: {value:?}"
                                );

                                #[cfg(feature = "logging")]
                                error!(target: "stdout", "{}", &err_msg);

                                return Err(LlamaCoreError::Operation(err_msg));
                            }
                        };

                        let function = Function { name, arguments };

                        let tool_call = ToolCall {
                            id: "call_abc123".to_string(),
                            ty: "function".to_string(),
                            function,
                        };

                        tool_calls.push(tool_call);
                    }

                    let parsed = if tool_calls.is_empty() {
                        ParseResult {
                            raw: input.to_owned(),
                            content: Some(input.to_owned()),
                            tool_calls: vec![],
                        }
                    } else {
                        ParseResult {
                            raw: input.to_owned(),
                            content: None,
                            tool_calls,
                        }
                    };

                    #[cfg(feature = "logging")]
                    info!(target: "stdout", "parsed result: {parsed:?}");

                    Ok(parsed)
                }
                Err(e) => {
                    let err_msg = format!("Failed to create a regex pattern. Reason: {e}");

                    #[cfg(feature = "logging")]
                    error!(target: "stdout", "{}", &err_msg);

                    Err(LlamaCoreError::Operation(err_msg))
                }
            }
        }
1366        PromptTemplateType::Llama3Tool => {
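            // The whole output is expected to be a single JSON object with `name` and `parameters`
            // fields; output that is not a JSON object yields an empty tool-call list.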
1367            #[cfg(feature = "logging")]
1368            info!(target: "stdout", "raw input: {input}");
1369
1370            let re = match regex::Regex::new(r"^\{(.|\r|\n)*\}$") {
1371                Ok(re) => re,
1372                Err(e) => {
1373                    let err_msg = format!("Failed to create a regex pattern. Reason: {e}");
1374
1375                    #[cfg(feature = "logging")]
1376                    error!(target: "stdout", "{}", &err_msg);
1377
1378                    return Err(LlamaCoreError::Operation(err_msg));
1379                }
1380            };
1381
1382            if re.is_match(input) {
1383                match serde_json::from_str::<serde_json::Value>(input) {
1384                    Ok(value) => {
1385                        let values: Vec<serde_json::Value> = vec![value];
1386
1387                        let mut tool_calls: Vec<ToolCall> = vec![];
1388                        for value in values.iter() {
1389                            let name = match value.get("name") {
1390                                Some(name) => name.to_string().replace("\"", ""),
1391                                None => {
1392                                    let err_msg = format!(
1393                                        "Failed to get the name of the function. Tool call: {value:?}"
1394                                    );
1395
1396                                    #[cfg(feature = "logging")]
1397                                    error!(target: "stdout", "{}", &err_msg);
1398
1399                                    return Err(LlamaCoreError::Operation(err_msg));
1400                                }
1401                            };
1402
1403                            let arguments = match value.get("parameters") {
1404                                Some(arguments) => arguments.to_string(),
1405                                None => {
1406                                    let err_msg = format!(
1407                                        "Failed to get the arguments of the function. Tool call: {value:?}"
1408                                    );
1409
1410                                    #[cfg(feature = "logging")]
1411                                    error!(target: "stdout", "{}", &err_msg);
1412
1413                                    return Err(LlamaCoreError::Operation(err_msg));
1414                                }
1415                            };
1416
1417                            let function = Function { name, arguments };
1418
1419                            let tool_call = ToolCall {
1420                                id: "call_abc123".to_string(),
1421                                ty: "function".to_string(),
1422                                function,
1423                            };
1424
1425                            tool_calls.push(tool_call);
1426                        }
1427
1428                        let parsed = ParseResult {
1429                            raw: input.to_owned(),
1430                            content: None,
1431                            tool_calls,
1432                        };
1433
1434                        #[cfg(feature = "logging")]
1435                        info!(target: "stdout", "parsed result: {parsed:?}");
1436
1437                        Ok(parsed)
1438                    }
1439                    Err(e) => {
1440                        let err_msg =
1441                            format!("Failed to deserialize generated tool calls. Reason: {e}");
1442
1443                        #[cfg(feature = "logging")]
1444                        error!(target: "stdout", "{}", &err_msg);
1445
1446                        Err(LlamaCoreError::Operation(err_msg))
1447                    }
1448                }
1449            } else {
1450                let parsed = ParseResult {
1451                    raw: input.to_owned(),
1452                    content: None,
1453                    tool_calls: vec![],
1454                };
1455
1456                #[cfg(feature = "logging")]
1457                info!(target: "stdout", "parsed result: {parsed:?}");
1458
1459                Ok(parsed)
1460            }
1461        }
1462        PromptTemplateType::InternLM2Tool => {
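            // Split the output on `<|action_start|><|plugin|>`; blocks ending with `<|action_end|>` are
            // parsed as JSON tool calls, everything else is collected as plain content.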
1463            #[cfg(feature = "logging")]
1464            info!(target: "stdout", "raw input: {input}");
1465
1466            let blocks: Vec<&str> = input.trim().split("<|action_start|><|plugin|>").collect();
1467
1468            #[cfg(feature = "logging")]
1469            info!(target: "stdout", "blocks: {blocks:?}");
1470
1471            let mut tool_calls: Vec<ToolCall> = vec![];
1472            let mut content = String::new();
1473            for block in blocks {
1474                let block = block.trim();
1475                if !block.is_empty() {
1476                    if block.ends_with("<|action_end|>") {
1477                        let value = block.trim().trim_end_matches("<|action_end|>");
1478
1479                        #[cfg(feature = "logging")]
1480                        info!(target: "stdout", "tool call: {value}");
1481
1482                        match serde_json::from_str::<serde_json::Value>(value) {
1483                            Ok(value) => {
1484                                let name = match value.get("name") {
1485                                    Some(name) => name.to_string().replace("\"", ""),
1486                                    None => {
1487                                        let err_msg = format!(
1488                                            "Failed to get the name of the function. Tool call: {value:?}"
1489                                        );
1490
1491                                        #[cfg(feature = "logging")]
1492                                        error!(target: "stdout", "{}", &err_msg);
1493
1494                                        return Err(LlamaCoreError::Operation(err_msg));
1495                                    }
1496                                };
1497
1498                                let arguments = match value.get("parameters") {
1499                                    Some(arguments) => arguments.to_string(),
1500                                    None => {
1501                                        let err_msg = format!(
1502                                            "Failed to get the arguments of the function. Tool call: {value:?}"
1503                                        );
1504
1505                                        #[cfg(feature = "logging")]
1506                                        error!(target: "stdout", "{}", &err_msg);
1507
1508                                        return Err(LlamaCoreError::Operation(err_msg));
1509                                    }
1510                                };
1511
1512                                let function = Function { name, arguments };
1513
1514                                let tool_call = ToolCall {
1515                                    id: "call_abc123".to_string(),
1516                                    ty: "function".to_string(),
1517                                    function,
1518                                };
1519
1520                                tool_calls.push(tool_call);
1521                            }
1522                            Err(e) => {
1523                                let err_msg = format!(
1524                                    "Failed to deserialize generated tool calls. Reason: {e}"
1525                                );
1526
1527                                #[cfg(feature = "logging")]
1528                                error!(target: "stdout", "{}", &err_msg);
1529
1530                                return Err(LlamaCoreError::Operation(err_msg));
1531                            }
1532                        }
1533                    } else {
1534                        content.push_str(block);
1535                        content.push('\n');
1536                    }
1537                }
1538            }
1539
1540            let parsed = match content.is_empty() {
1541                true => ParseResult {
1542                    raw: input.to_owned(),
1543                    content: None,
1544                    tool_calls,
1545                },
1546                false => ParseResult {
1547                    raw: input.to_owned(),
1548                    content: Some(content.trim().to_owned()),
1549                    tool_calls,
1550                },
1551            };
1552
1553            #[cfg(feature = "logging")]
1554            info!(target: "stdout", "parsed result: {parsed:?}");
1555
1556            Ok(parsed)
1557        }
1558        PromptTemplateType::NemotronTool => {
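            // Extract every `<toolcall>...</toolcall>` block and deserialize it as a JSON tool call.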
1559            #[cfg(feature = "logging")]
1560            info!(target: "stdout", "raw input: {input}");
1561
1562            match regex::Regex::new(r"(?s)<toolcall>\s*(.*?)\s*</toolcall>") {
1563                Ok(re) => {
1564                    let mut values: Vec<serde_json::Value> = vec![];
1565                    for cap in re.captures_iter(input) {
1566                        #[cfg(feature = "logging")]
1567                        info!(target: "stdout", "captured: {}", &cap[0]);
1568
1569                        #[cfg(feature = "logging")]
1570                        info!(target: "stdout", "extracted: {}", &cap[1]);
1571
1572                        let matched = cap[1].trim();
1573
1574                        #[cfg(feature = "logging")]
1575                        info!(target: "stdout", "captured: {matched}");
1576
1577                        match serde_json::from_str::<serde_json::Value>(matched) {
1578                            Ok(value) => values.push(value),
1579                            Err(e) => {
1580                                let err_msg = format!(
1581                                    "Failed to deserialize generated tool calls. Reason: {e}"
1582                                );
1583
1584                                #[cfg(feature = "logging")]
1585                                error!(target: "stdout", "{}", &err_msg);
1586
1587                                return Err(LlamaCoreError::Operation(err_msg));
1588                            }
1589                        }
1590                    }
1591
1592                    let mut tool_calls: Vec<ToolCall> = vec![];
1593                    for value in values.iter() {
1594                        let name = match value.get("name") {
1595                            Some(name) => name.to_string().replace("\"", ""),
1596                            None => {
1597                                let err_msg = format!(
1598                                    "Failed to get the name of the function. Tool call: {value:?}"
1599                                );
1600
1601                                #[cfg(feature = "logging")]
1602                                error!(target: "stdout", "{}", &err_msg);
1603
1604                                return Err(LlamaCoreError::Operation(err_msg));
1605                            }
1606                        };
1607
1608                        let arguments = match value.get("arguments") {
1609                            Some(arguments) => arguments.to_string(),
1610                            None => {
1611                                let err_msg = format!(
1612                                    "Failed to get the arguments of the function. Tool call: {value:?}"
1613                                );
1614
1615                                #[cfg(feature = "logging")]
1616                                error!(target: "stdout", "{}", &err_msg);
1617
1618                                return Err(LlamaCoreError::Operation(err_msg));
1619                            }
1620                        };
1621
1622                        let function = Function { name, arguments };
1623
1624                        let tool_call = ToolCall {
1625                            id: "call_abc123".to_string(),
1626                            ty: "function".to_string(),
1627                            function,
1628                        };
1629
1630                        tool_calls.push(tool_call);
1631                    }
1632
1633                    let parsed = ParseResult {
1634                        raw: input.to_owned(),
1635                        content: None,
1636                        tool_calls,
1637                    };
1638
1639                    #[cfg(feature = "logging")]
1640                    info!(target: "stdout", "parsed result: {parsed:?}");
1641
1642                    Ok(parsed)
1643                }
1644                Err(e) => {
1645                    let err_msg = format!("Failed to create a regex pattern. Reason: {e}");
1646
1647                    #[cfg(feature = "logging")]
1648                    error!(target: "stdout", "{}", &err_msg);
1649
1650                    Err(LlamaCoreError::Operation(err_msg))
1651                }
1652            }
1653        }
1654        PromptTemplateType::FunctionaryV32 => {
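            // Match `>>>name{...}<|eot_id|>` segments; capture group 1 is the function name and
            // group 2 is the text between the braces, used as the arguments.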
1655            #[cfg(feature = "logging")]
1656            info!(target: "stdout", "raw input: {input}");
1657
1658            match regex::Regex::new(r">>>\s*(\w+)\s*\{(.*)\}<\|eot_id\|>") {
1659                Ok(re) => {
1660                    let mut tool_calls: Vec<ToolCall> = vec![];
1661                    for cap in re.captures_iter(input) {
1662                        #[cfg(feature = "logging")]
1663                        info!(target: "stdout", "func_name: {}", &cap[1]);
1664
1665                        #[cfg(feature = "logging")]
1666                        info!(target: "stdout", "arguments: {}", &cap[2]);
1667
1668                        let tool_call = ToolCall {
1669                            id: "call_abc123".to_string(),
1670                            ty: "function".to_string(),
1671                            function: Function {
1672                                name: cap[1].to_string(),
1673                                arguments: cap[2].to_string(),
1674                            },
1675                        };
1676
1677                        tool_calls.push(tool_call);
1678                    }
1679
1680                    let parsed = ParseResult {
1681                        raw: input.to_owned(),
1682                        content: None,
1683                        tool_calls,
1684                    };
1685
1686                    #[cfg(feature = "logging")]
1687                    info!(target: "stdout", "parsed result: {parsed:?}");
1688
1689                    Ok(parsed)
1690                }
1691                Err(e) => {
1692                    let warn_msg = format!("Failed to create a regex pattern. Reason: {e}");
1693
1694                    #[cfg(feature = "logging")]
1695                    warn!(target: "stdout", "{}", &warn_msg);
1696
1697                    Ok(ParseResult {
1698                        raw: input.to_owned(),
1699                        content: None,
1700                        tool_calls: vec![],
1701                    })
1702                }
1703            }
1704        }
1705        PromptTemplateType::FunctionaryV31 => {
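            // Match `<function=name>{...}</function>` segments; group 1 is the function name and
            // group 2 is the JSON argument object.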
1706            #[cfg(feature = "logging")]
1707            info!(target: "stdout", "raw input: {input}");
1708
1709            match regex::Regex::new(r"<function=(\w+)>\s*(\{.*?\})</function>") {
1710                Ok(re) => {
1711                    let mut tool_calls: Vec<ToolCall> = vec![];
1712                    for cap in re.captures_iter(input) {
1713                        #[cfg(feature = "logging")]
1714                        info!(target: "stdout", "func_name: {}", &cap[1]);
1715
1716                        #[cfg(feature = "logging")]
1717                        info!(target: "stdout", "arguments: {}", &cap[2]);
1718
1719                        let tool_call = ToolCall {
1720                            id: "call_abc123".to_string(),
1721                            ty: "function".to_string(),
1722                            function: Function {
1723                                name: cap[1].to_string(),
1724                                arguments: cap[2].to_string(),
1725                            },
1726                        };
1727
1728                        tool_calls.push(tool_call);
1729                    }
1730
1731                    let parsed = ParseResult {
1732                        raw: input.to_owned(),
1733                        content: None,
1734                        tool_calls,
1735                    };
1736
1737                    #[cfg(feature = "logging")]
1738                    info!(target: "stdout", "parsed result: {parsed:?}");
1739
1740                    Ok(parsed)
1741                }
1742                Err(e) => {
1743                    let warn_msg = format!("Failed to create a regex pattern. Reason: {e}");
1744
1745                    #[cfg(feature = "logging")]
1746                    warn!(target: "stdout", "{}", &warn_msg);
1747
1748                    Ok(ParseResult {
1749                        raw: input.to_owned(),
1750                        content: None,
1751                        tool_calls: vec![],
1752                    })
1753                }
1754            }
1755        }
1756        PromptTemplateType::MistralSmallTool => {
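            // Extract the JSON array following the `[TOOL_CALLS]` marker; each element either wraps a
            // `function` object or carries `name`/`arguments` fields directly.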
1757            #[cfg(feature = "logging")]
1758            info!(target: "stdout", "raw input: {input}");
1759
1760            match regex::Regex::new(r"\[TOOL_CALLS\]\s*(\[(.*?)\])") {
1761                Ok(re) => {
1762                    let mut values: Vec<serde_json::Value> = vec![];
1763                    if let Some(cap) = re.captures(input) {
1764                        let matched = cap[1].trim();
1765
1766                        #[cfg(feature = "logging")]
1767                        info!(target: "stdout", "captured: {matched}");
1768
1769                        match serde_json::from_str::<Vec<serde_json::Value>>(matched) {
1770                            Ok(vals) => values = vals,
1771                            Err(e) => {
1772                                let err_msg = format!(
1773                                    "Failed to deserialize generated tool calls. Reason: {e}"
1774                                );
1775
1776                                #[cfg(feature = "logging")]
1777                                error!(target: "stdout", "{}", &err_msg);
1778
1779                                return Err(LlamaCoreError::Operation(err_msg));
1780                            }
1781                        }
1782                    };
1783
1784                    let mut tool_calls: Vec<ToolCall> = vec![];
1785                    for value in values.iter() {
1786                        if let Some(object_map) = value.as_object() {
1787                            if object_map.contains_key("function") {
1788                                let mut function = Function {
1789                                    name: String::new(),
1790                                    arguments: String::new(),
1791                                };
1792
1793                                let value = object_map.get("function").unwrap();
1794                                let func_map = value.as_object().unwrap();
1795                                if func_map.contains_key("name") {
1796                                    let func_name = func_map.get("name").unwrap().as_str().unwrap();
1797                                    println!("Function name: {func_name:?}");
1798
1799                                    function.name = func_name.to_string();
1800                                }
1801                                if func_map.contains_key("arguments") {
1802                                    let args = func_map.get("arguments").unwrap();
1803                                    let arguments = args.to_string();
1804                                    println!("Arguments: {arguments:?}");
1805
1806                                    function.arguments = arguments;
1807                                }
1808
1809                                let tool_call = ToolCall {
1810                                    id: "call_abc123".to_string(),
1811                                    ty: "function".to_string(),
1812                                    function,
1813                                };
1814
1815                                tool_calls.push(tool_call);
1816                            } else if object_map.contains_key("name") {
1817                                let mut function = Function {
1818                                    name: String::new(),
1819                                    arguments: String::new(),
1820                                };
1821
1822                                let name = object_map.get("name").unwrap().as_str().unwrap();
1823                                println!("name: {name:?}");
1824                                function.name = name.to_string();
1825
1826                                if object_map.contains_key("arguments") {
1827                                    let args = object_map.get("arguments").unwrap();
1828                                    let arguments = args.to_string();
1829                                    println!("Arguments: {arguments:?}");
1830
1831                                    function.arguments = arguments;
1832                                }
1833
1834                                let tool_call = ToolCall {
1835                                    id: "call_abc123".to_string(),
1836                                    ty: "function".to_string(),
1837                                    function,
1838                                };
1839
1840                                tool_calls.push(tool_call);
1841                            }
1842                        }
1843                    }
1844
1845                    let parsed = ParseResult {
1846                        raw: input.to_owned(),
1847                        content: None,
1848                        tool_calls,
1849                    };
1850
1851                    #[cfg(feature = "logging")]
1852                    info!(target: "stdout", "parsed result: {parsed:?}");
1853
1854                    Ok(parsed)
1855                }
1856                Err(e) => {
1857                    let err_msg = format!("Failed to create a regex pattern. Reason: {e}");
1858
1859                    #[cfg(feature = "logging")]
1860                    error!(target: "stdout", "{}", &err_msg);
1861
1862                    Err(LlamaCoreError::Operation(err_msg))
1863                }
1864            }
1865        }
1866        PromptTemplateType::Llama4Chat => {
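            // The output is expected to be a single JSON object with a `name` field and optional
            // `parameters`; non-JSON output yields an empty tool-call list.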
1867            #[cfg(feature = "logging")]
1868            info!(target: "stdout", "raw input: {input:?}");
1869
1870            let mut tool_calls: Vec<ToolCall> = vec![];
1871            if let Ok(value) = serde_json::from_str::<serde_json::Value>(input) {
1872                match value.as_object() {
1873                    Some(object_map) => {
1874                        #[cfg(feature = "logging")]
1875                        debug!(target: "stdout", "object_map: {object_map:?}");
1876
1877                        // parse function name
1878                        if object_map.contains_key("name") {
1879                            let name = object_map.get("name").unwrap().as_str().unwrap();
1880
1881                            #[cfg(feature = "logging")]
1882                            debug!(target: "stdout", "name: {name:?}");
1883
1884                            let mut function = Function {
1885                                name: name.to_string(),
1886                                arguments: String::new(),
1887                            };
1888
1889                            // parse function arguments
1890                            if object_map.contains_key("parameters") {
1891                                let args = object_map.get("parameters").unwrap();
1892                                let arguments = args.to_string();
1893
1894                                #[cfg(feature = "logging")]
1895                                debug!(target: "stdout", "arguments: {:?}", &arguments);
1896
1897                                function.arguments = arguments;
1898                            }
1899
1900                            tool_calls.push(ToolCall {
1901                                id: "call_abc123".to_string(),
1902                                ty: "function".to_string(),
1903                                function,
1904                            });
1905                        } else {
1906                            let err_msg = format!(
1907                                "Failed to get the name of the function. raw input: {input:?}"
1908                            );
1909
1910                            #[cfg(feature = "logging")]
1911                            error!(target: "stdout", "{}", &err_msg);
1912
1913                            return Err(LlamaCoreError::Operation(err_msg));
1914                        }
1915                    }
1916                    None => {
1917                        let err_msg = format!("Failed to parse the JSON string. JSON: {input}");
1918
1919                        #[cfg(feature = "logging")]
1920                        error!(target: "stdout", "{}", &err_msg);
1921
1922                        return Err(LlamaCoreError::Operation(err_msg));
1923                    }
1924                }
1925            }
1926
1927            let parsed = ParseResult {
1928                raw: input.to_owned(),
1929                content: None,
1930                tool_calls,
1931            };
1932
1933            #[cfg(feature = "logging")]
1934            info!(target: "stdout", "parsed result: {parsed:?}");
1935
1936            Ok(parsed)
1937        }
1938        PromptTemplateType::Qwen3NoThink | PromptTemplateType::Smol3NoThink => {
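            // Extract every `<tool_call>...</tool_call>` block; if none parses into a tool call, the
            // raw input is returned as plain content.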
1939            #[cfg(feature = "logging")]
1940            info!(target: "stdout", "raw input: {input:?}");
1941
1942            match regex::Regex::new(r"(?s)<tool_call>((.|\r|\n)*?)</tool_call>") {
1943                Ok(re) => {
1944                    let mut values: Vec<serde_json::Value> = vec![];
1945                    for cap in re.captures_iter(input) {
1946                        let mut matched = cap[1].trim();
1947
1948                        if matched.starts_with("\\n") {
1949                            matched = matched.trim_start_matches("\\n");
1950                        }
1951
1952                        if matched.ends_with("\\n") {
1953                            matched = matched.trim_end_matches("\\n");
1954                        }
1955
1956                        #[cfg(feature = "logging")]
1957                        info!(target: "stdout", "captured: {matched:#?}");
1958
1959                        if !matched.is_empty() {
1960                            match serde_json::from_str::<serde_json::Value>(matched) {
1961                                Ok(value) => values.push(value),
1962                                Err(e) => {
1963                                    let err_msg = format!(
1964                                    "Failed to deserialize generated tool calls: {matched:#?}. Reason: {e}"
1965                                );
1966
1967                                    #[cfg(feature = "logging")]
1968                                    error!(target: "stdout", "{}", &err_msg);
1969
1970                                    return Err(LlamaCoreError::Operation(err_msg));
1971                                }
1972                            }
1973                        }
1974                    }
1975
1976                    let mut tool_calls: Vec<ToolCall> = vec![];
1977                    for value in values.iter() {
1978                        let name = match value.get("name") {
1979                            Some(name) => name.to_string().replace("\"", ""),
1980                            None => {
1981                                let err_msg = format!(
1982                                    "Failed to get the name of the function. Tool call: {value:?}"
1983                                );
1984
1985                                #[cfg(feature = "logging")]
1986                                error!(target: "stdout", "{}", &err_msg);
1987
1988                                return Err(LlamaCoreError::Operation(err_msg));
1989                            }
1990                        };
1991
1992                        let arguments = match value.get("arguments") {
1993                            Some(arguments) => {
1994                                if arguments.is_string() {
1995                                    arguments.as_str().unwrap().to_string()
1996                                } else if arguments.is_object() {
1997                                    let map = arguments.as_object().unwrap();
1998
1999                                    #[cfg(feature = "logging")]
2000                                    info!(target: "stdout", "func arguments: {map:?}");
2001
2002                                    serde_json::to_string(map).unwrap()
2003                                } else {
2004                                    serde_json::to_string(arguments).unwrap()
2005                                }
2006                            }
2007                            None => {
2008                                let err_msg = format!(
2009                                    "Failed to get the arguments of the function. Tool call: {value:?}"
2010                                );
2011
2012                                #[cfg(feature = "logging")]
2013                                error!(target: "stdout", "{}", &err_msg);
2014
2015                                return Err(LlamaCoreError::Operation(err_msg));
2016                            }
2017                        };
2018
2019                        let function = Function { name, arguments };
2020
2021                        let tool_call = ToolCall {
2022                            id: "call_abc123".to_string(),
2023                            ty: "function".to_string(),
2024                            function,
2025                        };
2026
2027                        tool_calls.push(tool_call);
2028                    }
2029
2030                    let parsed = if tool_calls.is_empty() {
2031                        ParseResult {
2032                            raw: input.to_owned(),
2033                            content: Some(input.to_owned()),
2034                            tool_calls: vec![],
2035                        }
2036                    } else {
2037                        ParseResult {
2038                            raw: input.to_owned(),
2039                            content: None,
2040                            tool_calls,
2041                        }
2042                    };
2043
2044                    #[cfg(feature = "logging")]
2045                    info!(target: "stdout", "parsed result: {parsed:?}");
2046
2047                    Ok(parsed)
2048                }
2049                Err(e) => {
2050                    let err_msg = format!("Failed to create a regex pattern. Reason: {e}");
2051
2052                    #[cfg(feature = "logging")]
2053                    error!(target: "stdout", "{}", &err_msg);
2054
2055                    Err(LlamaCoreError::Operation(err_msg))
2056                }
2057            }
2058        }
2059        PromptTemplateType::Gemma3 => {
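            // Tool calls are emitted as fenced ```json blocks; extract and deserialize each block.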
2060            #[cfg(feature = "logging")]
2061            info!(target: "stdout", "raw input: {input:?}");
2062
2063            match regex::Regex::new(r"(?s)```json\s*(.*?)\s*```") {
2064                Ok(re) => {
2065                    let mut values: Vec<serde_json::Value> = vec![];
2066                    for cap in re.captures_iter(input) {
2067                        let mut matched = cap[1].trim();
2068
2069                        if matched.starts_with("\\n") {
2070                            matched = matched.trim_start_matches("\\n");
2071                        }
2072
2073                        if matched.ends_with("\\n") {
2074                            matched = matched.trim_end_matches("\\n");
2075                        }
2076
2077                        #[cfg(feature = "logging")]
2078                        info!(target: "stdout", "captured: {matched:#?}");
2079
2080                        if !matched.is_empty() {
2081                            match serde_json::from_str::<serde_json::Value>(matched) {
2082                                Ok(value) => values.push(value),
2083                                Err(e) => {
2084                                    let err_msg = format!(
2085                                    "Failed to deserialize generated tool calls: {matched:#?}. Reason: {e}"
2086                                );
2087
2088                                    #[cfg(feature = "logging")]
2089                                    error!(target: "stdout", "{}", &err_msg);
2090
2091                                    return Err(LlamaCoreError::Operation(err_msg));
2092                                }
2093                            }
2094                        }
2095                    }
2096
2097                    let mut tool_calls: Vec<ToolCall> = vec![];
2098                    for value in values.iter() {
2099                        let name = match value.get("name") {
2100                            Some(name) => name.to_string().replace("\"", ""),
2101                            None => {
2102                                let err_msg = format!(
2103                                    "Failed to get the name of the function. Tool call: {value:?}"
2104                                );
2105
2106                                #[cfg(feature = "logging")]
2107                                error!(target: "stdout", "{}", &err_msg);
2108
2109                                return Err(LlamaCoreError::Operation(err_msg));
2110                            }
2111                        };
2112
2113                        let arguments = match value.get("arguments") {
2114                            Some(arguments) => {
2115                                if arguments.is_string() {
2116                                    arguments.as_str().unwrap().to_string()
2117                                } else if arguments.is_object() {
2118                                    let map = arguments.as_object().unwrap();
2119
2120                                    #[cfg(feature = "logging")]
2121                                    info!(target: "stdout", "func arguments: {map:?}");
2122
2123                                    serde_json::to_string(map).unwrap()
2124                                } else {
2125                                    serde_json::to_string(arguments).unwrap()
2126                                }
2127                            }
2128                            None => {
2129                                let err_msg = format!(
2130                                    "Failed to get the arguments of the function. Tool call: {value:?}"
2131                                );
2132
2133                                #[cfg(feature = "logging")]
2134                                error!(target: "stdout", "{}", &err_msg);
2135
2136                                return Err(LlamaCoreError::Operation(err_msg));
2137                            }
2138                        };
2139
2140                        let function = Function { name, arguments };
2141
2142                        let tool_call = ToolCall {
2143                            id: "call_abc123".to_string(),
2144                            ty: "function".to_string(),
2145                            function,
2146                        };
2147
2148                        tool_calls.push(tool_call);
2149                    }
2150
2151                    let parsed = if tool_calls.is_empty() {
2152                        ParseResult {
2153                            raw: input.to_owned(),
2154                            content: Some(input.to_owned()),
2155                            tool_calls: vec![],
2156                        }
2157                    } else {
2158                        ParseResult {
2159                            raw: input.to_owned(),
2160                            content: None,
2161                            tool_calls,
2162                        }
2163                    };
2164
2165                    #[cfg(feature = "logging")]
2166                    info!(target: "stdout", "parsed result: {parsed:?}");
2167
2168                    Ok(parsed)
2169                }
2170                Err(e) => {
2171                    let err_msg = format!("Failed to create a regex pattern. Reason: {e}");
2172
2173                    #[cfg(feature = "logging")]
2174                    error!(target: "stdout", "{}", &err_msg);
2175
2176                    Err(LlamaCoreError::Operation(err_msg))
2177                }
2178            }
2179        }
2180        PromptTemplateType::GptOss => {
2181            #[cfg(feature = "logging")]
2182            info!(target: "stdout", "raw input: {input:?}");
2183
2184            // Match strings ending with: <|channel|>commentary to=functions.xxxxx <|constrain|>json<|message|>yyyyy<|call|>
2185            match regex::Regex::new(
2186                r"<\|channel\|>commentary to=functions\.([^<\s]+)\s*<\|constrain\|>json<\|message\|>([^<]*)<\|call\|>$",
2187            ) {
2188                Ok(re) => {
2189                    if let Some(cap) = re.captures(input) {
2190                        let function_name = cap[1].trim();
2191                        let arguments = cap[2].trim();
2192
2193                        #[cfg(feature = "logging")]
2194                        info!(target: "stdout", "extracted function_name: {function_name}, arguments: {arguments}");
2195
2196                        let function = Function {
2197                            name: function_name.to_string(),
2198                            arguments: arguments.to_string(),
2199                        };
2200
2201                        let tool_call = ToolCall {
2202                            id: "call_abc123".to_string(),
2203                            ty: "function".to_string(),
2204                            function,
2205                        };
2206
2207                        let parsed = ParseResult {
2208                            raw: input.to_owned(),
2209                            content: None,
2210                            tool_calls: vec![tool_call],
2211                        };
2212
2213                        #[cfg(feature = "logging")]
2214                        info!(target: "stdout", "parsed result: {parsed:?}");
2215
2216                        Ok(parsed)
2217                    } else {
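                        // Fall back to extracting tool calls from fenced ```json blocks.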
2218                        match regex::Regex::new(r"(?s)```json\s*(.*?)\s*```") {
2219                            Ok(re) => {
2220                                let mut values: Vec<serde_json::Value> = vec![];
2221                                for cap in re.captures_iter(input) {
2222                                    let mut matched = cap[1].trim();
2223
2224                                    if matched.starts_with("\\n") {
2225                                        matched = matched.trim_start_matches("\\n");
2226                                    }
2227
2228                                    if matched.ends_with("\\n") {
2229                                        matched = matched.trim_end_matches("\\n");
2230                                    }
2231
2232                                    #[cfg(feature = "logging")]
2233                                    info!(target: "stdout", "captured: {matched:#?}");
2234
2235                                    if !matched.is_empty() {
2236                                        match serde_json::from_str::<serde_json::Value>(matched) {
2237                                            Ok(value) => values.push(value),
2238                                            Err(e) => {
2239                                                let err_msg = format!(
2240                                                "Failed to deserialize generated tool calls: {matched:#?}. Reason: {e}"
2241                                            );
2242
2243                                                #[cfg(feature = "logging")]
2244                                                error!(target: "stdout", "{}", &err_msg);
2245
2246                                                return Err(LlamaCoreError::Operation(err_msg));
2247                                            }
2248                                        }
2249                                    }
2250                                }
2251
2252                                let mut tool_calls: Vec<ToolCall> = vec![];
2253                                for value in values.iter() {
2254                                    let name = match value.get("name") {
2255                                        Some(name) => name.to_string().replace("\"", ""),
2256                                        None => {
2257                                            let err_msg = format!(
2258                                                "Failed to get the name of the function. Tool call: {value:?}"
2259                                            );
2260
2261                                            #[cfg(feature = "logging")]
2262                                            error!(target: "stdout", "{}", &err_msg);
2263
2264                                            return Err(LlamaCoreError::Operation(err_msg));
2265                                        }
2266                                    };
2267
2268                                    let arguments = match value.get("arguments") {
2269                                        Some(arguments) => {
2270                                            if arguments.is_string() {
2271                                                arguments.as_str().unwrap().to_string()
2272                                            } else if arguments.is_object() {
2273                                                let map = arguments.as_object().unwrap();
2274
2275                                                #[cfg(feature = "logging")]
2276                                                info!(target: "stdout", "func arguments: {map:?}");
2277
2278                                                serde_json::to_string(map).unwrap()
2279                                            } else {
2280                                                serde_json::to_string(arguments).unwrap()
2281                                            }
2282                                        }
2283                                        None => {
2284                                            let err_msg = format!(
2285                                                "Failed to get the arguments of the function. Tool call: {value:?}"
2286                                            );
2287
2288                                            #[cfg(feature = "logging")]
2289                                            error!(target: "stdout", "{}", &err_msg);
2290
2291                                            return Err(LlamaCoreError::Operation(err_msg));
2292                                        }
2293                                    };
2294
2295                                    let function = Function { name, arguments };
2296
2297                                    let tool_call = ToolCall {
2298                                        id: "call_abc123".to_string(),
2299                                        ty: "function".to_string(),
2300                                        function,
2301                                    };
2302
2303                                    tool_calls.push(tool_call);
2304                                }
2305
2306                                let parsed = if tool_calls.is_empty() {
2307                                    ParseResult {
2308                                        raw: input.to_owned(),
2309                                        content: Some(input.to_owned()),
2310                                        tool_calls: vec![],
2311                                    }
2312                                } else {
2313                                    ParseResult {
2314                                        raw: input.to_owned(),
2315                                        content: Some(input.to_owned()),
2316                                        tool_calls,
2317                                    }
2318                                };
2319
2320                                #[cfg(feature = "logging")]
2321                                info!(target: "stdout", "parsed result: {parsed:?}");
2322
2323                                Ok(parsed)
2324                            }
2325                            Err(e) => {
2326                                let err_msg =
2327                                    format!("Failed to create a regex pattern. Reason: {e}");
2328
2329                                #[cfg(feature = "logging")]
2330                                error!(target: "stdout", "{}", &err_msg);
2331
2332                                Err(LlamaCoreError::Operation(err_msg))
2333                            }
2334                        }
2335                    }
2336                }
2337                Err(e) => {
2338                    let err_msg = format!("Failed to create a regex pattern. Reason: {e}");
2339
2340                    #[cfg(feature = "logging")]
2341                    error!(target: "stdout", "{}", &err_msg);
2342
2343                    Err(LlamaCoreError::Operation(err_msg))
2344                }
2345            }
2346        }
2347        PromptTemplateType::Qwen3Agent => {
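            // Look for an `<action>...</action>` block containing a JSON tool call; otherwise the
            // output is treated as a final answer and wrapped in `<final_answer>` tags if missing.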
2348            #[cfg(feature = "logging")]
2349            info!(target: "stdout", "Raw input to tool call parser: {input:?}");
2350
2351            // Detect the <action> tag
2352            match regex::Regex::new(r"<action>(.*?)</action>")
2353                .unwrap()
2354                .captures(input)
2355            {
2356                Some(captures) => {
2357                    let action = captures.get(1).unwrap().as_str();
2358
2359                    #[cfg(feature = "logging")]
2360                    info!(target: "stdout", "Action: {action}");
2361
2362                    match serde_json::from_str::<serde_json::Value>(action) {
2363                        Ok(value) => {
2364                            let name = match value.get("name") {
2365                                Some(name) => name.to_string().replace("\"", ""),
2366                                None => {
2367                                    let err_msg = format!(
2368                                        "Failed to get the name of the function. Tool call: {value:?}"
2369                                    );
2370
2371                                    #[cfg(feature = "logging")]
2372                                    error!(target: "stdout", "{}", &err_msg);
2373
2374                                    return Err(LlamaCoreError::Operation(err_msg));
2375                                }
2376                            };
2377
2378                            let arguments = match value.get("arguments") {
2379                                Some(arguments) => {
2380                                    if arguments.is_string() {
2381                                        arguments.as_str().unwrap().to_string()
2382                                    } else if arguments.is_object() {
2383                                        let map = arguments.as_object().unwrap();
2384
2385                                        #[cfg(feature = "logging")]
2386                                        info!(target: "stdout", "func arguments: {map:?}");
2387
2388                                        serde_json::to_string(map).unwrap()
2389                                    } else {
2390                                        serde_json::to_string(arguments).unwrap()
2391                                    }
2392                                }
2393                                None => {
2394                                    let err_msg = format!(
2395                                        "Failed to get the arguments of the function. Tool call: {value:?}"
2396                                    );
2397
2398                                    #[cfg(feature = "logging")]
2399                                    error!(target: "stdout", "{}", &err_msg);
2400
2401                                    return Err(LlamaCoreError::Operation(err_msg));
2402                                }
2403                            };
2404
2405                            let function = Function { name, arguments };
2406
2407                            let tool_call = ToolCall {
2408                                id: "call_abc123".to_string(),
2409                                ty: "function".to_string(),
2410                                function,
2411                            };
2412
2413                            Ok(ParseResult {
2414                                raw: input.to_owned(),
2415                                content: Some(input.to_owned()),
2416                                tool_calls: vec![tool_call],
2417                            })
2418                        }
2419                        Err(e) => {
2420                            let err_msg = format!(
2421                            "Failed to deserialize generated tool calls: {action:#?}. Reason: {e}"
2422                        );
2423
2424                            #[cfg(feature = "logging")]
2425                            error!(target: "stdout", "{}", &err_msg);
2426
2427                            Err(LlamaCoreError::Operation(err_msg))
2428                        }
2429                    }
2430                }
2431                None => match input.contains("<final_answer>") {
2432                    true => Ok(ParseResult {
2433                        raw: input.to_owned(),
2434                        content: Some(input.to_owned()),
2435                        tool_calls: vec![],
2436                    }),
2437                    false => {
2438                        let content = format!("<final_answer>{}</final_answer>", input.trim());
2439
2440                        Ok(ParseResult {
2441                            raw: input.to_owned(),
2442                            content: Some(content),
2443                            tool_calls: vec![],
2444                        })
2445                    }
2446                },
2447            }
2448        }
2449        _ => {
2450            let err_msg = format!(
2451                "The tool use is only supported for prompt templates: {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, and {}.",
2452                PromptTemplateType::MistralTool,
2453                PromptTemplateType::ChatMLTool,
2454                PromptTemplateType::GroqLlama3Tool,
2455                PromptTemplateType::Llama3Tool,
2456                PromptTemplateType::InternLM2Tool,
2457                PromptTemplateType::NemotronTool,
2458                PromptTemplateType::FunctionaryV32,
2459                PromptTemplateType::MistralSmallTool,
2460                PromptTemplateType::Llama4Chat,
2461                PromptTemplateType::Qwen3NoThink,
2462                PromptTemplateType::Smol3NoThink,
2463                PromptTemplateType::Gemma3,
2464                PromptTemplateType::GptOss,
2465                PromptTemplateType::Qwen3Agent,
2466            );
2467
2468            #[cfg(feature = "logging")]
2469            error!(target: "stdout", "{}", &err_msg);
2470
2471            Err(LlamaCoreError::Operation(err_msg))
2472        }
2473    }
2474}
2475
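/// Builds the effective model metadata for a chat request: validates image inputs, overrides the
/// sampling parameters (temperature, top_p, frequency/presence penalties) from the request,
/// disables the `embeddings` option, and pushes the updated metadata to the target graph when
/// anything changed.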
2476fn check_model_metadata(
2477    chat_request: &ChatCompletionRequest,
2478) -> Result<GgmlMetadata, LlamaCoreError> {
2479    let mut should_update = false;
2480    let mut metadata = get_model_metadata(chat_request.model.as_ref())?;
2481
2482    // check if necessary to update `image`
2483    if metadata.prompt_template.is_image_supported() {
2484        if let Some(ChatCompletionRequestMessage::User(user_message)) = chat_request.messages.last()
2485        {
2486            if let ChatCompletionUserMessageContent::Parts(parts) = user_message.content() {
2487                for part in parts {
2488                    if let ContentPart::Image(image_part) = part {
2489                        let image = image_part.image();
2490
2491                        if image.is_url() {
2492                            let err_msg = "The image is provided in URL format. Only base64 format is supported.".to_string();
2493
2494                            #[cfg(feature = "logging")]
2495                            error!(target: "stdout", "{}", &err_msg);
2496
2497                            return Err(LlamaCoreError::Operation(err_msg));
2498                        } else {
2499                            #[cfg(feature = "logging")]
2500                            info!(target: "stdout", "The image is provided in base64 format.");
2501
2502                            // TODO: currently only a single image is supported
2503
2504                            break;
2505                        }
2506                    }
2507                }
2508            }
2509        }
2510    }
2511
2512    // check if necessary to update temperature
2513    if let Some(temp) = chat_request.temperature {
2514        if metadata.temperature != temp {
2515            // update temperature
2516            metadata.temperature = temp;
2517
2518            if !should_update {
2519                should_update = true;
2520            }
2521        }
2522    }
2523
2524    // check if necessary to update top_p
2525    if let Some(top_p) = chat_request.top_p {
2526        if metadata.top_p != top_p {
2527            // update top_p
2528            metadata.top_p = top_p;
2529
2530            if !should_update {
2531                should_update = true;
2532            }
2533        }
2534    }
2535
2536    // check if necessary to update frequency_penalty
2537    if let Some(frequency_penalty) = chat_request.frequency_penalty {
2538        if metadata.frequency_penalty != frequency_penalty {
2539            // update frequency_penalty
2540            metadata.frequency_penalty = frequency_penalty;
2541
2542            if !should_update {
2543                should_update = true;
2544            }
2545        }
2546    }
2547
2548    // check if necessary to update presence_penalty
2549    if let Some(presence_penalty) = chat_request.presence_penalty {
2550        if metadata.presence_penalty != presence_penalty {
2551            // update presence_penalty
2552            metadata.presence_penalty = presence_penalty;
2553
2554            if !should_update {
2555                should_update = true;
2556            }
2557        }
2558    }
2559
2560    // ensure the `embeddings` option is disabled for chat
2561    if metadata.embeddings {
2562        metadata.embeddings = false;
2563
2564        if !should_update {
2565            should_update = true;
2566        }
2567    }
2568
2569    if should_update {
2570        #[cfg(feature = "logging")]
2571        info!(target: "stdout", "Update the model metadata.");
2572
2573        // update the target graph with the new metadata
2574        update_model_metadata(chat_request.model.as_ref(), &metadata)?;
2575    }
2576
2577    Ok(metadata)
2578}
2579
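/// Illustrative sketch (invented numbers, not this function's exact branches) of the
/// sentinel handling below: a negative `n_predict` is replaced by the number of
/// completion tokens still available in the context window.
/// ```
/// let available_completion_tokens: u64 = 1024;
/// let mut n_predict: i32 = -1; // negative sentinel: no explicit limit
///
/// if n_predict < 0 {
///     n_predict = available_completion_tokens as i32;
/// }
/// assert_eq!(n_predict, 1024);
/// ```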
2580fn update_n_predict(
2581    chat_request: &ChatCompletionRequest,
2582    metadata: &mut GgmlMetadata,
2583    available_completion_tokens: u64,
2584) -> Result<(), LlamaCoreError> {
2585    let mut should_update = false;
2586
2587    #[cfg(feature = "logging")]
2588    info!(target: "stdout", "n_predict: {}", metadata.n_predict);
2589
2590    // From high to low priority
2591    // 1. chat_request.max_completion_tokens
2592    // 2. available_completion_tokens
2593    // 3. n_predict
2594
2595    if let Some(max_completion_tokens) = chat_request.max_completion_tokens {
2596        if metadata.n_predict != max_completion_tokens {
2597            #[cfg(feature = "logging")]
2598            info!(target: "stdout", "Update n_predict with max_completion_tokens from {} to {}", metadata.n_predict, max_completion_tokens);
2599
2600            metadata.n_predict = max_completion_tokens;
2601
2602            if !should_update {
2603                should_update = true;
2604            }
2605        }
2606    }
2607
2608    // TODO: remove this condition after [Issue #3958 on WasmEdge](https://github.com/WasmEdge/WasmEdge/issues/3958) is fixed
2609    if metadata.n_predict == -2 {
2610        #[cfg(feature = "logging")]
2611        info!(target: "stdout", "Update n_predict with available_completion_tokens from {} to {}", metadata.n_predict, available_completion_tokens);
2612
2613        // update n_predict
2614        metadata.n_predict = available_completion_tokens as i32;
2615
2616        if !should_update {
2617            should_update = true;
2618        }
2619    }
2620
2621    if metadata.n_predict == -1
2622        || (metadata.n_predict > 0 && metadata.n_predict < available_completion_tokens as i32)
2623        || (metadata.n_predict < 0 && metadata.n_predict != -2)
2624    // TODO: remove this condition after [Issue #3958 on WasmEdge](https://github.com/WasmEdge/WasmEdge/issues/3958) is fixed
2625    {
2626        #[cfg(feature = "logging")]
2627        info!(target: "stdout", "Update n_predict with available_completion_tokens from {} to {}", metadata.n_predict, available_completion_tokens);
2628
2629        // update n_predict
2630        metadata.n_predict = available_completion_tokens as i32;
2631
2632        if !should_update {
2633            should_update = true;
2634        }
2635    }
2636
2637    if should_update {
2638        #[cfg(feature = "logging")]
2639        info!(target: "stdout", "Update the model metadata.");
2640
2641        // update the target graph with the new metadata
2642        update_model_metadata(chat_request.model.as_ref(), metadata)?;
2643    }
2644
2645    Ok(())
2646}
2647
2648/// Post-process the generated output according to the prompt template type.
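///
/// # Example
///
/// Standalone sketch (plain `std` string handling, not a call into this private
/// function) of the end-token trimming applied for several templates below.
/// ```
/// let raw = "Hello there!<|im_end|>";
/// let cleaned = raw.trim_end_matches("<|im_end|>").trim();
/// assert_eq!(cleaned, "Hello there!");
/// ```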
2649fn post_process(
2650    output: impl AsRef<str>,
2651    template_ty: &PromptTemplateType,
2652) -> Result<String, String> {
2653    let output = if *template_ty == PromptTemplateType::Baichuan2 {
2654        if output.as_ref().contains("用户:") {
2655            output.as_ref().trim_end_matches("用户:").trim().to_owned()
2656        } else {
2657            output.as_ref().trim().to_owned()
2658        }
2659    } else if *template_ty == PromptTemplateType::OpenChat {
2660        if output.as_ref().contains("<|end_of_turn|>") {
2661            output
2662                .as_ref()
2663                .trim_end_matches("<|end_of_turn|>")
2664                .trim()
2665                .to_owned()
2666        } else {
2667            output.as_ref().trim().to_owned()
2668        }
2669    } else if *template_ty == PromptTemplateType::GemmaInstruct
2670        || *template_ty == PromptTemplateType::Gemma3
2671    {
2672        let s = output.as_ref().trim();
2673        if s.ends_with("<end_of_turn>") {
2674            s.trim_end_matches("<end_of_turn>").trim().to_owned()
2675        } else {
2676            s.to_owned()
2677        }
2678    } else if *template_ty == PromptTemplateType::ChatML
2679        || *template_ty == PromptTemplateType::ChatMLTool
2680        || *template_ty == PromptTemplateType::InternLM2Tool
2681        || *template_ty == PromptTemplateType::MiniCPMV
2682    {
2683        let mut s = output.as_ref().trim();
2684        if s.ends_with("<|endoftext|>") {
2685            s = s.trim_end_matches("<|endoftext|>").trim();
2686        }
2687
2688        if s.starts_with(":") {
2689            s = s.trim_start_matches(":").trim();
2690        }
2691
2692        // handle Qwen3 empty think tags
2693        let x = {
2694            let pat = r#"<think>
2695
2696</think>
2697"#;
2698            if s.contains(pat) {
2699                let x = s.replace(pat, "");
2700                if x.starts_with("()") {
2701                    x.trim_start_matches("()").to_owned()
2702                } else {
2703                    x.to_owned()
2704                }
2705            } else {
2706                s.to_owned()
2707            }
2708        };
2709        s = x.trim();
2710
2711        if s.contains("<|im_start|>") && s.contains("<|im_end|>") {
2712            let idx_start = s.find("<|im_start|>").unwrap();
2713            let idx_end = s.find("<|im_end|>").unwrap();
2714
2715            match idx_start <= idx_end {
2716                true => s.split("<|im_start|>").collect::<Vec<_>>()[0]
2717                    .trim()
2718                    .to_owned(),
2719                false => s.split("<|im_end|>").collect::<Vec<_>>()[0]
2720                    .trim()
2721                    .to_owned(),
2722            }
2723        } else if s.contains("<|im_start|>") {
2724            s.split("<|im_start|>").collect::<Vec<_>>()[0]
2725                .trim()
2726                .to_owned()
2727        } else if s.contains("<|im_end|>") {
2728            let output = s.trim_end_matches("<|im_end|>").trim();
2729            if output.starts_with(": ") {
2730                output.trim_start_matches(": ").to_owned()
2731            } else {
2732                output.to_owned()
2733            }
2734        } else {
2735            s.to_owned()
2736        }
2737    } else if *template_ty == PromptTemplateType::Zephyr
2738        || *template_ty == PromptTemplateType::MistralLite
2739        || *template_ty == PromptTemplateType::MistralTool
2740        || *template_ty == PromptTemplateType::MistralInstruct
2741        || *template_ty == PromptTemplateType::MistralSmallChat
2742        || *template_ty == PromptTemplateType::MistralSmallTool
2743        || *template_ty == PromptTemplateType::BreezeInstruct
2744    {
2745        if output.as_ref().contains("</s><") {
2746            output.as_ref().trim_end_matches("</s><").trim().to_owned()
2747        } else if output.as_ref().contains("</s>") {
2748            output
2749                .as_ref()
2750                .strip_suffix("</s>")
2751                .unwrap()
2752                .trim()
2753                .to_owned()
2754        } else {
2755            output.as_ref().trim().to_owned()
2756        }
2757    } else if *template_ty == PromptTemplateType::DeepseekChat {
2758        if output.as_ref().contains("<|end_of_sentence|>") {
2759            output
2760                .as_ref()
2761                .trim_end_matches("<|end_of_sentence|>")
2762                .trim()
2763                .replace("<|end_of_sentence|>", " ")
2764                .trim()
2765                .to_owned()
2766        } else {
2767            output.as_ref().trim().to_owned()
2768        }
2769    } else if *template_ty == PromptTemplateType::HumanAssistant {
2770        if output.as_ref().contains("Human:") {
2771            output.as_ref().trim_end_matches("Human:").trim().to_owned()
2772        } else {
2773            output.as_ref().trim().to_owned()
2774        }
2775    } else if *template_ty == PromptTemplateType::SolarInstruct {
2776        let s = output.as_ref().trim();
2777
2778        if s.starts_with("### Answer") {
2779            let s = s.trim_start_matches("###").trim();
2780
2781            if s.starts_with("Answer:\n") {
2782                s.replace("Answer:\n", "Answer: ")
2783            } else {
2784                s.to_owned()
2785            }
2786        } else {
2787            s.to_owned()
2788        }
2789    } else if *template_ty == PromptTemplateType::Llama2Chat
2790        || *template_ty == PromptTemplateType::NemotronTool
2791        || *template_ty == PromptTemplateType::NemotronChat
2792    {
2793        let s = output.as_ref().trim();
2794        if s.ends_with("</s>") {
2795            s.trim_end_matches("</s>").trim().to_owned()
2796        } else {
2797            s.to_owned()
2798        }
2799    } else if *template_ty == PromptTemplateType::Llama3Chat
2800        || *template_ty == PromptTemplateType::GroqLlama3Tool
2801        || *template_ty == PromptTemplateType::Llama3Tool
2802        || *template_ty == PromptTemplateType::FunctionaryV32
2803    {
2804        let s = output.as_ref().trim();
2805        if s.ends_with("<|eot_id|>") {
2806            s.trim_end_matches("<|eot_id|>").trim().to_owned()
2807        } else {
2808            s.to_owned()
2809        }
2810    } else if *template_ty == PromptTemplateType::Phi3Chat {
2811        let s = output.as_ref().trim();
2812        if s.ends_with("<|end|>") {
2813            s.trim_end_matches("<|end|>").trim().to_owned()
2814        } else {
2815            s.to_owned()
2816        }
2817    } else if *template_ty == PromptTemplateType::Phi4Chat {
2818        let mut s = output.as_ref().trim();
2819
2820        if s.starts_with("think>") {
2821            s = s.trim_start_matches("think>").trim();
2822        }
2823
2824        if s.ends_with("<|im_end|>") {
2825            s.trim_end_matches("<|im_end|>").trim().to_owned()
2826        } else if s.ends_with("<|end|>") {
2827            s.trim_end_matches("<|end|>").trim().to_owned()
2828        } else {
2829            s.to_owned()
2830        }
2831    } else if *template_ty == PromptTemplateType::FunctionaryV31 {
2832        let mut s = output.as_ref().trim();
2833        if s.ends_with("<|eot_id|>") {
2834            s = s.trim_end_matches("<|eot_id|>").trim();
2835        }
2836        if s.ends_with("<|eom_id|>") {
2837            s = s.trim_end_matches("<|eom_id|>").trim();
2838        }
2839        s.to_owned()
2840    } else if *template_ty == PromptTemplateType::MoxinChat
2841        || *template_ty == PromptTemplateType::MoxinInstruct
2842    {
2843        let s = output.as_ref().trim();
2844        if s.ends_with("</s>") {
2845            s.trim_end_matches("</s>").trim().to_owned()
2846        } else if s.ends_with("[INST]") {
2847            s.trim_end_matches("[INST]").trim().to_owned()
2848        } else {
2849            s.to_owned()
2850        }
2851    } else if *template_ty == PromptTemplateType::Falcon3 {
2852        let s = output.as_ref().trim();
2853        if s.ends_with("<|endoftext|>") {
2854            s.trim_end_matches("<|endoftext|>").trim().to_owned()
2855        } else {
2856            s.to_owned()
2857        }
2858    } else if *template_ty == PromptTemplateType::Megrez {
2859        let s = output.as_ref().trim();
2860        if s.ends_with("<|turn_end|>") {
2861            s.trim_end_matches("<|turn_end|>").trim().to_owned()
2862        } else {
2863            s.to_owned()
2864        }
2865    } else if *template_ty == PromptTemplateType::Qwen2vl
2866        || *template_ty == PromptTemplateType::Qwen3NoThink
2867        || *template_ty == PromptTemplateType::ChatMLThink
2868    {
2869        let mut s = output.as_ref().trim();
2870
2871        if s.starts_with(":") {
2872            s = s.trim_start_matches(":").trim();
2873        }
2874
2875        if s.starts_with("</think>") {
2876            s = s.trim_start_matches("</think>").trim();
2877        }
2878
2879        if s.ends_with("<|im_end|>") {
2880            s.trim_end_matches("<|im_end|>").trim().to_owned()
2881        } else {
2882            s.to_owned()
2883        }
2884    } else if *template_ty == PromptTemplateType::VicunaLlava {
2885        let s = output.as_ref().trim();
2886        if s.ends_with("</s>") {
2887            s.trim_end_matches("</s>").trim().to_owned()
2888        } else {
2889            s.to_owned()
2890        }
2891    } else if *template_ty == PromptTemplateType::ExaoneDeepChat
2892        || *template_ty == PromptTemplateType::ExaoneChat
2893    {
2894        let mut s = output.as_ref().trim();
2895
2896        if s.ends_with("[|endofturn|]") {
2897            s = s.trim_end_matches("[|endofturn|]").trim();
2898        }
2899
2900        s.to_owned()
2901    } else if *template_ty == PromptTemplateType::Llama4Chat {
2902        let mut s = output.as_ref().trim();
2903
2904        if s.ends_with("<|eot|>") {
2905            s = s.trim_end_matches("<|eot|>").trim();
2906        }
2907
2908        s.to_owned()
2909    } else if *template_ty == PromptTemplateType::Smolvl {
2910        let mut s = output.as_ref().trim();
2911
2912        if s.starts_with(":") {
2913            s = s.trim_start_matches(":").trim();
2914        }
2915
2916        if s.ends_with("<end_of_utterance>") {
2917            s = s.trim_end_matches("<end_of_utterance>").trim();
2918        }
2919
2920        if s.contains("<end_of_utterance>:") {
2921            let parts = s.split("<end_of_utterance>:").collect::<Vec<_>>();
2922            parts.last().unwrap().trim().to_owned()
2923        } else {
2924            s.to_owned()
2925        }
2926    } else if *template_ty == PromptTemplateType::Smol3NoThink {
2927        let mut s = output.as_ref().trim();
2928
2929        if s.ends_with("<|im_end|>") {
2930            s = s.trim_end_matches("<|im_end|>").trim();
2931        }
2932
2933        let re = regex::Regex::new(r"(?s)^<think>.*?</think>\s*").unwrap();
2934        re.replace(s, "").to_string()
2935    } else if *template_ty == PromptTemplateType::GptOss {
2936        let s = output.as_ref().trim();
2937
2938        let re =
2939            regex::Regex::new(r"(?s).*<\|channel\|>final<\|message\|>(.*?)<\|return\|>$").unwrap();
2940
2941        if let Some(caps) = re.captures(s) {
2942            let extracted = &caps[1];
2943            extracted.to_owned()
2944        } else {
2945            s.to_owned()
2946        }
2947    } else if *template_ty == PromptTemplateType::Qwen3Agent {
2948        let mut s = output.as_ref().trim();
2949
2950        if s.starts_with(":") {
2951            s = s.trim_start_matches(":").trim();
2952        }
2953
2954        if s.starts_with("</think>") {
2955            s = s.trim_start_matches("</think>").trim();
2956        }
2957
2958        if s.ends_with("<|im_end|>") {
2959            s = s.trim_end_matches("<|im_end|>").trim();
2960        }
2961
2962        if s.contains("<final_answer>") && !s.contains("</final_answer>") {
2963            format!("{s}</final_answer>")
2964        } else {
2965            s.to_owned()
2966        }
2967    } else {
2968        output.as_ref().trim().to_owned()
2969    };
2970
2971    Ok(output)
2972}
2973
2974/// Build the chat prompt from the chat messages.
2975///
2976/// # Arguments
2977///
2978/// * `model_name`: The name of the model.
2979///
2980/// * `chat_request`: The chat request.
2981///
2982/// # Returns
2983///
2984/// A tuple containing the prompt, the number of available tokens for completions, and a boolean indicating whether tools are used.
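///
/// # Example
///
/// A minimal sketch (plain integers, not this function's full logic) of the prompt
/// budget computed below: 80% of the context size is reserved for the prompt.
/// ```
/// let ctx_size: u64 = 4096;
/// let max_prompt_tokens = ctx_size * 4 / 5;
/// assert_eq!(max_prompt_tokens, 3276);
/// ```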
2985fn build_prompt(
2986    model_name: Option<&String>,
2987    chat_request: &mut ChatCompletionRequest,
2988) -> Result<(String, u64, bool), LlamaCoreError> {
2989    let metadata = get_model_metadata(model_name)?;
2990    let ctx_size = metadata.ctx_size as u64;
2991    let chat_prompt = ChatPrompt::from(metadata.prompt_template);
2992
2993    // compute max prompt tokens, which is 80% of the context size
2994    let max_prompt_tokens = ctx_size * 4 / 5;
2995
2996    loop {
2997        // ! DO NOT REMOVE
2998        {
2999            // // build prompt
3000            // let prompt = match chat_prompt.build(&mut chat_request.messages) {
3001            //     Ok(prompt) => prompt,
3002            //     Err(e) => {
3003            //         let err_msg = format!("Fail to build chat prompts. Reason: {}", e);
3004
3005            //         #[cfg(feature = "logging")]
3006            //         error!(target: "stdout", "{}", &err_msg);
3007
3008            //         return Err(LlamaCoreError::Operation(err_msg));
3009            //     }
3010            // };
3011        }
3012
3013        if chat_request.messages.is_empty() {
3014            let err_msg = "The messages in the chat request are empty.";
3015
3016            #[cfg(feature = "logging")]
3017            error!(target: "stdout", "{err_msg}");
3018
3019            return Err(LlamaCoreError::Operation(err_msg.to_owned()));
3020        }
3021
3022        #[cfg(feature = "logging")]
3023        {
3024            let mut role_chain = String::new();
3025            for (idx, message) in chat_request.messages.iter().enumerate() {
3026                if idx == chat_request.messages.len() - 1 {
3027                    role_chain.push_str(&format!("{}", message.role()));
3028                } else {
3029                    role_chain.push_str(&format!("{} -> ", message.role()));
3030                }
3031            }
3032            info!(target: "stdout", "Role chain: {role_chain}");
3033        }
3034
3035        let (prompt, tool_use) = match chat_request.tool_choice.as_ref() {
3036            Some(tool_choice) => match tool_choice {
3037                ToolChoice::None => {
3038                    match chat_prompt.build_with_tools(&mut chat_request.messages, Some(&[])) {
3039                        Ok(prompt) => (prompt, false),
3040                        Err(e) => {
3041                            let err_msg = format!("Fail to build chat prompts. Reason: {e}");
3042
3043                            #[cfg(feature = "logging")]
3044                            error!(target: "stdout", "{}", &err_msg);
3045
3046                            return Err(LlamaCoreError::Operation(err_msg));
3047                        }
3048                    }
3049                }
3050                _ => match chat_request.tools.as_ref() {
3051                    Some(tools) => match chat_prompt
3052                        .build_with_tools(&mut chat_request.messages, Some(tools.as_slice()))
3053                    {
3054                        Ok(prompt) => (prompt, true),
3055                        Err(e) => {
3056                            let err_msg = format!("Fail to build chat prompts. Reason: {e}");
3057
3058                            #[cfg(feature = "logging")]
3059                            error!(target: "stdout", "{}", &err_msg);
3060
3061                            return Err(LlamaCoreError::Operation(err_msg));
3062                        }
3063                    },
3064                    None => {
3065                        #[cfg(feature = "logging")]
3066                        warn!(target: "stdout", "The tool choice without tools is not supported.");
3067
3068                        match chat_prompt.build_with_tools(&mut chat_request.messages, None) {
3069                            Ok(prompt) => (prompt, false),
3070                            Err(e) => {
3071                                let err_msg = format!("Fail to build chat prompts. Reason: {e}");
3072
3073                                #[cfg(feature = "logging")]
3074                                error!(target: "stdout", "{}", &err_msg);
3075
3076                                return Err(LlamaCoreError::Operation(err_msg));
3077                            }
3078                        }
3079                    }
3080                },
3081            },
3082            None => match chat_prompt.build_with_tools(&mut chat_request.messages, None) {
3083                Ok(prompt) => (prompt, false),
3084                Err(e) => {
3085                    let err_msg = format!("Fail to build chat prompts. Reason: {e}");
3086
3087                    #[cfg(feature = "logging")]
3088                    error!(target: "stdout", "{}", &err_msg);
3089
3090                    return Err(LlamaCoreError::Operation(err_msg));
3091                }
3092            },
3093        };
3094        #[cfg(feature = "logging")]
3095        info!(target: "stdout", "Try to set prompt: {prompt}");
3096
3097        // set prompt
3098        set_prompt(model_name, &prompt)?;
3099
3100        // Retrieve the number of prompt tokens.
3101        let token_info = get_token_info_by_graph_name(model_name)?;
3102
3103        match token_info.prompt_tokens > max_prompt_tokens {
3104            true => {
3105                match chat_request.messages[0].role() {
3106                    ChatCompletionRole::System => {
3107                        // corner case: context size is too small, `system -> user -> assistant -> tool` cannot be trimmed.
3108                        if chat_request.messages.len() == 4
3109                            && chat_request.messages[1].role() == ChatCompletionRole::User
3110                            && chat_request.messages[2].role() == ChatCompletionRole::Assistant
3111                            && chat_request.messages[3].role() == ChatCompletionRole::Tool
3112                        {
3113                            let err_msg = format!(
3114                                "The number of prompt tokens ({}) is greater than the max prompt tokens ({}). Please increase the context size.",
3115                                token_info.prompt_tokens, max_prompt_tokens
3116                            );
3117
3118                            #[cfg(feature = "logging")]
3119                            error!(target: "stdout", "{}", &err_msg);
3120
3121                            return Err(LlamaCoreError::Operation(err_msg));
3122                        }
3123
3124                        if chat_request.messages.len() > 2 {
3125                            #[cfg(feature = "logging")]
3126                            info!(target: "stdout", "Prune chat history: current length {}", chat_request.messages.len());
3127
3128                            // remove user_1 if it exists
3129                            // For example, `system -> user_1 -> ... -> user_2 -> ... -> user_latest` will be converted to `system -> ... -> user_2 -> ... -> user_latest`
3130                            if chat_request.messages[1].role() == ChatCompletionRole::User {
3131                                let user_message = chat_request.messages.remove(1);
3132
3133                                #[cfg(feature = "logging")]
3134                                info!(target: "stdout", "Remove a user message from the chat history: {user_message:?}");
3135                            }
3136
3137                            // remove all messages until the message is of `user`
3138                            // For example, `system -> ... -> user_2 -> ... -> user_latest` will be converted to `system -> user_2 -> ... -> user_latest`
3139                            while chat_request.messages[1].role() != ChatCompletionRole::User {
3140                                let message = chat_request.messages.remove(1);
3141
3142                                #[cfg(feature = "logging")]
3143                                info!(target: "stdout", "Remove a {} message from the chat history: {:?}", message.role(), message);
3144
3145                                if chat_request.messages.len() == 1 {
3146                                    let err_msg = format!("The last message in the chat history should be a user message, but found a {} message.", message.role());
3147
3148                                    #[cfg(feature = "logging")]
3149                                    error!(target: "stdout", "{err_msg}");
3150
3151                                    return Err(LlamaCoreError::Operation(err_msg));
3152                                }
3153                            }
3154                        } else if token_info.prompt_tokens > ctx_size {
3155                            let err_msg = format!(
3156                                    "The number of prompt tokens ({}) is greater than the context size ({}). Please increase the context size, or simplify the input message.",
3157                                    token_info.prompt_tokens, ctx_size
3158                                );
3159
3160                            #[cfg(feature = "logging")]
3161                            error!(target: "stdout", "{}", &err_msg);
3162
3163                            return Err(LlamaCoreError::Operation(err_msg));
3164                        } else {
3165                            return Ok((prompt, ctx_size - token_info.prompt_tokens, tool_use));
3166                        }
3167                    }
3168                    ChatCompletionRole::User => {
3169                        // corner case: context size is too small, `user -> assistant -> tool` cannot be trimmed.
3170                        if chat_request.messages.len() == 3
3171                            && chat_request.messages[0].role() == ChatCompletionRole::User
3172                            && chat_request.messages[1].role() == ChatCompletionRole::Assistant
3173                            && chat_request.messages[2].role() == ChatCompletionRole::Tool
3174                        {
3175                            let err_msg = format!(
3176                            "The number of prompt tokens ({}) is greater than the max prompt tokens ({}). Please increase the context size.",
3177                            token_info.prompt_tokens, max_prompt_tokens
3178                        );
3179
3180                            #[cfg(feature = "logging")]
3181                            error!(target: "stdout", "{}", &err_msg);
3182
3183                            return Err(LlamaCoreError::Operation(err_msg));
3184                        }
3185
3186                        if chat_request.messages.len() > 1 {
3187                            // user_1 -> ... -> user_2 -> ... -> user_latest
3188
3189                            // remove user_1 if it exists
3190                            // For example, `user_1 -> ... -> user_2 -> ... -> user_latest` will be converted to `... -> user_2 -> ... -> user_latest`
3191                            if chat_request.messages[0].role() == ChatCompletionRole::User {
3192                                let user_message = chat_request.messages.remove(0);
3193
3194                                #[cfg(feature = "logging")]
3195                                info!(target: "stdout", "Remove a user message from the chat history: {user_message:?}");
3196                            }
3197
3198                            // remove all messages until the message is of `user`
3199                            // For example, `... -> user_2 -> ... -> user_latest` will be converted to `user_2 -> ... -> user_latest`
3200                            while chat_request.messages[0].role() != ChatCompletionRole::User {
3201                                let message = chat_request.messages.remove(0);
3202
3203                                #[cfg(feature = "logging")]
3204                                info!(target: "stdout", "Remove a {} message from the chat history: {:?}", message.role(), message);
3205
3206                                if chat_request.messages.is_empty() {
3207                                    let err_msg = format!("The last message in the chat history should be a user message, but found a {} message.", message.role());
3208
3209                                    #[cfg(feature = "logging")]
3210                                    error!(target: "stdout", "{err_msg}");
3211
3212                                    return Err(LlamaCoreError::Operation(err_msg));
3213                                }
3214                            }
3215                        } else if token_info.prompt_tokens > ctx_size {
3216                            let err_msg = format!(
3217                                    "The number of prompt tokens ({}) is greater than the context size ({}). Please increase the context size, or simplify the input message.",
3218                                    token_info.prompt_tokens, ctx_size
3219                                );
3220
3221                            #[cfg(feature = "logging")]
3222                            error!(target: "stdout", "{}", &err_msg);
3223
3224                            return Err(LlamaCoreError::Operation(err_msg));
3225                        } else {
3226                            return Ok((prompt, ctx_size - token_info.prompt_tokens, tool_use));
3227                        }
3228                    }
3229                    _ => {
3230                        #[cfg(feature = "logging")]
3231                        info!(target: "stdout", "remove a {} message from the message queue", chat_request.messages[0].role());
3232
3233                        chat_request.messages.remove(0);
3234                    }
3235                }
3236
3237                continue;
3238            }
3239            false => return Ok((prompt, ctx_size - max_prompt_tokens, tool_use)),
3240        }
3241    }
3242}
3243
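/// Minimal sketch (toy map, not the real `CHAT_GRAPHS` registry) of the model-lookup
/// fallback used by this function and the helpers below: prefer the graph registered
/// under the requested name, otherwise fall back to the first available graph.
/// ```
/// use std::collections::BTreeMap;
///
/// let mut graphs = BTreeMap::new();
/// graphs.insert("llama".to_string(), "graph-a");
///
/// let requested: Option<&str> = Some("qwen");
/// let chosen = requested
///     .and_then(|name| graphs.get(name))
///     .or_else(|| graphs.values().next());
/// assert_eq!(chosen, Some(&"graph-a"));
/// ```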
3244fn set_prompt(model_name: Option<&String>, prompt: impl AsRef<str>) -> Result<(), LlamaCoreError> {
3245    let chat_graphs = match CHAT_GRAPHS.get() {
3246        Some(chat_graphs) => chat_graphs,
3247        None => {
3248            let err_msg = "Fail to get the underlying value of `CHAT_GRAPHS`.";
3249
3250            #[cfg(feature = "logging")]
3251            error!(target: "stdout", "{}", &err_msg);
3252
3253            return Err(LlamaCoreError::Operation(err_msg.into()));
3254        }
3255    };
3256
3257    let mut chat_graphs = chat_graphs.lock().map_err(|e| {
3258        let err_msg = format!("Fail to acquire the lock of `CHAT_GRAPHS`. {e}");
3259
3260        #[cfg(feature = "logging")]
3261        error!(target: "stdout", "{}", &err_msg);
3262
3263        LlamaCoreError::Operation(err_msg)
3264    })?;
3265
3266    match model_name {
3267        Some(model_name) => {
3268            #[cfg(feature = "logging")]
3269            info!(target: "stdout", "Set prompt to the chat model named {model_name}");
3270
3271            match chat_graphs.contains_key(model_name) {
3272                true => {
3273                    let graph = chat_graphs.get_mut(model_name).unwrap();
3274                    let tensor_data = prompt.as_ref().as_bytes().to_vec();
3275                    set_tensor_data_u8(graph, 0, &tensor_data)
3276                }
3277                false => match chat_graphs.iter_mut().next() {
3278                    Some((_, graph)) => {
3279                        let tensor_data = prompt.as_ref().as_bytes().to_vec();
3280                        set_tensor_data_u8(graph, 0, &tensor_data)
3281                    }
3282                    None => {
3283                        let err_msg = "There is no model available in the chat graphs.";
3284
3285                        #[cfg(feature = "logging")]
3286                        error!(target: "stdout", "{}", &err_msg);
3287
3288                        Err(LlamaCoreError::Operation(err_msg.into()))
3289                    }
3290                },
3291            }
3292        }
3293        None => {
3294            #[cfg(feature = "logging")]
3295            info!(target: "stdout", "Set prompt to the default chat model.");
3296
3297            match chat_graphs.iter_mut().next() {
3298                Some((_, graph)) => {
3299                    let tensor_data = prompt.as_ref().as_bytes().to_vec();
3300                    set_tensor_data_u8(graph, 0, &tensor_data)
3301                }
3302                None => {
3303                    let err_msg = "There is no model available in the chat graphs while trying to set prompt to the default model.";
3304
3305                    #[cfg(feature = "logging")]
3306                    error!(target: "stdout", "{err_msg}");
3307
3308                    Err(LlamaCoreError::Operation(err_msg.into()))
3309                }
3310            }
3311        }
3312    }
3313}
3314
3315/// Get a copy of the metadata of the model.
3316fn get_model_metadata(model_name: Option<&String>) -> Result<GgmlMetadata, LlamaCoreError> {
3317    let chat_graphs = match CHAT_GRAPHS.get() {
3318        Some(chat_graphs) => chat_graphs,
3319        None => {
3320            let err_msg = "Fail to get the underlying value of `CHAT_GRAPHS`.";
3321
3322            #[cfg(feature = "logging")]
3323            error!(target: "stdout", "{err_msg}");
3324
3325            return Err(LlamaCoreError::Operation(err_msg.into()));
3326        }
3327    };
3328
3329    let chat_graphs = chat_graphs.lock().map_err(|e| {
3330        let err_msg = format!("Fail to acquire the lock of `CHAT_GRAPHS`. {e}");
3331
3332        #[cfg(feature = "logging")]
3333        error!(target: "stdout", "{}", &err_msg);
3334
3335        LlamaCoreError::Operation(err_msg)
3336    })?;
3337
3338    match model_name {
3339        Some(model_name) => match chat_graphs.contains_key(model_name) {
3340            true => {
3341                let graph = chat_graphs.get(model_name).unwrap();
3342                Ok(graph.metadata.clone())
3343            }
3344            false => match chat_graphs.iter().next() {
3345                Some((_, graph)) => Ok(graph.metadata.clone()),
3346                None => {
3347                    let err_msg = "There is no model available in the chat graphs.";
3348
3349                    #[cfg(feature = "logging")]
3350                    error!(target: "stdout", "{}", &err_msg);
3351
3352                    Err(LlamaCoreError::Operation(err_msg.into()))
3353                }
3354            },
3355        },
3356        None => match chat_graphs.iter().next() {
3357            Some((_, graph)) => Ok(graph.metadata.clone()),
3358            None => {
3359                let err_msg = "There is no model available in the chat graphs.";
3360
3361                #[cfg(feature = "logging")]
3362                error!(target: "stdout", "{err_msg}");
3363
3364                Err(LlamaCoreError::Operation(err_msg.into()))
3365            }
3366        },
3367    }
3368}
3369
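/// Minimal sketch (field names invented for the example, not the real `GgmlMetadata`
/// layout): the metadata is serialized to a JSON string before being written into the
/// graph's configuration tensor below.
/// ```
/// let config = serde_json::json!({ "temperature": 0.8, "top-p": 0.9 }).to_string();
/// assert!(config.contains("temperature"));
/// ```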
3370fn update_model_metadata(
3371    model_name: Option<&String>,
3372    metadata: &GgmlMetadata,
3373) -> Result<(), LlamaCoreError> {
3374    let config = match serde_json::to_string(metadata) {
3375        Ok(config) => config,
3376        Err(e) => {
3377            let err_msg = format!("Fail to serialize metadata to a JSON string. {e}");
3378
3379            #[cfg(feature = "logging")]
3380            error!(target: "stdout", "{}", &err_msg);
3381
3382            return Err(LlamaCoreError::Operation(err_msg));
3383        }
3384    };
3385
3386    let chat_graphs = match CHAT_GRAPHS.get() {
3387        Some(chat_graphs) => chat_graphs,
3388        None => {
3389            let err_msg = "Fail to get the underlying value of `CHAT_GRAPHS`.";
3390
3391            #[cfg(feature = "logging")]
3392            error!(target: "stdout", "{err_msg}");
3393
3394            return Err(LlamaCoreError::Operation(err_msg.into()));
3395        }
3396    };
3397
3398    let mut chat_graphs = chat_graphs.lock().map_err(|e| {
3399        let err_msg = format!("Fail to acquire the lock of `CHAT_GRAPHS`. Reason: {e}");
3400
3401        #[cfg(feature = "logging")]
3402        error!(target: "stdout", "{}", &err_msg);
3403
3404        LlamaCoreError::Operation(err_msg)
3405    })?;
3406
3407    match model_name {
3408        Some(model_name) => {
3409            match chat_graphs.contains_key(model_name) {
3410                true => {
3411                    let graph = chat_graphs.get_mut(model_name).unwrap();
3412                    // update metadata
3413                    set_tensor_data_u8(graph, 1, config.as_bytes())
3414                }
3415                false => match chat_graphs.iter_mut().next() {
3416                    Some((_, graph)) => {
3417                        // update metadata
3418                        set_tensor_data_u8(graph, 1, config.as_bytes())
3419                    }
3420                    None => {
3421                        let err_msg = "There is no model available in the chat graphs.";
3422
3423                        #[cfg(feature = "logging")]
3424                        error!(target: "stdout", "{}", &err_msg);
3425
3426                        Err(LlamaCoreError::Operation(err_msg.into()))
3427                    }
3428                },
3429            }
3430        }
3431        None => {
3432            match chat_graphs.iter_mut().next() {
3433                Some((_, graph)) => {
3434                    // update metadata
3435                    set_tensor_data_u8(graph, 1, config.as_bytes())
3436                }
3437                None => {
3438                    let err_msg = "There is no model available in the chat graphs.";
3439
3440                    #[cfg(feature = "logging")]
3441                    error!(target: "stdout", "{err_msg}");
3442
3443                    Err(LlamaCoreError::Operation(err_msg.into()))
3444                }
3445            }
3446        }
3447    }
3448}
3449
3450fn reset_model_metadata(model_name: Option<&String>) -> Result<(), LlamaCoreError> {
3451    // get metadata
3452    let metadata = get_model_metadata(model_name)?;
3453
3454    // update model with the original metadata
3455    update_model_metadata(model_name, &metadata)
3456}
3457
3458#[derive(Debug, Clone, Copy, PartialEq, Eq)]
3459enum ContextFullState {
3460    Message,
3461    Usage,
3462    Done,
3463    EndOfSequence,
3464}
3465
3466#[derive(Debug, Clone, Copy, PartialEq, Eq)]
3467enum StreamState {
3468    Usage,
3469    NoUsage,
3470    Done,
3471    EndOfSequence,
3472}
3473
3474#[derive(Debug, Clone, Copy, PartialEq, Eq)]
3475enum PromptTooLongState {
3476    Message,
3477    Usage,
3478    Done,
3479    EndOfSequence,
3480}
3481
3482struct ChatStream {
3483    id: String,
3484    model: Option<String>,
3485    include_usage: bool,
3486    context_full_state: ContextFullState,
3487    prompt_too_long_state: PromptTooLongState,
3488    stream_state: StreamState,
3489    cache: Option<VecDeque<String>>,
3490    is_waiting: bool,
3491    has_lock: bool,
3492}
3493impl ChatStream {
3494    fn new(
3495        model: Option<String>,
3496        id: String,
3497        include_usage: bool,
3498        cache: Option<Vec<String>>,
3499    ) -> Self {
3500        // Try to acquire lock
3501        let has_lock = CHAT_STREAM_ACTIVE
3502            .compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst)
3503            .is_ok();
3504
3505        #[cfg(feature = "logging")]
3506        if !has_lock {
3507            info!(target: "stdout", "Lock acquisition failed in ChatStream::new, creating with waiting status");
3508        }
3509
3510        ChatStream {
3511            id,
3512            model,
3513            include_usage,
3514            context_full_state: ContextFullState::Message,
3515            prompt_too_long_state: PromptTooLongState::Message,
3516            stream_state: if include_usage {
3517                StreamState::Usage
3518            } else {
3519                StreamState::NoUsage
3520            },
3521            cache: cache.map(VecDeque::from),
3522            is_waiting: !has_lock,
3523            has_lock,
3524        }
3525    }
3526
3527    // Try to acquire the stream lock; returns whether it succeeded (see the sketch after this impl block)
3528    fn try_acquire_lock(&mut self) -> bool {
3529        if self.has_lock {
3530            return true;
3531        }
3532
3533        let acquired = CHAT_STREAM_ACTIVE
3534            .compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst)
3535            .is_ok();
3536
3537        if acquired {
3538            self.has_lock = true;
3539            self.is_waiting = false;
3540        }
3541
3542        acquired
3543    }
3544}
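// Illustrative sketch (not part of the original module): how the `compare_exchange`
// guard used by `ChatStream` behaves. Only one caller can flip the flag from `false`
// to `true`; later callers fail until the flag is reset, as `Drop` does below.
#[cfg(test)]
mod chat_stream_lock_sketch {
    use std::sync::atomic::{AtomicBool, Ordering};

    #[test]
    fn only_one_acquirer_wins() {
        let active = AtomicBool::new(false);

        // First acquisition succeeds: false -> true.
        assert!(active
            .compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst)
            .is_ok());

        // A second acquisition fails while the flag is still set.
        assert!(active
            .compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst)
            .is_err());

        // Releasing the flag lets the next stream acquire the lock.
        active.store(false, Ordering::SeqCst);
        assert!(active
            .compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst)
            .is_ok());
    }
}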
3545impl Drop for ChatStream {
3546    fn drop(&mut self) {
3547        // Cleanup is only needed if we hold the lock or if the stream was actually used
3548        if self.has_lock || (self.cache.is_none() && !self.is_waiting) {
3549            #[cfg(feature = "logging")]
3550            info!(target: "stdout", "Cleaning up context for ChatStream {}", &self.id);
3551
3552            match &self.model {
3553                Some(model_name) => {
3554                    match CHAT_GRAPHS.get() {
3555                        Some(chat_graphs) => {
3556                            match chat_graphs.lock() {
3557                                Ok(mut chat_graphs) => match chat_graphs.contains_key(model_name) {
3558                                    true => {
3559                                        let graph = chat_graphs.get_mut(model_name).unwrap();
3560
3561                                        // clean up the context
3562                                        if let Err(e) = graph.finish_single() {
3563                                            let err_msg = format!(
3564                                                "Failed to clean up the context. Reason: {e}"
3565                                            );
3566
3567                                            #[cfg(feature = "logging")]
3568                                            error!(target: "stdout", "{}", &err_msg);
3569
3570                                            #[cfg(not(feature = "logging"))]
3571                                            println!(
3572                                                "[ERROR][llama_core] Failed to clean up the context. Reason: {}",
3573                                                &err_msg
3574                                            );
3575                                        }
3576                                    }
3577                                    false => match chat_graphs.iter_mut().next() {
3578                                        Some((_, graph)) => {
3579                                            // clean up the context
3580                                            if let Err(e) = graph.finish_single() {
3581                                                let err_msg = format!(
3582                                                    "Failed to clean up the context. Reason: {e}"
3583                                                );
3584
3585                                                #[cfg(feature = "logging")]
3586                                                error!(target: "stdout", "{}", &err_msg);
3587
3588                                                #[cfg(not(feature = "logging"))]
3589                                                println!(
3590                                                    "[ERROR][llama_core] Failed to clean up the context. Reason: {}",
3591                                                    &err_msg
3592                                                );
3593                                            }
3594                                        }
3595                                        None => {
3596                                            let err_msg =
3597                                                "There is no model available in the chat graphs.";
3598
3599                                            #[cfg(feature = "logging")]
3600                                            error!(target: "stdout", "{}", &err_msg);
3601
3602                                            #[cfg(not(feature = "logging"))]
3603                                            println!(
3604                                                "[ERROR][llama_core] Failed to clean up the context. Reason: {}",
3605                                                &err_msg
3606                                            );
3607                                        }
3608                                    },
3609                                },
3610                                Err(e) => {
3611                                    let err_msg =
3612                                        format!("Fail to acquire the lock of `CHAT_GRAPHS`. {e}");
3613
3614                                    #[cfg(feature = "logging")]
3615                                    error!(target: "stdout", "{}", &err_msg);
3616
3617                                    #[cfg(not(feature = "logging"))]
3618                                    println!(
3619                                        "[ERROR][llama_core] Failed to clean up the context. Reason: {}",
3620                                        &err_msg
3621                                    );
3622                                }
3623                            }
3624                        }
3625                        None => {
3626                            let err_msg = "Fail to get the underlying value of `CHAT_GRAPHS`.";
3627
3628                            #[cfg(feature = "logging")]
3629                            error!(target: "stdout", "{}", &err_msg);
3630
3631                            #[cfg(not(feature = "logging"))]
3632                            println!(
3633                                "[ERROR][llama_core] Failed to clean up the context. Reason: {}",
3634                                &err_msg
3635                            );
3636                        }
3637                    };
3638                }
3639                None => {
3640                    match CHAT_GRAPHS.get() {
3641                        Some(chat_graphs) => {
3642                            match chat_graphs.lock() {
3643                                Ok(mut chat_graphs) => match chat_graphs.iter_mut().next() {
3644                                    Some((_, graph)) => {
3645                                        // clean up the context
3646                                        if let Err(e) = graph.finish_single() {
3647                                            let err_msg = format!(
3648                                                "Failed to clean up the context. Reason: {e}"
3649                                            );
3650
3651                                            #[cfg(feature = "logging")]
3652                                            error!(target: "stdout", "{}", &err_msg);
3653
3654                                            #[cfg(not(feature = "logging"))]
3655                                            println!(
3656                                                "[ERROR][llama_core] Failed to clean up the context. Reason: {}",
3657                                                &err_msg
3658                                            );
3659                                        }
3660                                    }
3661                                    None => {
3662                                        let err_msg =
3663                                            "There is no model available in the chat graphs.";
3664
3665                                        #[cfg(feature = "logging")]
3666                                        error!(target: "stdout", "{err_msg}");
3667
3668                                        #[cfg(not(feature = "logging"))]
3669                                        println!(
3670                                            "[ERROR][llama_core] Failed to clean up the context. Reason: {}",
3671                                            err_msg
3672                                        );
3673                                    }
3674                                },
3675                                Err(e) => {
3676                                    let err_msg =
3677                                        format!("Fail to acquire the lock of `CHAT_GRAPHS`. {e}");
3678
3679                                    #[cfg(feature = "logging")]
3680                                    error!(target: "stdout", "{}", &err_msg);
3681
3682                                    #[cfg(not(feature = "logging"))]
3683                                    println!(
3684                                        "[ERROR][llama_core] Failed to clean up the context. Reason: {}",
3685                                        &err_msg
3686                                    );
3687                                }
3688                            }
3689                        }
3690                        None => {
3691                            let err_msg = "Fail to get the underlying value of `CHAT_GRAPHS`.";
3692
3693                            #[cfg(feature = "logging")]
3694                            error!(target: "stdout", "{}", &err_msg);
3695
3696                            #[cfg(not(feature = "logging"))]
3697                            println!(
3698                                "[ERROR][llama_core] Failed to clean up the context. Reason: {}",
3699                                &err_msg
3700                            );
3701                        }
3702                    };
3703                }
3704            }
3705
3706            #[cfg(feature = "logging")]
3707            info!(target: "stdout", "Model context cleanup done!");
3708        }
3709
3710        // reset the model metadata
3711        if let Err(e) = reset_model_metadata(self.model.as_ref()) {
3712            let err_msg = format!("Fail to reset model metadata. Reason: {e}");
3713
3714            #[cfg(feature = "logging")]
3715            error!(target: "stdout", "{}", &err_msg);
3716
3717            #[cfg(not(feature = "logging"))]
3718            println!("[ERROR][llama_core] {}", &err_msg);
3719        }
3720        #[cfg(feature = "logging")]
3721        info!(target: "stdout", "Model metadata reset done!");
3722
3723        // When dropping a ChatStream that held the lock, check for waiting streams (see the sketch after this impl block)
3724        if self.has_lock {
3725            // Reset the atomic flag
3726            CHAT_STREAM_ACTIVE.store(false, Ordering::SeqCst);
3727
3728            #[cfg(feature = "logging")]
3729            info!(target: "stdout", "Lock from ChatStream {} released", &self.id);
3730
3731            // Wake up waiting streams
3732            if let Ok(mut queue) = get_chat_stream_waker_queue().lock() {
3733                if let Some(waker) = queue.pop_front() {
3734                    #[cfg(feature = "logging")]
3735                    info!(target: "stdout", "Waking up a waiting ChatStream");
3736
3737                    waker.wake();
3738                }
3739            }
3740        }
3741    }
3742}
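// Illustrative sketch (not part of the original module): the wake-on-release pattern
// used in `Drop` above. A pending stream's `Waker` is queued and the dropping lock
// holder wakes exactly one waiter. `noop_waker` stands in for a real task waker.
#[cfg(test)]
mod chat_stream_waker_sketch {
    use futures::task::noop_waker;
    use std::collections::VecDeque;
    use std::sync::Mutex;
    use std::task::Waker;

    #[test]
    fn wake_first_waiter_on_release() {
        let queue: Mutex<VecDeque<Waker>> = Mutex::new(VecDeque::new());

        // A waiting stream registers its waker ...
        queue.lock().unwrap().push_back(noop_waker());

        // ... and the released lock holder wakes the first waiter in FIFO order.
        if let Some(waker) = queue.lock().unwrap().pop_front() {
            waker.wake();
        }
        assert!(queue.lock().unwrap().is_empty());
    }
}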
3743impl futures::Stream for ChatStream {
3744    type Item = Result<String, LlamaCoreError>;
3745
3746    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
3747        let this = self.get_mut();
3748
3749        // If this is a waiting stream, try to acquire the lock
3750        if this.is_waiting {
3751            if !this.try_acquire_lock() {
3752                // Store the waker to be notified when the lock becomes available
3753                if let Ok(mut queue) = get_chat_stream_waker_queue().lock() {
3754                    // Remove any previous instance of this waker
3755                    queue.retain(|w| !w.will_wake(cx.waker()));
3756                    // Add this waker to the queue
3757                    queue.push_back(cx.waker().clone());
3758
3759                    #[cfg(feature = "logging")]
3760                    debug!(target: "stdout", "ChatStream {} is waiting for lock, added waker to queue", &this.id);
3761                }
3762
3763                return Poll::Pending;
3764            }
3765
3766            #[cfg(feature = "logging")]
3767            info!(target: "stdout", "ChatStream {} acquired lock and is now active", &this.id);
3768            // If we got here, we successfully acquired the lock and can proceed
3769        }
3770
3771        // Ensure we still have the lock
3772        if !this.has_lock && !this.try_acquire_lock() {
3773            // Lost the lock, need to wait
3774            this.is_waiting = true;
3775
3776            // Register waker to be notified when lock is available
3777            if let Ok(mut queue) = get_chat_stream_waker_queue().lock() {
3778                queue.retain(|w| !w.will_wake(cx.waker()));
3779                queue.push_back(cx.waker().clone());
3780            }
3781
3782            return Poll::Pending;
3783        }
3784
3785        if this.cache.is_none() {
3786            let res = compute_stream(
3787                this.model.clone(),
3788                this.id.clone(),
3789                this.include_usage,
3790                &mut this.prompt_too_long_state,
3791                &mut this.context_full_state,
3792                &mut this.stream_state,
3793            );
3794
3795            match res {
3796                Ok(x) => {
3797                    #[cfg(feature = "logging")]
3798                    info!(target: "stdout", "next item for ChatStream {}: {}", &this.id, &x);
3799
3800                    if x != "[GGML] End of sequence" && !x.is_empty() {
3801                        Poll::Ready(Some(Ok(x)))
3802                    } else {
3803                        // stopped
3804                        Poll::Ready(None)
3805                    }
3806                }
3807                Err(e) => Poll::Ready(Some(Err(e))),
3808            }
3809        } else {
3810            let x = this.cache.as_mut().unwrap().pop_front();
3811
3812            #[cfg(feature = "logging")]
3813            info!(target: "stdout", "Get the next item from the cache for ChatStream {}: {:?}", &this.id, &x);
3814
3815            match x {
3816                Some(x) => Poll::Ready(Some(Ok(x))),
3817                None => Poll::Ready(None),
3818            }
3819        }
3820    }
3821}
3822
3823/// Helper function to get or initialize the waker queue for waiting ChatStreams
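///
/// Standalone sketch (toy element type, not the real waker queue) of the `OnceLock`
/// lazy-initialization pattern used here.
/// ```
/// use std::collections::VecDeque;
/// use std::sync::{Mutex, OnceLock};
///
/// static QUEUE: OnceLock<Mutex<VecDeque<u32>>> = OnceLock::new();
///
/// let queue = QUEUE.get_or_init(|| Mutex::new(VecDeque::new()));
/// queue.lock().unwrap().push_back(7);
/// assert_eq!(QUEUE.get().unwrap().lock().unwrap().len(), 1);
/// ```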
3824fn get_chat_stream_waker_queue() -> &'static Mutex<VecDeque<Waker>> {
3825    CHAT_STREAM_WAKER_QUEUE.get_or_init(|| {
3826        #[cfg(feature = "logging")]
3827        info!(target: "stdout", "Initializing ChatStream waker queue");
3828        Mutex::new(VecDeque::new())
3829    })
3830}
3831
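/// Standalone sketch (invented bytes, not data from the original source) of why raw
/// output bytes are cached when they are not yet valid UTF-8: a multi-byte character
/// can be split across two streamed chunks and only decodes once both arrive.
/// ```
/// let chunk_a = vec![0xE4, 0xBD]; // first two bytes of "你"
/// let chunk_b = vec![0xA0];       // remaining byte
///
/// assert!(String::from_utf8(chunk_a.clone()).is_err());
///
/// let mut cached = chunk_a;
/// cached.extend_from_slice(&chunk_b);
/// assert_eq!(String::from_utf8(cached).unwrap(), "你");
/// ```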
3832fn compute_stream(
3833    model_name: Option<String>,
3834    id: String,
3835    include_usage: bool,
3836    prompt_too_long_state: &mut PromptTooLongState,
3837    context_full_state: &mut ContextFullState,
3838    stream_state: &mut StreamState,
3839) -> Result<String, LlamaCoreError> {
3840    #[cfg(feature = "logging")]
3841    info!(target: "stdout", "Computing stream chunk for ChatStream {}", &id);
3842
3843    #[cfg(feature = "logging")]
3844    debug!(target: "stdout", "prompt_too_long_state: {:?}", *prompt_too_long_state);
3845    #[cfg(feature = "logging")]
3846    debug!(target: "stdout", "context_full_state: {:?}", *context_full_state);
3847    #[cfg(feature = "logging")]
3848    debug!(target: "stdout", "stream_state: {:?}", *stream_state);
3849
3850    if *prompt_too_long_state == PromptTooLongState::EndOfSequence
3851        || *context_full_state == ContextFullState::EndOfSequence
3852        || *stream_state == StreamState::EndOfSequence
3853    {
3854        #[cfg(feature = "logging")]
3855        info!(target: "stdout", "Stream already reached end of sequence; returning the end-of-sequence sentinel");
3856
3857        return Ok("[GGML] End of sequence".to_string());
3858    }
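
    // Three small state machines drive the tail of the stream: `StreamState` for a
    // normal end of generation, `ContextFullState` for a full KV context, and
    // `PromptTooLongState` for an oversized prompt. Each one advances toward
    // `EndOfSequence` across successive polls so that the optional usage chunk and
    // the final "data: [DONE]" frame are emitted in order.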
3859
3860    let chat_graphs = match CHAT_GRAPHS.get() {
3861        Some(chat_graphs) => chat_graphs,
3862        None => {
3863            let err_msg = "Failed to get the underlying value of `CHAT_GRAPHS`.";
3864
3865            #[cfg(feature = "logging")]
3866            error!(target: "stdout", "{}", &err_msg);
3867
3868            return Err(LlamaCoreError::Operation(err_msg.into()));
3869        }
3870    };
3871
3872    // We're already holding the ChatStream lock, so we know we have exclusive access to the graph
3873    let mut chat_graphs = chat_graphs.lock().map_err(|e| {
3874        let err_msg = format!("Failed to acquire the lock of `CHAT_GRAPHS`. {e}");
3875
3876        #[cfg(feature = "logging")]
3877        error!(target: "stdout", "{}", &err_msg);
3878
3879        LlamaCoreError::Operation(err_msg)
3880    })?;
3881
3882    // Get the graph based on model name
3883    let res = match &model_name {
3884        Some(model_name) => {
3885            match chat_graphs.contains_key(model_name) {
3886                true => {
3887                    let graph = chat_graphs.get_mut(model_name).unwrap();
3888                    // run a single compute step to produce the next output chunk
3889                    match graph.compute_single() {
3890                        Ok(_) => {
3891                            #[cfg(feature = "logging")]
3892                            debug!(target: "stdout", "Computed the chat stream chunk successfully.");
3893
3894                            // Process according to state
3895                            match stream_state {
3896                                StreamState::Usage | StreamState::NoUsage => {
3897                                    // Retrieve the output
3898                                    let output_buffer =
3899                                        get_output_buffer_single(graph, OUTPUT_TENSOR)?;
3900
3901                                    #[cfg(feature = "logging")]
3902                                    info!(target: "stdout", "retrieved the output buffer");
3903
3904                                    // decode the output buffer to a utf8 string
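                                    // A chunk may end in the middle of a multi-byte UTF-8
                                    // code point (e.g. an emoji split across two chunks), so
                                    // undecodable bytes are buffered in `CACHED_UTF8_ENCODINGS`
                                    // and retried once the next chunk arrives.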
3905                                    let output = match String::from_utf8(output_buffer.clone()) {
3906                                        Ok(token) => token,
3907                                        Err(_) => {
3908                                            let mutex = CACHED_UTF8_ENCODINGS
3909                                                .get_or_init(|| Mutex::new(Vec::new()));
3910                                            let mut cached_encodings = mutex.lock().map_err(|e| {
3911                                            let err_msg = format!(
3912                                                "Failed to acquire the lock of `CACHED_UTF8_ENCODINGS`. Reason: {e}"
3913                                            );
3914
3915                                            #[cfg(feature = "logging")]
3916                                            error!(target: "stdout", "{}", &err_msg);
3917
3918
3919                                            LlamaCoreError::Operation(err_msg)
3920                                        })?;
3921
3922                                            // cache the bytes for future decoding
3923                                            cached_encodings.extend_from_slice(&output_buffer[..]);
3924
3925                                            match String::from_utf8(cached_encodings.to_vec()) {
3926                                                Ok(token) => {
3927                                                    // clear CACHED_UTF8_ENCODINGS
3928                                                    cached_encodings.clear();
3929
3930                                                    token
3931                                                }
3932                                                Err(e) => {
3933                                                    // TODO: temporary guard against the cached encodings growing without bound.
3934                                                    if cached_encodings.len() > 4 {
3935                                                        let err_msg = format!("Failed to convert the cached bytes to a string: the cached UTF-8 buffer exceeds 4 bytes. {e}");
3936
3937                                                        #[cfg(feature = "logging")]
3938                                                        error!(target: "stdout", "{}", &err_msg);
3939
3940                                                        #[cfg(feature = "logging")]
3941                                                        error!(target: "stdout", "The cached buffer: {:?}", &cached_encodings[..]);
3942
3943                                                        // let token = String::from_utf8_lossy(
3944                                                        //     &cached_encodings,
3945                                                        // )
3946                                                        // .to_string();
3947
3948                                                        // clear CACHED_UTF8_ENCODINGS
3949                                                        cached_encodings.clear();
3950
3951                                                        String::from("")
3952                                                    } else {
3953                                                        let warn_msg = format!("Failed to convert the cached bytes to a string yet; keeping them for the next chunk. {e}");
3954
3955                                                        #[cfg(feature = "logging")]
3956                                                        warn!(target: "stdout", "{}", &warn_msg);
3957
3958                                                        String::from("")
3959                                                    }
3960                                                }
3961                                            }
3962                                        }
3963                                    };
3964
3965                                    #[cfg(feature = "logging")]
3966                                    info!(target: "stdout", "decoded the output buffer");
3967
3968                                    let created = SystemTime::now()
3969                                        .duration_since(std::time::UNIX_EPOCH)
3970                                        .map_err(|e| {
3971                                            let err_msg = format!(
3972                                                "Failed to get the current time. Reason: {e}"
3973                                            );
3974
3975                                            #[cfg(feature = "logging")]
3976                                            error!(target: "stdout", "{}", &err_msg);
3977
3978                                            LlamaCoreError::Operation(err_msg)
3979                                        })?;
3980
3981                                    let chat_completion_chunk = ChatCompletionChunk {
3982                                        id,
3983                                        object: "chat.completion.chunk".to_string(),
3984                                        created: created.as_secs(),
3985                                        model: graph.name().to_owned(),
3986                                        system_fingerprint: "fp_44709d6fcb".to_string(),
3987                                        choices: vec![ChatCompletionChunkChoice {
3988                                            index: 0,
3989                                            delta: ChatCompletionChunkChoiceDelta {
3990                                                role: ChatCompletionRole::Assistant,
3991                                                content: Some(output),
3992                                                tool_calls: vec![],
3993                                            },
3994                                            logprobs: None,
3995                                            finish_reason: None,
3996                                        }],
3997                                        usage: None,
3998                                    };
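                                    // The chunk follows the OpenAI `chat.completion.chunk`
                                    // schema; `system_fingerprint` is a hard-coded placeholder
                                    // rather than a real fingerprint.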
3999
4000                                    #[cfg(feature = "logging")]
4001                                    info!(target: "stdout", "created chat completion chunk");
4002
4003                                    // serialize chat completion chunk
4004                                    let chunk_str = serde_json::to_string(&chat_completion_chunk)
4005                                        .map_err(|e| {
4006                                        let err_msg = format!(
4007                                            "Failed to serialize chat completion chunk. Reason: {e}"
4008                                        );
4009
4010                                        #[cfg(feature = "logging")]
4011                                        error!(target: "stdout", "{}", &err_msg);
4012
4013                                        LlamaCoreError::Operation(err_msg)
4014                                    })?;
4015
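                                    // Frame the payload as a Server-Sent Events message: a
                                    // "data: " prefix plus a blank line terminates each event.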
4016                                    Ok(format!("data: {chunk_str}\n\n"))
4017                                }
4018                                StreamState::Done => {
4019                                    *stream_state = StreamState::EndOfSequence;
4020
4021                                    Ok("data: [DONE]\n\n".to_string())
4022                                }
4023                                StreamState::EndOfSequence => {
4024                                    Ok("[GGML] End of sequence".to_string())
4025                                }
4026                            }
4027                        }
4028                        Err(wasmedge_wasi_nn::Error::BackendError(
4029                            wasmedge_wasi_nn::BackendError::EndOfSequence,
4030                        )) => {
4031                            #[cfg(feature = "logging")]
4032                            debug!(target: "stdout", "End of sequence");
4033
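                            // Normal end of generation: emit the usage chunk first if the
                            // client requested it, then "data: [DONE]" on the following poll.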
4034                            match stream_state {
4035                                StreamState::Usage => {
4036                                    *stream_state = StreamState::Done;
4037
4038                                    // retrieve the number of prompt and completion tokens
4039                                    let token_info = get_token_info_by_graph(graph)?;
4040
4041                                    let usage = Some(Usage {
4042                                        prompt_tokens: token_info.prompt_tokens,
4043                                        completion_tokens: token_info.completion_tokens,
4044                                        total_tokens: token_info.prompt_tokens
4045                                            + token_info.completion_tokens,
4046                                    });
4047
4048                                    #[cfg(feature = "logging")]
4049                                    info!(target: "stdout", "token_info: {} prompt tokens, {} completion tokens", token_info.prompt_tokens, token_info.completion_tokens);
4050
4051                                    let created = SystemTime::now()
4052                                        .duration_since(std::time::UNIX_EPOCH)
4053                                        .map_err(|e| {
4054                                            let err_msg = format!(
4055                                                "Failed to get the current time. Reason: {e}"
4056                                            );
4057
4058                                            #[cfg(feature = "logging")]
4059                                            error!(target: "stdout", "{}", &err_msg);
4060
4061                                            LlamaCoreError::Operation(err_msg)
4062                                        })?;
4063
4064                                    let chat_completion_chunk = ChatCompletionChunk {
4065                                        id,
4066                                        object: "chat.completion.chunk".to_string(),
4067                                        created: created.as_secs(),
4068                                        model: graph.name().to_owned(),
4069                                        system_fingerprint: "fp_44709d6fcb".to_string(),
4070                                        choices: vec![],
4071                                        usage,
4072                                    };
4073
4074                                    // serialize chat completion chunk
4075                                    let chunk_str = serde_json::to_string(&chat_completion_chunk)
4076                                        .map_err(|e| {
4077                                        let err_msg = format!(
4078                                            "Failed to serialize chat completion chunk. Reason: {e}"
4079                                        );
4080
4081                                        #[cfg(feature = "logging")]
4082                                        error!(target: "stdout", "{}", &err_msg);
4083
4084                                        LlamaCoreError::Operation(err_msg)
4085                                    })?;
4086
4087                                    Ok(format!("data: {chunk_str}\n\n"))
4088                                }
4089                                StreamState::Done | StreamState::NoUsage => {
4090                                    *stream_state = StreamState::EndOfSequence;
4091
4092                                    Ok("data: [DONE]\n\n".to_string())
4093                                }
4094                                StreamState::EndOfSequence => {
4095                                    Ok("[GGML] End of sequence".to_string())
4096                                }
4097                            }
4098                        }
4099                        Err(wasmedge_wasi_nn::Error::BackendError(
4100                            wasmedge_wasi_nn::BackendError::ContextFull,
4101                        )) => {
4102                            #[cfg(feature = "logging")]
4103                            debug!(target: "stdout", "Context full");
4104
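                            // The KV context filled up mid-generation: send a final content
                            // chunk carrying the "<|WASMEDGE-GGML-CONTEXT-FULL|>" marker with
                            // `FinishReason::length`, then (optionally) a usage chunk, and
                            // finally "data: [DONE]".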
4105                            match context_full_state {
4106                                ContextFullState::Message => {
4107                                    match include_usage {
4108                                        true => *context_full_state = ContextFullState::Usage,
4109                                        false => *context_full_state = ContextFullState::Done,
4110                                    }
4111
4112                                    let created = SystemTime::now()
4113                                        .duration_since(std::time::UNIX_EPOCH)
4114                                        .map_err(|e| {
4115                                            let err_msg = format!(
4116                                                "Failed to get the current time. Reason: {e}"
4117                                            );
4118
4119                                            #[cfg(feature = "logging")]
4120                                            error!(target: "stdout", "{}", &err_msg);
4121
4122                                            LlamaCoreError::Operation(err_msg)
4123                                        })?;
4124
4125                                    let chat_completion_chunk = ChatCompletionChunk {
4126                                        id,
4127                                        object: "chat.completion.chunk".to_string(),
4128                                        created: created.as_secs(),
4129                                        model: graph.name().to_owned(),
4130                                        system_fingerprint: "fp_44709d6fcb".to_string(),
4131                                        choices: vec![ChatCompletionChunkChoice {
4132                                            index: 0,
4133                                            delta: ChatCompletionChunkChoiceDelta {
4134                                                role: ChatCompletionRole::Assistant,
4135                                                content: Some(
4136                                                    "<|WASMEDGE-GGML-CONTEXT-FULL|>".to_string(),
4137                                                ),
4138                                                tool_calls: vec![],
4139                                            },
4140                                            logprobs: None,
4141                                            finish_reason: Some(FinishReason::length),
4142                                        }],
4143                                        usage: None,
4144                                    };
4145
4146                                    // serialize chat completion chunk
4147                                    let chunk_str = serde_json::to_string(&chat_completion_chunk)
4148                                        .map_err(|e| {
4149                                        let err_msg = format!(
4150                                            "Failed to serialize chat completion chunk. Reason: {e}"
4151                                        );
4152
4153                                        #[cfg(feature = "logging")]
4154                                        error!(target: "stdout", "{}", &err_msg);
4155
4156                                        LlamaCoreError::Operation(err_msg)
4157                                    })?;
4158
4159                                    Ok(format!("data: {chunk_str}\n\n"))
4160                                }
4161                                ContextFullState::Usage => {
4162                                    *context_full_state = ContextFullState::Done;
4163
4164                                    // retrieve the number of prompt and completion tokens
4165                                    let token_info = get_token_info_by_graph(graph)?;
4166
4167                                    let usage = Some(Usage {
4168                                        prompt_tokens: token_info.prompt_tokens,
4169                                        completion_tokens: token_info.completion_tokens,
4170                                        total_tokens: token_info.prompt_tokens
4171                                            + token_info.completion_tokens,
4172                                    });
4173
4174                                    let created = SystemTime::now()
4175                                        .duration_since(std::time::UNIX_EPOCH)
4176                                        .map_err(|e| {
4177                                            let err_msg = format!(
4178                                                "Failed to get the current time. Reason: {e}"
4179                                            );
4180
4181                                            #[cfg(feature = "logging")]
4182                                            error!(target: "stdout", "{}", &err_msg);
4183
4184                                            LlamaCoreError::Operation(err_msg)
4185                                        })?;
4186
4187                                    let chat_completion_chunk = ChatCompletionChunk {
4188                                        id,
4189                                        object: "chat.completion.chunk".to_string(),
4190                                        created: created.as_secs(),
4191                                        model: graph.name().to_owned(),
4192                                        system_fingerprint: "fp_44709d6fcb".to_string(),
4193                                        choices: vec![],
4194                                        usage,
4195                                    };
4196
4197                                    // serialize chat completion chunk
4198                                    let chunk_str = serde_json::to_string(&chat_completion_chunk)
4199                                        .map_err(|e| {
4200                                        let err_msg = format!(
4201                                            "Failed to serialize chat completion chunk. Reason: {e}"
4202                                        );
4203
4204                                        #[cfg(feature = "logging")]
4205                                        error!(target: "stdout", "{}", &err_msg);
4206
4207                                        LlamaCoreError::Operation(err_msg)
4208                                    })?;
4209
4210                                    Ok(format!("data: {chunk_str}\n\n"))
4211                                }
4212                                ContextFullState::Done => {
4213                                    *context_full_state = ContextFullState::EndOfSequence;
4214
4215                                    Ok("data: [DONE]\n\n".to_string())
4216                                }
4217                                ContextFullState::EndOfSequence => {
4218                                    Ok("[GGML] End of sequence".to_string())
4219                                }
4220                            }
4221                        }
4222                        Err(wasmedge_wasi_nn::Error::BackendError(
4223                            wasmedge_wasi_nn::BackendError::PromptTooLong,
4224                        )) => {
4225                            #[cfg(feature = "logging")]
4226                            debug!(target: "stdout", "Prompt too long");
4227
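                            // The prompt alone exceeds the context window: send an empty delta
                            // with `FinishReason::length`, then (optionally) a usage chunk, and
                            // finally "data: [DONE]".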
4228                            match prompt_too_long_state {
4229                                PromptTooLongState::Message => {
4230                                    match include_usage {
4231                                        true => *prompt_too_long_state = PromptTooLongState::Usage,
4232                                        false => *prompt_too_long_state = PromptTooLongState::Done,
4233                                    }
4234
4235                                    let created = SystemTime::now()
4236                                        .duration_since(std::time::UNIX_EPOCH)
4237                                        .map_err(|e| {
4238                                            let err_msg = format!(
4239                                                "Failed to get the current time. Reason: {e}"
4240                                            );
4241
4242                                            #[cfg(feature = "logging")]
4243                                            error!(target: "stdout", "{}", &err_msg);
4244
4245                                            LlamaCoreError::Operation(err_msg)
4246                                        })?;
4247
4248                                    let chat_completion_chunk = ChatCompletionChunk {
4249                                        id,
4250                                        object: "chat.completion.chunk".to_string(),
4251                                        created: created.as_secs(),
4252                                        model: graph.name().to_owned(),
4253                                        system_fingerprint: "fp_44709d6fcb".to_string(),
4254                                        choices: vec![ChatCompletionChunkChoice {
4255                                            index: 0,
4256                                            delta: ChatCompletionChunkChoiceDelta {
4257                                                role: ChatCompletionRole::Assistant,
4258                                                content: None,
4259                                                tool_calls: vec![],
4260                                            },
4261                                            logprobs: None,
4262                                            finish_reason: Some(FinishReason::length),
4263                                        }],
4264                                        usage: None,
4265                                    };
4266
4267                                    // serialize chat completion chunk
4268                                    let chunk_str = serde_json::to_string(&chat_completion_chunk)
4269                                        .map_err(|e| {
4270                                        let err_msg = format!(
4271                                            "Failed to serialize chat completion chunk. Reason: {e}"
4272                                        );
4273
4274                                        #[cfg(feature = "logging")]
4275                                        error!(target: "stdout", "{}", &err_msg);
4276
4277                                        LlamaCoreError::Operation(err_msg)
4278                                    })?;
4279
4280                                    Ok(format!("data: {chunk_str}\n\n"))
4281                                }
4282                                PromptTooLongState::Usage => {
4283                                    *prompt_too_long_state = PromptTooLongState::Done;
4284
4285                                    // retrieve the number of prompt and completion tokens
4286                                    let token_info = get_token_info_by_graph(graph)?;
4287
4288                                    let usage = Some(Usage {
4289                                        prompt_tokens: token_info.prompt_tokens,
4290                                        completion_tokens: token_info.completion_tokens,
4291                                        total_tokens: token_info.prompt_tokens
4292                                            + token_info.completion_tokens,
4293                                    });
4294
4295                                    let created = SystemTime::now()
4296                                        .duration_since(std::time::UNIX_EPOCH)
4297                                        .map_err(|e| {
4298                                            let err_msg = format!(
4299                                                "Failed to get the current time. Reason: {e}"
4300                                            );
4301
4302                                            #[cfg(feature = "logging")]
4303                                            error!(target: "stdout", "{}", &err_msg);
4304
4305                                            LlamaCoreError::Operation(err_msg)
4306                                        })?;
4307
4308                                    let chat_completion_chunk = ChatCompletionChunk {
4309                                        id,
4310                                        object: "chat.completion.chunk".to_string(),
4311                                        created: created.as_secs(),
4312                                        model: graph.name().to_owned(),
4313                                        system_fingerprint: "fp_44709d6fcb".to_string(),
4314                                        choices: vec![],
4315                                        usage,
4316                                    };
4317
4318                                    // serialize chat completion chunk
4319                                    let chunk_str = serde_json::to_string(&chat_completion_chunk)
4320                                        .map_err(|e| {
4321                                        let err_msg = format!(
4322                                            "Failed to serialize chat completion chunk. Reason: {e}"
4323                                        );
4324
4325                                        #[cfg(feature = "logging")]
4326                                        error!(target: "stdout", "{}", &err_msg);
4327
4328                                        LlamaCoreError::Operation(err_msg)
4329                                    })?;
4330
4331                                    Ok(format!("data: {chunk_str}\n\n"))
4332                                }
4333                                PromptTooLongState::Done => {
4334                                    *prompt_too_long_state = PromptTooLongState::EndOfSequence;
4335
4336                                    Ok("data: [DONE]\n\n".to_string())
4337                                }
4338                                PromptTooLongState::EndOfSequence => {
4339                                    Ok("[GGML] End of sequence".to_string())
4340                                }
4341                            }
4342                        }
4343                        Err(e) => {
4344                            let err_msg =
4345                                format!("Failed to compute the chat completion. Reason: {e}");
4346
4347                            #[cfg(feature = "logging")]
4348                            error!(target: "stdout", "{}", &err_msg);
4349
4350                            Err(LlamaCoreError::Backend(BackendError::ComputeSingle(
4351                                err_msg,
4352                            )))
4353                        }
4354                    }
4355                }
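                // No model name was supplied: fall back to the first graph registered in
                // `CHAT_GRAPHS`; the handling below mirrors the named-model branch above.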
4356                false => {
4357                    match chat_graphs.iter_mut().next() {
4358                        Some((_, graph)) => {
4359                            // run a single compute step to produce the next output chunk
4360                            match graph.compute_single() {
4361                                Ok(_) => {
4362                                    #[cfg(feature = "logging")]
4363                                    debug!(target: "stdout", "Computed the chat stream chunk successfully.");
4364
4365                                    match stream_state {
4366                                        StreamState::Usage | StreamState::NoUsage => {
4367                                            // Retrieve the output
4368                                            let output_buffer =
4369                                                get_output_buffer_single(graph, OUTPUT_TENSOR)?;
4370
4371                                            #[cfg(feature = "logging")]
4372                                            info!(target: "stdout", "retrieved the output buffer");
4373
4374                                            // decode the output buffer to a utf8 string
4375                                            let output = match String::from_utf8(
4376                                                output_buffer.clone(),
4377                                            ) {
4378                                                Ok(token) => token,
4379                                                Err(_) => {
4380                                                    let mutex = CACHED_UTF8_ENCODINGS
4381                                                        .get_or_init(|| Mutex::new(Vec::new()));
4382                                                    let mut cached_encodings = mutex.lock().map_err(|e| {
4383                                            let err_msg = format!(
4384                                                "Failed to acquire the lock of `CACHED_UTF8_ENCODINGS`. Reason: {e}"
4385                                            );
4386
4387                                            #[cfg(feature = "logging")]
4388                                            error!(target: "stdout", "{}", &err_msg);
4389
4390
4391                                            LlamaCoreError::Operation(err_msg)
4392                                        })?;
4393
4394                                                    // cache the bytes for future decoding
4395                                                    cached_encodings
4396                                                        .extend_from_slice(&output_buffer[..]);
4397
4398                                                    match String::from_utf8(
4399                                                        cached_encodings.to_vec(),
4400                                                    ) {
4401                                                        Ok(token) => {
4402                                                            // clear encodings
4403                                                            cached_encodings.clear();
4404
4405                                                            token
4406                                                        }
4407                                                        Err(e) => {
4408                                                            // TODO: temporary guard against the cached encodings growing without bound.
4409                                                            if cached_encodings.len() > 4 {
4410                                                                let err_msg = format!("Failed to convert the cached bytes to a string: the cached UTF-8 buffer exceeds 4 bytes. {e}");
4411
4412                                                                #[cfg(feature = "logging")]
4413                                                                error!(target: "stdout", "{}", &err_msg);
4414
4415                                                                #[cfg(feature = "logging")]
4416                                                                error!(target: "stdout", "The cached buffer: {:?}", &cached_encodings[..]);
4417
4418                                                                // let token =
4419                                                                //     String::from_utf8_lossy(
4420                                                                //         &cached_encodings,
4421                                                                //     )
4422                                                                //     .to_string();
4423
4424                                                                // clear CACHED_UTF8_ENCODINGS
4425                                                                cached_encodings.clear();
4426
4427                                                                String::from("")
4428                                                            } else {
4429                                                                let warn_msg = format!("Failed to convert the cached bytes to a string yet; keeping them for the next chunk. {e}");
4430
4431                                                                #[cfg(feature = "logging")]
4432                                                                warn!(target: "stdout", "{}", &warn_msg);
4433
4434                                                                String::from("")
4435                                                            }
4436                                                        }
4437                                                    }
4438                                                }
4439                                            };
4440
4441                                            #[cfg(feature = "logging")]
4442                                            info!(target: "stdout", "decoded the output buffer");
4443
4444                                            let created = SystemTime::now()
4445                                                .duration_since(std::time::UNIX_EPOCH)
4446                                                .map_err(|e| {
4447                                                    let err_msg = format!(
4448                                                "Failed to get the current time. Reason: {e}"
4449                                            );
4450
4451                                                    #[cfg(feature = "logging")]
4452                                                    error!(target: "stdout", "{}", &err_msg);
4453
4454                                                    LlamaCoreError::Operation(err_msg)
4455                                                })?;
4456
4457                                            let chat_completion_chunk = ChatCompletionChunk {
4458                                                id,
4459                                                object: "chat.completion.chunk".to_string(),
4460                                                created: created.as_secs(),
4461                                                model: graph.name().to_owned(),
4462                                                system_fingerprint: "fp_44709d6fcb".to_string(),
4463                                                choices: vec![ChatCompletionChunkChoice {
4464                                                    index: 0,
4465                                                    delta: ChatCompletionChunkChoiceDelta {
4466                                                        role: ChatCompletionRole::Assistant,
4467                                                        content: Some(output),
4468                                                        tool_calls: vec![],
4469                                                    },
4470                                                    logprobs: None,
4471                                                    finish_reason: None,
4472                                                }],
4473                                                usage: None,
4474                                            };
4475
4476                                            #[cfg(feature = "logging")]
4477                                            info!(target: "stdout", "created chat completion chunk");
4478
4479                                            // serialize chat completion chunk
4480                                            let chunk_str =
4481                                                serde_json::to_string(&chat_completion_chunk)
4482                                                    .map_err(|e| {
4483                                                        let err_msg = format!(
4484                                            "Failed to serialize chat completion chunk. Reason: {e}"
4485                                        );
4486
4487                                                        #[cfg(feature = "logging")]
4488                                                        error!(target: "stdout", "{}", &err_msg);
4489
4490                                                        LlamaCoreError::Operation(err_msg)
4491                                                    })?;
4492
4493                                            Ok(format!("data: {chunk_str}\n\n"))
4494                                        }
4495                                        StreamState::Done => {
4496                                            *stream_state = StreamState::EndOfSequence;
4497
4498                                            Ok("data: [DONE]\n\n".to_string())
4499                                        }
4500                                        StreamState::EndOfSequence => {
4501                                            Ok("[GGML] End of sequence".to_string())
4502                                        }
4503                                    }
4504                                }
4505                                Err(wasmedge_wasi_nn::Error::BackendError(
4506                                    wasmedge_wasi_nn::BackendError::EndOfSequence,
4507                                )) => {
4508                                    #[cfg(feature = "logging")]
4509                                    debug!(target: "stdout", "End of sequence");
4510
4511                                    match stream_state {
4512                                        StreamState::Usage => {
4513                                            *stream_state = StreamState::Done;
4514
4515                                            // retrieve the number of prompt and completion tokens
4516                                            let token_info = get_token_info_by_graph(graph)?;
4517
4518                                            let usage = Some(Usage {
4519                                                prompt_tokens: token_info.prompt_tokens,
4520                                                completion_tokens: token_info.completion_tokens,
4521                                                total_tokens: token_info.prompt_tokens
4522                                                    + token_info.completion_tokens,
4523                                            });
4524
4525                                            #[cfg(feature = "logging")]
4526                                            info!(target: "stdout", "token_info: {} prompt tokens, {} completion tokens", token_info.prompt_tokens, token_info.completion_tokens);
4527
4528                                            let created = SystemTime::now()
4529                                                .duration_since(std::time::UNIX_EPOCH)
4530                                                .map_err(|e| {
4531                                                    let err_msg = format!(
4532                                                "Failed to get the current time. Reason: {e}"
4533                                            );
4534
4535                                                    #[cfg(feature = "logging")]
4536                                                    error!(target: "stdout", "{}", &err_msg);
4537
4538                                                    LlamaCoreError::Operation(err_msg)
4539                                                })?;
4540
4541                                            let chat_completion_chunk = ChatCompletionChunk {
4542                                                id,
4543                                                object: "chat.completion.chunk".to_string(),
4544                                                created: created.as_secs(),
4545                                                model: graph.name().to_owned(),
4546                                                system_fingerprint: "fp_44709d6fcb".to_string(),
4547                                                choices: vec![],
4548                                                usage,
4549                                            };
4550
4551                                            // serialize chat completion chunk
4552                                            let chunk_str =
4553                                                serde_json::to_string(&chat_completion_chunk)
4554                                                    .map_err(|e| {
4555                                                        let err_msg = format!(
4556                                            "Failed to serialize chat completion chunk. Reason: {e}"
4557                                        );
4558
4559                                                        #[cfg(feature = "logging")]
4560                                                        error!(target: "stdout", "{}", &err_msg);
4561
4562                                                        LlamaCoreError::Operation(err_msg)
4563                                                    })?;
4564
4565                                            Ok(format!("data: {chunk_str}\n\n"))
4566                                        }
4567                                        StreamState::Done | StreamState::NoUsage => {
4568                                            *stream_state = StreamState::EndOfSequence;
4569
4570                                            Ok("data: [DONE]\n\n".to_string())
4571                                        }
4572                                        StreamState::EndOfSequence => {
4573                                            Ok("[GGML] End of sequence".to_string())
4574                                        }
4575                                    }
4576                                }
4577                                Err(wasmedge_wasi_nn::Error::BackendError(
4578                                    wasmedge_wasi_nn::BackendError::ContextFull,
4579                                )) => {
4580                                    #[cfg(feature = "logging")]
4581                                    debug!(target: "stdout", "Context full");
4582
4583                                    match context_full_state {
4584                                        ContextFullState::Message => {
4585                                            match include_usage {
4586                                                true => {
4587                                                    *context_full_state = ContextFullState::Usage
4588                                                }
4589                                                false => {
4590                                                    *context_full_state = ContextFullState::Done
4591                                                }
4592                                            }
4593
4594                                            let created = SystemTime::now()
4595                                                .duration_since(std::time::UNIX_EPOCH)
4596                                                .map_err(|e| {
4597                                                    let err_msg = format!(
4598                                                "Failed to get the current time. Reason: {e}"
4599                                            );
4600
4601                                                    #[cfg(feature = "logging")]
4602                                                    error!(target: "stdout", "{}", &err_msg);
4603
4604                                                    LlamaCoreError::Operation(err_msg)
4605                                                })?;
4606
4607                                            let chat_completion_chunk = ChatCompletionChunk {
4608                                                id,
4609                                                object: "chat.completion.chunk".to_string(),
4610                                                created: created.as_secs(),
4611                                                model: graph.name().to_owned(),
4612                                                system_fingerprint: "fp_44709d6fcb".to_string(),
4613                                                choices: vec![ChatCompletionChunkChoice {
4614                                                    index: 0,
4615                                                    delta: ChatCompletionChunkChoiceDelta {
4616                                                        role: ChatCompletionRole::Assistant,
4617                                                        content: Some(
4618                                                            "<|WASMEDGE-GGML-CONTEXT-FULL|>"
4619                                                                .to_string(),
4620                                                        ),
4621                                                        tool_calls: vec![],
4622                                                    },
4623                                                    logprobs: None,
4624                                                    finish_reason: Some(FinishReason::length),
4625                                                }],
4626                                                usage: None,
4627                                            };
4628
4629                                            // serialize chat completion chunk
4630                                            let chunk_str =
4631                                                serde_json::to_string(&chat_completion_chunk)
4632                                                    .map_err(|e| {
4633                                                        let err_msg = format!(
4634                                            "Failed to serialize chat completion chunk. Reason: {e}"
4635                                        );
4636
4637                                                        #[cfg(feature = "logging")]
4638                                                        error!(target: "stdout", "{}", &err_msg);
4639
4640                                                        LlamaCoreError::Operation(err_msg)
4641                                                    })?;
4642
4643                                            Ok(format!("data: {chunk_str}\n\n"))
4644                                        }
4645                                        ContextFullState::Usage => {
4646                                            *context_full_state = ContextFullState::Done;
4647
4648                                            // retrieve the number of prompt and completion tokens
4649                                            let token_info = get_token_info_by_graph(graph)?;
4650
4651                                            let usage = Some(Usage {
4652                                                prompt_tokens: token_info.prompt_tokens,
4653                                                completion_tokens: token_info.completion_tokens,
4654                                                total_tokens: token_info.prompt_tokens
4655                                                    + token_info.completion_tokens,
4656                                            });
4657
4658                                            let created = SystemTime::now()
4659                                                .duration_since(std::time::UNIX_EPOCH)
4660                                                .map_err(|e| {
4661                                                    let err_msg = format!(
4662                                                "Failed to get the current time. Reason: {e}"
4663                                            );
4664
4665                                                    #[cfg(feature = "logging")]
4666                                                    error!(target: "stdout", "{}", &err_msg);
4667
4668                                                    LlamaCoreError::Operation(err_msg)
4669                                                })?;
4670
4671                                            let chat_completion_chunk = ChatCompletionChunk {
4672                                                id,
4673                                                object: "chat.completion.chunk".to_string(),
4674                                                created: created.as_secs(),
4675                                                model: graph.name().to_owned(),
4676                                                system_fingerprint: "fp_44709d6fcb".to_string(),
4677                                                choices: vec![],
4678                                                usage,
4679                                            };
4680
4681                                            // serialize chat completion chunk
4682                                            let chunk_str =
4683                                                serde_json::to_string(&chat_completion_chunk)
4684                                                    .map_err(|e| {
4685                                                        let err_msg = format!(
4686                                            "Failed to serialize chat completion chunk. Reason: {e}"
4687                                        );
4688
4689                                                        #[cfg(feature = "logging")]
4690                                                        error!(target: "stdout", "{}", &err_msg);
4691
4692                                                        LlamaCoreError::Operation(err_msg)
4693                                                    })?;
4694
4695                                            Ok(format!("data: {chunk_str}\n\n"))
4696                                        }
4697                                        ContextFullState::Done => {
4698                                            *context_full_state = ContextFullState::EndOfSequence;
4699
4700                                            Ok("data: [DONE]\n\n".to_string())
4701                                        }
4702                                        ContextFullState::EndOfSequence => {
4703                                            Ok("[GGML] End of sequence".to_string())
4704                                        }
4705                                    }
4706                                }
4707                                Err(wasmedge_wasi_nn::Error::BackendError(
4708                                    wasmedge_wasi_nn::BackendError::PromptTooLong,
4709                                )) => {
4710                                    #[cfg(feature = "logging")]
4711                                    debug!(target: "stdout", "Prompt too long");
4712
4713                                    match prompt_too_long_state {
4714                                        PromptTooLongState::Message => {
4715                                            match include_usage {
4716                                                true => {
4717                                                    *prompt_too_long_state =
4718                                                        PromptTooLongState::Usage
4719                                                }
4720                                                false => {
4721                                                    *prompt_too_long_state =
4722                                                        PromptTooLongState::Done
4723                                                }
4724                                            }
4725
4726                                            let created = SystemTime::now()
4727                                                .duration_since(std::time::UNIX_EPOCH)
4728                                                .map_err(|e| {
4729                                                    let err_msg = format!(
4730                                                "Failed to get the current time. Reason: {e}"
4731                                            );
4732
4733                                                    #[cfg(feature = "logging")]
4734                                                    error!(target: "stdout", "{}", &err_msg);
4735
4736                                                    LlamaCoreError::Operation(err_msg)
4737                                                })?;
4738
4739                                            let chat_completion_chunk = ChatCompletionChunk {
4740                                                id,
4741                                                object: "chat.completion.chunk".to_string(),
4742                                                created: created.as_secs(),
4743                                                model: graph.name().to_owned(),
4744                                                system_fingerprint: "fp_44709d6fcb".to_string(),
4745                                                choices: vec![ChatCompletionChunkChoice {
4746                                                    index: 0,
4747                                                    delta: ChatCompletionChunkChoiceDelta {
4748                                                        role: ChatCompletionRole::Assistant,
4749                                                        content: None,
4750                                                        tool_calls: vec![],
4751                                                    },
4752                                                    logprobs: None,
4753                                                    finish_reason: Some(FinishReason::length),
4754                                                }],
4755                                                usage: None,
4756                                            };
4757
4758                                            // serialize chat completion chunk
4759                                            let chunk_str =
4760                                                serde_json::to_string(&chat_completion_chunk)
4761                                                    .map_err(|e| {
4762                                                        let err_msg = format!(
4763                                            "Failed to serialize chat completion chunk. Reason: {e}"
4764                                        );
4765
4766                                                        #[cfg(feature = "logging")]
4767                                                        error!(target: "stdout", "{}", &err_msg);
4768
4769                                                        LlamaCoreError::Operation(err_msg)
4770                                                    })?;
4771
4772                                            Ok(format!("data: {chunk_str}\n\n"))
4773                                        }
4774                                        PromptTooLongState::Usage => {
4775                                            *prompt_too_long_state = PromptTooLongState::Done;
4776
4777                                            // retrieve the number of prompt and completion tokens
4778                                            let token_info = get_token_info_by_graph(graph)?;
4779
4780                                            let usage = Some(Usage {
4781                                                prompt_tokens: token_info.prompt_tokens,
4782                                                completion_tokens: token_info.completion_tokens,
4783                                                total_tokens: token_info.prompt_tokens
4784                                                    + token_info.completion_tokens,
4785                                            });
4786
4787                                            let created = SystemTime::now()
4788                                                .duration_since(std::time::UNIX_EPOCH)
4789                                                .map_err(|e| {
4790                                                    let err_msg = format!(
4791                                                "Failed to get the current time. Reason: {e}"
4792                                            );
4793
4794                                                    #[cfg(feature = "logging")]
4795                                                    error!(target: "stdout", "{}", &err_msg);
4796
4797                                                    LlamaCoreError::Operation(err_msg)
4798                                                })?;
4799
4800                                            let chat_completion_chunk = ChatCompletionChunk {
4801                                                id,
4802                                                object: "chat.completion.chunk".to_string(),
4803                                                created: created.as_secs(),
4804                                                model: graph.name().to_owned(),
4805                                                system_fingerprint: "fp_44709d6fcb".to_string(),
4806                                                choices: vec![],
4807                                                usage,
4808                                            };
4809
4810                                            // serialize chat completion chunk
4811                                            let chunk_str =
4812                                                serde_json::to_string(&chat_completion_chunk)
4813                                                    .map_err(|e| {
4814                                                        let err_msg = format!(
4815                                            "Failed to serialize chat completion chunk. Reason: {e}"
4816                                        );
4817
4818                                                        #[cfg(feature = "logging")]
4819                                                        error!(target: "stdout", "{}", &err_msg);
4820
4821                                                        LlamaCoreError::Operation(err_msg)
4822                                                    })?;
4823
4824                                            Ok(format!("data: {chunk_str}\n\n"))
4825                                        }
4826                                        PromptTooLongState::Done => {
4827                                            *prompt_too_long_state =
4828                                                PromptTooLongState::EndOfSequence;
4829
4830                                            Ok("data: [DONE]\n\n".to_string())
4831                                        }
4832                                        PromptTooLongState::EndOfSequence => {
4833                                            Ok("[GGML] End of sequence".to_string())
4834                                        }
4835                                    }
4836                                }
4837                                Err(e) => {
4838                                    let err_msg = format!(
4839                                        "Failed to compute the chat completion. Reason: {e}"
4840                                    );
4841
4842                                    #[cfg(feature = "logging")]
4843                                    error!(target: "stdout", "{}", &err_msg);
4844
4845                                    Err(LlamaCoreError::Backend(BackendError::ComputeSingle(
4846                                        err_msg,
4847                                    )))
4848                                }
4849                            }
4850                        }
4851                        None => {
4852                            let err_msg = "There is no model available in the chat graphs.";
4853
4854                            #[cfg(feature = "logging")]
4855                            error!(target: "stdout", "{}", &err_msg);
4856
4857                            Err(LlamaCoreError::Operation(err_msg.into()))
4858                        }
4859                    }
4860                }
4861            }
4862        }
4863        None => {
4864            match chat_graphs.iter_mut().next() {
4865                Some((_, graph)) => {
4866                    // run a single decode step to produce the next token
4867                    match graph.compute_single() {
4868                        Ok(_) => {
4869                            #[cfg(feature = "logging")]
4870                            debug!(target: "stdout", "Compute the chat stream chunk successfully.");
4871
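                                // Stream-state machine: while the backend keeps producing tokens,
                                // each decoded token is wrapped in a ChatCompletionChunk and emitted
                                // as one SSE `data:` line; the Done and EndOfSequence arms then emit
                                // the `[DONE]` marker and the internal end-of-sequence sentinel.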
4872                            match stream_state {
4873                                StreamState::Usage | StreamState::NoUsage => {
4874                                    // Retrieve the output
4875                                    let output_buffer =
4876                                        get_output_buffer_single(graph, OUTPUT_TENSOR)?;
4877
4878                                    #[cfg(feature = "logging")]
4879                                    info!(target: "stdout", "retrieved the output buffer");
4880
4881                                    // decode the output buffer to a utf8 string
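                                    // A token may end in the middle of a multi-byte UTF-8 character,
                                    // so bytes that do not yet form valid UTF-8 are accumulated in
                                    // CACHED_UTF8_ENCODINGS until the character completes (the cache
                                    // is cleared once it exceeds 4 bytes).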
4882                                    let output = match String::from_utf8(output_buffer.clone()) {
4883                                        Ok(token) => token,
4884                                        Err(_) => {
4885                                            let mutex = CACHED_UTF8_ENCODINGS
4886                                                .get_or_init(|| Mutex::new(Vec::new()));
4887                                            let mut cached_encodings = mutex.lock().map_err(|e| {
4888                                            let err_msg = format!(
4889                                                "Failed to acquire the lock of `CACHED_UTF8_ENCODINGS`. Reason: {e}"
4890                                            );
4891
4892                                            #[cfg(feature = "logging")]
4893                                            error!(target: "stdout", "{}", &err_msg);
4894
4895                                            LlamaCoreError::Operation(err_msg)
4896                                        })?;
4897
4898                                            cached_encodings.extend_from_slice(&output_buffer[..]);
4899
4900                                            match String::from_utf8(cached_encodings.to_vec()) {
4901                                                Ok(token) => {
4902                                                    // clear encodings
4903                                                    cached_encodings.clear();
4904
4905                                                    token
4906                                                }
4907                                                Err(e) => {
4908                                                    // TODO: temporary guard against the cached encodings growing without bound.
4909                                                    if cached_encodings.len() > 4 {
4910                                                        let err_msg = format!("Failed to convert the cached bytes to a UTF-8 string: the buffer length exceeds 4 bytes. {e}");
4911
4912                                                        #[cfg(feature = "logging")]
4913                                                        error!(target: "stdout", "{}", &err_msg);
4914
4915                                                        #[cfg(feature = "logging")]
4916                                                        error!(target: "stdout", "The cached buffer: {:?}", &cached_encodings[..]);
4917
4918                                                        // let token = String::from_utf8_lossy(
4919                                                        //     &cached_encodings,
4920                                                        // )
4921                                                        // .to_string();
4922
4923                                                        // clear CACHED_UTF8_ENCODINGS
4924                                                        cached_encodings.clear();
4925
4926                                                        String::from("")
4927                                                    } else {
4928                                                        let warn_msg = format!("Failed to convert the cached bytes to a UTF-8 string. {e}");
4929
4930                                                        #[cfg(feature = "logging")]
4931                                                        warn!(target: "stdout", "{}", &warn_msg);
4932
4933                                                        String::from("")
4934                                                    }
4935                                                }
4936                                            }
4937                                        }
4938                                    };
4939
4940                                    #[cfg(feature = "logging")]
4941                                    info!(target: "stdout", "decoded the output buffer");
4942
4943                                    let created = SystemTime::now()
4944                                        .duration_since(std::time::UNIX_EPOCH)
4945                                        .map_err(|e| {
4946                                            let err_msg = format!(
4947                                                "Failed to get the current time. Reason: {e}"
4948                                            );
4949
4950                                            #[cfg(feature = "logging")]
4951                                            error!(target: "stdout", "{}", &err_msg);
4952
4953                                            LlamaCoreError::Operation(err_msg)
4954                                        })?;
4955
4956                                    let chat_completion_chunk = ChatCompletionChunk {
4957                                        id,
4958                                        object: "chat.completion.chunk".to_string(),
4959                                        created: created.as_secs(),
4960                                        model: graph.name().to_owned(),
4961                                        system_fingerprint: "fp_44709d6fcb".to_string(),
4962                                        choices: vec![ChatCompletionChunkChoice {
4963                                            index: 0,
4964                                            delta: ChatCompletionChunkChoiceDelta {
4965                                                role: ChatCompletionRole::Assistant,
4966                                                content: Some(output),
4967                                                tool_calls: vec![],
4968                                            },
4969                                            logprobs: None,
4970                                            finish_reason: None,
4971                                        }],
4972                                        usage: None,
4973                                    };
4974
4975                                    #[cfg(feature = "logging")]
4976                                    info!(target: "stdout", "created chat completion chunk");
4977
4978                                    // serialize chat completion chunk
4979                                    let chunk_str = serde_json::to_string(&chat_completion_chunk)
4980                                        .map_err(|e| {
4981                                        let err_msg = format!(
4982                                            "Failed to serialize chat completion chunk. Reason: {e}"
4983                                        );
4984
4985                                        #[cfg(feature = "logging")]
4986                                        error!(target: "stdout", "{}", &err_msg);
4987
4988                                        LlamaCoreError::Operation(err_msg)
4989                                    })?;
4990
4991                                    Ok(format!("data: {chunk_str}\n\n"))
4992                                }
4993                                StreamState::Done => {
4994                                    *stream_state = StreamState::EndOfSequence;
4995
4996                                    Ok("data: [DONE]\n\n".to_string())
4997                                }
4998                                StreamState::EndOfSequence => {
4999                                    Ok("[GGML] End of sequence".to_string())
5000                                }
5001                            }
5002                        }
5003                        Err(wasmedge_wasi_nn::Error::BackendError(
5004                            wasmedge_wasi_nn::BackendError::EndOfSequence,
5005                        )) => {
5006                            #[cfg(feature = "logging")]
5007                            debug!(target: "stdout", "End of sequence");
5008
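                            // EndOfSequence from the backend means generation has finished:
                            // emit a usage-only chunk first if usage was requested, then the
                            // `[DONE]` marker, and finally the internal end-of-sequence sentinel.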
5009                            match stream_state {
5010                                StreamState::Usage => {
5011                                    *stream_state = StreamState::Done;
5012
5013                                    // retrieve the number of prompt and completion tokens
5014                                    let token_info = get_token_info_by_graph(graph)?;
5015
5016                                    let usage = Some(Usage {
5017                                        prompt_tokens: token_info.prompt_tokens,
5018                                        completion_tokens: token_info.completion_tokens,
5019                                        total_tokens: token_info.prompt_tokens
5020                                            + token_info.completion_tokens,
5021                                    });
5022
5023                                    #[cfg(feature = "logging")]
5024                                    info!(target: "stdout", "token_info: {} prompt tokens, {} completion tokens", token_info.prompt_tokens, token_info.completion_tokens);
5025
5026                                    let created = SystemTime::now()
5027                                        .duration_since(std::time::UNIX_EPOCH)
5028                                        .map_err(|e| {
5029                                            let err_msg = format!(
5030                                                "Failed to get the current time. Reason: {e}"
5031                                            );
5032
5033                                            #[cfg(feature = "logging")]
5034                                            error!(target: "stdout", "{}", &err_msg);
5035
5036                                            LlamaCoreError::Operation(err_msg)
5037                                        })?;
5038
5039                                    let chat_completion_chunk = ChatCompletionChunk {
5040                                        id,
5041                                        object: "chat.completion.chunk".to_string(),
5042                                        created: created.as_secs(),
5043                                        model: graph.name().to_owned(),
5044                                        system_fingerprint: "fp_44709d6fcb".to_string(),
5045                                        choices: vec![],
5046                                        usage,
5047                                    };
5048
5049                                    // serialize chat completion chunk
5050                                    let chunk_str = serde_json::to_string(&chat_completion_chunk)
5051                                        .map_err(|e| {
5052                                        let err_msg = format!(
5053                                            "Failed to serialize chat completion chunk. Reason: {e}"
5054                                        );
5055
5056                                        #[cfg(feature = "logging")]
5057                                        error!(target: "stdout", "{}", &err_msg);
5058
5059                                        LlamaCoreError::Operation(err_msg)
5060                                    })?;
5061
5062                                    Ok(format!("data: {chunk_str}\n\n"))
5063                                }
5064                                StreamState::Done | StreamState::NoUsage => {
5065                                    *stream_state = StreamState::EndOfSequence;
5066
5067                                    Ok("data: [DONE]\n\n".to_string())
5068                                }
5069                                StreamState::EndOfSequence => {
5070                                    Ok("[GGML] End of sequence".to_string())
5071                                }
5072                            }
5073                        }
5074                        Err(wasmedge_wasi_nn::Error::BackendError(
5075                            wasmedge_wasi_nn::BackendError::ContextFull,
5076                        )) => {
5077                            #[cfg(feature = "logging")]
5078                            debug!(target: "stdout", "Context full");
5079
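                            // ContextFull: send a final content chunk carrying the
                            // <|WASMEDGE-GGML-CONTEXT-FULL|> marker with finish_reason `length`,
                            // optionally followed by a usage chunk, then close the stream.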
5080                            match context_full_state {
5081                                ContextFullState::Message => {
5082                                    match include_usage {
5083                                        true => *context_full_state = ContextFullState::Usage,
5084                                        false => *context_full_state = ContextFullState::Done,
5085                                    }
5086
5087                                    let created = SystemTime::now()
5088                                        .duration_since(std::time::UNIX_EPOCH)
5089                                        .map_err(|e| {
5090                                            let err_msg = format!(
5091                                                "Failed to get the current time. Reason: {e}"
5092                                            );
5093
5094                                            #[cfg(feature = "logging")]
5095                                            error!(target: "stdout", "{}", &err_msg);
5096
5097                                            LlamaCoreError::Operation(err_msg)
5098                                        })?;
5099
5100                                    let chat_completion_chunk = ChatCompletionChunk {
5101                                        id,
5102                                        object: "chat.completion.chunk".to_string(),
5103                                        created: created.as_secs(),
5104                                        model: graph.name().to_owned(),
5105                                        system_fingerprint: "fp_44709d6fcb".to_string(),
5106                                        choices: vec![ChatCompletionChunkChoice {
5107                                            index: 0,
5108                                            delta: ChatCompletionChunkChoiceDelta {
5109                                                role: ChatCompletionRole::Assistant,
5110                                                content: Some(
5111                                                    "<|WASMEDGE-GGML-CONTEXT-FULL|>".to_string(),
5112                                                ),
5113                                                tool_calls: vec![],
5114                                            },
5115                                            logprobs: None,
5116                                            finish_reason: Some(FinishReason::length),
5117                                        }],
5118                                        usage: None,
5119                                    };
5120
5121                                    // serialize chat completion chunk
5122                                    let chunk_str = serde_json::to_string(&chat_completion_chunk)
5123                                        .map_err(|e| {
5124                                        let err_msg = format!(
5125                                            "Failed to serialize chat completion chunk. Reason: {e}"
5126                                        );
5127
5128                                        #[cfg(feature = "logging")]
5129                                        error!(target: "stdout", "{}", &err_msg);
5130
5131                                        LlamaCoreError::Operation(err_msg)
5132                                    })?;
5133
5134                                    Ok(format!("data: {chunk_str}\n\n"))
5135                                }
5136                                ContextFullState::Usage => {
5137                                    *context_full_state = ContextFullState::Done;
5138
5139                                    // retrieve the number of prompt and completion tokens
5140                                    let token_info = get_token_info_by_graph(graph)?;
5141
5142                                    let usage = Some(Usage {
5143                                        prompt_tokens: token_info.prompt_tokens,
5144                                        completion_tokens: token_info.completion_tokens,
5145                                        total_tokens: token_info.prompt_tokens
5146                                            + token_info.completion_tokens,
5147                                    });
5148
5149                                    let created = SystemTime::now()
5150                                        .duration_since(std::time::UNIX_EPOCH)
5151                                        .map_err(|e| {
5152                                            let err_msg = format!(
5153                                                "Failed to get the current time. Reason: {e}"
5154                                            );
5155
5156                                            #[cfg(feature = "logging")]
5157                                            error!(target: "stdout", "{}", &err_msg);
5158
5159                                            LlamaCoreError::Operation(err_msg)
5160                                        })?;
5161
5162                                    let chat_completion_chunk = ChatCompletionChunk {
5163                                        id,
5164                                        object: "chat.completion.chunk".to_string(),
5165                                        created: created.as_secs(),
5166                                        model: graph.name().to_owned(),
5167                                        system_fingerprint: "fp_44709d6fcb".to_string(),
5168                                        choices: vec![],
5169                                        usage,
5170                                    };
5171
5172                                    // serialize chat completion chunk
5173                                    let chunk_str = serde_json::to_string(&chat_completion_chunk)
5174                                        .map_err(|e| {
5175                                        let err_msg = format!(
5176                                            "Failed to serialize chat completion chunk. Reason: {e}"
5177                                        );
5178
5179                                        #[cfg(feature = "logging")]
5180                                        error!(target: "stdout", "{}", &err_msg);
5181
5182                                        LlamaCoreError::Operation(err_msg)
5183                                    })?;
5184
5185                                    Ok(format!("data: {chunk_str}\n\n"))
5186                                }
5187                                ContextFullState::Done => {
5188                                    *context_full_state = ContextFullState::EndOfSequence;
5189
5190                                    Ok("data: [DONE]\n\n".to_string())
5191                                }
5192                                ContextFullState::EndOfSequence => {
5193                                    Ok("[GGML] End of sequence".to_string())
5194                                }
5195                            }
5196                        }
5197                        Err(wasmedge_wasi_nn::Error::BackendError(
5198                            wasmedge_wasi_nn::BackendError::PromptTooLong,
5199                        )) => {
5200                            #[cfg(feature = "logging")]
5201                            debug!(target: "stdout", "Prompt too long");
5202
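                            // PromptTooLong is handled like ContextFull, except the final
                            // chunk carries no content, only finish_reason `length`.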
5203                            match prompt_too_long_state {
5204                                PromptTooLongState::Message => {
5205                                    match include_usage {
5206                                        true => *prompt_too_long_state = PromptTooLongState::Usage,
5207                                        false => *prompt_too_long_state = PromptTooLongState::Done,
5208                                    }
5209
5210                                    let created = SystemTime::now()
5211                                        .duration_since(std::time::UNIX_EPOCH)
5212                                        .map_err(|e| {
5213                                            let err_msg = format!(
5214                                                "Failed to get the current time. Reason: {e}"
5215                                            );
5216
5217                                            #[cfg(feature = "logging")]
5218                                            error!(target: "stdout", "{}", &err_msg);
5219
5220                                            LlamaCoreError::Operation(err_msg)
5221                                        })?;
5222
5223                                    let chat_completion_chunk = ChatCompletionChunk {
5224                                        id,
5225                                        object: "chat.completion.chunk".to_string(),
5226                                        created: created.as_secs(),
5227                                        model: graph.name().to_owned(),
5228                                        system_fingerprint: "fp_44709d6fcb".to_string(),
5229                                        choices: vec![ChatCompletionChunkChoice {
5230                                            index: 0,
5231                                            delta: ChatCompletionChunkChoiceDelta {
5232                                                role: ChatCompletionRole::Assistant,
5233                                                content: None,
5234                                                tool_calls: vec![],
5235                                            },
5236                                            logprobs: None,
5237                                            finish_reason: Some(FinishReason::length),
5238                                        }],
5239                                        usage: None,
5240                                    };
5241
5242                                    // serialize chat completion chunk
5243                                    let chunk_str = serde_json::to_string(&chat_completion_chunk)
5244                                        .map_err(|e| {
5245                                        let err_msg = format!(
5246                                            "Failed to serialize chat completion chunk. Reason: {e}"
5247                                        );
5248
5249                                        #[cfg(feature = "logging")]
5250                                        error!(target: "stdout", "{}", &err_msg);
5251
5252                                        LlamaCoreError::Operation(err_msg)
5253                                    })?;
5254
5255                                    Ok(format!("data: {chunk_str}\n\n"))
5256                                }
5257                                PromptTooLongState::Usage => {
5258                                    *prompt_too_long_state = PromptTooLongState::Done;
5259
5260                                    // retrieve the number of prompt and completion tokens
5261                                    let token_info = get_token_info_by_graph(graph)?;
5262
5263                                    let usage = Some(Usage {
5264                                        prompt_tokens: token_info.prompt_tokens,
5265                                        completion_tokens: token_info.completion_tokens,
5266                                        total_tokens: token_info.prompt_tokens
5267                                            + token_info.completion_tokens,
5268                                    });
5269
5270                                    let created = SystemTime::now()
5271                                        .duration_since(std::time::UNIX_EPOCH)
5272                                        .map_err(|e| {
5273                                            let err_msg = format!(
5274                                                "Failed to get the current time. Reason: {e}"
5275                                            );
5276
5277                                            #[cfg(feature = "logging")]
5278                                            error!(target: "stdout", "{}", &err_msg);
5279
5280                                            LlamaCoreError::Operation(err_msg)
5281                                        })?;
5282
5283                                    let chat_completion_chunk = ChatCompletionChunk {
5284                                        id,
5285                                        object: "chat.completion.chunk".to_string(),
5286                                        created: created.as_secs(),
5287                                        model: graph.name().to_owned(),
5288                                        system_fingerprint: "fp_44709d6fcb".to_string(),
5289                                        choices: vec![],
5290                                        usage,
5291                                    };
5292
5293                                    // serialize chat completion chunk
5294                                    let chunk_str = serde_json::to_string(&chat_completion_chunk)
5295                                        .map_err(|e| {
5296                                        let err_msg = format!(
5297                                            "Failed to serialize chat completion chunk. Reason: {e}"
5298                                        );
5299
5300                                        #[cfg(feature = "logging")]
5301                                        error!(target: "stdout", "{}", &err_msg);
5302
5303                                        LlamaCoreError::Operation(err_msg)
5304                                    })?;
5305
5306                                    Ok(format!("data: {chunk_str}\n\n"))
5307                                }
5308                                PromptTooLongState::Done => {
5309                                    *prompt_too_long_state = PromptTooLongState::EndOfSequence;
5310
5311                                    Ok("data: [DONE]\n\n".to_string())
5312                                }
5313                                PromptTooLongState::EndOfSequence => {
5314                                    Ok("[GGML] End of sequence".to_string())
5315                                }
5316                            }
5317                        }
5318                        Err(e) => {
5319                            let err_msg =
5320                                format!("Failed to compute the chat completion. Reason: {e}");
5321
5322                            #[cfg(feature = "logging")]
5323                            error!(target: "stdout", "{}", &err_msg);
5324
5325                            Err(LlamaCoreError::Backend(BackendError::ComputeSingle(
5326                                err_msg,
5327                            )))
5328                        }
5329                    }
5330                }
5331                None => {
5332                    let err_msg = "There is no model available in the chat graphs.";
5333
5334                    #[cfg(feature = "logging")]
5335                    error!(target: "stdout", "{}", &err_msg);
5336
5337                    Err(LlamaCoreError::Operation(err_msg.into()))
5338                }
5339            }
5340        }
5341    };
5342
5343    #[cfg(feature = "logging")]
5344    info!(target: "stdout", "Return the chat stream chunk!");
5345
5346    res
5347}
5348
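/// Result of parsing a raw model response: the original text, any extracted
/// assistant content, and any tool calls it contains.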
5349#[allow(dead_code)]
5350#[derive(Debug)]
5351struct ParseResult {
5352    raw: String,
5353    content: Option<String>,
5354    tool_calls: Vec<ToolCall>,
5355}