llama_core/chat.rs

//! Define APIs for chat completion.
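//!
//! The public entry point is [`chat`], which dispatches a [`ChatCompletionRequest`]
//! either to the streaming path (`chat_stream`) or to the non-stream path
//! (`chat_once`), depending on the request's `stream` field.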

use crate::{
    error,
    metadata::ggml::GgmlMetadata,
    running_mode,
    utils::{
        gen_chat_id, get_output_buffer, get_output_buffer_single, get_token_info_by_graph,
        get_token_info_by_graph_name, set_tensor_data_u8,
    },
    Graph, RunningMode, CACHED_UTF8_ENCODINGS, CHAT_GRAPHS, OUTPUT_TENSOR,
};
use chat_prompts::{
    chat::{BuildChatPrompt, ChatPrompt},
    PromptTemplateType,
};
use either::{Either, Left, Right};
use endpoints::{
    chat::{
        ChatCompletionChunk, ChatCompletionChunkChoice, ChatCompletionChunkChoiceDelta,
        ChatCompletionObject, ChatCompletionObjectChoice, ChatCompletionObjectMessage,
        ChatCompletionRequest, ChatCompletionRequestMessage, ChatCompletionRole,
        ChatCompletionUserMessageContent, ContentPart, Function, ToolCall, ToolCallForChunk,
        ToolChoice,
    },
    common::{FinishReason, Usage},
};
use error::{BackendError, LlamaCoreError};
use std::{
    collections::VecDeque,
    pin::Pin,
    sync::{
        atomic::{AtomicBool, Ordering},
        Mutex, OnceLock,
    },
    task::{Context, Poll, Waker},
    time::SystemTime,
};

// Define a global waker queue for storing waiting ChatStreams
static CHAT_STREAM_WAKER_QUEUE: OnceLock<Mutex<VecDeque<Waker>>> = OnceLock::new();

// Define a global atomic boolean indicating whether there is an active ChatStream
static CHAT_STREAM_ACTIVE: AtomicBool = AtomicBool::new(false);

/// Processes a chat-completion request and returns either a stream of ChatCompletionChunk instances or a ChatCompletionObject instance.
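///
/// # Example
///
/// A minimal sketch of how a caller might drive this API; the request handling
/// around the call is an illustrative assumption, not part of this crate:
///
/// ```ignore
/// // `request` is assumed to have been deserialized from the incoming HTTP body.
/// let (result, _include_tool_calls) = chat(&mut request).await?;
/// match result {
///     Either::Left(stream) => {
///         // stream mode: forward each SSE-formatted chunk to the client
///     }
///     Either::Right(completion) => {
///         // non-stream mode: serialize the ChatCompletionObject into the response body
///     }
/// }
/// ```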
pub async fn chat(
    chat_request: &mut ChatCompletionRequest,
) -> Result<
    (
        Either<impl futures::TryStream<Ok = String, Error = LlamaCoreError>, ChatCompletionObject>,
        bool,
    ),
    LlamaCoreError,
> {
    #[cfg(feature = "logging")]
    {
        debug!(target: "stdout", "tool choice: {:?}", chat_request.tool_choice.as_ref());
        debug!(target: "stdout", "tools: {:?}", chat_request.tools.as_ref());
        debug!(target: "stdout", "stream mode: {:?}", chat_request.stream);
    }

    let result = match chat_request.stream {
        Some(true) => match chat_stream(chat_request).await {
            Ok((stream, include_tool_calls)) => Ok((Left(stream), include_tool_calls)),
            Err(e) => Err(e),
        },
        Some(false) | None => match chat_once(chat_request).await {
            Ok((chat_completion_object, include_tool_calls)) => {
                Ok((Right(chat_completion_object), include_tool_calls))
            }
            Err(e) => Err(e),
        },
    };

    #[cfg(feature = "logging")]
    info!(target: "stdout", "Reset the model metadata");

    result
}

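/// Processes a chat completion request in stream mode and returns an SSE chunk
/// stream together with a flag indicating whether the stream carries tool calls.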
async fn chat_stream(
    chat_request: &mut ChatCompletionRequest,
) -> Result<
    (
        impl futures::TryStream<Ok = String, Error = LlamaCoreError>,
        bool,
    ),
    LlamaCoreError,
> {
    #[cfg(feature = "logging")]
    info!(target: "stdout", "Process chat completion request in the stream mode");

    let running_mode = running_mode()?;
    if !running_mode.contains(RunningMode::CHAT) && !running_mode.contains(RunningMode::RAG) {
        let err_msg = "Chat completion is only supported in the chat or rag mode.";

        #[cfg(feature = "logging")]
        error!(target: "stdout", "{err_msg}");

        return Err(LlamaCoreError::Operation(err_msg.to_string()));
    }

    let model_name = chat_request.model.clone();
    let id = match &chat_request.user {
        Some(id) => id.clone(),
        None => gen_chat_id(),
    };
    #[cfg(feature = "logging")]
    info!(target: "stdout", "user: {}", &id);

    #[cfg(feature = "logging")]
    info!(target: "stdout", "Check model metadata");

    // update metadata
    let mut metadata = check_model_metadata(chat_request)?;

    // parse the `include_usage` option
    let include_usage = match chat_request.stream_options {
        Some(ref stream_options) => stream_options.include_usage.unwrap_or_default(),
        None => metadata.include_usage,
    };
    #[cfg(feature = "logging")]
    info!(target: "stdout", "include_usage: {include_usage}");

    #[cfg(feature = "logging")]
    info!(target: "stdout", "Build the chat prompt");

    // build prompt
    let (prompt, available_completion_tokens, tool_use) =
        build_prompt(model_name.as_ref(), chat_request)?;

    #[cfg(feature = "logging")]
    {
        info!(target: "stdout", "prompt:\n{}", &prompt);
        info!(target: "stdout", "available_completion_tokens: {available_completion_tokens}");
        info!(target: "stdout", "tool_use: {tool_use}");
    }

    #[cfg(feature = "logging")]
    info!(target: "stdout", "Update the n_predict");

    // update metadata n_predict
    update_n_predict(chat_request, &mut metadata, available_completion_tokens)?;

    #[cfg(feature = "logging")]
    info!(target: "stdout", "Feed the prompt to the model");

    // set prompt
    set_prompt(chat_request.model.as_ref(), &prompt)?;

    let stream = match tool_use {
        false => (ChatStream::new(model_name, id, include_usage, None), false),
        true => {
            let chat_graphs = match CHAT_GRAPHS.get() {
                Some(chat_graphs) => chat_graphs,
                None => {
                    let err_msg = "Failed to get the underlying value of `CHAT_GRAPHS`.";

                    #[cfg(feature = "logging")]
                    error!(target: "stdout", "{}", &err_msg);

                    return Err(LlamaCoreError::Operation(err_msg.into()));
                }
            };

            let mut chat_graphs = chat_graphs.lock().map_err(|e| {
                let err_msg = format!("Failed to acquire the lock of `CHAT_GRAPHS`. {e}");

                #[cfg(feature = "logging")]
                error!(target: "stdout", "{}", &err_msg);

                LlamaCoreError::Operation(err_msg)
            })?;

            match model_name {
                Some(model_name) => match chat_graphs.contains_key(&model_name) {
                    true => {
                        let graph = chat_graphs.get_mut(&model_name).unwrap();
                        chat_stream_for_tool(graph, id, include_usage)?
                    }
                    false => match chat_graphs.iter_mut().next() {
                        Some((_, graph)) => chat_stream_for_tool(graph, id, include_usage)?,
                        None => {
                            let err_msg = "There is no model available in the chat graphs.";

                            #[cfg(feature = "logging")]
                            error!(target: "stdout", "{}", &err_msg);

                            return Err(LlamaCoreError::Operation(err_msg.into()));
                        }
                    },
                },
                None => match chat_graphs.iter_mut().next() {
                    Some((_, graph)) => chat_stream_for_tool(graph, id, include_usage)?,
                    None => {
                        let err_msg = "There is no model available in the chat graphs.";

                        #[cfg(feature = "logging")]
                        error!(target: "stdout", "{}", &err_msg);

                        return Err(LlamaCoreError::Operation(err_msg.into()));
                    }
                },
            }
        }
    };

    #[cfg(feature = "logging")]
    info!(target: "stdout", "End of the chat completion stream.");

    Ok(stream)
}

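/// Runs a tool-capable model to completion, parses tool calls from the output, and
/// packages the result as a pre-built sequence of SSE chunks wrapped in a [`ChatStream`].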
fn chat_stream_for_tool(
    graph: &mut Graph<GgmlMetadata>,
    id: impl Into<String>,
    include_usage: bool,
) -> Result<(ChatStream, bool), LlamaCoreError> {
    #[cfg(feature = "logging")]
    info!(target: "stdout", "Handle chat request with available tools by the model named {}.", graph.name());

    let id = id.into();

    match graph.compute() {
        Ok(_) => {
            // Retrieve the output.
            let output_buffer = get_output_buffer(graph, OUTPUT_TENSOR)?;
            let output = std::str::from_utf8(&output_buffer[..]).map_err(|e| {
                let err_msg = format!(
                    "Failed to decode the buffer of the inference result to a utf-8 string. {e}"
                );

                #[cfg(feature = "logging")]
                error!(target: "stdout", "{}", &err_msg);

                LlamaCoreError::Operation(err_msg)
            })?;

            #[cfg(feature = "logging")]
            info!(target: "stdout", "raw generation:\n{output}");

            // post-process
            let message = post_process(output, &graph.metadata.prompt_template).map_err(|e| {
                LlamaCoreError::Operation(format!("Failed to post-process the output. {e}"))
            })?;

            #[cfg(feature = "logging")]
            info!(target: "stdout", "post-processed generation:\n{}", &message);

            // retrieve the number of prompt and completion tokens
            let token_info = get_token_info_by_graph(graph)?;

            #[cfg(feature = "logging")]
            info!(target: "stdout", "prompt tokens: {}, completion tokens: {}", token_info.prompt_tokens, token_info.completion_tokens);

            let usage = Some(Usage {
                prompt_tokens: token_info.prompt_tokens,
                completion_tokens: token_info.completion_tokens,
                total_tokens: token_info.prompt_tokens + token_info.completion_tokens,
            });

            let created = SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .map_err(|e| {
                    let err_msg = format!("Failed to get the current time. Reason: {e}");

                    #[cfg(feature = "logging")]
                    error!(target: "stdout", "{}", &err_msg);

                    LlamaCoreError::Operation(err_msg)
                })?;

            if graph.metadata.prompt_template != PromptTemplateType::MistralTool
                && graph.metadata.prompt_template != PromptTemplateType::ChatMLTool
                && graph.metadata.prompt_template != PromptTemplateType::GroqLlama3Tool
                && graph.metadata.prompt_template != PromptTemplateType::Llama3Tool
                && graph.metadata.prompt_template != PromptTemplateType::InternLM2Tool
                && graph.metadata.prompt_template != PromptTemplateType::NemotronTool
                && graph.metadata.prompt_template != PromptTemplateType::FunctionaryV32
                && graph.metadata.prompt_template != PromptTemplateType::FunctionaryV31
                && graph.metadata.prompt_template != PromptTemplateType::MistralSmallTool
                && graph.metadata.prompt_template != PromptTemplateType::Llama4Chat
                && graph.metadata.prompt_template != PromptTemplateType::Qwen3NoThink
            {
                let err_msg = format!("Unsupported prompt template: {}. The tool use is only supported for 'mistral-tool', 'chatml-tool', 'groq-llama3-tool', 'llama-3-tool', 'internlm-2-tool', 'nemotron-tool', 'functionary-31', 'functionary-32', 'mistral-small-tool', 'llama-4-chat', and 'qwen-3-no-think' prompt templates.", graph.metadata.prompt_template);

                #[cfg(feature = "logging")]
                error!(target: "stdout", "{}", &err_msg);

                return Err(LlamaCoreError::Operation(err_msg));
            }

            let parsed_result = parse_tool_calls(&message, graph.metadata.prompt_template)?;

            let content = if parsed_result.tool_calls.is_empty() {
                Some(parsed_result.raw.clone())
            } else {
                parsed_result.content.clone()
            };

            let (tool_calls, include_tool_calls) = match parsed_result.tool_calls.is_empty() {
                false => {
                    let tool_calls: Vec<ToolCallForChunk> = parsed_result
                        .tool_calls
                        .into_iter()
                        .enumerate()
                        .map(|(index, tool_call)| ToolCallForChunk {
                            index,
                            id: tool_call.id,
                            ty: tool_call.ty,
                            function: tool_call.function,
                        })
                        .collect();
                    (tool_calls, true)
                }
                true => (vec![], false),
            };

            // tool_calls chunk
            let tool_call_chunk = {
                let chat_completion_chunk = ChatCompletionChunk {
                    id: id.clone(),
                    object: "chat.completion.chunk".to_string(),
                    created: created.as_secs(),
                    model: graph.name().to_owned(),
                    system_fingerprint: "fp_44709d6fcb".to_string(),
                    choices: vec![ChatCompletionChunkChoice {
                        index: 0,
                        delta: ChatCompletionChunkChoiceDelta {
                            role: ChatCompletionRole::Assistant,
                            content,
                            tool_calls,
                        },
                        logprobs: None,
                        finish_reason: None,
                    }],
                    usage: None,
                };
                let chunk_str = serde_json::to_string(&chat_completion_chunk).map_err(|e| {
                    let err_msg = format!("Failed to serialize chat completion chunk. Reason: {e}");

                    #[cfg(feature = "logging")]
                    error!(target: "stdout", "{}", &err_msg);

                    LlamaCoreError::Operation(err_msg)
                })?;

                format!("data: {chunk_str}\n\n")
            };

            // token usage chunk
            let usage_chunk = {
                let chat_completion_chunk = ChatCompletionChunk {
                    id: id.clone(),
                    object: "chat.completion.chunk".to_string(),
                    created: created.as_secs(),
                    model: graph.name().to_owned(),
                    system_fingerprint: "fp_44709d6fcb".to_string(),
                    choices: vec![],
                    usage,
                };
                let chunk_str = serde_json::to_string(&chat_completion_chunk).map_err(|e| {
                    let err_msg = format!("Failed to serialize chat completion chunk. Reason: {e}");

                    #[cfg(feature = "logging")]
                    error!(target: "stdout", "{}", &err_msg);

                    LlamaCoreError::Operation(err_msg)
                })?;

                format!("data: {chunk_str}\n\n")
            };

            // ending chunk
            let ending_chunk = "data: [DONE]\n\n".to_string();

            let chunks = vec![tool_call_chunk, usage_chunk, ending_chunk];

            let stream = ChatStream::new(
                Some(graph.name().to_owned()),
                id,
                include_usage,
                Some(chunks),
            );

            Ok((stream, include_tool_calls))
        }
        Err(wasmedge_wasi_nn::Error::BackendError(wasmedge_wasi_nn::BackendError::ContextFull)) => {
            // Retrieve the output.
            let output_buffer = get_output_buffer(graph, OUTPUT_TENSOR)?;
            let output = std::str::from_utf8(&output_buffer[..]).map_err(|e| {
                let err_msg = format!(
                    "Failed to decode the buffer of the inference result to a utf-8 string. {e}"
                );

                #[cfg(feature = "logging")]
                error!(target: "stdout", "{}", &err_msg);

                LlamaCoreError::Operation(err_msg)
            })?;

            // post-process
            let message = post_process(output, &graph.metadata.prompt_template).map_err(|e| {
                let err_msg = format!("Failed to post-process the output. {e}");

                #[cfg(feature = "logging")]
                error!(target: "stdout", "{}", &err_msg);

                LlamaCoreError::Operation(err_msg)
            })?;

            // retrieve the number of prompt and completion tokens
            let token_info = get_token_info_by_graph(graph)?;

            #[cfg(feature = "logging")]
            info!(target: "stdout", "prompt tokens: {}, completion tokens: {}", token_info.prompt_tokens, token_info.completion_tokens);

            let usage = Some(Usage {
                prompt_tokens: token_info.prompt_tokens,
                completion_tokens: token_info.completion_tokens,
                total_tokens: token_info.prompt_tokens + token_info.completion_tokens,
            });

            let created = SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .map_err(|e| {
                    let err_msg = format!("Failed to get the current time. Reason: {e}");

                    #[cfg(feature = "logging")]
                    error!(target: "stdout", "{}", &err_msg);

                    LlamaCoreError::Operation(err_msg)
                })?;

            // context full chunk
            let context_full_chunk = {
                let chat_completion_chunk = ChatCompletionChunk {
                    id: id.clone(),
                    object: "chat.completion.chunk".to_string(),
                    created: created.as_secs(),
                    model: graph.name().to_owned(),
                    system_fingerprint: "fp_44709d6fcb".to_string(),
                    choices: vec![ChatCompletionChunkChoice {
                        index: 0,
                        delta: ChatCompletionChunkChoiceDelta {
                            role: ChatCompletionRole::Assistant,
                            content: Some(message),
                            tool_calls: vec![],
                        },
                        logprobs: None,
                        finish_reason: Some(FinishReason::length),
                    }],
                    usage: None,
                };

                // serialize chat completion chunk
                let chunk_str = serde_json::to_string(&chat_completion_chunk).map_err(|e| {
                    let err_msg = format!("Failed to serialize chat completion chunk. Reason: {e}");

                    #[cfg(feature = "logging")]
                    error!(target: "stdout", "{}", &err_msg);

                    LlamaCoreError::Operation(err_msg)
                })?;

                format!("data: {chunk_str}\n\n")
            };

            // usage chunk
            let usage_chunk = {
                let chat_completion_chunk = ChatCompletionChunk {
                    id: id.clone(),
                    object: "chat.completion.chunk".to_string(),
                    created: created.as_secs(),
                    model: graph.name().to_owned(),
                    system_fingerprint: "fp_44709d6fcb".to_string(),
                    choices: vec![],
                    usage,
                };

                // serialize chat completion chunk
                let chunk_str = serde_json::to_string(&chat_completion_chunk).map_err(|e| {
                    let err_msg = format!("Failed to serialize chat completion chunk. Reason: {e}");

                    #[cfg(feature = "logging")]
                    error!(target: "stdout", "{}", &err_msg);

                    LlamaCoreError::Operation(err_msg)
                })?;

                format!("data: {chunk_str}\n\n")
            };

            // ending chunk
            let ending_chunk = "data: [DONE]\n\n".to_string();

            let chunks = vec![context_full_chunk, usage_chunk, ending_chunk];

            let stream = ChatStream::new(
                Some(graph.name().to_owned()),
                id,
                include_usage,
                Some(chunks),
            );

            Ok((stream, false))
        }
        Err(wasmedge_wasi_nn::Error::BackendError(
            wasmedge_wasi_nn::BackendError::PromptTooLong,
        )) => {
            #[cfg(feature = "logging")]
            warn!(target: "stdout", "The prompt is too long. Please reduce the length of your input and try again.");

            // Retrieve the output.
            let output_buffer = get_output_buffer(graph, OUTPUT_TENSOR)?;
            let output = std::str::from_utf8(&output_buffer[..]).map_err(|e| {
                let err_msg = format!(
                    "Failed to decode the buffer of the inference result to a utf-8 string. {e}"
                );

                #[cfg(feature = "logging")]
                error!(target: "stdout", "{}", &err_msg);

                LlamaCoreError::Operation(err_msg)
            })?;

            // post-process
            let message = post_process(output, &graph.metadata.prompt_template).map_err(|e| {
                let err_msg = format!("Failed to post-process the output. {e}");

                #[cfg(feature = "logging")]
                error!(target: "stdout", "{}", &err_msg);

                LlamaCoreError::Operation(err_msg)
            })?;

            // retrieve the number of prompt and completion tokens
            let token_info = get_token_info_by_graph(graph)?;

            #[cfg(feature = "logging")]
            info!(target: "stdout", "prompt tokens: {}, completion tokens: {}", token_info.prompt_tokens, token_info.completion_tokens);

            let usage = Some(Usage {
                prompt_tokens: token_info.prompt_tokens,
                completion_tokens: token_info.completion_tokens,
                total_tokens: token_info.prompt_tokens + token_info.completion_tokens,
            });

            let created = SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .map_err(|e| {
                    let err_msg = format!("Failed to get the current time. Reason: {e}");

                    #[cfg(feature = "logging")]
                    error!(target: "stdout", "{}", &err_msg);

                    LlamaCoreError::Operation(err_msg)
                })?;

            // prompt too long chunk
            let prompt_too_long_chunk = {
                let chat_completion_chunk = ChatCompletionChunk {
                    id: id.clone(),
                    object: "chat.completion.chunk".to_string(),
                    created: created.as_secs(),
                    model: graph.name().to_owned(),
                    system_fingerprint: "fp_44709d6fcb".to_string(),
                    choices: vec![ChatCompletionChunkChoice {
                        index: 0,
                        delta: ChatCompletionChunkChoiceDelta {
                            role: ChatCompletionRole::Assistant,
                            content: Some(message),
                            tool_calls: vec![],
                        },
                        logprobs: None,
                        finish_reason: Some(FinishReason::length),
                    }],
                    usage: None,
                };

                // serialize chat completion chunk
                let chunk_str = serde_json::to_string(&chat_completion_chunk).map_err(|e| {
                    let err_msg = format!("Failed to serialize chat completion chunk. Reason: {e}");

                    #[cfg(feature = "logging")]
                    error!(target: "stdout", "{}", &err_msg);

                    LlamaCoreError::Operation(err_msg)
                })?;

                format!("data: {chunk_str}\n\n")
            };

            // usage chunk
            let usage_chunk = {
                let chat_completion_chunk = ChatCompletionChunk {
                    id: id.clone(),
                    object: "chat.completion.chunk".to_string(),
                    created: created.as_secs(),
                    model: graph.name().to_owned(),
                    system_fingerprint: "fp_44709d6fcb".to_string(),
                    choices: vec![],
                    usage,
                };

                // serialize chat completion chunk
                let chunk_str = serde_json::to_string(&chat_completion_chunk).map_err(|e| {
                    let err_msg = format!("Failed to serialize chat completion chunk. Reason: {e}");

                    #[cfg(feature = "logging")]
                    error!(target: "stdout", "{}", &err_msg);

                    LlamaCoreError::Operation(err_msg)
                })?;

                format!("data: {chunk_str}\n\n")
            };

            // ending chunk
            let ending_chunk = "data: [DONE]\n\n".to_string();

            let chunks = vec![prompt_too_long_chunk, usage_chunk, ending_chunk];

            let stream = ChatStream::new(
                Some(graph.name().to_owned()),
                id,
                include_usage,
                Some(chunks),
            );

            Ok((stream, false))
        }
        Err(e) => {
            let err_msg = format!("Failed to compute the chat completion. Reason: {e}");

            #[cfg(feature = "logging")]
            error!(target: "stdout", "{}", &err_msg);

            Err(LlamaCoreError::Backend(BackendError::Compute(err_msg)))
        }
    }
}

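/// Processes a chat completion request in non-stream mode and returns a single
/// [`ChatCompletionObject`] together with a flag indicating whether it carries tool calls.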
async fn chat_once(
    chat_request: &mut ChatCompletionRequest,
) -> Result<(ChatCompletionObject, bool), LlamaCoreError> {
    #[cfg(feature = "logging")]
    info!(target: "stdout", "Processing chat completion request in non-stream mode");

    let running_mode = running_mode()?;
    if !running_mode.contains(RunningMode::CHAT) && !running_mode.contains(RunningMode::RAG) {
        let err_msg = "Chat completion is only supported in the chat or rag mode.";

        #[cfg(feature = "logging")]
        error!(target: "stdout", "{err_msg}");

        return Err(LlamaCoreError::Operation(err_msg.to_string()));
    }

    let model_name = chat_request.model.clone();
    let id = match &chat_request.user {
        Some(id) => id.clone(),
        None => gen_chat_id(),
    };

    #[cfg(feature = "logging")]
    info!(target: "stdout", "user: {}", &id);

    #[cfg(feature = "logging")]
    info!(target: "stdout", "Check model metadata");

    // update metadata
    let mut metadata = check_model_metadata(chat_request)?;

    #[cfg(feature = "logging")]
    info!(target: "stdout", "Build the chat prompt");

    // build prompt
    let (prompt, available_completion_tokens, tool_use) =
        build_prompt(model_name.as_ref(), chat_request)?;

    #[cfg(feature = "logging")]
    {
        info!(target: "stdout", "prompt:\n{}", &prompt);
        info!(target: "stdout", "available_completion_tokens: {available_completion_tokens}");
        info!(target: "stdout", "tool_use: {tool_use}");
    }

    #[cfg(feature = "logging")]
    info!(target: "stdout", "Update n_predict");

    // update metadata n_predict
    update_n_predict(chat_request, &mut metadata, available_completion_tokens)?;

    #[cfg(feature = "logging")]
    info!(target: "stdout", "Feed the prompt to the model");

    // feed the prompt to the model
    set_prompt(model_name.as_ref(), &prompt)?;

    #[cfg(feature = "logging")]
    info!(target: "stdout", "Compute chat completion.");

    // compute
    let res = compute(model_name.as_ref(), id, tool_use);

    #[cfg(feature = "logging")]
    info!(target: "stdout", "End of the chat completion");

    // reset the model metadata
    reset_model_metadata(model_name.as_ref())?;

    res
}

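/// Looks up the target chat graph (by model name, falling back to the first
/// available graph) and delegates the actual inference to [`compute_by_graph`].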
fn compute(
    model_name: Option<&String>,
    id: impl Into<String>,
    tool_use: bool,
) -> Result<(ChatCompletionObject, bool), LlamaCoreError> {
    let chat_graphs = match CHAT_GRAPHS.get() {
        Some(chat_graphs) => chat_graphs,
        None => {
            let err_msg = "Failed to get the underlying value of `CHAT_GRAPHS`.";

            #[cfg(feature = "logging")]
            error!(target: "stdout", "{}", &err_msg);

            return Err(LlamaCoreError::Operation(err_msg.into()));
        }
    };

    let mut chat_graphs = chat_graphs.lock().map_err(|e| {
        let err_msg = format!("Failed to acquire the lock of `CHAT_GRAPHS`. {e}");

        #[cfg(feature = "logging")]
        error!(target: "stdout", "{}", &err_msg);

        LlamaCoreError::Operation(err_msg)
    })?;

    match model_name {
        Some(model_name) => match chat_graphs.contains_key(model_name) {
            true => {
                let graph = chat_graphs.get_mut(model_name).unwrap();
                compute_by_graph(graph, id, tool_use)
            }
            false => match chat_graphs.iter_mut().next() {
                Some((_, graph)) => compute_by_graph(graph, id, tool_use),
                None => {
                    let err_msg = "There is no model available in the chat graphs.";

                    #[cfg(feature = "logging")]
                    error!(target: "stdout", "{}", &err_msg);

                    Err(LlamaCoreError::Operation(err_msg.into()))
                }
            },
        },
        None => match chat_graphs.iter_mut().next() {
            Some((_, graph)) => compute_by_graph(graph, id, tool_use),
            None => {
                let err_msg = "There is no model available in the chat graphs.";

                #[cfg(feature = "logging")]
                error!(target: "stdout", "{}", &err_msg);

                Err(LlamaCoreError::Operation(err_msg.into()))
            }
        },
    }
}

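/// Runs one inference pass on the given graph and converts the backend result,
/// including the context-full and prompt-too-long cases, into a [`ChatCompletionObject`].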
fn compute_by_graph(
    graph: &mut Graph<GgmlMetadata>,
    id: impl Into<String>,
    tool_use: bool,
) -> Result<(ChatCompletionObject, bool), LlamaCoreError> {
    #[cfg(feature = "logging")]
    info!(target: "stdout", "Compute chat completion by the model named {}.", graph.name());

    match graph.compute() {
        Ok(_) => {
            // Retrieve the output.
            let output_buffer = get_output_buffer(graph, OUTPUT_TENSOR)?;
            let output = std::str::from_utf8(&output_buffer[..]).map_err(|e| {
                let err_msg = format!(
                    "Failed to decode the buffer of the inference result to a utf-8 string. {e}"
                );

                #[cfg(feature = "logging")]
                error!(target: "stdout", "{}", &err_msg);

                LlamaCoreError::Operation(err_msg)
            })?;

            #[cfg(feature = "logging")]
            info!(target: "stdout", "raw generation: {output}");

            // post-process
            let message = post_process(output, &graph.metadata.prompt_template).map_err(|e| {
                LlamaCoreError::Operation(format!("Failed to post-process the output. {e}"))
            })?;

            #[cfg(feature = "logging")]
            info!(target: "stdout", "post-processed generation:\n{}", &message);

            // retrieve the number of prompt and completion tokens
            let token_info = get_token_info_by_graph(graph)?;

            #[cfg(feature = "logging")]
            info!(target: "stdout", "prompt tokens: {}, completion tokens: {}", token_info.prompt_tokens, token_info.completion_tokens);

            let created = SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .map_err(|e| {
                    let err_msg = format!("Failed to get the current time. Reason: {e}");

                    #[cfg(feature = "logging")]
                    error!(target: "stdout", "{}", &err_msg);

                    LlamaCoreError::Operation(err_msg)
                })?;

            match tool_use {
                true => {
                    if graph.metadata.prompt_template != PromptTemplateType::MistralTool
                        && graph.metadata.prompt_template != PromptTemplateType::ChatMLTool
                        && graph.metadata.prompt_template != PromptTemplateType::GroqLlama3Tool
                        && graph.metadata.prompt_template != PromptTemplateType::Llama3Tool
                        && graph.metadata.prompt_template != PromptTemplateType::InternLM2Tool
                        && graph.metadata.prompt_template != PromptTemplateType::NemotronTool
                        && graph.metadata.prompt_template != PromptTemplateType::FunctionaryV32
                        && graph.metadata.prompt_template != PromptTemplateType::FunctionaryV31
                        && graph.metadata.prompt_template != PromptTemplateType::MistralSmallTool
                        && graph.metadata.prompt_template != PromptTemplateType::Llama4Chat
                        && graph.metadata.prompt_template != PromptTemplateType::Qwen3NoThink
                    {
                        let err_msg = format!("Unsupported prompt template: {}. The tool use is only supported for 'mistral-tool', 'chatml-tool', 'groq-llama3-tool', 'llama-3-tool', 'internlm-2-tool', 'nemotron-tool', 'functionary-31', 'functionary-32', 'mistral-small-tool', 'llama-4-chat', and 'qwen-3-no-think' prompt templates.", graph.metadata.prompt_template);

                        #[cfg(feature = "logging")]
                        error!(target: "stdout", "{}", &err_msg);

                        return Err(LlamaCoreError::Operation(err_msg));
                    }

                    let parsed_result = parse_tool_calls(&message, graph.metadata.prompt_template)?;

                    let (finish_reason, content, include_tool_calls) =
                        if parsed_result.tool_calls.is_empty() {
                            (FinishReason::stop, Some(parsed_result.raw.clone()), false)
                        } else {
                            (
                                FinishReason::tool_calls,
                                parsed_result.content.clone(),
                                true,
                            )
                        };

                    let res = ChatCompletionObject {
                        id: id.into(),
                        object: String::from("chat.completion"),
                        created: created.as_secs(),
                        model: graph.name().to_owned(),
                        choices: vec![ChatCompletionObjectChoice {
                            index: 0,
                            message: ChatCompletionObjectMessage {
                                role: ChatCompletionRole::Assistant,
                                content,
                                tool_calls: parsed_result.tool_calls,
                                function_call: None,
                            },
                            finish_reason,
                            logprobs: None,
                        }],
                        usage: Usage {
                            prompt_tokens: token_info.prompt_tokens,
                            completion_tokens: token_info.completion_tokens,
                            total_tokens: token_info.prompt_tokens + token_info.completion_tokens,
                        },
                    };

                    // create ChatCompletionResponse
                    Ok((res, include_tool_calls))
                }
                false => {
                    // create ChatCompletionResponse
                    let res = ChatCompletionObject {
                        id: id.into(),
                        object: String::from("chat.completion"),
                        created: created.as_secs(),
                        model: graph.name().to_owned(),
                        choices: vec![ChatCompletionObjectChoice {
                            index: 0,
                            message: ChatCompletionObjectMessage {
                                role: ChatCompletionRole::Assistant,
                                content: Some(message),
                                tool_calls: vec![],
                                function_call: None,
                            },
                            finish_reason: FinishReason::stop,
                            logprobs: None,
                        }],
                        usage: Usage {
                            prompt_tokens: token_info.prompt_tokens,
                            completion_tokens: token_info.completion_tokens,
                            total_tokens: token_info.prompt_tokens + token_info.completion_tokens,
                        },
                    };

                    Ok((res, false))
                }
            }
        }
        Err(wasmedge_wasi_nn::Error::BackendError(wasmedge_wasi_nn::BackendError::ContextFull)) => {
            // Retrieve the output.
            let output_buffer = get_output_buffer(graph, OUTPUT_TENSOR)?;
            let output = std::str::from_utf8(&output_buffer[..]).map_err(|e| {
                let err_msg = format!(
                    "Failed to decode the buffer of the inference result to a utf-8 string. {e}"
                );

                #[cfg(feature = "logging")]
                error!(target: "stdout", "{}", &err_msg);

                LlamaCoreError::Operation(err_msg)
            })?;

            // post-process
            let message = post_process(output, &graph.metadata.prompt_template).map_err(|e| {
                let err_msg = format!("Failed to post-process the output. {e}");

                #[cfg(feature = "logging")]
                error!(target: "stdout", "{}", &err_msg);

                LlamaCoreError::Operation(err_msg)
            })?;

            // retrieve the number of prompt and completion tokens
            let token_info = get_token_info_by_graph(graph)?;

            #[cfg(feature = "logging")]
            info!(target: "stdout", "prompt tokens: {}, completion tokens: {}", token_info.prompt_tokens, token_info.completion_tokens);

            let created = SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .map_err(|e| {
                    let err_msg = format!("Failed to get the current time. Reason: {e}");

                    #[cfg(feature = "logging")]
                    error!(target: "stdout", "{}", &err_msg);

                    LlamaCoreError::Operation(err_msg)
                })?;

            // create ChatCompletionResponse
            let res = ChatCompletionObject {
                id: id.into(),
                object: String::from("chat.completion"),
                created: created.as_secs(),
                model: graph.name().to_owned(),
                choices: vec![ChatCompletionObjectChoice {
                    index: 0,
                    message: ChatCompletionObjectMessage {
                        role: ChatCompletionRole::Assistant,
                        content: Some(message),
                        tool_calls: vec![],
                        function_call: None,
                    },
                    finish_reason: FinishReason::length,
                    logprobs: None,
                }],
                usage: Usage {
                    prompt_tokens: token_info.prompt_tokens,
                    completion_tokens: token_info.completion_tokens,
                    total_tokens: token_info.prompt_tokens + token_info.completion_tokens,
                },
            };

            Ok((res, false))
        }
        Err(wasmedge_wasi_nn::Error::BackendError(
            wasmedge_wasi_nn::BackendError::PromptTooLong,
        )) => {
            #[cfg(feature = "logging")]
            warn!(target: "stdout", "The prompt is too long. Please reduce the length of your input and try again.");

            // Retrieve the output.
            let output_buffer = get_output_buffer(graph, OUTPUT_TENSOR)?;
            let output = std::str::from_utf8(&output_buffer[..]).map_err(|e| {
                let err_msg = format!(
                    "Failed to decode the buffer of the inference result to a utf-8 string. {e}"
                );

                #[cfg(feature = "logging")]
                error!(target: "stdout", "{}", &err_msg);

                LlamaCoreError::Operation(err_msg)
            })?;

            // post-process
            let message = post_process(output, &graph.metadata.prompt_template).map_err(|e| {
                let err_msg = format!("Failed to post-process the output. {e}");

                #[cfg(feature = "logging")]
                error!(target: "stdout", "{}", &err_msg);

                LlamaCoreError::Operation(err_msg)
            })?;

            // retrieve the number of prompt and completion tokens
            let token_info = get_token_info_by_graph(graph)?;

            #[cfg(feature = "logging")]
            info!(target: "stdout", "prompt tokens: {}, completion tokens: {}", token_info.prompt_tokens, token_info.completion_tokens);

            let usage = Usage {
                prompt_tokens: token_info.prompt_tokens,
                completion_tokens: token_info.completion_tokens,
                total_tokens: token_info.prompt_tokens + token_info.completion_tokens,
            };

            let created = SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .map_err(|e| {
                    let err_msg = format!("Failed to get the current time. Reason: {e}");

                    #[cfg(feature = "logging")]
                    error!(target: "stdout", "{}", &err_msg);

                    LlamaCoreError::Operation(err_msg)
                })?;

            // create ChatCompletionResponse
            let res = ChatCompletionObject {
                id: id.into(),
                object: String::from("chat.completion"),
                created: created.as_secs(),
                model: graph.name().to_owned(),
                choices: vec![ChatCompletionObjectChoice {
                    index: 0,
                    message: ChatCompletionObjectMessage {
                        role: ChatCompletionRole::Assistant,
                        content: Some(message),
                        tool_calls: vec![],
                        function_call: None,
                    },
                    finish_reason: FinishReason::length,
                    logprobs: None,
                }],
                usage,
            };

            Ok((res, false))
        }
        Err(e) => {
            let err_msg = format!("Failed to compute the chat completion. Reason: {e}");

            #[cfg(feature = "logging")]
            error!(target: "stdout", "{}", &err_msg);

            Err(LlamaCoreError::Backend(BackendError::Compute(err_msg)))
        }
    }
}

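/// Extracts tool calls from the post-processed model output according to the
/// conventions of the given prompt template.
///
/// For illustration only (the exact formats are defined by the templates), the
/// parsers below expect shapes such as:
/// - Mistral-style: `[{"name": "...", "arguments": {...}}]`
/// - ChatML/Groq-style: `<tool_call>{"name": "...", "arguments": {...}}</tool_call>`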
fn parse_tool_calls(
    input: &str,
    prompt_template: PromptTemplateType,
) -> Result<ParseResult, LlamaCoreError> {
    match prompt_template {
        PromptTemplateType::MistralTool => match regex::Regex::new(r"\[\{.*?\}\]") {
            Ok(re) => {
                let mut values: Vec<serde_json::Value> = vec![];
                for cap in re.captures_iter(input) {
                    let matched = &cap[0];

                    #[cfg(feature = "logging")]
                    info!(target: "stdout", "captured: {matched}");

                    match serde_json::from_str::<Vec<serde_json::Value>>(matched) {
                        Ok(group) => values.extend(group),
                        Err(e) => {
                            let err_msg =
                                format!("Failed to deserialize generated tool calls. Reason: {e}");

                            #[cfg(feature = "logging")]
                            error!(target: "stdout", "{}", &err_msg);

                            return Err(LlamaCoreError::Operation(err_msg));
                        }
                    }
                }

                let mut tool_calls: Vec<ToolCall> = vec![];
                for value in values.iter() {
                    let name = match value.get("name") {
                        Some(name) => name.to_string().replace("\"", ""),
                        None => {
                            let err_msg = format!(
                                "Failed to get the name of the function. Tool call: {value:?}"
                            );

                            #[cfg(feature = "logging")]
                            error!(target: "stdout", "{}", &err_msg);

                            return Err(LlamaCoreError::Operation(err_msg));
                        }
                    };

                    let arguments = match value.get("arguments") {
                        Some(arguments) => arguments.to_string(),
                        None => {
                            let err_msg = format!(
                                "Failed to get the arguments of the function. Tool call: {value:?}"
                            );

                            #[cfg(feature = "logging")]
                            error!(target: "stdout", "{}", &err_msg);

                            return Err(LlamaCoreError::Operation(err_msg));
                        }
                    };

                    let function = Function { name, arguments };

                    let tool_call = ToolCall {
                        id: "call_abc123".to_string(),
                        ty: "function".to_string(),
                        function,
                    };

                    tool_calls.push(tool_call);
                }

                let parsed = ParseResult {
                    raw: input.to_owned(),
                    content: None,
                    tool_calls,
                };

                #[cfg(feature = "logging")]
                info!(target: "stdout", "parsed result: {parsed:?}");

                Ok(parsed)
            }
            Err(e) => {
                let err_msg = format!("Failed to create a regex pattern. Reason: {e}");

                #[cfg(feature = "logging")]
                error!(target: "stdout", "{}", &err_msg);

                Err(LlamaCoreError::Operation(err_msg))
            }
        },
        PromptTemplateType::ChatMLTool => {
            match regex::Regex::new(r"<tool_call>(.*?)</tool_call>") {
                Ok(re) => {
                    let mut values: Vec<serde_json::Value> = vec![];
                    for cap in re.captures_iter(input) {
                        let matched = cap[1].replace("\\n", ""); // Remove "\\n" from the captured group

                        #[cfg(feature = "logging")]
                        info!(target: "stdout", "captured: {}", &matched);

                        match serde_json::from_str::<serde_json::Value>(&matched) {
                            Ok(value) => values.push(value),
                            Err(e) => {
                                let err_msg = format!(
                                    "Failed to deserialize generated tool calls. Reason: {e}"
                                );

                                #[cfg(feature = "logging")]
                                error!(target: "stdout", "{}", &err_msg);

                                return Err(LlamaCoreError::Operation(err_msg));
                            }
                        }
                    }

                    let mut tool_calls: Vec<ToolCall> = vec![];
                    for value in values.iter() {
                        let name = match value.get("name") {
                            Some(name) => name.to_string().replace("\"", ""),
                            None => {
                                let err_msg = format!(
                                    "Failed to get the name of the function. Tool call: {value:?}"
                                );

                                #[cfg(feature = "logging")]
                                error!(target: "stdout", "{}", &err_msg);

                                return Err(LlamaCoreError::Operation(err_msg));
                            }
                        };

                        let arguments = match value.get("arguments") {
                            Some(arguments) => arguments.to_string(),
                            None => {
                                let err_msg = format!(
                                    "Failed to get the arguments of the function. Tool call: {value:?}"
                                );

                                #[cfg(feature = "logging")]
                                error!(target: "stdout", "{}", &err_msg);

                                return Err(LlamaCoreError::Operation(err_msg));
                            }
                        };

                        let function = Function { name, arguments };

                        let tool_call = ToolCall {
                            id: "call_abc123".to_string(),
                            ty: "function".to_string(),
                            function,
                        };

                        tool_calls.push(tool_call);
                    }

                    let parsed = ParseResult {
                        raw: input.to_owned(),
                        content: None,
                        tool_calls,
                    };

                    #[cfg(feature = "logging")]
                    info!(target: "stdout", "parsed result: {parsed:?}");

                    Ok(parsed)
                }
                Err(e) => {
                    let err_msg = format!("Failed to create a regex pattern. Reason: {e}");

                    #[cfg(feature = "logging")]
                    error!(target: "stdout", "{}", &err_msg);

                    Err(LlamaCoreError::Operation(err_msg))
                }
            }
        }
1244        PromptTemplateType::GroqLlama3Tool => {
1245            #[cfg(feature = "logging")]
1246            info!(target: "stdout", "raw input: {input}");
1247
1248            match regex::Regex::new(r"(?s)<tool_call>((.|\r|\n)*?)</tool_call>") {
1249                Ok(re) => {
1250                    let mut values: Vec<serde_json::Value> = vec![];
1251                    for cap in re.captures_iter(input) {
1252                        let matched = cap[1].trim();
1253
1254                        #[cfg(feature = "logging")]
1255                        info!(target: "stdout", "captured: {matched}");
1256
1257                        match serde_json::from_str::<serde_json::Value>(matched) {
1258                            Ok(value) => values.push(value),
1259                            Err(e) => {
1260                                let err_msg = format!(
1261                                    "Failed to deserialize generated tool calls. Reason: {e}"
1262                                );
1263
1264                                #[cfg(feature = "logging")]
1265                                error!(target: "stdout", "{}", &err_msg);
1266
1267                                return Err(LlamaCoreError::Operation(err_msg));
1268                            }
1269                        }
1270                    }
1271
1272                    let mut tool_calls: Vec<ToolCall> = vec![];
1273                    for value in values.iter() {
1274                        let name = match value.get("name") {
1275                            Some(name) => name.to_string().replace("\"", ""),
1276                            None => {
1277                                let err_msg = format!(
1278                                    "Failed to get the name of the function. Tool call: {value:?}"
1279                                );
1280
1281                                #[cfg(feature = "logging")]
1282                                error!(target: "stdout", "{}", &err_msg);
1283
1284                                return Err(LlamaCoreError::Operation(err_msg));
1285                            }
1286                        };
1287
1288                        let arguments = match value.get("arguments") {
1289                            Some(arguments) => {
1290                                if arguments.is_string() {
1291                                    arguments.as_str().unwrap().to_string()
1292                                } else if arguments.is_object() {
1293                                    let map = arguments.as_object().unwrap();
1294
1295                                    #[cfg(feature = "logging")]
1296                                    info!(target: "stdout", "func arguments: {map:?}");
1297
1298                                    serde_json::to_string(map).unwrap()
1299                                } else {
1300                                    serde_json::to_string(arguments).unwrap()
1301                                }
1302                            }
1303                            None => {
1304                                let err_msg = format!(
1305                                    "Failed to get the arguments of the function. Tool call: {value:?}"
1306                                );
1307
1308                                #[cfg(feature = "logging")]
1309                                error!(target: "stdout", "{}", &err_msg);
1310
1311                                return Err(LlamaCoreError::Operation(err_msg));
1312                            }
1313                        };
1314
1315                        let function = Function { name, arguments };
1316
1317                        let tool_call = ToolCall {
1318                            id: "call_abc123".to_string(),
1319                            ty: "function".to_string(),
1320                            function,
1321                        };
1322
1323                        tool_calls.push(tool_call);
1324                    }
1325
1326                    let parsed = if tool_calls.is_empty() {
1327                        ParseResult {
1328                            raw: input.to_owned(),
1329                            content: Some(input.to_owned()),
1330                            tool_calls: vec![],
1331                        }
1332                    } else {
1333                        ParseResult {
1334                            raw: input.to_owned(),
1335                            content: None,
1336                            tool_calls,
1337                        }
1338                    };
1339
1340                    #[cfg(feature = "logging")]
1341                    info!(target: "stdout", "parsed result: {parsed:?}");
1342
1343                    Ok(parsed)
1344                }
1345                Err(e) => {
1346                    let err_msg = format!("Failed to create a regex pattern. Reason: {e}");
1347
1348                    #[cfg(feature = "logging")]
1349                    error!(target: "stdout", "{}", &err_msg);
1350
1351                    Err(LlamaCoreError::Operation(err_msg))
1352                }
1353            }
1354        }
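        // Llama3Tool: the entire output is expected to be a single JSON object of the form
        // {"name": ..., "parameters": {...}}; output that does not look like a JSON object
        // yields an empty result (no content, no tool calls).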
1355        PromptTemplateType::Llama3Tool => {
1356            #[cfg(feature = "logging")]
1357            info!(target: "stdout", "raw input: {input}");
1358
1359            let re = match regex::Regex::new(r"^\{(.|\r|\n)*\}$") {
1360                Ok(re) => re,
1361                Err(e) => {
1362                    let err_msg = format!("Failed to create a regex pattern. Reason: {e}");
1363
1364                    #[cfg(feature = "logging")]
1365                    error!(target: "stdout", "{}", &err_msg);
1366
1367                    return Err(LlamaCoreError::Operation(err_msg));
1368                }
1369            };
1370
1371            if re.is_match(input) {
1372                match serde_json::from_str::<serde_json::Value>(input) {
1373                    Ok(value) => {
1374                        let values: Vec<serde_json::Value> = vec![value];
1375
1376                        let mut tool_calls: Vec<ToolCall> = vec![];
1377                        for value in values.iter() {
1378                            let name = match value.get("name") {
1379                                Some(name) => name.to_string().replace("\"", ""),
1380                                None => {
1381                                    let err_msg = format!(
1382                                        "Failed to get the name of the function. Tool call: {value:?}"
1383                                    );
1384
1385                                    #[cfg(feature = "logging")]
1386                                    error!(target: "stdout", "{}", &err_msg);
1387
1388                                    return Err(LlamaCoreError::Operation(err_msg));
1389                                }
1390                            };
1391
1392                            let arguments = match value.get("parameters") {
1393                                Some(arguments) => arguments.to_string(),
1394                                None => {
1395                                    let err_msg = format!(
1396                                        "Failed to get the arguments of the function. Tool call: {value:?}"
1397                                    );
1398
1399                                    #[cfg(feature = "logging")]
1400                                    error!(target: "stdout", "{}", &err_msg);
1401
1402                                    return Err(LlamaCoreError::Operation(err_msg));
1403                                }
1404                            };
1405
1406                            let function = Function { name, arguments };
1407
1408                            let tool_call = ToolCall {
1409                                id: "call_abc123".to_string(),
1410                                ty: "function".to_string(),
1411                                function,
1412                            };
1413
1414                            tool_calls.push(tool_call);
1415                        }
1416
1417                        let parsed = ParseResult {
1418                            raw: input.to_owned(),
1419                            content: None,
1420                            tool_calls,
1421                        };
1422
1423                        #[cfg(feature = "logging")]
1424                        info!(target: "stdout", "parsed result: {parsed:?}");
1425
1426                        Ok(parsed)
1427                    }
1428                    Err(e) => {
1429                        let err_msg =
1430                            format!("Failed to deserialize generated tool calls. Reason: {e}");
1431
1432                        #[cfg(feature = "logging")]
1433                        error!(target: "stdout", "{}", &err_msg);
1434
1435                        Err(LlamaCoreError::Operation(err_msg))
1436                    }
1437                }
1438            } else {
1439                let parsed = ParseResult {
1440                    raw: input.to_owned(),
1441                    content: None,
1442                    tool_calls: vec![],
1443                };
1444
1445                #[cfg(feature = "logging")]
1446                info!(target: "stdout", "parsed result: {parsed:?}");
1447
1448                Ok(parsed)
1449            }
1450        }
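        // InternLM2Tool: tool calls are delimited by <|action_start|><|plugin|> ... <|action_end|>;
        // each delimited block is parsed as a JSON object with "name" and "parameters", while text
        // outside the markers is kept as regular content.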
1451        PromptTemplateType::InternLM2Tool => {
1452            #[cfg(feature = "logging")]
1453            info!(target: "stdout", "raw input: {input}");
1454
1455            let blocks: Vec<&str> = input.trim().split("<|action_start|><|plugin|>").collect();
1456
1457            #[cfg(feature = "logging")]
1458            info!(target: "stdout", "blocks: {blocks:?}");
1459
1460            let mut tool_calls: Vec<ToolCall> = vec![];
1461            let mut content = String::new();
1462            for block in blocks {
1463                let block = block.trim();
1464                if !block.is_empty() {
1465                    if block.ends_with("<|action_end|>") {
1466                        let value = block.trim().trim_end_matches("<|action_end|>");
1467
1468                        #[cfg(feature = "logging")]
1469                        info!(target: "stdout", "tool call: {value}");
1470
1471                        match serde_json::from_str::<serde_json::Value>(value) {
1472                            Ok(value) => {
1473                                let name = match value.get("name") {
1474                                    Some(name) => name.to_string().replace("\"", ""),
1475                                    None => {
1476                                        let err_msg = format!(
1477                                            "Failed to get the name of the function. Tool call: {value:?}"
1478                                        );
1479
1480                                        #[cfg(feature = "logging")]
1481                                        error!(target: "stdout", "{}", &err_msg);
1482
1483                                        return Err(LlamaCoreError::Operation(err_msg));
1484                                    }
1485                                };
1486
1487                                let arguments = match value.get("parameters") {
1488                                    Some(arguments) => arguments.to_string(),
1489                                    None => {
1490                                        let err_msg = format!(
1491                                            "Failed to get the arguments of the function. Tool call: {value:?}"
1492                                        );
1493
1494                                        #[cfg(feature = "logging")]
1495                                        error!(target: "stdout", "{}", &err_msg);
1496
1497                                        return Err(LlamaCoreError::Operation(err_msg));
1498                                    }
1499                                };
1500
1501                                let function = Function { name, arguments };
1502
1503                                let tool_call = ToolCall {
1504                                    id: "call_abc123".to_string(),
1505                                    ty: "function".to_string(),
1506                                    function,
1507                                };
1508
1509                                tool_calls.push(tool_call);
1510                            }
1511                            Err(e) => {
1512                                let err_msg = format!(
1513                                    "Failed to deserialize generated tool calls. Reason: {e}"
1514                                );
1515
1516                                #[cfg(feature = "logging")]
1517                                error!(target: "stdout", "{}", &err_msg);
1518
1519                                return Err(LlamaCoreError::Operation(err_msg));
1520                            }
1521                        }
1522                    } else {
1523                        content.push_str(block);
1524                        content.push('\n');
1525                    }
1526                }
1527            }
1528
1529            let parsed = match content.is_empty() {
1530                true => ParseResult {
1531                    raw: input.to_owned(),
1532                    content: None,
1533                    tool_calls,
1534                },
1535                false => ParseResult {
1536                    raw: input.to_owned(),
1537                    content: Some(content.trim().to_owned()),
1538                    tool_calls,
1539                },
1540            };
1541
1542            #[cfg(feature = "logging")]
1543            info!(target: "stdout", "parsed result: {parsed:?}");
1544
1545            Ok(parsed)
1546        }
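        // NemotronTool: tool calls are wrapped in <toolcall> ... </toolcall> tags, each containing
        // a JSON object with "name" and "arguments", e.g.
        // <toolcall> {"name": "get_weather", "arguments": {"city": "Paris"}} </toolcall>
        // (illustrative only; the function name is hypothetical).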
1547        PromptTemplateType::NemotronTool => {
1548            #[cfg(feature = "logging")]
1549            info!(target: "stdout", "raw input: {input}");
1550
1551            match regex::Regex::new(r"(?s)<toolcall>\s*(.*?)\s*</toolcall>") {
1552                Ok(re) => {
1553                    let mut values: Vec<serde_json::Value> = vec![];
1554                    for cap in re.captures_iter(input) {
1555                        #[cfg(feature = "logging")]
1556                        info!(target: "stdout", "captured: {}", &cap[0]);
1557
1561                        let matched = cap[1].trim();
1562
1563                        #[cfg(feature = "logging")]
1564                        info!(target: "stdout", "extracted: {matched}");
1565
1566                        match serde_json::from_str::<serde_json::Value>(matched) {
1567                            Ok(value) => values.push(value),
1568                            Err(e) => {
1569                                let err_msg = format!(
1570                                    "Failed to deserialize generated tool calls. Reason: {e}"
1571                                );
1572
1573                                #[cfg(feature = "logging")]
1574                                error!(target: "stdout", "{}", &err_msg);
1575
1576                                return Err(LlamaCoreError::Operation(err_msg));
1577                            }
1578                        }
1579                    }
1580
1581                    let mut tool_calls: Vec<ToolCall> = vec![];
1582                    for value in values.iter() {
1583                        let name = match value.get("name") {
1584                            Some(name) => name.to_string().replace("\"", ""),
1585                            None => {
1586                                let err_msg = format!(
1587                                    "Failed to get the name of the function. Tool call: {value:?}"
1588                                );
1589
1590                                #[cfg(feature = "logging")]
1591                                error!(target: "stdout", "{}", &err_msg);
1592
1593                                return Err(LlamaCoreError::Operation(err_msg));
1594                            }
1595                        };
1596
1597                        let arguments = match value.get("arguments") {
1598                            Some(arguments) => arguments.to_string(),
1599                            None => {
1600                                let err_msg = format!(
1601                                    "Failed to get the arguments of the function. Tool call: {value:?}"
1602                                );
1603
1604                                #[cfg(feature = "logging")]
1605                                error!(target: "stdout", "{}", &err_msg);
1606
1607                                return Err(LlamaCoreError::Operation(err_msg));
1608                            }
1609                        };
1610
1611                        let function = Function { name, arguments };
1612
1613                        let tool_call = ToolCall {
1614                            id: "call_abc123".to_string(),
1615                            ty: "function".to_string(),
1616                            function,
1617                        };
1618
1619                        tool_calls.push(tool_call);
1620                    }
1621
1622                    let parsed = ParseResult {
1623                        raw: input.to_owned(),
1624                        content: None,
1625                        tool_calls,
1626                    };
1627
1628                    #[cfg(feature = "logging")]
1629                    info!(target: "stdout", "parsed result: {parsed:?}");
1630
1631                    Ok(parsed)
1632                }
1633                Err(e) => {
1634                    let err_msg = format!("Failed to create a regex pattern. Reason: {e}");
1635
1636                    #[cfg(feature = "logging")]
1637                    error!(target: "stdout", "{}", &err_msg);
1638
1639                    Err(LlamaCoreError::Operation(err_msg))
1640                }
1641            }
1642        }
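        // FunctionaryV32: each call is emitted as `>>>func_name{...}<|eot_id|>`; the function name
        // and the argument text are taken directly from the regex captures. A regex-construction
        // failure is downgraded to a warning and yields an empty result.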
1643        PromptTemplateType::FunctionaryV32 => {
1644            #[cfg(feature = "logging")]
1645            info!(target: "stdout", "raw input: {input}");
1646
1647            match regex::Regex::new(r">>>\s*(\w+)\s*\{(.*)\}<\|eot_id\|>") {
1648                Ok(re) => {
1649                    let mut tool_calls: Vec<ToolCall> = vec![];
1650                    for cap in re.captures_iter(input) {
1651                        #[cfg(feature = "logging")]
1652                        info!(target: "stdout", "func_name: {}", &cap[1]);
1653
1654                        #[cfg(feature = "logging")]
1655                        info!(target: "stdout", "arguments: {}", &cap[2]);
1656
1657                        let tool_call = ToolCall {
1658                            id: "call_abc123".to_string(),
1659                            ty: "function".to_string(),
1660                            function: Function {
1661                                name: cap[1].to_string(),
1662                                arguments: cap[2].to_string(),
1663                            },
1664                        };
1665
1666                        tool_calls.push(tool_call);
1667                    }
1668
1669                    let parsed = ParseResult {
1670                        raw: input.to_owned(),
1671                        content: None,
1672                        tool_calls,
1673                    };
1674
1675                    #[cfg(feature = "logging")]
1676                    info!(target: "stdout", "parsed result: {parsed:?}");
1677
1678                    Ok(parsed)
1679                }
1680                Err(e) => {
1681                    let warn_msg = format!("Failed to create a regex pattern. Reason: {e}");
1682
1683                    #[cfg(feature = "logging")]
1684                    warn!(target: "stdout", "{}", &warn_msg);
1685
1686                    Ok(ParseResult {
1687                        raw: input.to_owned(),
1688                        content: None,
1689                        tool_calls: vec![],
1690                    })
1691                }
1692            }
1693        }
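        // FunctionaryV31: calls are emitted as `<function=name>{...}</function>`; parsing mirrors
        // the FunctionaryV32 branch above.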
1694        PromptTemplateType::FunctionaryV31 => {
1695            #[cfg(feature = "logging")]
1696            info!(target: "stdout", "raw input: {input}");
1697
1698            match regex::Regex::new(r"<function=(\w+)>\s*(\{.*?\})</function>") {
1699                Ok(re) => {
1700                    let mut tool_calls: Vec<ToolCall> = vec![];
1701                    for cap in re.captures_iter(input) {
1702                        #[cfg(feature = "logging")]
1703                        info!(target: "stdout", "func_name: {}", &cap[1]);
1704
1705                        #[cfg(feature = "logging")]
1706                        info!(target: "stdout", "arguments: {}", &cap[2]);
1707
1708                        let tool_call = ToolCall {
1709                            id: "call_abc123".to_string(),
1710                            ty: "function".to_string(),
1711                            function: Function {
1712                                name: cap[1].to_string(),
1713                                arguments: cap[2].to_string(),
1714                            },
1715                        };
1716
1717                        tool_calls.push(tool_call);
1718                    }
1719
1720                    let parsed = ParseResult {
1721                        raw: input.to_owned(),
1722                        content: None,
1723                        tool_calls,
1724                    };
1725
1726                    #[cfg(feature = "logging")]
1727                    info!(target: "stdout", "parsed result: {parsed:?}");
1728
1729                    Ok(parsed)
1730                }
1731                Err(e) => {
1732                    let warn_msg = format!("Failed to create a regex pattern. Reason: {e}");
1733
1734                    #[cfg(feature = "logging")]
1735                    warn!(target: "stdout", "{}", &warn_msg);
1736
1737                    Ok(ParseResult {
1738                        raw: input.to_owned(),
1739                        content: None,
1740                        tool_calls: vec![],
1741                    })
1742                }
1743            }
1744        }
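        // MistralSmallTool: the output carries a `[TOOL_CALLS] [...]` block whose JSON array
        // entries are either `{"function": {"name": ..., "arguments": ...}}` wrappers or bare
        // `{"name": ..., "arguments": ...}` objects.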
1745        PromptTemplateType::MistralSmallTool => {
1746            #[cfg(feature = "logging")]
1747            info!(target: "stdout", "raw input: {input}");
1748
1749            match regex::Regex::new(r"\[TOOL_CALLS\]\s*(\[(.*?)\])") {
1750                Ok(re) => {
1751                    let mut values: Vec<serde_json::Value> = vec![];
1752                    if let Some(cap) = re.captures(input) {
1753                        let matched = cap[1].trim();
1754
1755                        #[cfg(feature = "logging")]
1756                        info!(target: "stdout", "captured: {matched}");
1757
1758                        match serde_json::from_str::<Vec<serde_json::Value>>(matched) {
1759                            Ok(vals) => values = vals,
1760                            Err(e) => {
1761                                let err_msg = format!(
1762                                    "Failed to deserialize generated tool calls. Reason: {e}"
1763                                );
1764
1765                                #[cfg(feature = "logging")]
1766                                error!(target: "stdout", "{}", &err_msg);
1767
1768                                return Err(LlamaCoreError::Operation(err_msg));
1769                            }
1770                        }
1771                    }
1772
1773                    let mut tool_calls: Vec<ToolCall> = vec![];
1774                    for value in values.iter() {
1775                        if let Some(object_map) = value.as_object() {
1776                            if object_map.contains_key("function") {
1777                                let mut function = Function {
1778                                    name: String::new(),
1779                                    arguments: String::new(),
1780                                };
1781
1782                                let value = object_map.get("function").unwrap();
1783                                let func_map = value.as_object().unwrap();
1784                                if func_map.contains_key("name") {
1785                                    let func_name = func_map.get("name").unwrap().as_str().unwrap();
1786                                    #[cfg(feature = "logging")]
                                    info!(target: "stdout", "Function name: {func_name:?}");
1787
1788                                    function.name = func_name.to_string();
1789                                }
1790                                if func_map.contains_key("arguments") {
1791                                    let args = func_map.get("arguments").unwrap();
1792                                    let arguments = args.to_string();
1793                                    #[cfg(feature = "logging")]
                                    info!(target: "stdout", "Arguments: {arguments:?}");
1794
1795                                    function.arguments = arguments;
1796                                }
1797
1798                                let tool_call = ToolCall {
1799                                    id: "call_abc123".to_string(),
1800                                    ty: "function".to_string(),
1801                                    function,
1802                                };
1803
1804                                tool_calls.push(tool_call);
1805                            } else if object_map.contains_key("name") {
1806                                let mut function = Function {
1807                                    name: String::new(),
1808                                    arguments: String::new(),
1809                                };
1810
1811                                let name = object_map.get("name").unwrap().as_str().unwrap();
1812                                #[cfg(feature = "logging")]
                                info!(target: "stdout", "name: {name:?}");
1813                                function.name = name.to_string();
1814
1815                                if object_map.contains_key("arguments") {
1816                                    let args = object_map.get("arguments").unwrap();
1817                                    let arguments = args.to_string();
1818                                    #[cfg(feature = "logging")]
                                    info!(target: "stdout", "Arguments: {arguments:?}");
1819
1820                                    function.arguments = arguments;
1821                                }
1822
1823                                let tool_call = ToolCall {
1824                                    id: "call_abc123".to_string(),
1825                                    ty: "function".to_string(),
1826                                    function,
1827                                };
1828
1829                                tool_calls.push(tool_call);
1830                            }
1831                        }
1832                    }
1833
1834                    let parsed = ParseResult {
1835                        raw: input.to_owned(),
1836                        content: None,
1837                        tool_calls,
1838                    };
1839
1840                    #[cfg(feature = "logging")]
1841                    info!(target: "stdout", "parsed result: {parsed:?}");
1842
1843                    Ok(parsed)
1844                }
1845                Err(e) => {
1846                    let err_msg = format!("Failed to create a regex pattern. Reason: {e}");
1847
1848                    #[cfg(feature = "logging")]
1849                    error!(target: "stdout", "{}", &err_msg);
1850
1851                    Err(LlamaCoreError::Operation(err_msg))
1852                }
1853            }
1854        }
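        // Llama4Chat: the entire output is expected to be a single JSON object with "name" and
        // optional "parameters"; output that is not valid JSON produces an empty tool-call list.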
1855        PromptTemplateType::Llama4Chat => {
1856            #[cfg(feature = "logging")]
1857            info!(target: "stdout", "raw input: {input:?}");
1858
1859            let mut tool_calls: Vec<ToolCall> = vec![];
1860            if let Ok(value) = serde_json::from_str::<serde_json::Value>(input) {
1861                match value.as_object() {
1862                    Some(object_map) => {
1863                        #[cfg(feature = "logging")]
1864                        debug!(target: "stdout", "object_map: {object_map:?}");
1865
1866                        // parse function name
1867                        if object_map.contains_key("name") {
1868                            let name = object_map.get("name").unwrap().as_str().unwrap();
1869
1870                            #[cfg(feature = "logging")]
1871                            debug!(target: "stdout", "name: {name:?}");
1872
1873                            let mut function = Function {
1874                                name: name.to_string(),
1875                                arguments: String::new(),
1876                            };
1877
1878                            // parse function arguments
1879                            if object_map.contains_key("parameters") {
1880                                let args = object_map.get("parameters").unwrap();
1881                                let arguments = args.to_string();
1882
1883                                #[cfg(feature = "logging")]
1884                                debug!(target: "stdout", "arguments: {:?}", &arguments);
1885
1886                                function.arguments = arguments;
1887                            }
1888
1889                            tool_calls.push(ToolCall {
1890                                id: "call_abc123".to_string(),
1891                                ty: "function".to_string(),
1892                                function,
1893                            });
1894                        } else {
1895                            let err_msg = format!(
1896                                "Failed to get the name of the function. raw input: {input:?}"
1897                            );
1898
1899                            #[cfg(feature = "logging")]
1900                            error!(target: "stdout", "{}", &err_msg);
1901
1902                            return Err(LlamaCoreError::Operation(err_msg));
1903                        }
1904                    }
1905                    None => {
1906                        let err_msg = format!("Failed to parse the JSON string. JSON: {input}");
1907
1908                        #[cfg(feature = "logging")]
1909                        error!(target: "stdout", "{}", &err_msg);
1910
1911                        return Err(LlamaCoreError::Operation(err_msg));
1912                    }
1913                }
1914            }
1915
1916            let parsed = ParseResult {
1917                raw: input.to_owned(),
1918                content: None,
1919                tool_calls,
1920            };
1921
1922            #[cfg(feature = "logging")]
1923            info!(target: "stdout", "parsed result: {parsed:?}");
1924
1925            Ok(parsed)
1926        }
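        // Qwen3NoThink: follows the same <tool_call> ... </tool_call> convention handled earlier
        // in this match; "arguments" may be a JSON string, an object, or any other JSON value.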
1927        PromptTemplateType::Qwen3NoThink => {
1928            #[cfg(feature = "logging")]
1929            info!(target: "stdout", "raw input: {input:?}");
1930
1931            match regex::Regex::new(r"(?s)<tool_call>((.|\r|\n)*?)</tool_call>") {
1932                Ok(re) => {
1933                    let mut values: Vec<serde_json::Value> = vec![];
1934                    for cap in re.captures_iter(input) {
1935                        let matched = cap[1].trim();
1936
1937                        #[cfg(feature = "logging")]
1938                        info!(target: "stdout", "captured: {matched}");
1939
1940                        match serde_json::from_str::<serde_json::Value>(matched) {
1941                            Ok(value) => values.push(value),
1942                            Err(e) => {
1943                                let err_msg = format!(
1944                                    "Failed to deserialize generated tool calls. Reason: {e}"
1945                                );
1946
1947                                #[cfg(feature = "logging")]
1948                                error!(target: "stdout", "{}", &err_msg);
1949
1950                                return Err(LlamaCoreError::Operation(err_msg));
1951                            }
1952                        }
1953                    }
1954
1955                    let mut tool_calls: Vec<ToolCall> = vec![];
1956                    for value in values.iter() {
1957                        let name = match value.get("name") {
1958                            Some(name) => name.to_string().replace("\"", ""),
1959                            None => {
1960                                let err_msg = format!(
1961                                    "Failed to get the name of the function. Tool call: {value:?}"
1962                                );
1963
1964                                #[cfg(feature = "logging")]
1965                                error!(target: "stdout", "{}", &err_msg);
1966
1967                                return Err(LlamaCoreError::Operation(err_msg));
1968                            }
1969                        };
1970
1971                        let arguments = match value.get("arguments") {
1972                            Some(arguments) => {
1973                                if arguments.is_string() {
1974                                    arguments.as_str().unwrap().to_string()
1975                                } else if arguments.is_object() {
1976                                    let map = arguments.as_object().unwrap();
1977
1978                                    #[cfg(feature = "logging")]
1979                                    info!(target: "stdout", "func arguments: {map:?}");
1980
1981                                    serde_json::to_string(map).unwrap()
1982                                } else {
1983                                    serde_json::to_string(arguments).unwrap()
1984                                }
1985                            }
1986                            None => {
1987                                let err_msg = format!(
1988                                    "Failed to get the arguments of the function. Tool call: {value:?}"
1989                                );
1990
1991                                #[cfg(feature = "logging")]
1992                                error!(target: "stdout", "{}", &err_msg);
1993
1994                                return Err(LlamaCoreError::Operation(err_msg));
1995                            }
1996                        };
1997
1998                        let function = Function { name, arguments };
1999
2000                        let tool_call = ToolCall {
2001                            id: "call_abc123".to_string(),
2002                            ty: "function".to_string(),
2003                            function,
2004                        };
2005
2006                        tool_calls.push(tool_call);
2007                    }
2008
2009                    let parsed = if tool_calls.is_empty() {
2010                        ParseResult {
2011                            raw: input.to_owned(),
2012                            content: Some(input.to_owned()),
2013                            tool_calls: vec![],
2014                        }
2015                    } else {
2016                        ParseResult {
2017                            raw: input.to_owned(),
2018                            content: None,
2019                            tool_calls,
2020                        }
2021                    };
2022
2023                    #[cfg(feature = "logging")]
2024                    info!(target: "stdout", "parsed result: {parsed:?}");
2025
2026                    Ok(parsed)
2027                }
2028                Err(e) => {
2029                    let err_msg = format!("Failed to create a regex pattern. Reason: {e}");
2030
2031                    #[cfg(feature = "logging")]
2032                    error!(target: "stdout", "{}", &err_msg);
2033
2034                    Err(LlamaCoreError::Operation(err_msg))
2035                }
2036            }
2037        }
2038        _ => {
2039            let err_msg = format!(
2040                "Tool use is only supported for the following prompt templates: {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, and {}.",
2041                PromptTemplateType::MistralTool,
2042                PromptTemplateType::ChatMLTool,
2043                PromptTemplateType::GroqLlama3Tool,
2044                PromptTemplateType::Llama3Tool,
2045                PromptTemplateType::InternLM2Tool,
2046                PromptTemplateType::NemotronTool,
2047                PromptTemplateType::FunctionaryV32,
                PromptTemplateType::FunctionaryV31,
2048                PromptTemplateType::MistralSmallTool,
2049                PromptTemplateType::Llama4Chat,
2050                PromptTemplateType::Qwen3NoThink,
2051            );
2052
2053            #[cfg(feature = "logging")]
2054            error!(target: "stdout", "{}", &err_msg);
2055
2056            Err(LlamaCoreError::Operation(err_msg))
2057        }
2058    }
2059}
2060
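/// Reconcile the incoming request's sampling options (temperature, top_p, frequency and presence
/// penalties) and the `embeddings` flag with the cached model metadata, pushing the updated
/// metadata to the target graph only when something actually changed; image inputs that are not
/// base64-encoded are rejected here as well.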
2061fn check_model_metadata(
2062    chat_request: &ChatCompletionRequest,
2063) -> Result<GgmlMetadata, LlamaCoreError> {
2064    let mut should_update = false;
2065    let mut metadata = get_model_metadata(chat_request.model.as_ref())?;
2066
2067    // check if necessary to update `image`
2068    if metadata.prompt_template.is_image_supported() {
2069        if let Some(ChatCompletionRequestMessage::User(user_message)) = chat_request.messages.last()
2070        {
2071            if let ChatCompletionUserMessageContent::Parts(parts) = user_message.content() {
2072                for part in parts {
2073                    if let ContentPart::Image(image_part) = part {
2074                        let image = image_part.image();
2075
2076                        if image.is_url() {
2077                            let err_msg = "The image is provided in URL format. Only base64 format is supported.".to_string();
2078
2079                            #[cfg(feature = "logging")]
2080                            error!(target: "stdout", "{}", &err_msg);
2081
2082                            return Err(LlamaCoreError::Operation(err_msg));
2083                        } else {
2084                            #[cfg(feature = "logging")]
2085                            info!(target: "stdout", "The image is provided in base64 format.");
2086
2087                            // TODO: now only support a single image
2088
2089                            break;
2090                        }
2091                    }
2092                }
2093            }
2094        }
2095    }
2096
2097    // check if necessary to update temperature
2098    if let Some(temp) = chat_request.temperature {
2099        if metadata.temperature != temp {
2100            // update temperature
2101            metadata.temperature = temp;
2102
2103            if !should_update {
2104                should_update = true;
2105            }
2106        }
2107    }
2108
2109    // check if necessary to update top_p
2110    if let Some(top_p) = chat_request.top_p {
2111        if metadata.top_p != top_p {
2112            // update top_p
2113            metadata.top_p = top_p;
2114
2115            if !should_update {
2116                should_update = true;
2117            }
2118        }
2119    }
2120
2121    // check if necessary to update frequency_penalty
2122    if let Some(frequency_penalty) = chat_request.frequency_penalty {
2123        if metadata.frequency_penalty != frequency_penalty {
2124            // update frequency_penalty
2125            metadata.frequency_penalty = frequency_penalty;
2126
2127            if !should_update {
2128                should_update = true;
2129            }
2130        }
2131    }
2132
2133    // check if necessary to update presence_penalty
2134    if let Some(presence_penalty) = chat_request.presence_penalty {
2135        if metadata.presence_penalty != presence_penalty {
2136            // update presence_penalty
2137            metadata.presence_penalty = presence_penalty;
2138
2139            if !should_update {
2140                should_update = true;
2141            }
2142        }
2143    }
2144
2145    // ensure the `embeddings` option is disabled for chat completions
2146    if metadata.embeddings {
2147        metadata.embeddings = false;
2148
2149        if !should_update {
2150            should_update = true;
2151        }
2152    }
2153
2154    if should_update {
2155        #[cfg(feature = "logging")]
2156        info!(target: "stdout", "Update the model metadata.");
2157
2158        // update the target graph with the new metadata
2159        update_model_metadata(chat_request.model.as_ref(), &metadata)?;
2160    }
2161
2162    Ok(metadata)
2163}
2164
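/// Adjust `n_predict` for this request: it is first overridden by `max_completion_tokens` when
/// provided, and is then raised to the completion-token budget still available in the context
/// window (`ctx_size - prompt_tokens`, as returned by `build_prompt`) whenever it is negative or
/// smaller than that budget. For example, with a 4096-token context and a 1000-token prompt the
/// available budget is 3096 tokens.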
2165fn update_n_predict(
2166    chat_request: &ChatCompletionRequest,
2167    metadata: &mut GgmlMetadata,
2168    available_completion_tokens: u64,
2169) -> Result<(), LlamaCoreError> {
2170    let mut should_update = false;
2171
2172    #[cfg(feature = "logging")]
2173    info!(target: "stdout", "n_predict: {}", metadata.n_predict);
2174
2175    // From high to low priority
2176    // 1. chat_request.max_completion_tokens
2177    // 2. available_completion_tokens
2178    // 3. n_predict
2179
2180    if let Some(max_completion_tokens) = chat_request.max_completion_tokens {
2181        if metadata.n_predict != max_completion_tokens {
2182            #[cfg(feature = "logging")]
2183            info!(target: "stdout", "Update n_predict with max_completion_tokens from {} to {}", metadata.n_predict, max_completion_tokens);
2184
2185            metadata.n_predict = max_completion_tokens;
2186
2187            if !should_update {
2188                should_update = true;
2189            }
2190        }
2191    }
2192
2193    // TODO: remove this condition after [Issue #3958 on WasmEdge](https://github.com/WasmEdge/WasmEdge/issues/3958) is fixed
2194    if metadata.n_predict == -2 {
2195        #[cfg(feature = "logging")]
2196        info!(target: "stdout", "Update n_predict with available_completion_tokens from {} to {}", metadata.n_predict, available_completion_tokens);
2197
2198        // update n_predict
2199        metadata.n_predict = available_completion_tokens as i32;
2200
2201        if !should_update {
2202            should_update = true;
2203        }
2204    }
2205
2206    if metadata.n_predict == -1
2207        || (metadata.n_predict > 0 && metadata.n_predict < available_completion_tokens as i32)
2208        || (metadata.n_predict < 0 && metadata.n_predict != -2)
2209    // TODO: remove this condition after [Issue #3958 on WasmEdge](https://github.com/WasmEdge/WasmEdge/issues/3958) is fixed
2210    {
2211        #[cfg(feature = "logging")]
2212        info!(target: "stdout", "Update n_predict with available_completion_tokens from {} to {}", metadata.n_predict, available_completion_tokens);
2213
2214        // update n_predict
2215        metadata.n_predict = available_completion_tokens as i32;
2216
2217        if !should_update {
2218            should_update = true;
2219        }
2220    }
2221
2222    if should_update {
2223        #[cfg(feature = "logging")]
2224        info!(target: "stdout", "Update the model metadata.");
2225
2226        // update the target graph with the new metadata
2227        update_model_metadata(chat_request.model.as_ref(), metadata)?;
2228    }
2229
2230    Ok(())
2231}
2232
2233/// Post-process the generated output according to the prompt template type, stripping template-specific stop tokens and other artifacts.
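///
/// For example, with a ChatML-style template an output such as `"Hello!<|im_end|>"` is trimmed
/// to `"Hello!"`.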
2234fn post_process(
2235    output: impl AsRef<str>,
2236    template_ty: &PromptTemplateType,
2237) -> Result<String, String> {
2238    let output = if *template_ty == PromptTemplateType::Baichuan2 {
2239        if output.as_ref().contains("用户:") {
2240            output.as_ref().trim_end_matches("用户:").trim().to_owned()
2241        } else {
2242            output.as_ref().trim().to_owned()
2243        }
2244    } else if *template_ty == PromptTemplateType::OpenChat {
2245        if output.as_ref().contains("<|end_of_turn|>") {
2246            output
2247                .as_ref()
2248                .trim_end_matches("<|end_of_turn|>")
2249                .trim()
2250                .to_owned()
2251        } else {
2252            output.as_ref().trim().to_owned()
2253        }
2254    } else if *template_ty == PromptTemplateType::GemmaInstruct
2255        || *template_ty == PromptTemplateType::Gemma3
2256    {
2257        let s = output.as_ref().trim();
2258        if s.ends_with("<end_of_turn>") {
2259            s.trim_end_matches("<end_of_turn>").trim().to_owned()
2260        } else {
2261            s.to_owned()
2262        }
2263    } else if *template_ty == PromptTemplateType::ChatML
2264        || *template_ty == PromptTemplateType::ChatMLTool
2265        || *template_ty == PromptTemplateType::InternLM2Tool
2266        || *template_ty == PromptTemplateType::MiniCPMV
2267    {
2268        let mut s = output.as_ref().trim();
2269        if s.ends_with("<|endoftext|>") {
2270            s = s.trim_end_matches("<|endoftext|>").trim();
2271        }
2272
2273        if s.starts_with(":") {
2274            s = s.trim_start_matches(":").trim();
2275        }
2276
2277        // handle Qwen3 empty think tags
2278        let x = {
2279            let pat = r#"<think>
2280
2281</think>
2282"#;
2283            if s.contains(pat) {
2284                let x = s.replace(pat, "");
2285                if x.starts_with("()") {
2286                    x.trim_start_matches("()").to_owned()
2287                } else {
2288                    x.to_owned()
2289                }
2290            } else {
2291                s.to_owned()
2292            }
2293        };
2294        s = x.trim();
2295
2296        if s.contains("<|im_start|>") && s.contains("<|im_end|>") {
2297            let idx_start = s.find("<|im_start|>").unwrap();
2298            let idx_end = s.find("<|im_end|>").unwrap();
2299
2300            match idx_start <= idx_end {
2301                true => s.split("<|im_start|>").collect::<Vec<_>>()[0]
2302                    .trim()
2303                    .to_owned(),
2304                false => s.split("<|im_end|>").collect::<Vec<_>>()[0]
2305                    .trim()
2306                    .to_owned(),
2307            }
2308        } else if s.contains("<|im_start|>") {
2309            s.split("<|im_start|>").collect::<Vec<_>>()[0]
2310                .trim()
2311                .to_owned()
2312        } else if s.contains("<|im_end|>") {
2313            let output = s.trim_end_matches("<|im_end|>").trim();
2314            if output.starts_with(": ") {
2315                output.trim_start_matches(": ").to_owned()
2316            } else {
2317                output.to_owned()
2318            }
2319        } else {
2320            s.to_owned()
2321        }
2322    } else if *template_ty == PromptTemplateType::Zephyr
2323        || *template_ty == PromptTemplateType::MistralLite
2324        || *template_ty == PromptTemplateType::MistralTool
2325        || *template_ty == PromptTemplateType::MistralInstruct
2326        || *template_ty == PromptTemplateType::MistralSmallChat
2327        || *template_ty == PromptTemplateType::MistralSmallTool
2328        || *template_ty == PromptTemplateType::BreezeInstruct
2329    {
2330        if output.as_ref().contains("</s><") {
2331            output.as_ref().trim_end_matches("</s><").trim().to_owned()
2332        } else if output.as_ref().contains("</s>") {
2333            output
2334                .as_ref()
2335                .strip_suffix("</s>")
2336                .unwrap()
2337                .trim()
2338                .to_owned()
2339        } else {
2340            output.as_ref().trim().to_owned()
2341        }
2342    } else if *template_ty == PromptTemplateType::DeepseekChat {
2343        if output.as_ref().contains("<|end_of_sentence|>") {
2344            output
2345                .as_ref()
2346                .trim_end_matches("<|end_of_sentence|>")
2347                .trim()
2348                .replace("<|end_of_sentence|>", " ")
2349                .trim()
2350                .to_owned()
2351        } else {
2352            output.as_ref().trim().to_owned()
2353        }
2354    } else if *template_ty == PromptTemplateType::HumanAssistant {
2355        if output.as_ref().contains("Human:") {
2356            output.as_ref().trim_end_matches("Human:").trim().to_owned()
2357        } else {
2358            output.as_ref().trim().to_owned()
2359        }
2360    } else if *template_ty == PromptTemplateType::SolarInstruct {
2361        let s = output.as_ref().trim();
2362
2363        if s.starts_with("### Answer") {
2364            let s = s.trim_start_matches("###").trim();
2365
2366            if s.starts_with("Answer:\n") {
2367                s.replace("Answer:\n", "Answer: ")
2368            } else {
2369                s.to_owned()
2370            }
2371        } else {
2372            s.to_owned()
2373        }
2374    } else if *template_ty == PromptTemplateType::Llama2Chat
2375        || *template_ty == PromptTemplateType::NemotronTool
2376        || *template_ty == PromptTemplateType::NemotronChat
2377    {
2378        let s = output.as_ref().trim();
2379        if s.ends_with("</s>") {
2380            s.trim_end_matches("</s>").trim().to_owned()
2381        } else {
2382            s.to_owned()
2383        }
2384    } else if *template_ty == PromptTemplateType::Llama3Chat
2385        || *template_ty == PromptTemplateType::GroqLlama3Tool
2386        || *template_ty == PromptTemplateType::Llama3Tool
2387        || *template_ty == PromptTemplateType::FunctionaryV32
2388    {
2389        let s = output.as_ref().trim();
2390        if s.ends_with("<|eot_id|>") {
2391            s.trim_end_matches("<|eot_id|>").trim().to_owned()
2392        } else {
2393            s.to_owned()
2394        }
2395    } else if *template_ty == PromptTemplateType::Phi3Chat {
2396        let s = output.as_ref().trim();
2397        if s.ends_with("<|end|>") {
2398            s.trim_end_matches("<|end|>").trim().to_owned()
2399        } else {
2400            s.to_owned()
2401        }
2402    } else if *template_ty == PromptTemplateType::Phi4Chat {
2403        let mut s = output.as_ref().trim();
2404
2405        if s.starts_with("think>") {
2406            s = s.trim_start_matches("think>").trim();
2407        }
2408
2409        if s.ends_with("<|im_end|>") {
2410            s.trim_end_matches("<|im_end|>").trim().to_owned()
2411        } else if s.ends_with("<|end|>") {
2412            s.trim_end_matches("<|end|>").trim().to_owned()
2413        } else {
2414            s.to_owned()
2415        }
2416    } else if *template_ty == PromptTemplateType::FunctionaryV31 {
2417        let mut s = output.as_ref().trim();
2418        if s.ends_with("<|eot_id|>") {
2419            s = s.trim_end_matches("<|eot_id|>").trim();
2420        }
2421        if s.ends_with("<|eom_id|>") {
2422            s = s.trim_end_matches("<|eom_id|>").trim();
2423        }
2424        s.to_owned()
2425    } else if *template_ty == PromptTemplateType::MoxinChat
2426        || *template_ty == PromptTemplateType::MoxinInstruct
2427    {
2428        let s = output.as_ref().trim();
2429        if s.ends_with("</s>") {
2430            s.trim_end_matches("</s>").trim().to_owned()
2431        } else if s.ends_with("[INST]") {
2432            s.trim_end_matches("[INST]").trim().to_owned()
2433        } else {
2434            s.to_owned()
2435        }
2436    } else if *template_ty == PromptTemplateType::Falcon3 {
2437        let s = output.as_ref().trim();
2438        if s.ends_with("<|endoftext|>") {
2439            s.trim_end_matches("<|endoftext|>").trim().to_owned()
2440        } else {
2441            s.to_owned()
2442        }
2443    } else if *template_ty == PromptTemplateType::Megrez {
2444        let s = output.as_ref().trim();
2445        if s.ends_with("<|turn_end|>") {
2446            s.trim_end_matches("<|turn_end|>").trim().to_owned()
2447        } else {
2448            s.to_owned()
2449        }
2450    } else if *template_ty == PromptTemplateType::Qwen2vl
2451        || *template_ty == PromptTemplateType::Qwen3NoThink
2452        || *template_ty == PromptTemplateType::ChatMLThink
2453    {
2454        let mut s = output.as_ref().trim();
2455
2456        if s.starts_with(":") {
2457            s = s.trim_start_matches(":").trim();
2458        }
2459
2460        if s.starts_with("</think>") {
2461            s = s.trim_start_matches("</think>").trim();
2462        }
2463
2464        if s.ends_with("<|im_end|>") {
2465            s.trim_end_matches("<|im_end|>").trim().to_owned()
2466        } else {
2467            s.to_owned()
2468        }
2469    } else if *template_ty == PromptTemplateType::VicunaLlava {
2470        let s = output.as_ref().trim();
2471        if s.ends_with("</s>") {
2472            s.trim_end_matches("</s>").trim().to_owned()
2473        } else {
2474            s.to_owned()
2475        }
2476    } else if *template_ty == PromptTemplateType::ExaoneDeepChat
2477        || *template_ty == PromptTemplateType::ExaoneChat
2478    {
2479        let mut s = output.as_ref().trim();
2480
2481        if s.ends_with("[|endofturn|]") {
2482            s = s.trim_end_matches("[|endofturn|]").trim();
2483        }
2484
2485        s.to_owned()
2486    } else if *template_ty == PromptTemplateType::Llama4Chat {
2487        let mut s = output.as_ref().trim();
2488
2489        if s.ends_with("<|eot|>") {
2490            s = s.trim_end_matches("<|eot|>").trim();
2491        }
2492
2493        s.to_owned()
2494    } else if *template_ty == PromptTemplateType::Smolvl {
2495        let mut s = output.as_ref().trim();
2496
2497        if s.starts_with(":") {
2498            s = s.trim_start_matches(":").trim();
2499        }
2500
2501        if s.ends_with("<end_of_utterance>") {
2502            s = s.trim_end_matches("<end_of_utterance>").trim();
2503        }
2504
2505        if s.contains("<end_of_utterance>:") {
2506            let parts = s.split("<end_of_utterance>:").collect::<Vec<_>>();
2507            parts.last().unwrap().trim().to_owned()
2508        } else {
2509            s.to_owned()
2510        }
2511    } else {
2512        output.as_ref().trim().to_owned()
2513    };
2514
2515    Ok(output)
2516}
2517
2518/// Build the chat prompt from the chat messages.
2519///
2520/// # Arguments
2521///
2522/// * `model_name`: The name of the model.
2523///
2524/// * `chat_request`: The chat request.
2525///
2526/// # Returns
2527///
2528/// A tuple containing the prompt, the number of available tokens for completions, and a boolean indicating whether tools are used.
2529fn build_prompt(
2530    model_name: Option<&String>,
2531    chat_request: &mut ChatCompletionRequest,
2532) -> Result<(String, u64, bool), LlamaCoreError> {
2533    let metadata = get_model_metadata(model_name)?;
2534    let ctx_size = metadata.ctx_size as u64;
2535    let chat_prompt = ChatPrompt::from(metadata.prompt_template);
2536
2537    // compute max prompt tokens, which is 80% of the context size
2538    let max_prompt_tokens = ctx_size * 4 / 5;
2539
2540    loop {
2541        // ! DO NOT REMOVE
2542        {
2543            // // build prompt
2544            // let prompt = match chat_prompt.build(&mut chat_request.messages) {
2545            //     Ok(prompt) => prompt,
2546            //     Err(e) => {
2547            //         let err_msg = format!("Fail to build chat prompts. Reason: {}", e);
2548
2549            //         #[cfg(feature = "logging")]
2550            //         error!(target: "stdout", "{}", &err_msg);
2551
2552            //         return Err(LlamaCoreError::Operation(err_msg));
2553            //     }
2554            // };
2555        }
2556
2557        if chat_request.messages.is_empty() {
2558            let err_msg = "The messages in the chat request are empty.";
2559
2560            #[cfg(feature = "logging")]
2561            error!(target: "stdout", "{err_msg}");
2562
2563            return Err(LlamaCoreError::Operation(err_msg.to_owned()));
2564        }
2565
2566        #[cfg(feature = "logging")]
2567        {
2568            let mut role_chain = String::new();
2569            for (idx, message) in chat_request.messages.iter().enumerate() {
2570                if idx == chat_request.messages.len() - 1 {
2571                    role_chain.push_str(&format!("{}", message.role()));
2572                } else {
2573                    role_chain.push_str(&format!("{} -> ", message.role()));
2574                }
2575            }
2576            info!(target: "stdout", "Role chain: {role_chain}");
2577        }
2578
2579        let (prompt, tool_use) = match chat_request.tool_choice.as_ref() {
2580            Some(tool_choice) => match tool_choice {
2581                ToolChoice::None => {
2582                    match chat_prompt.build_with_tools(&mut chat_request.messages, Some(&[])) {
2583                        Ok(prompt) => (prompt, false),
2584                        Err(e) => {
2585                            let err_msg = format!("Fail to build chat prompts. Reason: {e}");
2586
2587                            #[cfg(feature = "logging")]
2588                            error!(target: "stdout", "{}", &err_msg);
2589
2590                            return Err(LlamaCoreError::Operation(err_msg));
2591                        }
2592                    }
2593                }
2594                _ => match chat_request.tools.as_ref() {
2595                    Some(tools) => match chat_prompt
2596                        .build_with_tools(&mut chat_request.messages, Some(tools.as_slice()))
2597                    {
2598                        Ok(prompt) => (prompt, true),
2599                        Err(e) => {
2600                            let err_msg = format!("Fail to build chat prompts. Reason: {e}");
2601
2602                            #[cfg(feature = "logging")]
2603                            error!(target: "stdout", "{}", &err_msg);
2604
2605                            return Err(LlamaCoreError::Operation(err_msg));
2606                        }
2607                    },
2608                    None => {
2609                        #[cfg(feature = "logging")]
2610                        warn!(target: "stdout", "A tool choice was specified without any tools; building the prompt without tool definitions.");
2611
2612                        match chat_prompt.build_with_tools(&mut chat_request.messages, None) {
2613                            Ok(prompt) => (prompt, false),
2614                            Err(e) => {
2615                                let err_msg = format!("Fail to build chat prompts. Reason: {e}");
2616
2617                                #[cfg(feature = "logging")]
2618                                error!(target: "stdout", "{}", &err_msg);
2619
2620                                return Err(LlamaCoreError::Operation(err_msg));
2621                            }
2622                        }
2623                    }
2624                },
2625            },
2626            None => match chat_prompt.build_with_tools(&mut chat_request.messages, None) {
2627                Ok(prompt) => (prompt, false),
2628                Err(e) => {
2629                    let err_msg = format!("Fail to build chat prompts. Reason: {e}");
2630
2631                    #[cfg(feature = "logging")]
2632                    error!(target: "stdout", "{}", &err_msg);
2633
2634                    return Err(LlamaCoreError::Operation(err_msg));
2635                }
2636            },
2637        };
2638        #[cfg(feature = "logging")]
2639        info!(target: "stdout", "Try to set prompt: {prompt}");
2640
2641        // set prompt
2642        set_prompt(model_name, &prompt)?;
2643
2644        // Retrieve the number of prompt tokens.
2645        let token_info = get_token_info_by_graph_name(model_name)?;
2646
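        // If the prompt exceeds the 80% budget (`max_prompt_tokens`), the oldest non-system
        // messages are pruned below and the loop retries with the shorter history.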
2647        match token_info.prompt_tokens > max_prompt_tokens {
2648            true => {
2649                match chat_request.messages[0].role() {
2650                    ChatCompletionRole::System => {
2651                        if chat_request.messages.len() > 2 {
2652                            #[cfg(feature = "logging")]
2653                            info!(target: "stdout", "Prune chat history: current length {}", chat_request.messages.len());
2654
2655                            // remove user_1 if it exists
2656                            // For example, `system -> user_1 -> ... -> user_2 -> ... -> user_latest` will be converted to `system -> ... -> user_2 -> ... -> user_latest`
2657                            if chat_request.messages[1].role() == ChatCompletionRole::User {
2658                                let user_message = chat_request.messages.remove(1);
2659
2660                                #[cfg(feature = "logging")]
2661                                info!(target: "stdout", "Remove a user message from the chat history: {user_message:?}");
2662                            }
2663
2664                            // remove messages until the next message after the system prompt is a `user` message
2665                            // For example, `system -> ... -> user_2 -> ... -> user_latest` will be converted to `system -> user_2 -> ... -> user_latest`
2666                            while chat_request.messages[1].role() != ChatCompletionRole::User {
2667                                let message = chat_request.messages.remove(1);
2668
2669                                #[cfg(feature = "logging")]
2670                                info!(target: "stdout", "Remove a {} message from the chat history: {:?}", message.role(), message);
2671
2672                                if chat_request.messages.len() == 1 {
2673                                    let err_msg = format!("The last message in the chat history should be a user message, but found a {} message.", message.role());
2674
2675                                    #[cfg(feature = "logging")]
2676                                    error!(target: "stdout", "{err_msg}");
2677
2678                                    return Err(LlamaCoreError::Operation(err_msg));
2679                                }
2680                            }
2681                        } else if token_info.prompt_tokens > ctx_size {
2682                            let err_msg = format!(
2683                                    "The number of prompt tokens ({}) is greater than the context size ({}). Please increase the context size, or simplify the input message.",
2684                                    token_info.prompt_tokens, ctx_size
2685                                );
2686
2687                            #[cfg(feature = "logging")]
2688                            error!(target: "stdout", "{}", &err_msg);
2689
2690                            return Err(LlamaCoreError::Operation(err_msg));
2691                        } else {
2692                            return Ok((prompt, ctx_size - token_info.prompt_tokens, tool_use));
2693                        }
2694                    }
2695                    ChatCompletionRole::User => {
2696                        if chat_request.messages.len() > 1 {
2697                            // user_1 -> ... -> user_2 -> ... -> user_latest
2698
2699                            // remove user_1 if it exists
2700                            // For example, `user_1 -> ... -> user_2 -> ... -> user_latest` will be converted to `... -> user_2 -> ... -> user_latest`
2701                            if chat_request.messages[0].role() == ChatCompletionRole::User {
2702                                let user_message = chat_request.messages.remove(0);
2703
2704                                #[cfg(feature = "logging")]
2705                                info!(target: "stdout", "Remove a user message from the chat history: {user_message:?}");
2706                            }
2707
2708                            // remove messages until the first remaining message is a `user` message
2709                            // For example, `... -> user_2 -> ... -> user_latest` will be converted to `user_2 -> ... -> user_latest`
2710                            while chat_request.messages[0].role() != ChatCompletionRole::User {
2711                                let message = chat_request.messages.remove(0);
2712
2713                                #[cfg(feature = "logging")]
2714                                info!(target: "stdout", "Remove a {} message from the chat history: {:?}", message.role(), message);
2715
2716                                if chat_request.messages.is_empty() {
2717                                    let err_msg = format!("The last message in the chat history should be a user message, but found a {} message.", message.role());
2718
2719                                    #[cfg(feature = "logging")]
2720                                    error!(target: "stdout", "{err_msg}");
2721
2722                                    return Err(LlamaCoreError::Operation(err_msg));
2723                                }
2724                            }
2725                        } else if token_info.prompt_tokens > ctx_size {
2726                            let err_msg = format!(
2727                                    "The number of prompt tokens ({}) is greater than the context size ({}). Please increase the context size, or simplify the input message.",
2728                                    token_info.prompt_tokens, ctx_size
2729                                );
2730
2731                            #[cfg(feature = "logging")]
2732                            error!(target: "stdout", "{}", &err_msg);
2733
2734                            return Err(LlamaCoreError::Operation(err_msg));
2735                        } else {
2736                            return Ok((prompt, ctx_size - token_info.prompt_tokens, tool_use));
2737                        }
2738                    }
2739                    _ => {
2740                        #[cfg(feature = "logging")]
2741                        info!(target: "stdout", "Remove a {} message from the chat history", chat_request.messages[0].role());
2742
2743                        chat_request.messages.remove(0);
2744                    }
2745                }
2746
2747                continue;
2748            }
2749            false => return Ok((prompt, ctx_size - max_prompt_tokens, tool_use)),
2750        }
2751    }
2752}
2753
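/// Writes the given prompt into input tensor 0 of the chat graph selected by
/// `model_name`, falling back to the first registered graph when the name is
/// `None` or not found.
///
/// A minimal, hypothetical call (the prompt text is illustrative only):
///
/// ```ignore
/// set_prompt(None, "You are a helpful assistant.\n\nUser: Hello!")?;
/// ```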
2754fn set_prompt(model_name: Option<&String>, prompt: impl AsRef<str>) -> Result<(), LlamaCoreError> {
2755    let chat_graphs = match CHAT_GRAPHS.get() {
2756        Some(chat_graphs) => chat_graphs,
2757        None => {
2758            let err_msg = "Fail to get the underlying value of `CHAT_GRAPHS`.";
2759
2760            #[cfg(feature = "logging")]
2761            error!(target: "stdout", "{}", &err_msg);
2762
2763            return Err(LlamaCoreError::Operation(err_msg.into()));
2764        }
2765    };
2766
2767    let mut chat_graphs = chat_graphs.lock().map_err(|e| {
2768        let err_msg = format!("Fail to acquire the lock of `CHAT_GRAPHS`. {e}");
2769
2770        #[cfg(feature = "logging")]
2771        error!(target: "stdout", "{}", &err_msg);
2772
2773        LlamaCoreError::Operation(err_msg)
2774    })?;
2775
2776    match model_name {
2777        Some(model_name) => {
2778            #[cfg(feature = "logging")]
2779            info!(target: "stdout", "Set prompt to the chat model named {model_name}");
2780
2781            match chat_graphs.contains_key(model_name) {
2782                true => {
2783                    let graph = chat_graphs.get_mut(model_name).unwrap();
2784                    let tensor_data = prompt.as_ref().as_bytes().to_vec();
2785                    set_tensor_data_u8(graph, 0, &tensor_data)
2786                }
2787                false => match chat_graphs.iter_mut().next() {
2788                    Some((_, graph)) => {
2789                        let tensor_data = prompt.as_ref().as_bytes().to_vec();
2790                        set_tensor_data_u8(graph, 0, &tensor_data)
2791                    }
2792                    None => {
2793                        let err_msg = "There is no model available in the chat graphs.";
2794
2795                        #[cfg(feature = "logging")]
2796                        error!(target: "stdout", "{}", &err_msg);
2797
2798                        Err(LlamaCoreError::Operation(err_msg.into()))
2799                    }
2800                },
2801            }
2802        }
2803        None => {
2804            #[cfg(feature = "logging")]
2805            info!(target: "stdout", "Set prompt to the default chat model.");
2806
2807            match chat_graphs.iter_mut().next() {
2808                Some((_, graph)) => {
2809                    let tensor_data = prompt.as_ref().as_bytes().to_vec();
2810                    set_tensor_data_u8(graph, 0, &tensor_data)
2811                }
2812                None => {
2813                    let err_msg = "There is no model available in the chat graphs while trying to set prompt to the default model.";
2814
2815                    #[cfg(feature = "logging")]
2816                    error!(target: "stdout", "{err_msg}");
2817
2818                    Err(LlamaCoreError::Operation(err_msg.into()))
2819                }
2820            }
2821        }
2822    }
2823}
2824
2825/// Get a copy of the metadata of the model.
2826fn get_model_metadata(model_name: Option<&String>) -> Result<GgmlMetadata, LlamaCoreError> {
2827    let chat_graphs = match CHAT_GRAPHS.get() {
2828        Some(chat_graphs) => chat_graphs,
2829        None => {
2830            let err_msg = "Fail to get the underlying value of `CHAT_GRAPHS`.";
2831
2832            #[cfg(feature = "logging")]
2833            error!(target: "stdout", "{err_msg}");
2834
2835            return Err(LlamaCoreError::Operation(err_msg.into()));
2836        }
2837    };
2838
2839    let chat_graphs = chat_graphs.lock().map_err(|e| {
2840        let err_msg = format!("Fail to acquire the lock of `CHAT_GRAPHS`. {e}");
2841
2842        #[cfg(feature = "logging")]
2843        error!(target: "stdout", "{}", &err_msg);
2844
2845        LlamaCoreError::Operation(err_msg)
2846    })?;
2847
2848    match model_name {
2849        Some(model_name) => match chat_graphs.contains_key(model_name) {
2850            true => {
2851                let graph = chat_graphs.get(model_name).unwrap();
2852                Ok(graph.metadata.clone())
2853            }
2854            false => match chat_graphs.iter().next() {
2855                Some((_, graph)) => Ok(graph.metadata.clone()),
2856                None => {
2857                    let err_msg = "There is no model available in the chat graphs.";
2858
2859                    #[cfg(feature = "logging")]
2860                    error!(target: "stdout", "{}", &err_msg);
2861
2862                    Err(LlamaCoreError::Operation(err_msg.into()))
2863                }
2864            },
2865        },
2866        None => match chat_graphs.iter().next() {
2867            Some((_, graph)) => Ok(graph.metadata.clone()),
2868            None => {
2869                let err_msg = "There is no model available in the chat graphs.";
2870
2871                #[cfg(feature = "logging")]
2872                error!(target: "stdout", "{err_msg}");
2873
2874                Err(LlamaCoreError::Operation(err_msg.into()))
2875            }
2876        },
2877    }
2878}
2879
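/// Serializes `metadata` to a JSON string and writes it into input tensor 1 of
/// the chat graph selected by `model_name`, falling back to the first
/// registered graph when the name is `None` or not found.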
2880fn update_model_metadata(
2881    model_name: Option<&String>,
2882    metadata: &GgmlMetadata,
2883) -> Result<(), LlamaCoreError> {
2884    let config = match serde_json::to_string(metadata) {
2885        Ok(config) => config,
2886        Err(e) => {
2887            let err_msg = format!("Fail to serialize metadata to a JSON string. {e}");
2888
2889            #[cfg(feature = "logging")]
2890            error!(target: "stdout", "{}", &err_msg);
2891
2892            return Err(LlamaCoreError::Operation(err_msg));
2893        }
2894    };
2895
2896    let chat_graphs = match CHAT_GRAPHS.get() {
2897        Some(chat_graphs) => chat_graphs,
2898        None => {
2899            let err_msg = "Fail to get the underlying value of `CHAT_GRAPHS`.";
2900
2901            #[cfg(feature = "logging")]
2902            error!(target: "stdout", "{err_msg}");
2903
2904            return Err(LlamaCoreError::Operation(err_msg.into()));
2905        }
2906    };
2907
2908    let mut chat_graphs = chat_graphs.lock().map_err(|e| {
2909        let err_msg = format!("Fail to acquire the lock of `CHAT_GRAPHS`. Reason: {e}");
2910
2911        #[cfg(feature = "logging")]
2912        error!(target: "stdout", "{}", &err_msg);
2913
2914        LlamaCoreError::Operation(err_msg)
2915    })?;
2916
2917    match model_name {
2918        Some(model_name) => {
2919            match chat_graphs.contains_key(model_name) {
2920                true => {
2921                    let graph = chat_graphs.get_mut(model_name).unwrap();
2922                    // update metadata
2923                    set_tensor_data_u8(graph, 1, config.as_bytes())
2924                }
2925                false => match chat_graphs.iter_mut().next() {
2926                    Some((_, graph)) => {
2927                        // update metadata
2928                        set_tensor_data_u8(graph, 1, config.as_bytes())
2929                    }
2930                    None => {
2931                        let err_msg = "There is no model available in the chat graphs.";
2932
2933                        #[cfg(feature = "logging")]
2934                        error!(target: "stdout", "{}", &err_msg);
2935
2936                        Err(LlamaCoreError::Operation(err_msg.into()))
2937                    }
2938                },
2939            }
2940        }
2941        None => {
2942            match chat_graphs.iter_mut().next() {
2943                Some((_, graph)) => {
2944                    // update metadata
2945                    set_tensor_data_u8(graph, 1, config.as_bytes())
2946                }
2947                None => {
2948                    let err_msg = "There is no model available in the chat graphs.";
2949
2950                    #[cfg(feature = "logging")]
2951                    error!(target: "stdout", "{err_msg}");
2952
2953                    Err(LlamaCoreError::Operation(err_msg.into()))
2954                }
2955            }
2956        }
2957    }
2958}
2959
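/// Resets the selected model's metadata by reading the copy stored on its
/// `Graph` (via `get_model_metadata`) and writing it back with
/// `update_model_metadata`.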
2960fn reset_model_metadata(model_name: Option<&String>) -> Result<(), LlamaCoreError> {
2961    // get metadata
2962    let metadata = get_model_metadata(model_name)?;
2963
2964    // update model with the original metadata
2965    update_model_metadata(model_name, &metadata)
2966}
2967
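/// Shutdown sequence for a stream after the backend reports `ContextFull`:
/// emit a final message chunk, optionally a usage chunk, then `[DONE]`.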
2968#[derive(Debug, Clone, Copy, PartialEq, Eq)]
2969enum ContextFullState {
2970    Message,
2971    Usage,
2972    Done,
2973    EndOfSequence,
2974}
2975
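/// Per-chunk state for normal streaming: `Usage` streams emit a trailing usage
/// chunk once the backend reaches end of sequence, `NoUsage` streams skip it;
/// both then emit `[DONE]` followed by the end-of-sequence sentinel.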
2976#[derive(Debug, Clone, Copy, PartialEq, Eq)]
2977enum StreamState {
2978    Usage,
2979    NoUsage,
2980    Done,
2981    EndOfSequence,
2982}
2983
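/// Shutdown sequence for a stream after the backend reports `PromptTooLong`:
/// emit a final chunk with the `length` finish reason, optionally a usage
/// chunk, then `[DONE]`.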
2984#[derive(Debug, Clone, Copy, PartialEq, Eq)]
2985enum PromptTooLongState {
2986    Message,
2987    Usage,
2988    Done,
2989    EndOfSequence,
2990}
2991
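/// A server-sent-events stream of chat completion chunks.
///
/// Only one `ChatStream` may drive the model at a time: construction tries to
/// set `CHAT_STREAM_ACTIVE`, and a stream that loses the race parks its waker
/// in `CHAT_STREAM_WAKER_QUEUE` until the active stream is dropped.
///
/// A minimal consumption sketch (illustrative only; assumes the stream was
/// obtained from `chat_stream` and that the caller is in an async context):
///
/// ```ignore
/// use futures::StreamExt;
///
/// while let Some(item) = stream.next().await {
///     match item {
///         // Each item is an SSE-formatted chunk, e.g. "data: {...}\n\n"
///         Ok(chunk) => print!("{chunk}"),
///         Err(e) => eprintln!("stream error: {e}"),
///     }
/// }
/// ```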
2992struct ChatStream {
2993    id: String,
2994    model: Option<String>,
2995    include_usage: bool,
2996    context_full_state: ContextFullState,
2997    prompt_too_long_state: PromptTooLongState,
2998    stream_state: StreamState,
2999    cache: Option<VecDeque<String>>,
3000    is_waiting: bool,
3001    has_lock: bool,
3002}
3003impl ChatStream {
3004    fn new(
3005        model: Option<String>,
3006        id: String,
3007        include_usage: bool,
3008        cache: Option<Vec<String>>,
3009    ) -> Self {
3010        // Try to acquire lock
3011        let has_lock = CHAT_STREAM_ACTIVE
3012            .compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst)
3013            .is_ok();
3014
3015        #[cfg(feature = "logging")]
3016        if !has_lock {
3017            info!(target: "stdout", "Lock acquisition failed in ChatStream::new, creating with waiting status");
3018        }
3019
3020        ChatStream {
3021            id,
3022            model,
3023            include_usage,
3024            context_full_state: ContextFullState::Message,
3025            prompt_too_long_state: PromptTooLongState::Message,
3026            stream_state: if include_usage {
3027                StreamState::Usage
3028            } else {
3029                StreamState::NoUsage
3030            },
3031            cache: cache.map(VecDeque::from),
3032            is_waiting: !has_lock,
3033            has_lock,
3034        }
3035    }
3036
3037    // Try to acquire lock, returns whether successful
3038    fn try_acquire_lock(&mut self) -> bool {
3039        if self.has_lock {
3040            return true;
3041        }
3042
3043        let acquired = CHAT_STREAM_ACTIVE
3044            .compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst)
3045            .is_ok();
3046
3047        if acquired {
3048            self.has_lock = true;
3049            self.is_waiting = false;
3050        }
3051
3052        acquired
3053    }
3054}
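// Dropping a ChatStream finishes any in-flight single-token computation on the
// graph, restores the original model metadata, and, if this stream held the
// global lock, releases `CHAT_STREAM_ACTIVE` and wakes the next waiting stream.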
3055impl Drop for ChatStream {
3056    fn drop(&mut self) {
3057        // Cleanup is only needed if we hold the lock or if the stream was actually used
3058        if self.has_lock || (self.cache.is_none() && !self.is_waiting) {
3059            #[cfg(feature = "logging")]
3060            info!(target: "stdout", "Cleaning up context for ChatStream {}", &self.id);
3061
3062            match &self.model {
3063                Some(model_name) => {
3064                    match CHAT_GRAPHS.get() {
3065                        Some(chat_graphs) => {
3066                            match chat_graphs.lock() {
3067                                Ok(mut chat_graphs) => match chat_graphs.contains_key(model_name) {
3068                                    true => {
3069                                        let graph = chat_graphs.get_mut(model_name).unwrap();
3070
3071                                        // clean up the context
3072                                        if let Err(e) = graph.finish_single() {
3073                                            let err_msg = format!(
3074                                                "Failed to clean up the context. Reason: {e}"
3075                                            );
3076
3077                                            #[cfg(feature = "logging")]
3078                                            error!(target: "stdout", "{}", &err_msg);
3079
3080                                            #[cfg(not(feature = "logging"))]
3081                                            println!("[ERROR][llama_core] {}", &err_msg);
3085                                        }
3086                                    }
3087                                    false => match chat_graphs.iter_mut().next() {
3088                                        Some((_, graph)) => {
3089                                            // clean up the context
3090                                            if let Err(e) = graph.finish_single() {
3091                                                let err_msg = format!(
3092                                                    "Failed to clean up the context. Reason: {e}"
3093                                                );
3094
3095                                                #[cfg(feature = "logging")]
3096                                                error!(target: "stdout", "{}", &err_msg);
3097
3098                                                #[cfg(not(feature = "logging"))]
3099                                                println!("[ERROR][llama_core] {}", &err_msg);
3103                                            }
3104                                        }
3105                                        None => {
3106                                            let err_msg =
3107                                                "There is no model available in the chat graphs.";
3108
3109                                            #[cfg(feature = "logging")]
3110                                            error!(target: "stdout", "{}", &err_msg);
3111
3112                                            #[cfg(not(feature = "logging"))]
3113                                            println!(
3114                                                "[ERROR][llama_core] Failed to clean up the context. Reason: {}",
3115                                                &err_msg
3116                                            );
3117                                        }
3118                                    },
3119                                },
3120                                Err(e) => {
3121                                    let err_msg =
3122                                        format!("Fail to acquire the lock of `CHAT_GRAPHS`. {e}");
3123
3124                                    #[cfg(feature = "logging")]
3125                                    error!(target: "stdout", "{}", &err_msg);
3126
3127                                    #[cfg(not(feature = "logging"))]
3128                                    println!(
3129                                        "[ERROR][llama_core] Failed to clean up the context. Reason: {}",
3130                                        &err_msg
3131                                    );
3132                                }
3133                            }
3134                        }
3135                        None => {
3136                            let err_msg = "Fail to get the underlying value of `CHAT_GRAPHS`.";
3137
3138                            #[cfg(feature = "logging")]
3139                            error!(target: "stdout", "{}", &err_msg);
3140
3141                            #[cfg(not(feature = "logging"))]
3142                            println!(
3143                                "[ERROR][llama_core] Failed to clean up the context. Reason: {}",
3144                                &err_msg
3145                            );
3146                        }
3147                    };
3148                }
3149                None => {
3150                    match CHAT_GRAPHS.get() {
3151                        Some(chat_graphs) => {
3152                            match chat_graphs.lock() {
3153                                Ok(mut chat_graphs) => match chat_graphs.iter_mut().next() {
3154                                    Some((_, graph)) => {
3155                                        // clean up the context
3156                                        if let Err(e) = graph.finish_single() {
3157                                            let err_msg = format!(
3158                                                "Failed to clean up the context. Reason: {e}"
3159                                            );
3160
3161                                            #[cfg(feature = "logging")]
3162                                            error!(target: "stdout", "{}", &err_msg);
3163
3164                                            #[cfg(not(feature = "logging"))]
3165                                            println!("[ERROR][llama_core] {}", &err_msg);
3169                                        }
3170                                    }
3171                                    None => {
3172                                        let err_msg =
3173                                            "There is no model available in the chat graphs.";
3174
3175                                        #[cfg(feature = "logging")]
3176                                        error!(target: "stdout", "{err_msg}");
3177
3178                                        #[cfg(not(feature = "logging"))]
3179                                        println!(
3180                                            "[ERROR][llama_core] Failed to clean up the context. Reason: {}",
3181                                            err_msg
3182                                        );
3183                                    }
3184                                },
3185                                Err(e) => {
3186                                    let err_msg =
3187                                        format!("Fail to acquire the lock of `CHAT_GRAPHS`. {e}");
3188
3189                                    #[cfg(feature = "logging")]
3190                                    error!(target: "stdout", "{}", &err_msg);
3191
3192                                    #[cfg(not(feature = "logging"))]
3193                                    println!(
3194                                        "[ERROR][llama_core] Failed to clean up the context. Reason: {}",
3195                                        &err_msg
3196                                    );
3197                                }
3198                            }
3199                        }
3200                        None => {
3201                            let err_msg = "Fail to get the underlying value of `CHAT_GRAPHS`.";
3202
3203                            #[cfg(feature = "logging")]
3204                            error!(target: "stdout", "{}", &err_msg);
3205
3206                            #[cfg(not(feature = "logging"))]
3207                            println!(
3208                                "[ERROR][llama_core] Failed to clean up the context. Reason: {}",
3209                                &err_msg
3210                            );
3211                        }
3212                    };
3213                }
3214            }
3215
3216            #[cfg(feature = "logging")]
3217            info!(target: "stdout", "Model context cleanup done!");
3218        }
3219
3220        // reset the model metadata
3221        if let Err(e) = reset_model_metadata(self.model.as_ref()) {
3222            let err_msg = format!("Fail to reset model metadata. Reason: {e}");
3223
3224            #[cfg(feature = "logging")]
3225            error!(target: "stdout", "{}", &err_msg);
3226
3227            #[cfg(not(feature = "logging"))]
3228            println!("[ERROR][llama_core] {}", &err_msg);
3229        }
3230        #[cfg(feature = "logging")]
3231        info!(target: "stdout", "Model metadata reset done!");
3232
3233        // When dropping a ChatStream that held the lock, check if there are waiting streams
3234        if self.has_lock {
3235            // Reset the atomic flag
3236            CHAT_STREAM_ACTIVE.store(false, Ordering::SeqCst);
3237
3238            #[cfg(feature = "logging")]
3239            info!(target: "stdout", "Lock from ChatStream {} released", &self.id);
3240
3241            // Wake up waiting streams
3242            if let Ok(mut queue) = get_chat_stream_waker_queue().lock() {
3243                if let Some(waker) = queue.pop_front() {
3244                    #[cfg(feature = "logging")]
3245                    info!(target: "stdout", "Waking up a waiting ChatStream");
3246
3247                    waker.wake();
3248                }
3249            }
3250        }
3251    }
3252}
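// Poll order: a waiting stream first tries to (re)acquire the global lock and
// parks its waker on failure; an active stream either replays cached chunks or
// computes the next chunk from the graph via `compute_stream`.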
3253impl futures::Stream for ChatStream {
3254    type Item = Result<String, LlamaCoreError>;
3255
3256    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
3257        let this = self.get_mut();
3258
3259        // If this is a waiting stream, try to acquire the lock
3260        if this.is_waiting {
3261            if !this.try_acquire_lock() {
3262                // Store the waker to be notified when the lock becomes available
3263                if let Ok(mut queue) = get_chat_stream_waker_queue().lock() {
3264                    // Remove any previous instance of this waker
3265                    queue.retain(|w| !w.will_wake(cx.waker()));
3266                    // Add this waker to the queue
3267                    queue.push_back(cx.waker().clone());
3268
3269                    #[cfg(feature = "logging")]
3270                    debug!(target: "stdout", "ChatStream {} is waiting for lock, added waker to queue", &this.id);
3271                }
3272
3273                return Poll::Pending;
3274            }
3275
3276            #[cfg(feature = "logging")]
3277            info!(target: "stdout", "ChatStream {} acquired lock and is now active", &this.id);
3278            // If we got here, we successfully acquired the lock and can proceed
3279        }
3280
3281        // Ensure we still have the lock
3282        if !this.has_lock && !this.try_acquire_lock() {
3283            // Lost the lock, need to wait
3284            this.is_waiting = true;
3285
3286            // Register waker to be notified when lock is available
3287            if let Ok(mut queue) = get_chat_stream_waker_queue().lock() {
3288                queue.retain(|w| !w.will_wake(cx.waker()));
3289                queue.push_back(cx.waker().clone());
3290            }
3291
3292            return Poll::Pending;
3293        }
3294
3295        if this.cache.is_none() {
3296            let res = compute_stream(
3297                this.model.clone(),
3298                this.id.clone(),
3299                this.include_usage,
3300                &mut this.prompt_too_long_state,
3301                &mut this.context_full_state,
3302                &mut this.stream_state,
3303            );
3304
3305            match res {
3306                Ok(x) => {
3307                    #[cfg(feature = "logging")]
3308                    info!(target: "stdout", "next item for ChatStream {}: {}", &this.id, &x);
3309
3310                    if x != "[GGML] End of sequence" && !x.is_empty() {
3311                        Poll::Ready(Some(Ok(x)))
3312                    } else {
3313                        // stopped
3314                        Poll::Ready(None)
3315                    }
3316                }
3317                Err(e) => Poll::Ready(Some(Err(e))),
3318            }
3319        } else {
3320            let x = this.cache.as_mut().unwrap().pop_front();
3321
3322            #[cfg(feature = "logging")]
3323            info!(target: "stdout", "Get the next item from the cache for ChatStream {}: {:?}", &this.id, &x);
3324
3325            match x {
3326                Some(x) => Poll::Ready(Some(Ok(x))),
3327                None => Poll::Ready(None),
3328            }
3329        }
3330    }
3331}
3332
3333/// Helper function to get or initialize the waker queue for waiting ChatStreams
3334fn get_chat_stream_waker_queue() -> &'static Mutex<VecDeque<Waker>> {
3335    CHAT_STREAM_WAKER_QUEUE.get_or_init(|| {
3336        #[cfg(feature = "logging")]
3337        info!(target: "stdout", "Initializing ChatStream waker queue");
3338        Mutex::new(VecDeque::new())
3339    })
3340}
3341
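/// Computes the next server-sent-events chunk for a streaming chat completion.
///
/// Runs one `compute_single` step on the selected graph, decodes the emitted
/// token (buffering incomplete UTF-8 sequences in `CACHED_UTF8_ENCODINGS`),
/// and advances the `StreamState`, `ContextFullState`, or `PromptTooLongState`
/// machines on the corresponding backend errors. Returns the literal
/// `"[GGML] End of sequence"` once the stream is exhausted.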
3342fn compute_stream(
3343    model_name: Option<String>,
3344    id: String,
3345    include_usage: bool,
3346    prompt_too_long_state: &mut PromptTooLongState,
3347    context_full_state: &mut ContextFullState,
3348    stream_state: &mut StreamState,
3349) -> Result<String, LlamaCoreError> {
3350    #[cfg(feature = "logging")]
3351    info!(target: "stdout", "Computing stream chunk for ChatStream {}", &id);
3352
3353    #[cfg(feature = "logging")]
3354    debug!(target: "stdout", "prompt_too_long_state: {:?}", *prompt_too_long_state);
3355    #[cfg(feature = "logging")]
3356    debug!(target: "stdout", "context_full_state: {:?}", *context_full_state);
3357    #[cfg(feature = "logging")]
3358    debug!(target: "stdout", "stream_state: {:?}", *stream_state);
3359
3360    if *prompt_too_long_state == PromptTooLongState::EndOfSequence
3361        || *context_full_state == ContextFullState::EndOfSequence
3362        || *stream_state == StreamState::EndOfSequence
3363    {
3364        #[cfg(feature = "logging")]
3365        info!(target: "stdout", "Return the chat stream chunk!");
3366
3367        return Ok("[GGML] End of sequence".to_string());
3368    }
3369
3370    let chat_graphs = match CHAT_GRAPHS.get() {
3371        Some(chat_graphs) => chat_graphs,
3372        None => {
3373            let err_msg = "Fail to get the underlying value of `CHAT_GRAPHS`.";
3374
3375            #[cfg(feature = "logging")]
3376            error!(target: "stdout", "{}", &err_msg);
3377
3378            return Err(LlamaCoreError::Operation(err_msg.into()));
3379        }
3380    };
3381
3382    // We're already holding the ChatStream lock, so we know we have exclusive access to the graph
3383    let mut chat_graphs = chat_graphs.lock().map_err(|e| {
3384        let err_msg = format!("Fail to acquire the lock of `CHAT_GRAPHS`. {e}");
3385
3386        #[cfg(feature = "logging")]
3387        error!(target: "stdout", "{}", &err_msg);
3388
3389        LlamaCoreError::Operation(err_msg)
3390    })?;
3391
3392    // Get the graph based on model name
3393    let res = match &model_name {
3394        Some(model_name) => {
3395            match chat_graphs.contains_key(model_name) {
3396                true => {
3397                    let graph = chat_graphs.get_mut(model_name).unwrap();
3398                    // compute
3399                    match graph.compute_single() {
3400                        Ok(_) => {
3401                            #[cfg(feature = "logging")]
3402                            debug!(target: "stdout", "Compute the chat stream chunk successfully.");
3403
3404                            // Process according to state
3405                            match stream_state {
3406                                StreamState::Usage | StreamState::NoUsage => {
3407                                    // Retrieve the output
3408                                    let output_buffer =
3409                                        get_output_buffer_single(graph, OUTPUT_TENSOR)?;
3410
3411                                    #[cfg(feature = "logging")]
3412                                    info!(target: "stdout", "retrieved the output buffer");
3413
3414                                    // decode the output buffer to a utf8 string
3415                                    let output = match String::from_utf8(output_buffer.clone()) {
3416                                        Ok(token) => token,
3417                                        Err(_) => {
3418                                            let mutex = CACHED_UTF8_ENCODINGS
3419                                                .get_or_init(|| Mutex::new(Vec::new()));
3420                                            let mut cached_encodings = mutex.lock().map_err(|e| {
3421                                            let err_msg = format!(
3422                                                "Fail to acquire the lock of `CACHED_UTF8_ENCODINGS`. Reason: {e}"
3423                                            );
3424
3425                                            #[cfg(feature = "logging")]
3426                                            error!(target: "stdout", "{}", &err_msg);
3427
3429                                            LlamaCoreError::Operation(err_msg)
3430                                        })?;
3431
3432                                            // cache the bytes for future decoding
3433                                            cached_encodings.extend_from_slice(&output_buffer[..]);
3434
3435                                            match String::from_utf8(cached_encodings.to_vec()) {
3436                                                Ok(token) => {
3437                                                    // clear CACHED_UTF8_ENCODINGS
3438                                                    cached_encodings.clear();
3439
3440                                                    token
3441                                                }
3442                                                Err(e) => {
3443                                                    // TODO: temporary guard against the cached encodings growing without bound.
3444                                                    if cached_encodings.len() > 4 {
3445                                                        let err_msg = format!("Fail to convert the cached bytes to a UTF-8 string: the cached buffer exceeds 4 bytes. {e}");
3446
3447                                                        #[cfg(feature = "logging")]
3448                                                        error!(target: "stdout", "{}", &err_msg);
3449
3450                                                        #[cfg(feature = "logging")]
3451                                                        error!(target: "stdout", "The cached buffer: {:?}", &cached_encodings[..]);
3452
3453                                                        // let token = String::from_utf8_lossy(
3454                                                        //     &cached_encodings,
3455                                                        // )
3456                                                        // .to_string();
3457
3458                                                        // clear CACHED_UTF8_ENCODINGS
3459                                                        cached_encodings.clear();
3460
3461                                                        String::from("")
3462                                                    } else {
3463                                                        let warn_msg = format!("Fail to convert the cached bytes to a UTF-8 string (likely an incomplete UTF-8 sequence). {e}");
3464
3465                                                        #[cfg(feature = "logging")]
3466                                                        warn!(target: "stdout", "{}", &warn_msg);
3467
3468                                                        String::from("")
3469                                                    }
3470                                                }
3471                                            }
3472                                        }
3473                                    };
3474
3475                                    #[cfg(feature = "logging")]
3476                                    info!(target: "stdout", "decoded the output buffer");
3477
3478                                    let created = SystemTime::now()
3479                                        .duration_since(std::time::UNIX_EPOCH)
3480                                        .map_err(|e| {
3481                                            let err_msg = format!(
3482                                                "Failed to get the current time. Reason: {e}"
3483                                            );
3484
3485                                            #[cfg(feature = "logging")]
3486                                            error!(target: "stdout", "{}", &err_msg);
3487
3488                                            LlamaCoreError::Operation(err_msg)
3489                                        })?;
3490
3491                                    let chat_completion_chunk = ChatCompletionChunk {
3492                                        id,
3493                                        object: "chat.completion.chunk".to_string(),
3494                                        created: created.as_secs(),
3495                                        model: graph.name().to_owned(),
3496                                        system_fingerprint: "fp_44709d6fcb".to_string(),
3497                                        choices: vec![ChatCompletionChunkChoice {
3498                                            index: 0,
3499                                            delta: ChatCompletionChunkChoiceDelta {
3500                                                role: ChatCompletionRole::Assistant,
3501                                                content: Some(output),
3502                                                tool_calls: vec![],
3503                                            },
3504                                            logprobs: None,
3505                                            finish_reason: None,
3506                                        }],
3507                                        usage: None,
3508                                    };
3509
3510                                    #[cfg(feature = "logging")]
3511                                    info!(target: "stdout", "created chat completion chunk");
3512
3513                                    // serialize chat completion chunk
3514                                    let chunk_str = serde_json::to_string(&chat_completion_chunk)
3515                                        .map_err(|e| {
3516                                        let err_msg = format!(
3517                                            "Failed to serialize chat completion chunk. Reason: {e}"
3518                                        );
3519
3520                                        #[cfg(feature = "logging")]
3521                                        error!(target: "stdout", "{}", &err_msg);
3522
3523                                        LlamaCoreError::Operation(err_msg)
3524                                    })?;
3525
3526                                    Ok(format!("data: {chunk_str}\n\n"))
3527                                }
3528                                StreamState::Done => {
3529                                    *stream_state = StreamState::EndOfSequence;
3530
3531                                    Ok("data: [DONE]\n\n".to_string())
3532                                }
3533                                StreamState::EndOfSequence => {
3534                                    Ok("[GGML] End of sequence".to_string())
3535                                }
3536                            }
3537                        }
3538                        Err(wasmedge_wasi_nn::Error::BackendError(
3539                            wasmedge_wasi_nn::BackendError::EndOfSequence,
3540                        )) => {
3541                            #[cfg(feature = "logging")]
3542                            debug!(target: "stdout", "End of sequence");
3543
3544                            match stream_state {
3545                                StreamState::Usage => {
3546                                    *stream_state = StreamState::Done;
3547
3548                                    // retrieve the number of prompt and completion tokens
3549                                    let token_info = get_token_info_by_graph(graph)?;
3550
3551                                    let usage = Some(Usage {
3552                                        prompt_tokens: token_info.prompt_tokens,
3553                                        completion_tokens: token_info.completion_tokens,
3554                                        total_tokens: token_info.prompt_tokens
3555                                            + token_info.completion_tokens,
3556                                    });
3557
3558                                    #[cfg(feature = "logging")]
3559                                    info!(target: "stdout", "token_info: {} prompt tokens, {} completion tokens", token_info.prompt_tokens, token_info.completion_tokens);
3560
3561                                    let created = SystemTime::now()
3562                                        .duration_since(std::time::UNIX_EPOCH)
3563                                        .map_err(|e| {
3564                                            let err_msg = format!(
3565                                                "Failed to get the current time. Reason: {e}"
3566                                            );
3567
3568                                            #[cfg(feature = "logging")]
3569                                            error!(target: "stdout", "{}", &err_msg);
3570
3571                                            LlamaCoreError::Operation(err_msg)
3572                                        })?;
3573
3574                                    let chat_completion_chunk = ChatCompletionChunk {
3575                                        id,
3576                                        object: "chat.completion.chunk".to_string(),
3577                                        created: created.as_secs(),
3578                                        model: graph.name().to_owned(),
3579                                        system_fingerprint: "fp_44709d6fcb".to_string(),
3580                                        choices: vec![],
3581                                        usage,
3582                                    };
3583
3584                                    // serialize chat completion chunk
3585                                    let chunk_str = serde_json::to_string(&chat_completion_chunk)
3586                                        .map_err(|e| {
3587                                        let err_msg = format!(
3588                                            "Failed to serialize chat completion chunk. Reason: {e}"
3589                                        );
3590
3591                                        #[cfg(feature = "logging")]
3592                                        error!(target: "stdout", "{}", &err_msg);
3593
3594                                        LlamaCoreError::Operation(err_msg)
3595                                    })?;
3596
3597                                    Ok(format!("data: {chunk_str}\n\n"))
3598                                }
3599                                StreamState::Done | StreamState::NoUsage => {
3600                                    *stream_state = StreamState::EndOfSequence;
3601
3602                                    Ok("data: [DONE]\n\n".to_string())
3603                                }
3604                                StreamState::EndOfSequence => {
3605                                    Ok("[GGML] End of sequence".to_string())
3606                                }
3607                            }
3608                        }
3609                        Err(wasmedge_wasi_nn::Error::BackendError(
3610                            wasmedge_wasi_nn::BackendError::ContextFull,
3611                        )) => {
3612                            #[cfg(feature = "logging")]
3613                            debug!(target: "stdout", "Context full");
3614
3615                            match context_full_state {
3616                                ContextFullState::Message => {
3617                                    match include_usage {
3618                                        true => *context_full_state = ContextFullState::Usage,
3619                                        false => *context_full_state = ContextFullState::Done,
3620                                    }
3621
3622                                    let created = SystemTime::now()
3623                                        .duration_since(std::time::UNIX_EPOCH)
3624                                        .map_err(|e| {
3625                                            let err_msg = format!(
3626                                                "Failed to get the current time. Reason: {e}"
3627                                            );
3628
3629                                            #[cfg(feature = "logging")]
3630                                            error!(target: "stdout", "{}", &err_msg);
3631
3632                                            LlamaCoreError::Operation(err_msg)
3633                                        })?;
3634
3635                                    let chat_completion_chunk = ChatCompletionChunk {
3636                                        id,
3637                                        object: "chat.completion.chunk".to_string(),
3638                                        created: created.as_secs(),
3639                                        model: graph.name().to_owned(),
3640                                        system_fingerprint: "fp_44709d6fcb".to_string(),
3641                                        choices: vec![ChatCompletionChunkChoice {
3642                                            index: 0,
3643                                            delta: ChatCompletionChunkChoiceDelta {
3644                                                role: ChatCompletionRole::Assistant,
3645                                                content: Some(
3646                                                    "<|WASMEDGE-GGML-CONTEXT-FULL|>".to_string(),
3647                                                ),
3648                                                tool_calls: vec![],
3649                                            },
3650                                            logprobs: None,
3651                                            finish_reason: Some(FinishReason::length),
3652                                        }],
3653                                        usage: None,
3654                                    };
3655
3656                                    // serialize chat completion chunk
3657                                    let chunk_str = serde_json::to_string(&chat_completion_chunk)
3658                                        .map_err(|e| {
3659                                        let err_msg = format!(
3660                                            "Failed to serialize chat completion chunk. Reason: {e}"
3661                                        );
3662
3663                                        #[cfg(feature = "logging")]
3664                                        error!(target: "stdout", "{}", &err_msg);
3665
3666                                        LlamaCoreError::Operation(err_msg)
3667                                    })?;
3668
3669                                    Ok(format!("data: {chunk_str}\n\n"))
3670                                }
3671                                ContextFullState::Usage => {
3672                                    *context_full_state = ContextFullState::Done;
3673
3674                                    // retrieve the number of prompt and completion tokens
3675                                    let token_info = get_token_info_by_graph(graph)?;
3676
3677                                    let usage = Some(Usage {
3678                                        prompt_tokens: token_info.prompt_tokens,
3679                                        completion_tokens: token_info.completion_tokens,
3680                                        total_tokens: token_info.prompt_tokens
3681                                            + token_info.completion_tokens,
3682                                    });
3683
3684                                    let created = SystemTime::now()
3685                                        .duration_since(std::time::UNIX_EPOCH)
3686                                        .map_err(|e| {
3687                                            let err_msg = format!(
3688                                                "Failed to get the current time. Reason: {e}"
3689                                            );
3690
3691                                            #[cfg(feature = "logging")]
3692                                            error!(target: "stdout", "{}", &err_msg);
3693
3694                                            LlamaCoreError::Operation(err_msg)
3695                                        })?;
3696
3697                                    let chat_completion_chunk = ChatCompletionChunk {
3698                                        id,
3699                                        object: "chat.completion.chunk".to_string(),
3700                                        created: created.as_secs(),
3701                                        model: graph.name().to_owned(),
3702                                        system_fingerprint: "fp_44709d6fcb".to_string(),
3703                                        choices: vec![],
3704                                        usage,
3705                                    };
3706
3707                                    // serialize chat completion chunk
3708                                    let chunk_str = serde_json::to_string(&chat_completion_chunk)
3709                                        .map_err(|e| {
3710                                        let err_msg = format!(
3711                                            "Failed to serialize chat completion chunk. Reason: {e}"
3712                                        );
3713
3714                                        #[cfg(feature = "logging")]
3715                                        error!(target: "stdout", "{}", &err_msg);
3716
3717                                        LlamaCoreError::Operation(err_msg)
3718                                    })?;
3719
3720                                    Ok(format!("data: {chunk_str}\n\n"))
3721                                }
3722                                ContextFullState::Done => {
3723                                    *context_full_state = ContextFullState::EndOfSequence;
3724
3725                                    Ok("data: [DONE]\n\n".to_string())
3726                                }
3727                                ContextFullState::EndOfSequence => {
3728                                    Ok("[GGML] End of sequence".to_string())
3729                                }
3730                            }
3731                        }
3732                        Err(wasmedge_wasi_nn::Error::BackendError(
3733                            wasmedge_wasi_nn::BackendError::PromptTooLong,
3734                        )) => {
3735                            #[cfg(feature = "logging")]
3736                            debug!(target: "stdout", "Prompt too long");
3737
3738                            match prompt_too_long_state {
3739                                PromptTooLongState::Message => {
3740                                    match include_usage {
3741                                        true => *prompt_too_long_state = PromptTooLongState::Usage,
3742                                        false => *prompt_too_long_state = PromptTooLongState::Done,
3743                                    }
3744
3745                                    let created = SystemTime::now()
3746                                        .duration_since(std::time::UNIX_EPOCH)
3747                                        .map_err(|e| {
3748                                            let err_msg = format!(
3749                                                "Failed to get the current time. Reason: {e}"
3750                                            );
3751
3752                                            #[cfg(feature = "logging")]
3753                                            error!(target: "stdout", "{}", &err_msg);
3754
3755                                            LlamaCoreError::Operation(err_msg)
3756                                        })?;
3757
3758                                    let chat_completion_chunk = ChatCompletionChunk {
3759                                        id,
3760                                        object: "chat.completion.chunk".to_string(),
3761                                        created: created.as_secs(),
3762                                        model: graph.name().to_owned(),
3763                                        system_fingerprint: "fp_44709d6fcb".to_string(),
3764                                        choices: vec![ChatCompletionChunkChoice {
3765                                            index: 0,
3766                                            delta: ChatCompletionChunkChoiceDelta {
3767                                                role: ChatCompletionRole::Assistant,
3768                                                content: None,
3769                                                tool_calls: vec![],
3770                                            },
3771                                            logprobs: None,
3772                                            finish_reason: Some(FinishReason::length),
3773                                        }],
3774                                        usage: None,
3775                                    };
3776
3777                                    // serialize chat completion chunk
3778                                    let chunk_str = serde_json::to_string(&chat_completion_chunk)
3779                                        .map_err(|e| {
3780                                        let err_msg = format!(
3781                                            "Failed to serialize chat completion chunk. Reason: {e}"
3782                                        );
3783
3784                                        #[cfg(feature = "logging")]
3785                                        error!(target: "stdout", "{}", &err_msg);
3786
3787                                        LlamaCoreError::Operation(err_msg)
3788                                    })?;
3789
3790                                    Ok(format!("data: {chunk_str}\n\n"))
3791                                }
3792                                PromptTooLongState::Usage => {
3793                                    *prompt_too_long_state = PromptTooLongState::Done;
3794
3795                                    // retrieve the number of prompt and completion tokens
3796                                    let token_info = get_token_info_by_graph(graph)?;
3797
3798                                    let usage = Some(Usage {
3799                                        prompt_tokens: token_info.prompt_tokens,
3800                                        completion_tokens: token_info.completion_tokens,
3801                                        total_tokens: token_info.prompt_tokens
3802                                            + token_info.completion_tokens,
3803                                    });
3804
3805                                    let created = SystemTime::now()
3806                                        .duration_since(std::time::UNIX_EPOCH)
3807                                        .map_err(|e| {
3808                                            let err_msg = format!(
3809                                                "Failed to get the current time. Reason: {e}"
3810                                            );
3811
3812                                            #[cfg(feature = "logging")]
3813                                            error!(target: "stdout", "{}", &err_msg);
3814
3815                                            LlamaCoreError::Operation(err_msg)
3816                                        })?;
3817
3818                                    let chat_completion_chunk = ChatCompletionChunk {
3819                                        id,
3820                                        object: "chat.completion.chunk".to_string(),
3821                                        created: created.as_secs(),
3822                                        model: graph.name().to_owned(),
3823                                        system_fingerprint: "fp_44709d6fcb".to_string(),
3824                                        choices: vec![],
3825                                        usage,
3826                                    };
3827
3828                                    // serialize chat completion chunk
3829                                    let chunk_str = serde_json::to_string(&chat_completion_chunk)
3830                                        .map_err(|e| {
3831                                        let err_msg = format!(
3832                                            "Failed to serialize chat completion chunk. Reason: {e}"
3833                                        );
3834
3835                                        #[cfg(feature = "logging")]
3836                                        error!(target: "stdout", "{}", &err_msg);
3837
3838                                        LlamaCoreError::Operation(err_msg)
3839                                    })?;
3840
3841                                    Ok(format!("data: {chunk_str}\n\n"))
3842                                }
3843                                PromptTooLongState::Done => {
3844                                    *prompt_too_long_state = PromptTooLongState::EndOfSequence;
3845
3846                                    Ok("data: [DONE]\n\n".to_string())
3847                                }
3848                                PromptTooLongState::EndOfSequence => {
3849                                    Ok("[GGML] End of sequence".to_string())
3850                                }
3851                            }
3852                        }
3853                        Err(e) => {
3854                            let err_msg =
3855                                format!("Failed to compute the chat completion. Reason: {e}");
3856
3857                            #[cfg(feature = "logging")]
3858                            error!(target: "stdout", "{}", &err_msg);
3859
3860                            Err(LlamaCoreError::Backend(BackendError::ComputeSingle(
3861                                err_msg,
3862                            )))
3863                        }
3864                    }
3865                }
3866                false => {
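                        // Fall back to the first available graph in `chat_graphs`; the
                        // per-error handling below mirrors the branch above.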
3867                    match chat_graphs.iter_mut().next() {
3868                        Some((_, graph)) => {
3869                            // compute
3870                            match graph.compute_single() {
3871                                Ok(_) => {
3872                                    #[cfg(feature = "logging")]
3873                                    debug!(target: "stdout", "Computed the chat stream chunk successfully.");
3874
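                                        // A token was produced: decode it and emit it as a
                                        // `chat.completion.chunk` SSE event. The `Done` and
                                        // `EndOfSequence` states only flush the stream terminators.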
3875                                    match stream_state {
3876                                        StreamState::Usage | StreamState::NoUsage => {
3877                                            // Retrieve the output
3878                                            let output_buffer =
3879                                                get_output_buffer_single(graph, OUTPUT_TENSOR)?;
3880
3881                                            #[cfg(feature = "logging")]
3882                                            info!(target: "stdout", "retrieved the output buffer");
3883
3884                                            // decode the output buffer to a utf8 string
3885                                            let output = match String::from_utf8(
3886                                                output_buffer.clone(),
3887                                            ) {
3888                                                Ok(token) => token,
3889                                                Err(_) => {
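                                                        // `compute_single` may emit only part of a
                                                        // multi-byte UTF-8 character, so buffer the
                                                        // bytes in `CACHED_UTF8_ENCODINGS` until they
                                                        // decode cleanly.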
3890                                                    let mutex = CACHED_UTF8_ENCODINGS
3891                                                        .get_or_init(|| Mutex::new(Vec::new()));
3892                                                    let mut cached_encodings = mutex.lock().map_err(|e| {
3893                                            let err_msg = format!(
3894                                                "Failed to acquire the lock of `CACHED_UTF8_ENCODINGS`. Reason: {e}"
3895                                            );
3896
3897                                            #[cfg(feature = "logging")]
3898                                            error!(target: "stdout", "{}", &err_msg);
3899
3901                                            LlamaCoreError::Operation(err_msg)
3902                                        })?;
3903
3904                                                    // cache the bytes for future decoding
3905                                                    cached_encodings
3906                                                        .extend_from_slice(&output_buffer[..]);
3907
3908                                                    match String::from_utf8(
3909                                                        cached_encodings.to_vec(),
3910                                                    ) {
3911                                                        Ok(token) => {
3912                                                            // clear encodings
3913                                                            cached_encodings.clear();
3914
3915                                                            token
3916                                                        }
3917                                                        Err(e) => {
3918                                                            // TODO: temporary guard against the cached encodings growing indefinitely. A single UTF-8 code point is at most 4 bytes, so a longer cache will never decode.
3919                                                            if cached_encodings.len() > 4 {
3920                                                                let err_msg = format!("Failed to convert a vector of bytes to a string. The length of the cached UTF-8 bytes exceeds 4. {e}");
3921
3922                                                                #[cfg(feature = "logging")]
3923                                                                error!(target: "stdout", "{}", &err_msg);
3924
3925                                                                #[cfg(feature = "logging")]
3926                                                                error!(target: "stdout", "The cached buffer: {:?}", &cached_encodings[..]);
3927
3928                                                                // let token =
3929                                                                //     String::from_utf8_lossy(
3930                                                                //         &cached_encodings,
3931                                                                //     )
3932                                                                //     .to_string();
3933
3934                                                                // clear CACHED_UTF8_ENCODINGS
3935                                                                cached_encodings.clear();
3936
3937                                                                String::from("")
3938                                                            } else {
3939                                                                let warn_msg = format!("Failed to convert a vector of bytes to a string. {e}");
3940
3941                                                                #[cfg(feature = "logging")]
3942                                                                warn!(target: "stdout", "{}", &warn_msg);
3943
3944                                                                String::from("")
3945                                                            }
3946                                                        }
3947                                                    }
3948                                                }
3949                                            };
3950
3951                                            #[cfg(feature = "logging")]
3952                                            info!(target: "stdout", "decoded the output buffer");
3953
3954                                            let created = SystemTime::now()
3955                                                .duration_since(std::time::UNIX_EPOCH)
3956                                                .map_err(|e| {
3957                                                    let err_msg = format!(
3958                                                "Failed to get the current time. Reason: {e}"
3959                                            );
3960
3961                                                    #[cfg(feature = "logging")]
3962                                                    error!(target: "stdout", "{}", &err_msg);
3963
3964                                                    LlamaCoreError::Operation(err_msg)
3965                                                })?;
3966
3967                                            let chat_completion_chunk = ChatCompletionChunk {
3968                                                id,
3969                                                object: "chat.completion.chunk".to_string(),
3970                                                created: created.as_secs(),
3971                                                model: graph.name().to_owned(),
3972                                                system_fingerprint: "fp_44709d6fcb".to_string(),
3973                                                choices: vec![ChatCompletionChunkChoice {
3974                                                    index: 0,
3975                                                    delta: ChatCompletionChunkChoiceDelta {
3976                                                        role: ChatCompletionRole::Assistant,
3977                                                        content: Some(output),
3978                                                        tool_calls: vec![],
3979                                                    },
3980                                                    logprobs: None,
3981                                                    finish_reason: None,
3982                                                }],
3983                                                usage: None,
3984                                            };
3985
3986                                            #[cfg(feature = "logging")]
3987                                            info!(target: "stdout", "created chat completion chunk");
3988
3989                                            // serialize chat completion chunk
3990                                            let chunk_str =
3991                                                serde_json::to_string(&chat_completion_chunk)
3992                                                    .map_err(|e| {
3993                                                        let err_msg = format!(
3994                                            "Failed to serialize chat completion chunk. Reason: {e}"
3995                                        );
3996
3997                                                        #[cfg(feature = "logging")]
3998                                                        error!(target: "stdout", "{}", &err_msg);
3999
4000                                                        LlamaCoreError::Operation(err_msg)
4001                                                    })?;
4002
4003                                            Ok(format!("data: {chunk_str}\n\n"))
4004                                        }
4005                                        StreamState::Done => {
4006                                            *stream_state = StreamState::EndOfSequence;
4007
4008                                            Ok("data: [DONE]\n\n".to_string())
4009                                        }
4010                                        StreamState::EndOfSequence => {
4011                                            Ok("[GGML] End of sequence".to_string())
4012                                        }
4013                                    }
4014                                }
4015                                Err(wasmedge_wasi_nn::Error::BackendError(
4016                                    wasmedge_wasi_nn::BackendError::EndOfSequence,
4017                                )) => {
4018                                    #[cfg(feature = "logging")]
4019                                    debug!(target: "stdout", "End of sequence");
4020
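                                        // The backend reported the end of generation: emit a
                                        // usage-only chunk when usage was requested, then
                                        // `data: [DONE]`, and finally the internal
                                        // "[GGML] End of sequence" marker.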
4021                                    match stream_state {
4022                                        StreamState::Usage => {
4023                                            *stream_state = StreamState::Done;
4024
4025                                            // retrieve the number of prompt and completion tokens
4026                                            let token_info = get_token_info_by_graph(graph)?;
4027
4028                                            let usage = Some(Usage {
4029                                                prompt_tokens: token_info.prompt_tokens,
4030                                                completion_tokens: token_info.completion_tokens,
4031                                                total_tokens: token_info.prompt_tokens
4032                                                    + token_info.completion_tokens,
4033                                            });
4034
4035                                            #[cfg(feature = "logging")]
4036                                            info!(target: "stdout", "token_info: {} prompt tokens, {} completion tokens", token_info.prompt_tokens, token_info.completion_tokens);
4037
4038                                            let created = SystemTime::now()
4039                                                .duration_since(std::time::UNIX_EPOCH)
4040                                                .map_err(|e| {
4041                                                    let err_msg = format!(
4042                                                "Failed to get the current time. Reason: {e}"
4043                                            );
4044
4045                                                    #[cfg(feature = "logging")]
4046                                                    error!(target: "stdout", "{}", &err_msg);
4047
4048                                                    LlamaCoreError::Operation(err_msg)
4049                                                })?;
4050
4051                                            let chat_completion_chunk = ChatCompletionChunk {
4052                                                id,
4053                                                object: "chat.completion.chunk".to_string(),
4054                                                created: created.as_secs(),
4055                                                model: graph.name().to_owned(),
4056                                                system_fingerprint: "fp_44709d6fcb".to_string(),
4057                                                choices: vec![],
4058                                                usage,
4059                                            };
4060
4061                                            // serialize chat completion chunk
4062                                            let chunk_str =
4063                                                serde_json::to_string(&chat_completion_chunk)
4064                                                    .map_err(|e| {
4065                                                        let err_msg = format!(
4066                                            "Failed to serialize chat completion chunk. Reason: {e}"
4067                                        );
4068
4069                                                        #[cfg(feature = "logging")]
4070                                                        error!(target: "stdout", "{}", &err_msg);
4071
4072                                                        LlamaCoreError::Operation(err_msg)
4073                                                    })?;
4074
4075                                            Ok(format!("data: {chunk_str}\n\n"))
4076                                        }
4077                                        StreamState::Done | StreamState::NoUsage => {
4078                                            *stream_state = StreamState::EndOfSequence;
4079
4080                                            Ok("data: [DONE]\n\n".to_string())
4081                                        }
4082                                        StreamState::EndOfSequence => {
4083                                            Ok("[GGML] End of sequence".to_string())
4084                                        }
4085                                    }
4086                                }
4087                                Err(wasmedge_wasi_nn::Error::BackendError(
4088                                    wasmedge_wasi_nn::BackendError::ContextFull,
4089                                )) => {
4090                                    #[cfg(feature = "logging")]
4091                                    debug!(target: "stdout", "Context full");
4092
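                                        // The context window is full: stream a final chunk whose
                                        // content is the "<|WASMEDGE-GGML-CONTEXT-FULL|>" marker with
                                        // `finish_reason: length`, then (optionally) usage,
                                        // `data: [DONE]`, and the end-of-sequence marker.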
4093                                    match context_full_state {
4094                                        ContextFullState::Message => {
4095                                            match include_usage {
4096                                                true => {
4097                                                    *context_full_state = ContextFullState::Usage
4098                                                }
4099                                                false => {
4100                                                    *context_full_state = ContextFullState::Done
4101                                                }
4102                                            }
4103
4104                                            let created = SystemTime::now()
4105                                                .duration_since(std::time::UNIX_EPOCH)
4106                                                .map_err(|e| {
4107                                                    let err_msg = format!(
4108                                                "Failed to get the current time. Reason: {e}"
4109                                            );
4110
4111                                                    #[cfg(feature = "logging")]
4112                                                    error!(target: "stdout", "{}", &err_msg);
4113
4114                                                    LlamaCoreError::Operation(err_msg)
4115                                                })?;
4116
4117                                            let chat_completion_chunk = ChatCompletionChunk {
4118                                                id,
4119                                                object: "chat.completion.chunk".to_string(),
4120                                                created: created.as_secs(),
4121                                                model: graph.name().to_owned(),
4122                                                system_fingerprint: "fp_44709d6fcb".to_string(),
4123                                                choices: vec![ChatCompletionChunkChoice {
4124                                                    index: 0,
4125                                                    delta: ChatCompletionChunkChoiceDelta {
4126                                                        role: ChatCompletionRole::Assistant,
4127                                                        content: Some(
4128                                                            "<|WASMEDGE-GGML-CONTEXT-FULL|>"
4129                                                                .to_string(),
4130                                                        ),
4131                                                        tool_calls: vec![],
4132                                                    },
4133                                                    logprobs: None,
4134                                                    finish_reason: Some(FinishReason::length),
4135                                                }],
4136                                                usage: None,
4137                                            };
4138
4139                                            // serialize chat completion chunk
4140                                            let chunk_str =
4141                                                serde_json::to_string(&chat_completion_chunk)
4142                                                    .map_err(|e| {
4143                                                        let err_msg = format!(
4144                                            "Failed to serialize chat completion chunk. Reason: {e}"
4145                                        );
4146
4147                                                        #[cfg(feature = "logging")]
4148                                                        error!(target: "stdout", "{}", &err_msg);
4149
4150                                                        LlamaCoreError::Operation(err_msg)
4151                                                    })?;
4152
4153                                            Ok(format!("data: {chunk_str}\n\n"))
4154                                        }
4155                                        ContextFullState::Usage => {
4156                                            *context_full_state = ContextFullState::Done;
4157
4158                                            // retrieve the number of prompt and completion tokens
4159                                            let token_info = get_token_info_by_graph(graph)?;
4160
4161                                            let usage = Some(Usage {
4162                                                prompt_tokens: token_info.prompt_tokens,
4163                                                completion_tokens: token_info.completion_tokens,
4164                                                total_tokens: token_info.prompt_tokens
4165                                                    + token_info.completion_tokens,
4166                                            });
4167
4168                                            let created = SystemTime::now()
4169                                                .duration_since(std::time::UNIX_EPOCH)
4170                                                .map_err(|e| {
4171                                                    let err_msg = format!(
4172                                                "Failed to get the current time. Reason: {e}"
4173                                            );
4174
4175                                                    #[cfg(feature = "logging")]
4176                                                    error!(target: "stdout", "{}", &err_msg);
4177
4178                                                    LlamaCoreError::Operation(err_msg)
4179                                                })?;
4180
4181                                            let chat_completion_chunk = ChatCompletionChunk {
4182                                                id,
4183                                                object: "chat.completion.chunk".to_string(),
4184                                                created: created.as_secs(),
4185                                                model: graph.name().to_owned(),
4186                                                system_fingerprint: "fp_44709d6fcb".to_string(),
4187                                                choices: vec![],
4188                                                usage,
4189                                            };
4190
4191                                            // serialize chat completion chunk
4192                                            let chunk_str =
4193                                                serde_json::to_string(&chat_completion_chunk)
4194                                                    .map_err(|e| {
4195                                                        let err_msg = format!(
4196                                            "Failed to serialize chat completion chunk. Reason: {e}"
4197                                        );
4198
4199                                                        #[cfg(feature = "logging")]
4200                                                        error!(target: "stdout", "{}", &err_msg);
4201
4202                                                        LlamaCoreError::Operation(err_msg)
4203                                                    })?;
4204
4205                                            Ok(format!("data: {chunk_str}\n\n"))
4206                                        }
4207                                        ContextFullState::Done => {
4208                                            *context_full_state = ContextFullState::EndOfSequence;
4209
4210                                            Ok("data: [DONE]\n\n".to_string())
4211                                        }
4212                                        ContextFullState::EndOfSequence => {
4213                                            Ok("[GGML] End of sequence".to_string())
4214                                        }
4215                                    }
4216                                }
4217                                Err(wasmedge_wasi_nn::Error::BackendError(
4218                                    wasmedge_wasi_nn::BackendError::PromptTooLong,
4219                                )) => {
4220                                    #[cfg(feature = "logging")]
4221                                    debug!(target: "stdout", "Prompt too long");
4222
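                                        // Prompt-too-long teardown for the fallback graph, mirroring
                                        // the handling above.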
4223                                    match prompt_too_long_state {
4224                                        PromptTooLongState::Message => {
4225                                            match include_usage {
4226                                                true => {
4227                                                    *prompt_too_long_state =
4228                                                        PromptTooLongState::Usage
4229                                                }
4230                                                false => {
4231                                                    *prompt_too_long_state =
4232                                                        PromptTooLongState::Done
4233                                                }
4234                                            }
4235
4236                                            let created = SystemTime::now()
4237                                                .duration_since(std::time::UNIX_EPOCH)
4238                                                .map_err(|e| {
4239                                                    let err_msg = format!(
4240                                                "Failed to get the current time. Reason: {e}"
4241                                            );
4242
4243                                                    #[cfg(feature = "logging")]
4244                                                    error!(target: "stdout", "{}", &err_msg);
4245
4246                                                    LlamaCoreError::Operation(err_msg)
4247                                                })?;
4248
4249                                            let chat_completion_chunk = ChatCompletionChunk {
4250                                                id,
4251                                                object: "chat.completion.chunk".to_string(),
4252                                                created: created.as_secs(),
4253                                                model: graph.name().to_owned(),
4254                                                system_fingerprint: "fp_44709d6fcb".to_string(),
4255                                                choices: vec![ChatCompletionChunkChoice {
4256                                                    index: 0,
4257                                                    delta: ChatCompletionChunkChoiceDelta {
4258                                                        role: ChatCompletionRole::Assistant,
4259                                                        content: None,
4260                                                        tool_calls: vec![],
4261                                                    },
4262                                                    logprobs: None,
4263                                                    finish_reason: Some(FinishReason::length),
4264                                                }],
4265                                                usage: None,
4266                                            };
4267
4268                                            // serialize chat completion chunk
4269                                            let chunk_str =
4270                                                serde_json::to_string(&chat_completion_chunk)
4271                                                    .map_err(|e| {
4272                                                        let err_msg = format!(
4273                                            "Failed to serialize chat completion chunk. Reason: {e}"
4274                                        );
4275
4276                                                        #[cfg(feature = "logging")]
4277                                                        error!(target: "stdout", "{}", &err_msg);
4278
4279                                                        LlamaCoreError::Operation(err_msg)
4280                                                    })?;
4281
4282                                            Ok(format!("data: {chunk_str}\n\n"))
4283                                        }
4284                                        PromptTooLongState::Usage => {
4285                                            *prompt_too_long_state = PromptTooLongState::Done;
4286
4287                                            // retrieve the number of prompt and completion tokens
4288                                            let token_info = get_token_info_by_graph(graph)?;
4289
4290                                            let usage = Some(Usage {
4291                                                prompt_tokens: token_info.prompt_tokens,
4292                                                completion_tokens: token_info.completion_tokens,
4293                                                total_tokens: token_info.prompt_tokens
4294                                                    + token_info.completion_tokens,
4295                                            });
4296
4297                                            let created = SystemTime::now()
4298                                                .duration_since(std::time::UNIX_EPOCH)
4299                                                .map_err(|e| {
4300                                                    let err_msg = format!(
4301                                                "Failed to get the current time. Reason: {e}"
4302                                            );
4303
4304                                                    #[cfg(feature = "logging")]
4305                                                    error!(target: "stdout", "{}", &err_msg);
4306
4307                                                    LlamaCoreError::Operation(err_msg)
4308                                                })?;
4309
4310                                            let chat_completion_chunk = ChatCompletionChunk {
4311                                                id,
4312                                                object: "chat.completion.chunk".to_string(),
4313                                                created: created.as_secs(),
4314                                                model: graph.name().to_owned(),
4315                                                system_fingerprint: "fp_44709d6fcb".to_string(),
4316                                                choices: vec![],
4317                                                usage,
4318                                            };
4319
4320                                            // serialize chat completion chunk
4321                                            let chunk_str =
4322                                                serde_json::to_string(&chat_completion_chunk)
4323                                                    .map_err(|e| {
4324                                                        let err_msg = format!(
4325                                            "Failed to serialize chat completion chunk. Reason: {e}"
4326                                        );
4327
4328                                                        #[cfg(feature = "logging")]
4329                                                        error!(target: "stdout", "{}", &err_msg);
4330
4331                                                        LlamaCoreError::Operation(err_msg)
4332                                                    })?;
4333
4334                                            Ok(format!("data: {chunk_str}\n\n"))
4335                                        }
4336                                        PromptTooLongState::Done => {
4337                                            *prompt_too_long_state =
4338                                                PromptTooLongState::EndOfSequence;
4339
4340                                            Ok("data: [DONE]\n\n".to_string())
4341                                        }
4342                                        PromptTooLongState::EndOfSequence => {
4343                                            Ok("[GGML] End of sequence".to_string())
4344                                        }
4345                                    }
4346                                }
4347                                Err(e) => {
4348                                    let err_msg = format!(
4349                                        "Failed to compute the chat completion. Reason: {e}"
4350                                    );
4351
4352                                    #[cfg(feature = "logging")]
4353                                    error!(target: "stdout", "{}", &err_msg);
4354
4355                                    Err(LlamaCoreError::Backend(BackendError::ComputeSingle(
4356                                        err_msg,
4357                                    )))
4358                                }
4359                            }
4360                        }
4361                        None => {
4362                            let err_msg = "There is no model available in the chat graphs.";
4363
4364                            #[cfg(feature = "logging")]
4365                            error!(target: "stdout", "{}", &err_msg);
4366
4367                            Err(LlamaCoreError::Operation(err_msg.into()))
4368                        }
4369                    }
4370                }
4371            }
4372        }
4373        None => {
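                // Use the first graph registered in `chat_graphs` and apply the same
                // streaming and error handling as above.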
4374            match chat_graphs.iter_mut().next() {
4375                Some((_, graph)) => {
4376                    // compute
4377                    match graph.compute_single() {
4378                        Ok(_) => {
4379                            #[cfg(feature = "logging")]
4380                            debug!(target: "stdout", "Computed the chat stream chunk successfully.");
4381
4382                            match stream_state {
4383                                StreamState::Usage | StreamState::NoUsage => {
4384                                    // Retrieve the output
4385                                    let output_buffer =
4386                                        get_output_buffer_single(graph, OUTPUT_TENSOR)?;
4387
4388                                    #[cfg(feature = "logging")]
4389                                    info!(target: "stdout", "retrieved the output buffer");
4390
4391                                    // decode the output buffer to a utf8 string
4392                                    let output = match String::from_utf8(output_buffer.clone()) {
4393                                        Ok(token) => token,
4394                                        Err(_) => {
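                                                // Partial multi-byte UTF-8 output: buffer the bytes in
                                                // `CACHED_UTF8_ENCODINGS` until they decode cleanly, as above.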
4395                                            let mutex = CACHED_UTF8_ENCODINGS
4396                                                .get_or_init(|| Mutex::new(Vec::new()));
4397                                            let mut cached_encodings = mutex.lock().map_err(|e| {
4398                                            let err_msg = format!(
4399                                                "Failed to acquire the lock of `CACHED_UTF8_ENCODINGS`. Reason: {e}"
4400                                            );
4401
4402                                            #[cfg(feature = "logging")]
4403                                            error!(target: "stdout", "{}", &err_msg);
4404
4405                                            LlamaCoreError::Operation(err_msg)
4406                                        })?;
4407
4408                                            cached_encodings.extend_from_slice(&output_buffer[..]);
4409
4410                                            match String::from_utf8(cached_encodings.to_vec()) {
4411                                                Ok(token) => {
4412                                                    // clear encodings
4413                                                    cached_encodings.clear();
4414
4415                                                    token
4416                                                }
4417                                                Err(e) => {
4418                                                    // TODO: temporary guard against the cached encodings growing indefinitely. A single UTF-8 code point is at most 4 bytes, so a longer cache will never decode.
4419                                                    if cached_encodings.len() > 4 {
4420                                                        let err_msg = format!("Failed to convert a vector of bytes to a string. The length of the cached UTF-8 bytes exceeds 4. {e}");
4421
4422                                                        #[cfg(feature = "logging")]
4423                                                        error!(target: "stdout", "{}", &err_msg);
4424
4425                                                        #[cfg(feature = "logging")]
4426                                                        error!(target: "stdout", "The cached buffer: {:?}", &cached_encodings[..]);
4427
4428                                                        // let token = String::from_utf8_lossy(
4429                                                        //     &cached_encodings,
4430                                                        // )
4431                                                        // .to_string();
4432
4433                                                        // clear CACHED_UTF8_ENCODINGS
4434                                                        cached_encodings.clear();
4435
4436                                                        String::from("")
4437                                                    } else {
4438                                                        let warn_msg = format!("Failed to convert a vector of bytes to a string. {e}");
4439
4440                                                        #[cfg(feature = "logging")]
4441                                                        warn!(target: "stdout", "{}", &warn_msg);
4442
4443                                                        String::from("")
4444                                                    }
4445                                                }
4446                                            }
4447                                        }
4448                                    };
4449
4450                                    #[cfg(feature = "logging")]
4451                                    info!(target: "stdout", "decoded the output buffer");
4452
4453                                    let created = SystemTime::now()
4454                                        .duration_since(std::time::UNIX_EPOCH)
4455                                        .map_err(|e| {
4456                                            let err_msg = format!(
4457                                                "Failed to get the current time. Reason: {e}"
4458                                            );
4459
4460                                            #[cfg(feature = "logging")]
4461                                            error!(target: "stdout", "{}", &err_msg);
4462
4463                                            LlamaCoreError::Operation(err_msg)
4464                                        })?;
4465
4466                                    let chat_completion_chunk = ChatCompletionChunk {
4467                                        id,
4468                                        object: "chat.completion.chunk".to_string(),
4469                                        created: created.as_secs(),
4470                                        model: graph.name().to_owned(),
4471                                        system_fingerprint: "fp_44709d6fcb".to_string(),
4472                                        choices: vec![ChatCompletionChunkChoice {
4473                                            index: 0,
4474                                            delta: ChatCompletionChunkChoiceDelta {
4475                                                role: ChatCompletionRole::Assistant,
4476                                                content: Some(output),
4477                                                tool_calls: vec![],
4478                                            },
4479                                            logprobs: None,
4480                                            finish_reason: None,
4481                                        }],
4482                                        usage: None,
4483                                    };
4484
4485                                    #[cfg(feature = "logging")]
4486                                    info!(target: "stdout", "created chat completion chunk");
4487
4488                                    // serialize chat completion chunk
4489                                    let chunk_str = serde_json::to_string(&chat_completion_chunk)
4490                                        .map_err(|e| {
4491                                        let err_msg = format!(
4492                                            "Failed to serialize chat completion chunk. Reason: {e}"
4493                                        );
4494
4495                                        #[cfg(feature = "logging")]
4496                                        error!(target: "stdout", "{}", &err_msg);
4497
4498                                        LlamaCoreError::Operation(err_msg)
4499                                    })?;
4500
4501                                    Ok(format!("data: {chunk_str}\n\n"))
4502                                }
4503                                StreamState::Done => {
4504                                    *stream_state = StreamState::EndOfSequence;
4505
4506                                    Ok("data: [DONE]\n\n".to_string())
4507                                }
4508                                StreamState::EndOfSequence => {
4509                                    Ok("[GGML] End of sequence".to_string())
4510                                }
4511                            }
4512                        }
4513                        Err(wasmedge_wasi_nn::Error::BackendError(
4514                            wasmedge_wasi_nn::BackendError::EndOfSequence,
4515                        )) => {
4516                            #[cfg(feature = "logging")]
4517                            debug!(target: "stdout", "End of sequence");
4518
4519                            match stream_state {
4520                                StreamState::Usage => {
4521                                    *stream_state = StreamState::Done;
4522
4523                                    // retrieve the number of prompt and completion tokens
4524                                    let token_info = get_token_info_by_graph(graph)?;
4525
4526                                    let usage = Some(Usage {
4527                                        prompt_tokens: token_info.prompt_tokens,
4528                                        completion_tokens: token_info.completion_tokens,
4529                                        total_tokens: token_info.prompt_tokens
4530                                            + token_info.completion_tokens,
4531                                    });
4532
4533                                    #[cfg(feature = "logging")]
4534                                    info!(target: "stdout", "token_info: {} prompt tokens, {} completion tokens", token_info.prompt_tokens, token_info.completion_tokens);
4535
4536                                    let created = SystemTime::now()
4537                                        .duration_since(std::time::UNIX_EPOCH)
4538                                        .map_err(|e| {
4539                                            let err_msg = format!(
4540                                                "Failed to get the current time. Reason: {e}"
4541                                            );
4542
4543                                            #[cfg(feature = "logging")]
4544                                            error!(target: "stdout", "{}", &err_msg);
4545
4546                                            LlamaCoreError::Operation(err_msg)
4547                                        })?;
4548
4549                                    let chat_completion_chunk = ChatCompletionChunk {
4550                                        id,
4551                                        object: "chat.completion.chunk".to_string(),
4552                                        created: created.as_secs(),
4553                                        model: graph.name().to_owned(),
4554                                        system_fingerprint: "fp_44709d6fcb".to_string(),
4555                                        choices: vec![],
4556                                        usage,
4557                                    };
4558
4559                                    // serialize chat completion chunk
4560                                    let chunk_str = serde_json::to_string(&chat_completion_chunk)
4561                                        .map_err(|e| {
4562                                        let err_msg = format!(
4563                                            "Failed to serialize chat completion chunk. Reason: {e}"
4564                                        );
4565
4566                                        #[cfg(feature = "logging")]
4567                                        error!(target: "stdout", "{}", &err_msg);
4568
4569                                        LlamaCoreError::Operation(err_msg)
4570                                    })?;
4571
4572                                    Ok(format!("data: {chunk_str}\n\n"))
4573                                }
4574                                StreamState::Done | StreamState::NoUsage => {
4575                                    *stream_state = StreamState::EndOfSequence;
4576
4577                                    Ok("data: [DONE]\n\n".to_string())
4578                                }
4579                                StreamState::EndOfSequence => {
4580                                    Ok("[GGML] End of sequence".to_string())
4581                                }
4582                            }
4583                        }
4584                        Err(wasmedge_wasi_nn::Error::BackendError(
4585                            wasmedge_wasi_nn::BackendError::ContextFull,
4586                        )) => {
4587                            #[cfg(feature = "logging")]
4588                            debug!(target: "stdout", "Context full");
4589
4590                            match context_full_state {
4591                                ContextFullState::Message => {
4592                                    match include_usage {
4593                                        true => *context_full_state = ContextFullState::Usage,
4594                                        false => *context_full_state = ContextFullState::Done,
4595                                    }
4596
4597                                    let created = SystemTime::now()
4598                                        .duration_since(std::time::UNIX_EPOCH)
4599                                        .map_err(|e| {
4600                                            let err_msg = format!(
4601                                                "Failed to get the current time. Reason: {e}"
4602                                            );
4603
4604                                            #[cfg(feature = "logging")]
4605                                            error!(target: "stdout", "{}", &err_msg);
4606
4607                                            LlamaCoreError::Operation(err_msg)
4608                                        })?;
4609
4610                                    let chat_completion_chunk = ChatCompletionChunk {
4611                                        id,
4612                                        object: "chat.completion.chunk".to_string(),
4613                                        created: created.as_secs(),
4614                                        model: graph.name().to_owned(),
4615                                        system_fingerprint: "fp_44709d6fcb".to_string(),
4616                                        choices: vec![ChatCompletionChunkChoice {
4617                                            index: 0,
4618                                            delta: ChatCompletionChunkChoiceDelta {
4619                                                role: ChatCompletionRole::Assistant,
4620                                                content: Some(
4621                                                    "<|WASMEDGE-GGML-CONTEXT-FULL|>".to_string(),
4622                                                ),
4623                                                tool_calls: vec![],
4624                                            },
4625                                            logprobs: None,
4626                                            finish_reason: Some(FinishReason::length),
4627                                        }],
4628                                        usage: None,
4629                                    };
4630
4631                                    // serialize chat completion chunk
4632                                    let chunk_str = serde_json::to_string(&chat_completion_chunk)
4633                                        .map_err(|e| {
4634                                        let err_msg = format!(
4635                                            "Failed to serialize chat completion chunk. Reason: {e}"
4636                                        );
4637
4638                                        #[cfg(feature = "logging")]
4639                                        error!(target: "stdout", "{}", &err_msg);
4640
4641                                        LlamaCoreError::Operation(err_msg)
4642                                    })?;
4643
4644                                    Ok(format!("data: {chunk_str}\n\n"))
4645                                }
4646                                ContextFullState::Usage => {
4647                                    *context_full_state = ContextFullState::Done;
4648
4649                                    // retrieve the number of prompt and completion tokens
4650                                    let token_info = get_token_info_by_graph(graph)?;
4651
4652                                    let usage = Some(Usage {
4653                                        prompt_tokens: token_info.prompt_tokens,
4654                                        completion_tokens: token_info.completion_tokens,
4655                                        total_tokens: token_info.prompt_tokens
4656                                            + token_info.completion_tokens,
4657                                    });
4658
4659                                    let created = SystemTime::now()
4660                                        .duration_since(std::time::UNIX_EPOCH)
4661                                        .map_err(|e| {
4662                                            let err_msg = format!(
4663                                                "Failed to get the current time. Reason: {e}"
4664                                            );
4665
4666                                            #[cfg(feature = "logging")]
4667                                            error!(target: "stdout", "{}", &err_msg);
4668
4669                                            LlamaCoreError::Operation(err_msg)
4670                                        })?;
4671
4672                                    let chat_completion_chunk = ChatCompletionChunk {
4673                                        id,
4674                                        object: "chat.completion.chunk".to_string(),
4675                                        created: created.as_secs(),
4676                                        model: graph.name().to_owned(),
4677                                        system_fingerprint: "fp_44709d6fcb".to_string(),
4678                                        choices: vec![],
4679                                        usage,
4680                                    };
4681
4682                                    // serialize chat completion chunk
4683                                    let chunk_str = serde_json::to_string(&chat_completion_chunk)
4684                                        .map_err(|e| {
4685                                        let err_msg = format!(
4686                                            "Failed to serialize chat completion chunk. Reason: {e}"
4687                                        );
4688
4689                                        #[cfg(feature = "logging")]
4690                                        error!(target: "stdout", "{}", &err_msg);
4691
4692                                        LlamaCoreError::Operation(err_msg)
4693                                    })?;
4694
4695                                    Ok(format!("data: {chunk_str}\n\n"))
4696                                }
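                                    // Done: send the SSE terminator; EndOfSequence: return the sentinel
                                    // string that marks the end of this stream.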
4697                                ContextFullState::Done => {
4698                                    *context_full_state = ContextFullState::EndOfSequence;
4699
4700                                    Ok("data: [DONE]\n\n".to_string())
4701                                }
4702                                ContextFullState::EndOfSequence => {
4703                                    Ok("[GGML] End of sequence".to_string())
4704                                }
4705                            }
4706                        }
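                            // The backend reported that the prompt is too long: stream a final chunk
                            // with `finish_reason: length`, optionally followed by a usage chunk,
                            // then terminate the stream.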
4707                        Err(wasmedge_wasi_nn::Error::BackendError(
4708                            wasmedge_wasi_nn::BackendError::PromptTooLong,
4709                        )) => {
4710                            #[cfg(feature = "logging")]
4711                            debug!(target: "stdout", "Prompt too long");
4712
4713                            match prompt_too_long_state {
4714                                PromptTooLongState::Message => {
4715                                    match include_usage {
4716                                        true => *prompt_too_long_state = PromptTooLongState::Usage,
4717                                        false => *prompt_too_long_state = PromptTooLongState::Done,
4718                                    }
4719
4720                                    let created = SystemTime::now()
4721                                        .duration_since(std::time::UNIX_EPOCH)
4722                                        .map_err(|e| {
4723                                            let err_msg = format!(
4724                                                "Failed to get the current time. Reason: {e}"
4725                                            );
4726
4727                                            #[cfg(feature = "logging")]
4728                                            error!(target: "stdout", "{}", &err_msg);
4729
4730                                            LlamaCoreError::Operation(err_msg)
4731                                        })?;
4732
4733                                    let chat_completion_chunk = ChatCompletionChunk {
4734                                        id,
4735                                        object: "chat.completion.chunk".to_string(),
4736                                        created: created.as_secs(),
4737                                        model: graph.name().to_owned(),
4738                                        system_fingerprint: "fp_44709d6fcb".to_string(),
4739                                        choices: vec![ChatCompletionChunkChoice {
4740                                            index: 0,
4741                                            delta: ChatCompletionChunkChoiceDelta {
4742                                                role: ChatCompletionRole::Assistant,
4743                                                content: None,
4744                                                tool_calls: vec![],
4745                                            },
4746                                            logprobs: None,
4747                                            finish_reason: Some(FinishReason::length),
4748                                        }],
4749                                        usage: None,
4750                                    };
4751
4752                                    // serialize chat completion chunk
4753                                    let chunk_str = serde_json::to_string(&chat_completion_chunk)
4754                                        .map_err(|e| {
4755                                        let err_msg = format!(
4756                                            "Failed to serialize chat completion chunk. Reason: {e}"
4757                                        );
4758
4759                                        #[cfg(feature = "logging")]
4760                                        error!(target: "stdout", "{}", &err_msg);
4761
4762                                        LlamaCoreError::Operation(err_msg)
4763                                    })?;
4764
4765                                    Ok(format!("data: {chunk_str}\n\n"))
4766                                }
4767                                PromptTooLongState::Usage => {
4768                                    *prompt_too_long_state = PromptTooLongState::Done;
4769
4770                                    // retrieve the number of prompt and completion tokens
4771                                    let token_info = get_token_info_by_graph(graph)?;
4772
4773                                    let usage = Some(Usage {
4774                                        prompt_tokens: token_info.prompt_tokens,
4775                                        completion_tokens: token_info.completion_tokens,
4776                                        total_tokens: token_info.prompt_tokens
4777                                            + token_info.completion_tokens,
4778                                    });
4779
4780                                    let created = SystemTime::now()
4781                                        .duration_since(std::time::UNIX_EPOCH)
4782                                        .map_err(|e| {
4783                                            let err_msg = format!(
4784                                                "Failed to get the current time. Reason: {e}"
4785                                            );
4786
4787                                            #[cfg(feature = "logging")]
4788                                            error!(target: "stdout", "{}", &err_msg);
4789
4790                                            LlamaCoreError::Operation(err_msg)
4791                                        })?;
4792
4793                                    let chat_completion_chunk = ChatCompletionChunk {
4794                                        id,
4795                                        object: "chat.completion.chunk".to_string(),
4796                                        created: created.as_secs(),
4797                                        model: graph.name().to_owned(),
4798                                        system_fingerprint: "fp_44709d6fcb".to_string(),
4799                                        choices: vec![],
4800                                        usage,
4801                                    };
4802
4803                                    // serialize chat completion chunk
4804                                    let chunk_str = serde_json::to_string(&chat_completion_chunk)
4805                                        .map_err(|e| {
4806                                        let err_msg = format!(
4807                                            "Failed to serialize chat completion chunk. Reason: {e}"
4808                                        );
4809
4810                                        #[cfg(feature = "logging")]
4811                                        error!(target: "stdout", "{}", &err_msg);
4812
4813                                        LlamaCoreError::Operation(err_msg)
4814                                    })?;
4815
4816                                    Ok(format!("data: {chunk_str}\n\n"))
4817                                }
4818                                PromptTooLongState::Done => {
4819                                    *prompt_too_long_state = PromptTooLongState::EndOfSequence;
4820
4821                                    Ok("data: [DONE]\n\n".to_string())
4822                                }
4823                                PromptTooLongState::EndOfSequence => {
4824                                    Ok("[GGML] End of sequence".to_string())
4825                                }
4826                            }
4827                        }
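                            // Any other backend error aborts the stream with a `ComputeSingle` error.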
4828                        Err(e) => {
4829                            let err_msg =
4830                                format!("Failed to compute the chat completion. Reason: {e}");
4831
4832                            #[cfg(feature = "logging")]
4833                            error!(target: "stdout", "{}", &err_msg);
4834
4835                            Err(LlamaCoreError::Backend(BackendError::ComputeSingle(
4836                                err_msg,
4837                            )))
4838                        }
4839                    }
4840                }
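                    // No matching model was found in the chat graphs.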
4841                None => {
4842                    let err_msg = "There is no model available in the chat graphs.";
4843
4844                    #[cfg(feature = "logging")]
4845                    error!(target: "stdout", "{}", &err_msg);
4846
4847                    Err(LlamaCoreError::Operation(err_msg.into()))
4848                }
4849            }
4850        }
4851    };
4852
4853    #[cfg(feature = "logging")]
4854    info!(target: "stdout", "Return the chat stream chunk!");
4855
4856    res
4857}
4858
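/// Intermediate result of parsing raw model output: the original text, the
/// extracted assistant content (if any), and any tool calls found in it.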
4859#[allow(dead_code)]
4860#[derive(Debug)]
4861struct ParseResult {
4862    raw: String,
4863    content: Option<String>,
4864    tool_calls: Vec<ToolCall>,
4865}