llama_core/
chat.rs

1//! Define APIs for chat completion.
2
3use crate::{
4    error,
5    metadata::ggml::GgmlMetadata,
6    running_mode,
7    utils::{
8        gen_chat_id, get_output_buffer, get_output_buffer_single, get_token_info_by_graph,
9        get_token_info_by_graph_name, set_tensor_data_u8,
10    },
11    Graph, RunningMode, CACHED_UTF8_ENCODINGS, CHAT_GRAPHS, OUTPUT_TENSOR,
12};
13use chat_prompts::{
14    chat::{BuildChatPrompt, ChatPrompt},
15    PromptTemplateType,
16};
17use either::{Either, Left, Right};
18use endpoints::{
19    chat::{
20        ChatCompletionChunk, ChatCompletionChunkChoice, ChatCompletionChunkChoiceDelta,
21        ChatCompletionObject, ChatCompletionObjectChoice, ChatCompletionObjectMessage,
22        ChatCompletionRequest, ChatCompletionRequestMessage, ChatCompletionRole,
23        ChatCompletionUserMessageContent, ContentPart, Function, ToolCall, ToolCallForChunk,
24        ToolChoice,
25    },
26    common::{FinishReason, Usage},
27};
28use error::{BackendError, LlamaCoreError};
29use futures::StreamExt;
30use std::{
31    collections::VecDeque,
32    fs::{self, File},
33    path::Path,
34    pin::Pin,
35    sync::{
36        atomic::{AtomicBool, Ordering},
37        Mutex, OnceLock,
38    },
39    task::{Context, Poll, Waker},
40    time::SystemTime,
41};
42
43// Define a global waker queue for storing waiting ChatStreams
44static CHAT_STREAM_WAKER_QUEUE: OnceLock<Mutex<VecDeque<Waker>>> = OnceLock::new();
45
46// Define a global atomic boolean indicating whether there is an active ChatStream
47static CHAT_STREAM_ACTIVE: AtomicBool = AtomicBool::new(false);
48
49/// Processes a chat-completion request and returns either a stream of ChatCompletionChunk instances or a ChatCompletionObject instance.
50pub async fn chat(
51    chat_request: &mut ChatCompletionRequest,
52) -> Result<
53    (
54        Either<impl futures::TryStream<Ok = String, Error = LlamaCoreError>, ChatCompletionObject>,
55        bool,
56    ),
57    LlamaCoreError,
58> {
59    #[cfg(feature = "logging")]
60    {
61        debug!(target: "stdout", "tool choice: {:?}", chat_request.tool_choice.as_ref());
62        debug!(target: "stdout", "tools: {:?}", chat_request.tools.as_ref());
63        debug!(target: "stdout", "stream mode: {:?}", chat_request.stream);
64    }
65
66    let result = match chat_request.stream {
67        Some(true) => match chat_stream(chat_request).await {
68            Ok((stream, include_tool_calls)) => Ok((Left(stream), include_tool_calls)),
69            Err(e) => Err(e),
70        },
71        Some(false) | None => match chat_once(chat_request).await {
72            Ok((chat_completion_object, include_tool_calls)) => {
73                Ok((Right(chat_completion_object), include_tool_calls))
74            }
75            Err(e) => Err(e),
76        },
77    };
78
79    #[cfg(feature = "logging")]
80    info!(target: "stdout", "Reset the model metadata");
81
82    result
83}
84
85// /// Processes a chat-completion request and returns ChatCompletionChunk instances in stream.
86// #[deprecated(since = "0.10.0", note = "Please use the `chat` function.")]
87// pub async fn chat_completions_stream(
88//     chat_request: &mut ChatCompletionRequest,
89// ) -> Result<impl futures::TryStream<Ok = String, Error = LlamaCoreError>, LlamaCoreError> {
90//     chat_stream(chat_request).await
91// }
92
93// /// Processes a chat-completion request and returns a ChatCompletionObject instance.
94// #[deprecated(since = "0.10.0", note = "Please use the `chat` function.")]
95// pub async fn chat_completions(
96//     chat_request: &mut ChatCompletionRequest,
97// ) -> Result<ChatCompletionObject, LlamaCoreError> {
98//     chat_once(chat_request).await
99// }
100
101async fn chat_stream(
102    chat_request: &mut ChatCompletionRequest,
103) -> Result<
104    (
105        impl futures::TryStream<Ok = String, Error = LlamaCoreError>,
106        bool,
107    ),
108    LlamaCoreError,
109> {
110    #[cfg(feature = "logging")]
111    info!(target: "stdout", "Process chat completion request in the stream mode");
112
113    let running_mode = running_mode()?;
114    if !running_mode.contains(RunningMode::CHAT) && !running_mode.contains(RunningMode::RAG) {
115        let err_msg = "The chat completion is only supported in the chat or rag mode.";
116
117        #[cfg(feature = "logging")]
118        error!(target: "stdout", "{err_msg}");
119
120        return Err(LlamaCoreError::Operation(err_msg.to_string()));
121    }
122
123    let model_name = chat_request.model.clone();
124    let id = match &chat_request.user {
125        Some(id) => id.clone(),
126        None => gen_chat_id(),
127    };
128    #[cfg(feature = "logging")]
129    info!(target: "stdout", "user: {}", &id);
130
131    #[cfg(feature = "logging")]
132    info!(target: "stdout", "Check model metadata");
133
134    // update metadata
135    let mut metadata = check_model_metadata(chat_request)?;
136
137    // parse the `include_usage` option
138    let include_usage = match chat_request.stream_options {
139        Some(ref stream_options) => stream_options.include_usage.unwrap_or_default(),
140        None => metadata.include_usage,
141    };
142    #[cfg(feature = "logging")]
143    info!(target: "stdout", "include_usage: {include_usage}");
144
145    #[cfg(feature = "logging")]
146    info!(target: "stdout", "Build the chat prompt");
147
148    // build prompt
149    let (prompt, avaible_completion_tokens, tool_use) =
150        build_prompt(model_name.as_ref(), chat_request)?;
151
152    #[cfg(feature = "logging")]
153    {
154        info!(target: "stdout", "prompt:\n{}", &prompt);
155        info!(target: "stdout", "available_completion_tokens: {avaible_completion_tokens}");
156        info!(target: "stdout", "tool_use: {tool_use}");
157    }
158
159    #[cfg(feature = "logging")]
160    info!(target: "stdout", "Update the n_predict");
161
162    // update metadata n_predict
163    update_n_predict(chat_request, &mut metadata, avaible_completion_tokens)?;
164
165    #[cfg(feature = "logging")]
166    info!(target: "stdout", "Feed the prompt to the model");
167
168    // set prompt
169    set_prompt(chat_request.model.as_ref(), &prompt)?;
170
171    let stream = match tool_use {
172        false => (ChatStream::new(model_name, id, include_usage, None), false),
173        true => {
174            let chat_graphs = match CHAT_GRAPHS.get() {
175                Some(chat_graphs) => chat_graphs,
176                None => {
177                    let err_msg = "Fail to get the underlying value of `CHAT_GRAPHS`.";
178
179                    #[cfg(feature = "logging")]
180                    error!(target: "stdout", "{}", &err_msg);
181
182                    return Err(LlamaCoreError::Operation(err_msg.into()));
183                }
184            };
185
186            let mut chat_graphs = chat_graphs.lock().map_err(|e| {
187                let err_msg = format!("Fail to acquire the lock of `CHAT_GRAPHS`. {e}");
188
189                #[cfg(feature = "logging")]
190                error!(target: "stdout", "{}", &err_msg);
191
192                LlamaCoreError::Operation(err_msg)
193            })?;
194
195            match model_name {
196                Some(model_name) => match chat_graphs.contains_key(&model_name) {
197                    true => {
198                        let graph = chat_graphs.get_mut(&model_name).unwrap();
199                        chat_stream_for_tool(graph, id, include_usage)?
200                    }
201                    false => match chat_graphs.iter_mut().next() {
202                        Some((_, graph)) => chat_stream_for_tool(graph, id, include_usage)?,
203                        None => {
204                            let err_msg = "There is no model available in the chat graphs.";
205
206                            #[cfg(feature = "logging")]
207                            error!(target: "stdout", "{}", &err_msg);
208
209                            return Err(LlamaCoreError::Operation(err_msg.into()));
210                        }
211                    },
212                },
213                None => match chat_graphs.iter_mut().next() {
214                    Some((_, graph)) => chat_stream_for_tool(graph, id, include_usage)?,
215                    None => {
216                        let err_msg = "There is no model available in the chat graphs.";
217
218                        #[cfg(feature = "logging")]
219                        error!(target: "stdout", "{}", &err_msg);
220
221                        return Err(LlamaCoreError::Operation(err_msg.into()));
222                    }
223                },
224            }
225        }
226    };
227
228    #[cfg(feature = "logging")]
229    info!(target: "stdout", "End of the chat completion stream.");
230
231    Ok(stream)
232}
233
234fn chat_stream_for_tool(
235    graph: &mut Graph<GgmlMetadata>,
236    id: impl Into<String>,
237    include_usage: bool,
238) -> Result<(ChatStream, bool), LlamaCoreError> {
239    #[cfg(feature = "logging")]
240    info!(target: "stdout", "Handle chat request with available tools by the model named {}.", graph.name());
241
242    let id = id.into();
243
244    match graph.compute() {
245        Ok(_) => {
246            // Retrieve the output.
247            let output_buffer = get_output_buffer(graph, OUTPUT_TENSOR)?;
248            let output = std::str::from_utf8(&output_buffer[..]).map_err(|e| {
249                let err_msg = format!(
250                    "Failed to decode the buffer of the inference result to a utf-8 string. {e}"
251                );
252
253                #[cfg(feature = "logging")]
254                error!(target: "stdout", "{}", &err_msg);
255
256                LlamaCoreError::Operation(err_msg)
257            })?;
258
259            #[cfg(feature = "logging")]
260            info!(target: "stdout", "raw generation:\n{output}");
261
262            // post-process
263            let message = post_process(output, &graph.metadata.prompt_template).map_err(|e| {
264                LlamaCoreError::Operation(format!("Failed to post-process the output. {e}"))
265            })?;
266
267            #[cfg(feature = "logging")]
268            info!(target: "stdout", "post-processed generation:\n{}", &message);
269
270            // retrieve the number of prompt and completion tokens
271            let token_info = get_token_info_by_graph(graph)?;
272
273            #[cfg(feature = "logging")]
274            info!(target: "stdout", "prompt tokens: {}, completion tokens: {}", token_info.prompt_tokens, token_info.completion_tokens);
275
276            let usage = Some(Usage {
277                prompt_tokens: token_info.prompt_tokens,
278                completion_tokens: token_info.completion_tokens,
279                total_tokens: token_info.prompt_tokens + token_info.completion_tokens,
280            });
281
282            let created = SystemTime::now()
283                .duration_since(std::time::UNIX_EPOCH)
284                .map_err(|e| {
285                    let err_msg = format!("Failed to get the current time. Reason: {e}");
286
287                    #[cfg(feature = "logging")]
288                    error!(target: "stdout", "{}", &err_msg);
289
290                    LlamaCoreError::Operation(err_msg)
291                })?;
292
293            if graph.metadata.prompt_template != PromptTemplateType::MistralTool
294                && graph.metadata.prompt_template != PromptTemplateType::ChatMLTool
295                && graph.metadata.prompt_template != PromptTemplateType::GroqLlama3Tool
296                && graph.metadata.prompt_template != PromptTemplateType::Llama3Tool
297                && graph.metadata.prompt_template != PromptTemplateType::InternLM2Tool
298                && graph.metadata.prompt_template != PromptTemplateType::NemotronTool
299                && graph.metadata.prompt_template != PromptTemplateType::FunctionaryV32
300                && graph.metadata.prompt_template != PromptTemplateType::FunctionaryV31
301                && graph.metadata.prompt_template != PromptTemplateType::MistralSmallTool
302                && graph.metadata.prompt_template != PromptTemplateType::Llama4Chat
303            {
304                let err_msg = format!("Unsupported prompt template: {}. The tool use is only supported for 'mistral-tool', 'chatml-tool', 'groq-llama3-tool', 'llama-3-tool', 'internlm-2-tool', 'nemotron-tool', 'functionary-31', 'functionary-32', 'mistral-small-tool', and 'llama-4-chat' prompt templates.", graph.metadata.prompt_template);
305
306                #[cfg(feature = "logging")]
307                error!(target: "stdout", "{}", &err_msg);
308
309                return Err(LlamaCoreError::Operation(err_msg));
310            }
311
312            let parsed_result = parse_tool_calls(&message, graph.metadata.prompt_template)?;
313
314            let content = if parsed_result.tool_calls.is_empty() {
315                Some(parsed_result.raw.clone())
316            } else {
317                parsed_result.content.clone()
318            };
319
320            let (tool_calls, include_tool_calls) = match parsed_result.tool_calls.is_empty() {
321                false => {
322                    let tool_calls: Vec<ToolCallForChunk> = parsed_result
323                        .tool_calls
324                        .into_iter()
325                        .enumerate()
326                        .map(|(index, tool_call)| ToolCallForChunk {
327                            index,
328                            id: tool_call.id,
329                            ty: tool_call.ty,
330                            function: tool_call.function,
331                        })
332                        .collect();
333                    (tool_calls, true)
334                }
335                true => (vec![], false),
336            };
337
338            // tool_calls chunk
339            let tool_call_chunk = {
340                let chat_completion_chunk = ChatCompletionChunk {
341                    id: id.clone(),
342                    object: "chat.completion.chunk".to_string(),
343                    created: created.as_secs(),
344                    model: graph.name().to_owned(),
345                    system_fingerprint: "fp_44709d6fcb".to_string(),
346                    choices: vec![ChatCompletionChunkChoice {
347                        index: 0,
348                        delta: ChatCompletionChunkChoiceDelta {
349                            role: ChatCompletionRole::Assistant,
350                            content,
351                            tool_calls,
352                        },
353                        logprobs: None,
354                        finish_reason: None,
355                    }],
356                    usage: None,
357                };
358                let chunk_str = serde_json::to_string(&chat_completion_chunk).map_err(|e| {
359                    let err_msg = format!("Failed to serialize chat completion chunk. Reason: {e}");
360
361                    #[cfg(feature = "logging")]
362                    error!(target: "stdout", "{}", &err_msg);
363
364                    LlamaCoreError::Operation(err_msg)
365                })?;
366
367                format!("data: {chunk_str}\n\n")
368            };
369
370            // token uage chunk
371            let usage_chunk = {
372                let chat_completion_chunk = ChatCompletionChunk {
373                    id: id.clone(),
374                    object: "chat.completion.chunk".to_string(),
375                    created: created.as_secs(),
376                    model: graph.name().to_owned(),
377                    system_fingerprint: "fp_44709d6fcb".to_string(),
378                    choices: vec![],
379                    usage,
380                };
381                let chunk_str = serde_json::to_string(&chat_completion_chunk).map_err(|e| {
382                    let err_msg = format!("Failed to serialize chat completion chunk. Reason: {e}");
383
384                    #[cfg(feature = "logging")]
385                    error!(target: "stdout", "{}", &err_msg);
386
387                    LlamaCoreError::Operation(err_msg)
388                })?;
389
390                format!("data: {chunk_str}\n\n")
391            };
392
393            // ending chunk
394            let ending_chunk = "data: [DONE]\n\n".to_string();
395
396            let chunks = vec![tool_call_chunk, usage_chunk, ending_chunk];
397
398            let stream = ChatStream::new(
399                Some(graph.name().to_owned()),
400                id,
401                include_usage,
402                Some(chunks),
403            );
404
405            Ok((stream, include_tool_calls))
406        }
407        Err(wasmedge_wasi_nn::Error::BackendError(wasmedge_wasi_nn::BackendError::ContextFull)) => {
408            // Retrieve the output.
409            let output_buffer = get_output_buffer(graph, OUTPUT_TENSOR)?;
410            let output = std::str::from_utf8(&output_buffer[..]).map_err(|e| {
411                let err_msg = format!(
412                    "Failed to decode the buffer of the inference result to a utf-8 string. {e}"
413                );
414
415                #[cfg(feature = "logging")]
416                error!(target: "stdout", "{}", &err_msg);
417
418                LlamaCoreError::Operation(err_msg)
419            })?;
420
421            // post-process
422            let message = post_process(output, &graph.metadata.prompt_template).map_err(|e| {
423                let err_msg = format!("Failed to post-process the output. {e}");
424
425                #[cfg(feature = "logging")]
426                error!(target: "stdout", "{}", &err_msg);
427
428                LlamaCoreError::Operation(err_msg)
429            })?;
430
431            // retrieve the number of prompt and completion tokens
432            let token_info = get_token_info_by_graph(graph)?;
433
434            #[cfg(feature = "logging")]
435            info!(target: "stdout", "prompt tokens: {}, completion tokens: {}", token_info.prompt_tokens, token_info.completion_tokens);
436
437            let usage = Some(Usage {
438                prompt_tokens: token_info.prompt_tokens,
439                completion_tokens: token_info.completion_tokens,
440                total_tokens: token_info.prompt_tokens + token_info.completion_tokens,
441            });
442
443            let created = SystemTime::now()
444                .duration_since(std::time::UNIX_EPOCH)
445                .map_err(|e| {
446                    let err_msg = format!("Failed to get the current time. Reason: {e}");
447
448                    #[cfg(feature = "logging")]
449                    error!(target: "stdout", "{}", &err_msg);
450
451                    LlamaCoreError::Operation(err_msg)
452                })?;
453
454            // context full chunk
455            let context_full_chunk = {
456                let chat_completion_chunk = ChatCompletionChunk {
457                    id: id.clone(),
458                    object: "chat.completion.chunk".to_string(),
459                    created: created.as_secs(),
460                    model: graph.name().to_owned(),
461                    system_fingerprint: "fp_44709d6fcb".to_string(),
462                    choices: vec![ChatCompletionChunkChoice {
463                        index: 0,
464                        delta: ChatCompletionChunkChoiceDelta {
465                            role: ChatCompletionRole::Assistant,
466                            content: Some(message),
467                            tool_calls: vec![],
468                        },
469                        logprobs: None,
470                        finish_reason: Some(FinishReason::length),
471                    }],
472                    usage: None,
473                };
474
475                // serialize chat completion chunk
476                let chunk_str = serde_json::to_string(&chat_completion_chunk).map_err(|e| {
477                    let err_msg = format!("Failed to serialize chat completion chunk. Reason: {e}");
478
479                    #[cfg(feature = "logging")]
480                    error!(target: "stdout", "{}", &err_msg);
481
482                    LlamaCoreError::Operation(err_msg)
483                })?;
484
485                format!("data: {chunk_str}\n\n")
486            };
487
488            // usage chunk
489            let usage_chunk = {
490                let chat_completion_chunk = ChatCompletionChunk {
491                    id: id.clone(),
492                    object: "chat.completion.chunk".to_string(),
493                    created: created.as_secs(),
494                    model: graph.name().to_owned(),
495                    system_fingerprint: "fp_44709d6fcb".to_string(),
496                    choices: vec![],
497                    usage,
498                };
499
500                // serialize chat completion chunk
501                let chunk_str = serde_json::to_string(&chat_completion_chunk).map_err(|e| {
502                    let err_msg = format!("Failed to serialize chat completion chunk. Reason: {e}");
503
504                    #[cfg(feature = "logging")]
505                    error!(target: "stdout", "{}", &err_msg);
506
507                    LlamaCoreError::Operation(err_msg)
508                })?;
509
510                format!("data: {chunk_str}\n\n")
511            };
512
513            // ending chunk
514            let ending_chunk = "data: [DONE]\n\n".to_string();
515
516            let chunks = vec![context_full_chunk, usage_chunk, ending_chunk];
517
518            let stream = ChatStream::new(
519                Some(graph.name().to_owned()),
520                id,
521                include_usage,
522                Some(chunks),
523            );
524
525            Ok((stream, false))
526        }
527        Err(wasmedge_wasi_nn::Error::BackendError(
528            wasmedge_wasi_nn::BackendError::PromptTooLong,
529        )) => {
530            #[cfg(feature = "logging")]
531            warn!(target: "stdout", "The prompt is too long. Please reduce the length of your input and try again.");
532
533            // Retrieve the output.
534            let output_buffer = get_output_buffer(graph, OUTPUT_TENSOR)?;
535            let output = std::str::from_utf8(&output_buffer[..]).map_err(|e| {
536                let err_msg = format!(
537                    "Failed to decode the buffer of the inference result to a utf-8 string. {e}"
538                );
539
540                #[cfg(feature = "logging")]
541                error!(target: "stdout", "{}", &err_msg);
542
543                LlamaCoreError::Operation(err_msg)
544            })?;
545
546            // post-process
547            let message = post_process(output, &graph.metadata.prompt_template).map_err(|e| {
548                let err_msg = format!("Failed to post-process the output. {e}");
549
550                #[cfg(feature = "logging")]
551                error!(target: "stdout", "{}", &err_msg);
552
553                LlamaCoreError::Operation(err_msg)
554            })?;
555
556            // retrieve the number of prompt and completion token
557            let token_info = get_token_info_by_graph(graph)?;
558
559            #[cfg(feature = "logging")]
560            info!(target: "stdout", "prompt tokens: {}, completion tokens: {}", token_info.prompt_tokens, token_info.completion_tokens);
561
562            let usage = Some(Usage {
563                prompt_tokens: token_info.prompt_tokens,
564                completion_tokens: token_info.completion_tokens,
565                total_tokens: token_info.prompt_tokens + token_info.completion_tokens,
566            });
567
568            let created = SystemTime::now()
569                .duration_since(std::time::UNIX_EPOCH)
570                .map_err(|e| {
571                    let err_msg = format!("Failed to get the current time. Reason: {e}");
572
573                    #[cfg(feature = "logging")]
574                    error!(target: "stdout", "{}", &err_msg);
575
576                    LlamaCoreError::Operation(err_msg)
577                })?;
578
579            // prompt too long chunk
580            let prompt_too_long_chunk = {
581                let chat_completion_chunk = ChatCompletionChunk {
582                    id: id.clone(),
583                    object: "chat.completion.chunk".to_string(),
584                    created: created.as_secs(),
585                    model: graph.name().to_owned(),
586                    system_fingerprint: "fp_44709d6fcb".to_string(),
587                    choices: vec![ChatCompletionChunkChoice {
588                        index: 0,
589                        delta: ChatCompletionChunkChoiceDelta {
590                            role: ChatCompletionRole::Assistant,
591                            content: Some(message),
592                            tool_calls: vec![],
593                        },
594                        logprobs: None,
595                        finish_reason: Some(FinishReason::length),
596                    }],
597                    usage: None,
598                };
599
600                // serialize chat completion chunk
601                let chunk_str = serde_json::to_string(&chat_completion_chunk).map_err(|e| {
602                    let err_msg = format!("Failed to serialize chat completion chunk. Reason: {e}");
603
604                    #[cfg(feature = "logging")]
605                    error!(target: "stdout", "{}", &err_msg);
606
607                    LlamaCoreError::Operation(err_msg)
608                })?;
609
610                format!("data: {chunk_str}\n\n")
611            };
612
613            // usage chunk
614            let usage_chunk = {
615                let chat_completion_chunk = ChatCompletionChunk {
616                    id: id.clone(),
617                    object: "chat.completion.chunk".to_string(),
618                    created: created.as_secs(),
619                    model: graph.name().to_owned(),
620                    system_fingerprint: "fp_44709d6fcb".to_string(),
621                    choices: vec![],
622                    usage,
623                };
624
625                // serialize chat completion chunk
626                let chunk_str = serde_json::to_string(&chat_completion_chunk).map_err(|e| {
627                    let err_msg = format!("Failed to serialize chat completion chunk. Reason: {e}");
628
629                    #[cfg(feature = "logging")]
630                    error!(target: "stdout", "{}", &err_msg);
631
632                    LlamaCoreError::Operation(err_msg)
633                })?;
634
635                format!("data: {chunk_str}\n\n")
636            };
637
638            // ending chunk
639            let ending_chunk = "data: [DONE]\n\n".to_string();
640
641            let chunks = vec![prompt_too_long_chunk, usage_chunk, ending_chunk];
642
643            let stream = ChatStream::new(
644                Some(graph.name().to_owned()),
645                id,
646                include_usage,
647                Some(chunks),
648            );
649
650            Ok((stream, false))
651        }
652        Err(e) => {
653            let err_msg = format!("Failed to compute the chat completion. Reason: {e}");
654
655            #[cfg(feature = "logging")]
656            error!(target: "stdout", "{}", &err_msg);
657
658            Err(LlamaCoreError::Backend(BackendError::Compute(err_msg)))
659        }
660    }
661}
662
663async fn chat_once(
664    chat_request: &mut ChatCompletionRequest,
665) -> Result<(ChatCompletionObject, bool), LlamaCoreError> {
666    #[cfg(feature = "logging")]
667    info!(target: "stdout", "Processing chat completion request in non-stream mode");
668
669    let running_mode = running_mode()?;
670    if !running_mode.contains(RunningMode::CHAT) && !running_mode.contains(RunningMode::RAG) {
671        let err_msg = "The chat completion is only supported in the chat or rag mode.";
672
673        #[cfg(feature = "logging")]
674        error!(target: "stdout", "{err_msg}");
675
676        return Err(LlamaCoreError::Operation(err_msg.to_string()));
677    }
678
679    let model_name = chat_request.model.clone();
680    let id = match &chat_request.user {
681        Some(id) => id.clone(),
682        None => gen_chat_id(),
683    };
684
685    #[cfg(feature = "logging")]
686    info!(target: "stdout", "user: {}", &id);
687
688    #[cfg(feature = "logging")]
689    info!(target: "stdout", "Check model metadata");
690
691    // update metadata
692    let mut metadata = check_model_metadata(chat_request)?;
693
694    #[cfg(feature = "logging")]
695    info!(target: "stdout", "Build the chat prompt");
696
697    // build prompt
698    let (prompt, avaible_completion_tokens, tool_use) =
699        build_prompt(model_name.as_ref(), chat_request)?;
700
701    #[cfg(feature = "logging")]
702    {
703        info!(target: "stdout", "prompt:\n{}", &prompt);
704        info!(target: "stdout", "available_completion_tokens: {avaible_completion_tokens}");
705        info!(target: "stdout", "tool_use: {tool_use}");
706    }
707
708    #[cfg(feature = "logging")]
709    info!(target: "stdout", "Update n_predict");
710
711    // update metadata n_predict
712    update_n_predict(chat_request, &mut metadata, avaible_completion_tokens)?;
713
714    #[cfg(feature = "logging")]
715    info!(target: "stdout", "Feed the prompt to the model");
716
717    // feed the prompt to the model
718    set_prompt(model_name.as_ref(), &prompt)?;
719
720    #[cfg(feature = "logging")]
721    info!(target: "stdout", "Compute chat completion.");
722
723    // compute
724    let res = compute(model_name.as_ref(), id, tool_use);
725
726    #[cfg(feature = "logging")]
727    info!(target: "stdout", "End of the chat completion");
728
729    // reset the model metadata
730    reset_model_metadata(model_name.as_ref())?;
731
732    res
733}
734
735fn compute(
736    model_name: Option<&String>,
737    id: impl Into<String>,
738    tool_use: bool,
739) -> Result<(ChatCompletionObject, bool), LlamaCoreError> {
740    let chat_graphs = match CHAT_GRAPHS.get() {
741        Some(chat_graphs) => chat_graphs,
742        None => {
743            let err_msg = "Fail to get the underlying value of `CHAT_GRAPHS`.";
744
745            #[cfg(feature = "logging")]
746            error!(target: "stdout", "{}", &err_msg);
747
748            return Err(LlamaCoreError::Operation(err_msg.into()));
749        }
750    };
751
752    let mut chat_graphs = chat_graphs.lock().map_err(|e| {
753        let err_msg = format!("Fail to acquire the lock of `CHAT_GRAPHS`. {e}");
754
755        #[cfg(feature = "logging")]
756        error!(target: "stdout", "{}", &err_msg);
757
758        LlamaCoreError::Operation(err_msg)
759    })?;
760
761    match model_name {
762        Some(model_name) => match chat_graphs.contains_key(model_name) {
763            true => {
764                let graph = chat_graphs.get_mut(model_name).unwrap();
765                compute_by_graph(graph, id, tool_use)
766            }
767            false => match chat_graphs.iter_mut().next() {
768                Some((_, graph)) => compute_by_graph(graph, id, tool_use),
769                None => {
770                    let err_msg = "There is no model available in the chat graphs.";
771
772                    #[cfg(feature = "logging")]
773                    error!(target: "stdout", "{}", &err_msg);
774
775                    Err(LlamaCoreError::Operation(err_msg.into()))
776                }
777            },
778        },
779        None => match chat_graphs.iter_mut().next() {
780            Some((_, graph)) => compute_by_graph(graph, id, tool_use),
781            None => {
782                let err_msg = "There is no model available in the chat graphs.";
783
784                #[cfg(feature = "logging")]
785                error!(target: "stdout", "{}", &err_msg);
786
787                Err(LlamaCoreError::Operation(err_msg.into()))
788            }
789        },
790    }
791}
792
793fn compute_by_graph(
794    graph: &mut Graph<GgmlMetadata>,
795    id: impl Into<String>,
796    tool_use: bool,
797) -> Result<(ChatCompletionObject, bool), LlamaCoreError> {
798    #[cfg(feature = "logging")]
799    info!(target: "stdout", "Compute chat completion by the model named {}.", graph.name());
800
801    match graph.compute() {
802        Ok(_) => {
803            // Retrieve the output.
804            let output_buffer = get_output_buffer(graph, OUTPUT_TENSOR)?;
805            let output = std::str::from_utf8(&output_buffer[..]).map_err(|e| {
806                let err_msg = format!(
807                    "Failed to decode the buffer of the inference result to a utf-8 string. {e}"
808                );
809
810                #[cfg(feature = "logging")]
811                error!(target: "stdout", "{}", &err_msg);
812
813                LlamaCoreError::Operation(err_msg)
814            })?;
815
816            #[cfg(feature = "logging")]
817            info!(target: "stdout", "raw generation: {output}");
818
819            // post-process
820            let message = post_process(output, &graph.metadata.prompt_template).map_err(|e| {
821                LlamaCoreError::Operation(format!("Failed to post-process the output. {e}"))
822            })?;
823
824            #[cfg(feature = "logging")]
825            info!(target: "stdout", "post-processed generation:\n{}", &message);
826
827            // retrieve the number of prompt and completion tokens
828            let token_info = get_token_info_by_graph(graph)?;
829
830            #[cfg(feature = "logging")]
831            info!(target: "stdout", "prompt tokens: {}, completion tokens: {}", token_info.prompt_tokens, token_info.completion_tokens);
832
833            let created = SystemTime::now()
834                .duration_since(std::time::UNIX_EPOCH)
835                .map_err(|e| {
836                    let err_msg = format!("Failed to get the current time. Reason: {e}");
837
838                    #[cfg(feature = "logging")]
839                    error!(target: "stdout", "{}", &err_msg);
840
841                    LlamaCoreError::Operation(err_msg)
842                })?;
843
844            match tool_use {
845                true => {
846                    if graph.metadata.prompt_template != PromptTemplateType::MistralTool
847                        && graph.metadata.prompt_template != PromptTemplateType::ChatMLTool
848                        && graph.metadata.prompt_template != PromptTemplateType::GroqLlama3Tool
849                        && graph.metadata.prompt_template != PromptTemplateType::Llama3Tool
850                        && graph.metadata.prompt_template != PromptTemplateType::InternLM2Tool
851                        && graph.metadata.prompt_template != PromptTemplateType::NemotronTool
852                        && graph.metadata.prompt_template != PromptTemplateType::FunctionaryV32
853                        && graph.metadata.prompt_template != PromptTemplateType::FunctionaryV31
854                        && graph.metadata.prompt_template != PromptTemplateType::MistralSmallTool
855                        && graph.metadata.prompt_template != PromptTemplateType::Llama4Chat
856                    {
857                        let err_msg = format!("Unsupported prompt template: {}. The tool use is only supported for 'mistral-tool', 'chatml-tool', 'groq-llama3-tool', 'llama-3-tool', 'internlm-2-tool', 'nemotron-tool', 'functionary-31', 'functionary-32', 'mistral-small-tool', and 'llama-4-chat' prompt templates.", graph.metadata.prompt_template);
858
859                        #[cfg(feature = "logging")]
860                        error!(target: "stdout", "{}", &err_msg);
861
862                        return Err(LlamaCoreError::Operation(err_msg));
863                    }
864
865                    let parsed_result = parse_tool_calls(&message, graph.metadata.prompt_template)?;
866
867                    let (finish_reason, content, include_tool_calls) =
868                        if parsed_result.tool_calls.is_empty() {
869                            (FinishReason::stop, Some(parsed_result.raw.clone()), false)
870                        } else {
871                            (
872                                FinishReason::tool_calls,
873                                parsed_result.content.clone(),
874                                true,
875                            )
876                        };
877
878                    let res = ChatCompletionObject {
879                        id: id.into(),
880                        object: String::from("chat.completion"),
881                        created: created.as_secs(),
882                        model: graph.name().to_owned(),
883                        choices: vec![ChatCompletionObjectChoice {
884                            index: 0,
885                            message: ChatCompletionObjectMessage {
886                                role: ChatCompletionRole::Assistant,
887                                content,
888                                tool_calls: parsed_result.tool_calls,
889                                function_call: None,
890                            },
891                            finish_reason,
892                            logprobs: None,
893                        }],
894                        usage: Usage {
895                            prompt_tokens: token_info.prompt_tokens,
896                            completion_tokens: token_info.completion_tokens,
897                            total_tokens: token_info.prompt_tokens + token_info.completion_tokens,
898                        },
899                    };
900
901                    // create ChatCompletionResponse
902                    Ok((res, include_tool_calls))
903                }
904                false => {
905                    // create ChatCompletionResponse
906                    let res = ChatCompletionObject {
907                        id: id.into(),
908                        object: String::from("chat.completion"),
909                        created: created.as_secs(),
910                        model: graph.name().to_owned(),
911                        choices: vec![ChatCompletionObjectChoice {
912                            index: 0,
913                            message: ChatCompletionObjectMessage {
914                                role: ChatCompletionRole::Assistant,
915                                content: Some(message),
916                                tool_calls: vec![],
917                                function_call: None,
918                            },
919                            finish_reason: FinishReason::stop,
920                            logprobs: None,
921                        }],
922                        usage: Usage {
923                            prompt_tokens: token_info.prompt_tokens,
924                            completion_tokens: token_info.completion_tokens,
925                            total_tokens: token_info.prompt_tokens + token_info.completion_tokens,
926                        },
927                    };
928
929                    Ok((res, false))
930                }
931            }
932        }
933        Err(wasmedge_wasi_nn::Error::BackendError(wasmedge_wasi_nn::BackendError::ContextFull)) => {
934            // Retrieve the output.
935            let output_buffer = get_output_buffer(graph, OUTPUT_TENSOR)?;
936            let output = std::str::from_utf8(&output_buffer[..]).map_err(|e| {
937                let err_msg = format!(
938                    "Failed to decode the buffer of the inference result to a utf-8 string. {e}"
939                );
940
941                #[cfg(feature = "logging")]
942                error!(target: "stdout", "{}", &err_msg);
943
944                LlamaCoreError::Operation(err_msg)
945            })?;
946
947            // post-process
948            let message = post_process(output, &graph.metadata.prompt_template).map_err(|e| {
949                let err_msg = format!("Failed to post-process the output. {e}");
950
951                #[cfg(feature = "logging")]
952                error!(target: "stdout", "{}", &err_msg);
953
954                LlamaCoreError::Operation(err_msg)
955            })?;
956
957            // retrieve the number of prompt and completion tokens
958            let token_info = get_token_info_by_graph(graph)?;
959
960            #[cfg(feature = "logging")]
961            info!(target: "stdout", "prompt tokens: {}, completion tokens: {}", token_info.prompt_tokens, token_info.completion_tokens);
962
963            let created = SystemTime::now()
964                .duration_since(std::time::UNIX_EPOCH)
965                .map_err(|e| {
966                    let err_msg = format!("Failed to get the current time. Reason: {e}");
967
968                    #[cfg(feature = "logging")]
969                    error!(target: "stdout", "{}", &err_msg);
970
971                    LlamaCoreError::Operation(err_msg)
972                })?;
973
974            // create ChatCompletionResponse
975            let res = ChatCompletionObject {
976                id: id.into(),
977                object: String::from("chat.completion"),
978                created: created.as_secs(),
979                model: graph.name().to_owned(),
980                choices: vec![ChatCompletionObjectChoice {
981                    index: 0,
982                    message: ChatCompletionObjectMessage {
983                        role: ChatCompletionRole::Assistant,
984                        content: Some(message),
985                        tool_calls: vec![],
986                        function_call: None,
987                    },
988                    finish_reason: FinishReason::length,
989                    logprobs: None,
990                }],
991                usage: Usage {
992                    prompt_tokens: token_info.prompt_tokens,
993                    completion_tokens: token_info.completion_tokens,
994                    total_tokens: token_info.prompt_tokens + token_info.completion_tokens,
995                },
996            };
997
998            Ok((res, false))
999        }
1000        Err(wasmedge_wasi_nn::Error::BackendError(
1001            wasmedge_wasi_nn::BackendError::PromptTooLong,
1002        )) => {
1003            #[cfg(feature = "logging")]
1004            warn!(target: "stdout", "The prompt is too long. Please reduce the length of your input and try again.");
1005
1006            // Retrieve the output.
1007            let output_buffer = get_output_buffer(graph, OUTPUT_TENSOR)?;
1008            let output = std::str::from_utf8(&output_buffer[..]).map_err(|e| {
1009                let err_msg = format!(
1010                    "Failed to decode the buffer of the inference result to a utf-8 string. {e}"
1011                );
1012
1013                #[cfg(feature = "logging")]
1014                error!(target: "stdout", "{}", &err_msg);
1015
1016                LlamaCoreError::Operation(err_msg)
1017            })?;
1018
1019            // post-process
1020            let message = post_process(output, &graph.metadata.prompt_template).map_err(|e| {
1021                let err_msg = format!("Failed to post-process the output. {e}");
1022
1023                #[cfg(feature = "logging")]
1024                error!(target: "stdout", "{}", &err_msg);
1025
1026                LlamaCoreError::Operation(err_msg)
1027            })?;
1028
1029            // retrieve the number of prompt and completion token
1030            let token_info = get_token_info_by_graph(graph)?;
1031
1032            #[cfg(feature = "logging")]
1033            info!(target: "stdout", "prompt tokens: {}, completion tokens: {}", token_info.prompt_tokens, token_info.completion_tokens);
1034
1035            let usage = Usage {
1036                prompt_tokens: token_info.prompt_tokens,
1037                completion_tokens: token_info.completion_tokens,
1038                total_tokens: token_info.prompt_tokens + token_info.completion_tokens,
1039            };
1040
1041            let created = SystemTime::now()
1042                .duration_since(std::time::UNIX_EPOCH)
1043                .map_err(|e| {
1044                    let err_msg = format!("Failed to get the current time. Reason: {e}");
1045
1046                    #[cfg(feature = "logging")]
1047                    error!(target: "stdout", "{}", &err_msg);
1048
1049                    LlamaCoreError::Operation(err_msg)
1050                })?;
1051
1052            // create ChatCompletionResponse
1053            let res = ChatCompletionObject {
1054                id: id.into(),
1055                object: String::from("chat.completion"),
1056                created: created.as_secs(),
1057                model: graph.name().to_owned(),
1058                choices: vec![ChatCompletionObjectChoice {
1059                    index: 0,
1060                    message: ChatCompletionObjectMessage {
1061                        role: ChatCompletionRole::Assistant,
1062                        content: Some(message),
1063                        tool_calls: vec![],
1064                        function_call: None,
1065                    },
1066                    finish_reason: FinishReason::length,
1067                    logprobs: None,
1068                }],
1069                usage,
1070            };
1071
1072            Ok((res, false))
1073        }
1074        Err(e) => {
1075            let err_msg = format!("Failed to compute the chat completion. Reason: {e}");
1076
1077            #[cfg(feature = "logging")]
1078            error!(target: "stdout", "{}", &err_msg);
1079
1080            Err(LlamaCoreError::Backend(BackendError::Compute(err_msg)))
1081        }
1082    }
1083}
1084
1085fn parse_tool_calls(
1086    input: &str,
1087    prompt_template: PromptTemplateType,
1088) -> Result<ParseResult, LlamaCoreError> {
1089    match prompt_template {
1090        PromptTemplateType::MistralTool => match regex::Regex::new(r"\[\{.*?\}\]") {
1091            Ok(re) => {
1092                let mut values: Vec<serde_json::Value> = vec![];
1093                for cap in re.captures_iter(input) {
1094                    let matched = &cap[0];
1095
1096                    #[cfg(feature = "logging")]
1097                    info!(target: "stdout", "captured: {matched}");
1098
1099                    match serde_json::from_str::<Vec<serde_json::Value>>(matched) {
1100                        Ok(group) => values.extend(group),
1101                        Err(e) => {
1102                            let err_msg =
1103                                format!("Failed to deserialize generated tool calls. Reason: {e}");
1104
1105                            #[cfg(feature = "logging")]
1106                            error!(target: "stdout", "{}", &err_msg);
1107
1108                            return Err(LlamaCoreError::Operation(err_msg));
1109                        }
1110                    }
1111                }
1112
1113                let mut tool_calls: Vec<ToolCall> = vec![];
1114                for value in values.iter() {
1115                    let name = match value.get("name") {
1116                        Some(name) => name.to_string().replace("\"", ""),
1117                        None => {
1118                            let err_msg = format!(
1119                                "Failed to get the name of the function. Tool call: {value:?}"
1120                            );
1121
1122                            #[cfg(feature = "logging")]
1123                            error!(target: "stdout", "{}", &err_msg);
1124
1125                            return Err(LlamaCoreError::Operation(err_msg));
1126                        }
1127                    };
1128
1129                    let arguments = match value.get("arguments") {
1130                        Some(arguments) => arguments.to_string(),
1131                        None => {
1132                            let err_msg = format!(
1133                                "Failed to get the arguments of the function. Tool call: {value:?}"
1134                            );
1135
1136                            #[cfg(feature = "logging")]
1137                            error!(target: "stdout", "{}", &err_msg);
1138
1139                            return Err(LlamaCoreError::Operation(err_msg));
1140                        }
1141                    };
1142
1143                    let function = Function { name, arguments };
1144
1145                    let tool_call = ToolCall {
1146                        id: "call_abc123".to_string(),
1147                        ty: "function".to_string(),
1148                        function,
1149                    };
1150
1151                    tool_calls.push(tool_call);
1152                }
1153
1154                let parsed = ParseResult {
1155                    raw: input.to_owned(),
1156                    content: None,
1157                    tool_calls,
1158                };
1159
1160                #[cfg(feature = "logging")]
1161                info!(target: "stdout", "parsed result: {parsed:?}");
1162
1163                Ok(parsed)
1164            }
1165            Err(e) => {
1166                let err_msg = format!("Failed to create a regex pattern. Reason: {e}");
1167
1168                #[cfg(feature = "logging")]
1169                error!(target: "stdout", "{}", &err_msg);
1170
1171                Err(LlamaCoreError::Operation(err_msg))
1172            }
1173        },
1174        PromptTemplateType::ChatMLTool => {
1175            match regex::Regex::new(r"<tool_call>(.*?)</tool_call>") {
1176                Ok(re) => {
1177                    let mut values: Vec<serde_json::Value> = vec![];
1178                    for cap in re.captures_iter(input) {
1179                        let matched = cap[1].replace("\\n", ""); // Remove "\\n" from the captured group
1180
1181                        #[cfg(feature = "logging")]
1182                        info!(target: "stdout", "captured: {}", &matched);
1183
1184                        match serde_json::from_str::<serde_json::Value>(&matched) {
1185                            Ok(value) => values.push(value),
1186                            Err(e) => {
1187                                let err_msg = format!(
1188                                    "Failed to deserialize generated tool calls. Reason: {e}"
1189                                );
1190
1191                                #[cfg(feature = "logging")]
1192                                error!(target: "stdout", "{}", &err_msg);
1193
1194                                return Err(LlamaCoreError::Operation(err_msg));
1195                            }
1196                        }
1197                    }
1198
1199                    let mut tool_calls: Vec<ToolCall> = vec![];
1200                    for value in values.iter() {
1201                        let name = match value.get("name") {
1202                            Some(name) => name.to_string().replace("\"", ""),
1203                            None => {
1204                                let err_msg = format!(
1205                                    "Failed to get the name of the function. Tool call: {value:?}"
1206                                );
1207
1208                                #[cfg(feature = "logging")]
1209                                error!(target: "stdout", "{}", &err_msg);
1210
1211                                return Err(LlamaCoreError::Operation(err_msg));
1212                            }
1213                        };
1214
1215                        let arguments = match value.get("arguments") {
1216                            Some(arguments) => arguments.to_string(),
1217                            None => {
1218                                let err_msg = format!(
1219                                    "Failed to get the arguments of the function. Tool call: {value:?}"
1220                                );
1221
1222                                #[cfg(feature = "logging")]
1223                                error!(target: "stdout", "{}", &err_msg);
1224
1225                                return Err(LlamaCoreError::Operation(err_msg));
1226                            }
1227                        };
1228
1229                        let function = Function { name, arguments };
1230
1231                        let tool_call = ToolCall {
1232                            id: "call_abc123".to_string(),
1233                            ty: "function".to_string(),
1234                            function,
1235                        };
1236
1237                        tool_calls.push(tool_call);
1238                    }
1239
1240                    let parsed = ParseResult {
1241                        raw: input.to_owned(),
1242                        content: None,
1243                        tool_calls,
1244                    };
1245
1246                    #[cfg(feature = "logging")]
1247                    info!(target: "stdout", "parsed result: {parsed:?}");
1248
1249                    Ok(parsed)
1250                }
1251                Err(e) => {
1252                    let err_msg = format!("Failed to create a regex pattern. Reason: {e}");
1253
1254                    #[cfg(feature = "logging")]
1255                    error!(target: "stdout", "{}", &err_msg);
1256
1257                    Err(LlamaCoreError::Operation(err_msg))
1258                }
1259            }
1260        }
1261        PromptTemplateType::GroqLlama3Tool => {
1262            #[cfg(feature = "logging")]
1263            info!(target: "stdout", "raw input: {input}");
1264
1265            match regex::Regex::new(r"(?s)<tool_call>((.|\r|\n)*?)</tool_call>") {
1266                Ok(re) => {
1267                    let mut values: Vec<serde_json::Value> = vec![];
1268                    for cap in re.captures_iter(input) {
1269                        let matched = cap[1].trim();
1270
1271                        #[cfg(feature = "logging")]
1272                        info!(target: "stdout", "captured: {matched}");
1273
1274                        match serde_json::from_str::<serde_json::Value>(matched) {
1275                            Ok(value) => values.push(value),
1276                            Err(e) => {
1277                                let err_msg = format!(
1278                                    "Failed to deserialize generated tool calls. Reason: {e}"
1279                                );
1280
1281                                #[cfg(feature = "logging")]
1282                                error!(target: "stdout", "{}", &err_msg);
1283
1284                                return Err(LlamaCoreError::Operation(err_msg));
1285                            }
1286                        }
1287                    }
1288
1289                    let mut tool_calls: Vec<ToolCall> = vec![];
1290                    for value in values.iter() {
1291                        let name = match value.get("name") {
1292                            Some(name) => name.to_string().replace("\"", ""),
1293                            None => {
1294                                let err_msg = format!(
1295                                    "Failed to get the name of the function. Tool call: {value:?}"
1296                                );
1297
1298                                #[cfg(feature = "logging")]
1299                                error!(target: "stdout", "{}", &err_msg);
1300
1301                                return Err(LlamaCoreError::Operation(err_msg));
1302                            }
1303                        };
1304
1305                        let arguments = match value.get("arguments") {
1306                            Some(arguments) => {
1307                                if arguments.is_string() {
1308                                    arguments.as_str().unwrap().to_string()
1309                                } else if arguments.is_object() {
1310                                    let map = arguments.as_object().unwrap();
1311
1312                                    #[cfg(feature = "logging")]
1313                                    info!(target: "stdout", "func arguments: {map:?}");
1314
1315                                    serde_json::to_string(map).unwrap()
1316                                } else {
1317                                    serde_json::to_string(arguments).unwrap()
1318                                }
1319                            }
1320                            None => {
1321                                let err_msg = format!(
1322                                    "Failed to get the arguments of the function. Tool call: {value:?}"
1323                                );
1324
1325                                #[cfg(feature = "logging")]
1326                                error!(target: "stdout", "{}", &err_msg);
1327
1328                                return Err(LlamaCoreError::Operation(err_msg));
1329                            }
1330                        };
1331
1332                        let function = Function { name, arguments };
1333
1334                        let tool_call = ToolCall {
1335                            id: "call_abc123".to_string(),
1336                            ty: "function".to_string(),
1337                            function,
1338                        };
1339
1340                        tool_calls.push(tool_call);
1341                    }
1342
1343                    let parsed = if tool_calls.is_empty() {
1344                        ParseResult {
1345                            raw: input.to_owned(),
1346                            content: Some(input.to_owned()),
1347                            tool_calls: vec![],
1348                        }
1349                    } else {
1350                        ParseResult {
1351                            raw: input.to_owned(),
1352                            content: None,
1353                            tool_calls,
1354                        }
1355                    };
1356
1357                    #[cfg(feature = "logging")]
1358                    info!(target: "stdout", "parsed result: {parsed:?}");
1359
1360                    Ok(parsed)
1361                }
1362                Err(e) => {
1363                    let err_msg = format!("Failed to create a regex pattern. Reason: {e}");
1364
1365                    #[cfg(feature = "logging")]
1366                    error!(target: "stdout", "{}", &err_msg);
1367
1368                    Err(LlamaCoreError::Operation(err_msg))
1369                }
1370            }
1371        }
1372        PromptTemplateType::Llama3Tool => {
1373            #[cfg(feature = "logging")]
1374            info!(target: "stdout", "raw input: {input}");
1375
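            // Expected tool-call format (illustrative): the entire output is a single JSON object, e.g.
            //   {"name": "get_current_weather", "parameters": {"location": "Beijing"}}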
1376            let re = match regex::Regex::new(r"^\{(.|\r|\n)*\}$") {
1377                Ok(re) => re,
1378                Err(e) => {
1379                    let err_msg = format!("Failed to create a regex pattern. Reason: {e}");
1380
1381                    #[cfg(feature = "logging")]
1382                    error!(target: "stdout", "{}", &err_msg);
1383
1384                    return Err(LlamaCoreError::Operation(err_msg));
1385                }
1386            };
1387
1388            if re.is_match(input) {
1389                match serde_json::from_str::<serde_json::Value>(input) {
1390                    Ok(value) => {
1391                        let values: Vec<serde_json::Value> = vec![value];
1392
1393                        let mut tool_calls: Vec<ToolCall> = vec![];
1394                        for value in values.iter() {
1395                            let name = match value.get("name") {
1396                                Some(name) => name.to_string().replace("\"", ""),
1397                                None => {
1398                                    let err_msg = format!(
1399                                        "Failed to get the name of the function. Tool call: {value:?}"
1400                                    );
1401
1402                                    #[cfg(feature = "logging")]
1403                                    error!(target: "stdout", "{}", &err_msg);
1404
1405                                    return Err(LlamaCoreError::Operation(err_msg));
1406                                }
1407                            };
1408
1409                            let arguments = match value.get("parameters") {
1410                                Some(arguments) => arguments.to_string(),
1411                                None => {
1412                                    let err_msg = format!(
1413                                        "Failed to get the arguments of the function. Tool call: {value:?}"
1414                                    );
1415
1416                                    #[cfg(feature = "logging")]
1417                                    error!(target: "stdout", "{}", &err_msg);
1418
1419                                    return Err(LlamaCoreError::Operation(err_msg));
1420                                }
1421                            };
1422
1423                            let function = Function { name, arguments };
1424
1425                            let tool_call = ToolCall {
1426                                id: "call_abc123".to_string(),
1427                                ty: "function".to_string(),
1428                                function,
1429                            };
1430
1431                            tool_calls.push(tool_call);
1432                        }
1433
1434                        let parsed = ParseResult {
1435                            raw: input.to_owned(),
1436                            content: None,
1437                            tool_calls,
1438                        };
1439
1440                        #[cfg(feature = "logging")]
1441                        info!(target: "stdout", "parsed result: {parsed:?}");
1442
1443                        Ok(parsed)
1444                    }
1445                    Err(e) => {
1446                        let err_msg =
1447                            format!("Failed to deserialize generated tool calls. Reason: {e}");
1448
1449                        #[cfg(feature = "logging")]
1450                        error!(target: "stdout", "{}", &err_msg);
1451
1452                        Err(LlamaCoreError::Operation(err_msg))
1453                    }
1454                }
1455            } else {
1456                let parsed = ParseResult {
1457                    raw: input.to_owned(),
1458                    content: None,
1459                    tool_calls: vec![],
1460                };
1461
1462                #[cfg(feature = "logging")]
1463                info!(target: "stdout", "parsed result: {parsed:?}");
1464
1465                Ok(parsed)
1466            }
1467        }
1468        PromptTemplateType::InternLM2Tool => {
1469            #[cfg(feature = "logging")]
1470            info!(target: "stdout", "raw input: {input}");
1471
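            // Expected tool-call format (illustrative):
            //   <|action_start|><|plugin|> {"name": "get_current_weather", "parameters": {"location": "Beijing"}} <|action_end|>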
1472            let blocks: Vec<&str> = input.trim().split("<|action_start|><|plugin|>").collect();
1473
1474            #[cfg(feature = "logging")]
1475            info!(target: "stdout", "blocks: {blocks:?}");
1476
1477            let mut tool_calls: Vec<ToolCall> = vec![];
1478            let mut content = String::new();
1479            for block in blocks {
1480                let block = block.trim();
1481                if !block.is_empty() {
1482                    if block.ends_with("<|action_end|>") {
1483                        let value = block.trim().trim_end_matches("<|action_end|>");
1484
1485                        #[cfg(feature = "logging")]
1486                        info!(target: "stdout", "tool call: {value}");
1487
1488                        match serde_json::from_str::<serde_json::Value>(value) {
1489                            Ok(value) => {
1490                                let name = match value.get("name") {
1491                                    Some(name) => name.to_string().replace("\"", ""),
1492                                    None => {
1493                                        let err_msg = format!(
1494                                            "Failed to get the name of the function. Tool call: {value:?}"
1495                                        );
1496
1497                                        #[cfg(feature = "logging")]
1498                                        error!(target: "stdout", "{}", &err_msg);
1499
1500                                        return Err(LlamaCoreError::Operation(err_msg));
1501                                    }
1502                                };
1503
1504                                let arguments = match value.get("parameters") {
1505                                    Some(arguments) => arguments.to_string(),
1506                                    None => {
1507                                        let err_msg = format!(
1508                                            "Failed to get the arguments of the function. Tool call: {value:?}"
1509                                        );
1510
1511                                        #[cfg(feature = "logging")]
1512                                        error!(target: "stdout", "{}", &err_msg);
1513
1514                                        return Err(LlamaCoreError::Operation(err_msg));
1515                                    }
1516                                };
1517
1518                                let function = Function { name, arguments };
1519
1520                                let tool_call = ToolCall {
1521                                    id: "call_abc123".to_string(),
1522                                    ty: "function".to_string(),
1523                                    function,
1524                                };
1525
1526                                tool_calls.push(tool_call);
1527                            }
1528                            Err(e) => {
1529                                let err_msg = format!(
1530                                    "Failed to deserialize generated tool calls. Reason: {e}"
1531                                );
1532
1533                                #[cfg(feature = "logging")]
1534                                error!(target: "stdout", "{}", &err_msg);
1535
1536                                return Err(LlamaCoreError::Operation(err_msg));
1537                            }
1538                        }
1539                    } else {
1540                        content.push_str(block);
1541                        content.push('\n');
1542                    }
1543                }
1544            }
1545
1546            let parsed = match content.is_empty() {
1547                true => ParseResult {
1548                    raw: input.to_owned(),
1549                    content: None,
1550                    tool_calls,
1551                },
1552                false => ParseResult {
1553                    raw: input.to_owned(),
1554                    content: Some(content.trim().to_owned()),
1555                    tool_calls,
1556                },
1557            };
1558
1559            #[cfg(feature = "logging")]
1560            info!(target: "stdout", "parsed result: {parsed:?}");
1561
1562            Ok(parsed)
1563        }
1564        PromptTemplateType::NemotronTool => {
1565            #[cfg(feature = "logging")]
1566            info!(target: "stdout", "raw input: {input}");
1567
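            // Expected tool-call format (illustrative):
            //   <toolcall> {"name": "get_current_weather", "arguments": {"location": "Beijing"}} </toolcall>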
1568            match regex::Regex::new(r"(?s)<toolcall>\s*(.*?)\s*</toolcall>") {
1569                Ok(re) => {
1570                    let mut values: Vec<serde_json::Value> = vec![];
1571                    for cap in re.captures_iter(input) {
1572                        #[cfg(feature = "logging")]
1573                        info!(target: "stdout", "captured: {}", &cap[0]);
1574
1575                        #[cfg(feature = "logging")]
1576                        info!(target: "stdout", "extracted: {}", &cap[1]);
1577
1578                        let matched = cap[1].trim();
1579
1580                        #[cfg(feature = "logging")]
1581                        info!(target: "stdout", "captured: {matched}");
1582
1583                        match serde_json::from_str::<serde_json::Value>(matched) {
1584                            Ok(value) => values.push(value),
1585                            Err(e) => {
1586                                let err_msg = format!(
1587                                    "Failed to deserialize generated tool calls. Reason: {e}"
1588                                );
1589
1590                                #[cfg(feature = "logging")]
1591                                error!(target: "stdout", "{}", &err_msg);
1592
1593                                return Err(LlamaCoreError::Operation(err_msg));
1594                            }
1595                        }
1596                    }
1597
1598                    let mut tool_calls: Vec<ToolCall> = vec![];
1599                    for value in values.iter() {
1600                        let name = match value.get("name") {
1601                            Some(name) => name.to_string().replace("\"", ""),
1602                            None => {
1603                                let err_msg = format!(
1604                                    "Failed to get the name of the function. Tool call: {value:?}"
1605                                );
1606
1607                                #[cfg(feature = "logging")]
1608                                error!(target: "stdout", "{}", &err_msg);
1609
1610                                return Err(LlamaCoreError::Operation(err_msg));
1611                            }
1612                        };
1613
1614                        let arguments = match value.get("arguments") {
1615                            Some(arguments) => arguments.to_string(),
1616                            None => {
1617                                let err_msg = format!(
1618                                    "Failed to get the arguments of the function. Tool call: {value:?}"
1619                                );
1620
1621                                #[cfg(feature = "logging")]
1622                                error!(target: "stdout", "{}", &err_msg);
1623
1624                                return Err(LlamaCoreError::Operation(err_msg));
1625                            }
1626                        };
1627
1628                        let function = Function { name, arguments };
1629
1630                        let tool_call = ToolCall {
1631                            id: "call_abc123".to_string(),
1632                            ty: "function".to_string(),
1633                            function,
1634                        };
1635
1636                        tool_calls.push(tool_call);
1637                    }
1638
1639                    let parsed = ParseResult {
1640                        raw: input.to_owned(),
1641                        content: None,
1642                        tool_calls,
1643                    };
1644
1645                    #[cfg(feature = "logging")]
1646                    info!(target: "stdout", "parsed result: {parsed:?}");
1647
1648                    Ok(parsed)
1649                }
1650                Err(e) => {
1651                    let err_msg = format!("Failed to create a regex pattern. Reason: {e}");
1652
1653                    #[cfg(feature = "logging")]
1654                    error!(target: "stdout", "{}", &err_msg);
1655
1656                    Err(LlamaCoreError::Operation(err_msg))
1657                }
1658            }
1659        }
1660        PromptTemplateType::FunctionaryV32 => {
1661            #[cfg(feature = "logging")]
1662            info!(target: "stdout", "raw input: {input}");
1663
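            // Expected tool-call format (illustrative):
            //   >>>get_current_weather{"location": "Beijing"}<|eot_id|>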
1664            match regex::Regex::new(r">>>\s*(\w+)\s*\{(.*)\}<\|eot_id\|>") {
1665                Ok(re) => {
1666                    let mut tool_calls: Vec<ToolCall> = vec![];
1667                    for cap in re.captures_iter(input) {
1668                        #[cfg(feature = "logging")]
1669                        info!(target: "stdout", "func_name: {}", &cap[1]);
1670
1671                        #[cfg(feature = "logging")]
1672                        info!(target: "stdout", "arguments: {}", &cap[2]);
1673
1674                        let tool_call = ToolCall {
1675                            id: "call_abc123".to_string(),
1676                            ty: "function".to_string(),
1677                            function: Function {
1678                                name: cap[1].to_string(),
1679                                arguments: cap[2].to_string(),
1680                            },
1681                        };
1682
1683                        tool_calls.push(tool_call);
1684                    }
1685
1686                    let parsed = ParseResult {
1687                        raw: input.to_owned(),
1688                        content: None,
1689                        tool_calls,
1690                    };
1691
1692                    #[cfg(feature = "logging")]
1693                    info!(target: "stdout", "parsed result: {parsed:?}");
1694
1695                    Ok(parsed)
1696                }
1697                Err(e) => {
1698                    let warn_msg = format!("Failed to create a regex pattern. Reason: {e}");
1699
1700                    #[cfg(feature = "logging")]
1701                    warn!(target: "stdout", "{}", &warn_msg);
1702
1703                    Ok(ParseResult {
1704                        raw: input.to_owned(),
1705                        content: None,
1706                        tool_calls: vec![],
1707                    })
1708                }
1709            }
1710        }
1711        PromptTemplateType::FunctionaryV31 => {
1712            #[cfg(feature = "logging")]
1713            info!(target: "stdout", "raw input: {input}");
1714
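            // Expected tool-call format (illustrative):
            //   <function=get_current_weather>{"location": "Beijing"}</function>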
1715            match regex::Regex::new(r"<function=(\w+)>\s*(\{.*?\})</function>") {
1716                Ok(re) => {
1717                    let mut tool_calls: Vec<ToolCall> = vec![];
1718                    for cap in re.captures_iter(input) {
1719                        #[cfg(feature = "logging")]
1720                        info!(target: "stdout", "func_name: {}", &cap[1]);
1721
1722                        #[cfg(feature = "logging")]
1723                        info!(target: "stdout", "arguments: {}", &cap[2]);
1724
1725                        let tool_call = ToolCall {
1726                            id: "call_abc123".to_string(),
1727                            ty: "function".to_string(),
1728                            function: Function {
1729                                name: cap[1].to_string(),
1730                                arguments: cap[2].to_string(),
1731                            },
1732                        };
1733
1734                        tool_calls.push(tool_call);
1735                    }
1736
1737                    let parsed = ParseResult {
1738                        raw: input.to_owned(),
1739                        content: None,
1740                        tool_calls,
1741                    };
1742
1743                    #[cfg(feature = "logging")]
1744                    info!(target: "stdout", "parsed result: {parsed:?}");
1745
1746                    Ok(parsed)
1747                }
1748                Err(e) => {
1749                    let warn_msg = format!("Failed to create a regex pattern. Reason: {e}");
1750
1751                    #[cfg(feature = "logging")]
1752                    warn!(target: "stdout", "{}", &warn_msg);
1753
1754                    Ok(ParseResult {
1755                        raw: input.to_owned(),
1756                        content: None,
1757                        tool_calls: vec![],
1758                    })
1759                }
1760            }
1761        }
1762        PromptTemplateType::MistralSmallTool => {
1763            #[cfg(feature = "logging")]
1764            info!(target: "stdout", "raw input: {input}");
1765
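            // Expected tool-call format (illustrative):
            //   [TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"location": "Beijing"}}]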
1766            match regex::Regex::new(r"\[TOOL_CALLS\]\s*(\[(.*?)\])") {
1767                Ok(re) => {
1768                    let mut values: Vec<serde_json::Value> = vec![];
1769                    if let Some(cap) = re.captures(input) {
1770                        let matched = cap[1].trim();
1771
1772                        #[cfg(feature = "logging")]
1773                        info!(target: "stdout", "captured: {matched}");
1774
1775                        match serde_json::from_str::<Vec<serde_json::Value>>(matched) {
1776                            Ok(vals) => values = vals,
1777                            Err(e) => {
1778                                let err_msg = format!(
1779                                    "Failed to deserialize generated tool calls. Reason: {e}"
1780                                );
1781
1782                                #[cfg(feature = "logging")]
1783                                error!(target: "stdout", "{}", &err_msg);
1784
1785                                return Err(LlamaCoreError::Operation(err_msg));
1786                            }
1787                        }
1788                    };
1789
1790                    let mut tool_calls: Vec<ToolCall> = vec![];
1791                    for value in values.iter() {
1792                        if let Some(object_map) = value.as_object() {
1793                            if object_map.contains_key("function") {
1794                                let mut function = Function {
1795                                    name: String::new(),
1796                                    arguments: String::new(),
1797                                };
1798
1799                                let value = object_map.get("function").unwrap();
1800                                let func_map = value.as_object().unwrap();
1801                                if func_map.contains_key("name") {
1802                                    let func_name = func_map.get("name").unwrap().as_str().unwrap();
1803                                    #[cfg(feature = "logging")]
                                    info!(target: "stdout", "func name: {func_name:?}");
1804
1805                                    function.name = func_name.to_string();
1806                                }
1807                                if func_map.contains_key("arguments") {
1808                                    let args = func_map.get("arguments").unwrap();
1809                                    let arguments = args.to_string();
1810                                    #[cfg(feature = "logging")]
                                    info!(target: "stdout", "func arguments: {arguments:?}");
1811
1812                                    function.arguments = arguments;
1813                                }
1814
1815                                let tool_call = ToolCall {
1816                                    id: "call_abc123".to_string(),
1817                                    ty: "function".to_string(),
1818                                    function,
1819                                };
1820
1821                                tool_calls.push(tool_call);
1822                            } else if object_map.contains_key("name") {
1823                                let mut function = Function {
1824                                    name: String::new(),
1825                                    arguments: String::new(),
1826                                };
1827
1828                                let name = object_map.get("name").unwrap().as_str().unwrap();
1829                                #[cfg(feature = "logging")]
                                info!(target: "stdout", "func name: {name:?}");
1830                                function.name = name.to_string();
1831
1832                                if object_map.contains_key("arguments") {
1833                                    let args = object_map.get("arguments").unwrap();
1834                                    let arguments = args.to_string();
1835                                    #[cfg(feature = "logging")]
                                    info!(target: "stdout", "func arguments: {arguments:?}");
1836
1837                                    function.arguments = arguments;
1838                                }
1839
1840                                let tool_call = ToolCall {
1841                                    id: "call_abc123".to_string(),
1842                                    ty: "function".to_string(),
1843                                    function,
1844                                };
1845
1846                                tool_calls.push(tool_call);
1847                            }
1848                        }
1849                    }
1850
1851                    let parsed = ParseResult {
1852                        raw: input.to_owned(),
1853                        content: None,
1854                        tool_calls,
1855                    };
1856
1857                    #[cfg(feature = "logging")]
1858                    info!(target: "stdout", "parsed result: {parsed:?}");
1859
1860                    Ok(parsed)
1861                }
1862                Err(e) => {
1863                    let err_msg = format!("Failed to create a regex pattern. Reason: {e}");
1864
1865                    #[cfg(feature = "logging")]
1866                    error!(target: "stdout", "{}", &err_msg);
1867
1868                    Err(LlamaCoreError::Operation(err_msg))
1869                }
1870            }
1871        }
1872        PromptTemplateType::Llama4Chat => {
1873            #[cfg(feature = "logging")]
1874            info!(target: "stdout", "raw input: {input:?}");
1875
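            // Expected tool-call format (illustrative): a single JSON object, e.g.
            //   {"name": "get_current_weather", "parameters": {"location": "Beijing"}}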
1876            let mut tool_calls: Vec<ToolCall> = vec![];
1877            if let Ok(value) = serde_json::from_str::<serde_json::Value>(input) {
1878                match value.as_object() {
1879                    Some(object_map) => {
1880                        #[cfg(feature = "logging")]
1881                        debug!(target: "stdout", "object_map: {object_map:?}");
1882
1883                        // parse function name
1884                        if object_map.contains_key("name") {
1885                            let name = object_map.get("name").unwrap().as_str().unwrap();
1886
1887                            #[cfg(feature = "logging")]
1888                            debug!(target: "stdout", "name: {name:?}");
1889
1890                            let mut function = Function {
1891                                name: name.to_string(),
1892                                arguments: String::new(),
1893                            };
1894
1895                            // parse function arguments
1896                            if object_map.contains_key("parameters") {
1897                                let args = object_map.get("parameters").unwrap();
1898                                let arguments = args.to_string();
1899
1900                                #[cfg(feature = "logging")]
1901                                debug!(target: "stdout", "arguments: {:?}", &arguments);
1902
1903                                function.arguments = arguments;
1904                            }
1905
1906                            tool_calls.push(ToolCall {
1907                                id: "call_abc123".to_string(),
1908                                ty: "function".to_string(),
1909                                function,
1910                            });
1911                        } else {
1912                            let err_msg = format!(
1913                                "Failed to get the name of the function. raw input: {input:?}"
1914                            );
1915
1916                            #[cfg(feature = "logging")]
1917                            error!(target: "stdout", "{}", &err_msg);
1918
1919                            return Err(LlamaCoreError::Operation(err_msg));
1920                        }
1921                    }
1922                    None => {
1923                        let err_msg = format!("Failed to parse the JSON string. JSON: {input}");
1924
1925                        #[cfg(feature = "logging")]
1926                        error!(target: "stdout", "{}", &err_msg);
1927
1928                        return Err(LlamaCoreError::Operation(err_msg));
1929                    }
1930                }
1931            }
1932
1933            let parsed = ParseResult {
1934                raw: input.to_owned(),
1935                content: None,
1936                tool_calls,
1937            };
1938
1939            #[cfg(feature = "logging")]
1940            info!(target: "stdout", "parsed result: {parsed:?}");
1941
1942            Ok(parsed)
1943        }
1944        _ => {
1945            let err_msg = format!(
1946                "Tool use is only supported for the following prompt templates: {}, {}, {}, {}, {}, {}, {}, {}, {}, and {}.",
1947                PromptTemplateType::MistralTool,
1948                PromptTemplateType::ChatMLTool,
1949                PromptTemplateType::GroqLlama3Tool,
1950                PromptTemplateType::Llama3Tool,
1951                PromptTemplateType::InternLM2Tool,
1952                PromptTemplateType::NemotronTool,
1953                PromptTemplateType::FunctionaryV32,
                PromptTemplateType::FunctionaryV31,
1954                PromptTemplateType::MistralSmallTool,
                PromptTemplateType::Llama4Chat,
1955            );
1956
1957            #[cfg(feature = "logging")]
1958            error!(target: "stdout", "{}", &err_msg);
1959
1960            Err(LlamaCoreError::Operation(err_msg))
1961        }
1962    }
1963}
1964
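/// Check the model metadata against the chat request and, if any field
/// (image, temperature, top_p, frequency_penalty, presence_penalty,
/// embeddings) needs to change, push the updated metadata to the target graph.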
1965fn check_model_metadata(
1966    chat_request: &ChatCompletionRequest,
1967) -> Result<GgmlMetadata, LlamaCoreError> {
1968    let mut should_update = false;
1969    let mut metadata = get_model_metadata(chat_request.model.as_ref())?;
1970
1971    // check if necessary to update `image`
1972    if metadata.prompt_template.is_image_supported() {
1973        if let Some(ChatCompletionRequestMessage::User(user_message)) = chat_request.messages.last()
1974        {
1975            if let ChatCompletionUserMessageContent::Parts(parts) = user_message.content() {
1976                for part in parts {
1977                    if let ContentPart::Image(image_part) = part {
1978                        let image = image_part.image();
1979
1980                        if image.is_url() {
1981                            #[cfg(feature = "logging")]
1982                            info!(target: "stdout", "The image is provided in URL format.");
1983
1984                            // download the image
1985                            let img_path_str = download_image(&image.url)?;
1986
1987                            #[cfg(feature = "logging")]
1988                            info!(target: "stdout", "The image is saved to {img_path_str}");
1989
1990                            // update metadata image
1991                            metadata.image = Some(img_path_str);
1992
1993                            if !should_update {
1994                                should_update = true;
1995                            }
1996
1997                            // TODO: only a single image is supported for now
1998
1999                            break;
2000                        } else {
2001                            #[cfg(feature = "logging")]
2002                            info!(target: "stdout", "The image is provided in base64 format.");
2003
2004                            // TODO: only a single image is supported for now
2005
2006                            break;
2007                        }
2008                    }
2009                }
2010            }
2011        }
2012    }
2013
2014    // check if necessary to update temperature
2015    if let Some(temp) = chat_request.temperature {
2016        if metadata.temperature != temp {
2017            // update temperature
2018            metadata.temperature = temp;
2019
2020            if !should_update {
2021                should_update = true;
2022            }
2023        }
2024    }
2025
2026    // check if necessary to update top_p
2027    if let Some(top_p) = chat_request.top_p {
2028        if metadata.top_p != top_p {
2029            // update top_p
2030            metadata.top_p = top_p;
2031
2032            if !should_update {
2033                should_update = true;
2034            }
2035        }
2036    }
2037
2038    // check if necessary to update frequency_penalty
2039    if let Some(frequency_penalty) = chat_request.frequency_penalty {
2040        if metadata.frequency_penalty != frequency_penalty {
2041            // update frequency_penalty
2042            metadata.frequency_penalty = frequency_penalty;
2043
2044            if !should_update {
2045                should_update = true;
2046            }
2047        }
2048    }
2049
2050    // check if necessary to update presence_penalty
2051    if let Some(presence_penalty) = chat_request.presence_penalty {
2052        if metadata.presence_penalty != presence_penalty {
2053            // update presence_penalty
2054            metadata.presence_penalty = presence_penalty;
2055
2056            if !should_update {
2057                should_update = true;
2058            }
2059        }
2060    }
2061
2062    // ensure the `embedding` option is disabled for chat completions
2063    if metadata.embeddings {
2064        metadata.embeddings = false;
2065
2066        if !should_update {
2067            should_update = true;
2068        }
2069    }
2070
2071    if should_update {
2072        #[cfg(feature = "logging")]
2073        info!(target: "stdout", "Update the model metadata.");
2074
2075        // update the target graph with the new metadata
2076        update_model_metadata(chat_request.model.as_ref(), &metadata)?;
2077    }
2078
2079    Ok(metadata)
2080}
2081
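/// Update `n_predict` in the model metadata based on the request's
/// `max_completion_tokens` and the number of tokens still available for
/// completion, then push the change to the target graph when needed.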
2082fn update_n_predict(
2083    chat_request: &ChatCompletionRequest,
2084    metadata: &mut GgmlMetadata,
2085    available_completion_tokens: u64,
2086) -> Result<(), LlamaCoreError> {
2087    let mut should_update = false;
2088
2089    #[cfg(feature = "logging")]
2090    info!(target: "stdout", "n_predict: {}", metadata.n_predict);
2091
2092    // Priority for determining `n_predict`, from high to low:
2093    // 1. chat_request.max_tokens & chat_request.max_completion_tokens
2094    // 2. available_completion_tokens
2095    // 3. n_predict
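    //
    // Illustrative example: with available_completion_tokens = 3096, no
    // max_completion_tokens in the request, and the default n_predict = -1,
    // n_predict is updated to 3096.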
2096
2097    if let Some(max_completion_tokens) = chat_request.max_completion_tokens {
2098        if metadata.n_predict != max_completion_tokens {
2099            #[cfg(feature = "logging")]
2100            info!(target: "stdout", "Update n_predict with max_completion_tokens from {} to {}", metadata.n_predict, max_completion_tokens);
2101
2102            metadata.n_predict = max_completion_tokens;
2103
2104            if !should_update {
2105                should_update = true;
2106            }
2107        }
2108    }
2109
2110    // TODO: remove this condition after [Issue #3958 on WasmEdge](https://github.com/WasmEdge/WasmEdge/issues/3958) is fixed
2111    if metadata.n_predict == -2 {
2112        #[cfg(feature = "logging")]
2113        info!(target: "stdout", "Update n_predict with available_completion_tokens from {} to {}", metadata.n_predict, available_completion_tokens);
2114
2115        // update n_predict
2116        metadata.n_predict = available_completion_tokens as i32;
2117
2118        if !should_update {
2119            should_update = true;
2120        }
2121    }
2122
2123    if metadata.n_predict == -1
2124        || (metadata.n_predict > 0 && metadata.n_predict < available_completion_tokens as i32)
2125        || (metadata.n_predict < 0 && metadata.n_predict != -2)
2126    // TODO: remove this condition after [Issue #3958 on WasmEdge](https://github.com/WasmEdge/WasmEdge/issues/3958) is fixed
2127    {
2128        #[cfg(feature = "logging")]
2129        info!(target: "stdout", "Update n_predict with available_completion_tokens from {} to {}", metadata.n_predict, available_completion_tokens);
2130
2131        // update n_predict
2132        metadata.n_predict = available_completion_tokens as i32;
2133
2134        if !should_update {
2135            should_update = true;
2136        }
2137    }
2138
2139    if should_update {
2140        #[cfg(feature = "logging")]
2141        info!(target: "stdout", "Update the model metadata.");
2142
2143        // update the target graph with the new metadata
2144        update_model_metadata(chat_request.model.as_ref(), metadata)?;
2145    }
2146
2147    Ok(())
2148}
2149
2150/// Post-process the generated output according to the prompt template type.
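///
/// For example (illustrative), with the `Llama3Chat` template the trailing
/// `<|eot_id|>` token is stripped:
///
/// ```ignore
/// // illustrative usage; `post_process` is private to this module
/// let cleaned = post_process("Hello!<|eot_id|>", &PromptTemplateType::Llama3Chat).unwrap();
/// assert_eq!(cleaned, "Hello!");
/// ```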
2151fn post_process(
2152    output: impl AsRef<str>,
2153    template_ty: &PromptTemplateType,
2154) -> Result<String, String> {
2155    let output = if *template_ty == PromptTemplateType::Baichuan2 {
2156        if output.as_ref().contains("用户:") {
2157            output.as_ref().trim_end_matches("用户:").trim().to_owned()
2158        } else {
2159            output.as_ref().trim().to_owned()
2160        }
2161    } else if *template_ty == PromptTemplateType::OpenChat {
2162        if output.as_ref().contains("<|end_of_turn|>") {
2163            output
2164                .as_ref()
2165                .trim_end_matches("<|end_of_turn|>")
2166                .trim()
2167                .to_owned()
2168        } else {
2169            output.as_ref().trim().to_owned()
2170        }
2171    } else if *template_ty == PromptTemplateType::GemmaInstruct
2172        || *template_ty == PromptTemplateType::Gemma3
2173    {
2174        let s = output.as_ref().trim();
2175        if s.ends_with("<end_of_turn>") {
2176            s.trim_end_matches("<end_of_turn>").trim().to_owned()
2177        } else {
2178            s.to_owned()
2179        }
2180    } else if *template_ty == PromptTemplateType::ChatML
2181        || *template_ty == PromptTemplateType::ChatMLTool
2182        || *template_ty == PromptTemplateType::InternLM2Tool
2183        || *template_ty == PromptTemplateType::MiniCPMV
2184    {
2185        let mut s = output.as_ref().trim();
2186        if s.ends_with("<|endoftext|>") {
2187            s = s.trim_end_matches("<|endoftext|>").trim();
2188        }
2189
2190        if s.starts_with(":") {
2191            s = s.trim_start_matches(":").trim();
2192        }
2193
2194        if s.contains("<|im_start|>") && s.contains("<|im_end|>") {
2195            let idx_start = s.find("<|im_start|>").unwrap();
2196            let idx_end = s.find("<|im_end|>").unwrap();
2197
2198            match idx_start <= idx_end {
2199                true => s.split("<|im_start|>").collect::<Vec<_>>()[0]
2200                    .trim()
2201                    .to_owned(),
2202                false => s.split("<|im_end|>").collect::<Vec<_>>()[0]
2203                    .trim()
2204                    .to_owned(),
2205            }
2206        } else if s.contains("<|im_start|>") {
2207            s.split("<|im_start|>").collect::<Vec<_>>()[0]
2208                .trim()
2209                .to_owned()
2210        } else if s.contains("<|im_end|>") {
2211            let output = s.trim_end_matches("<|im_end|>").trim();
2212            if output.starts_with(": ") {
2213                output.trim_start_matches(": ").to_owned()
2214            } else {
2215                output.to_owned()
2216            }
2217        } else {
2218            s.to_owned()
2219        }
2220    } else if *template_ty == PromptTemplateType::Zephyr
2221        || *template_ty == PromptTemplateType::MistralLite
2222        || *template_ty == PromptTemplateType::MistralTool
2223        || *template_ty == PromptTemplateType::MistralInstruct
2224        || *template_ty == PromptTemplateType::MistralSmallChat
2225        || *template_ty == PromptTemplateType::MistralSmallTool
2226        || *template_ty == PromptTemplateType::BreezeInstruct
2227    {
2228        if output.as_ref().contains("</s><") {
2229            output.as_ref().trim_end_matches("</s><").trim().to_owned()
2230        } else if output.as_ref().contains("</s>") {
2231            output
2232                .as_ref()
2233                .strip_suffix("</s>")
2234                .unwrap()
2235                .trim()
2236                .to_owned()
2237        } else {
2238            output.as_ref().trim().to_owned()
2239        }
2240    } else if *template_ty == PromptTemplateType::DeepseekChat {
2241        if output.as_ref().contains("<|end_of_sentence|>") {
2242            output
2243                .as_ref()
2244                .trim_end_matches("<|end_of_sentence|>")
2245                .trim()
2246                .replace("<|end_of_sentence|>", " ")
2247                .trim()
2248                .to_owned()
2249        } else {
2250            output.as_ref().trim().to_owned()
2251        }
2252    } else if *template_ty == PromptTemplateType::HumanAssistant {
2253        if output.as_ref().contains("Human:") {
2254            output.as_ref().trim_end_matches("Human:").trim().to_owned()
2255        } else {
2256            output.as_ref().trim().to_owned()
2257        }
2258    } else if *template_ty == PromptTemplateType::SolarInstruct {
2259        let s = output.as_ref().trim();
2260
2261        if s.starts_with("### Answer") {
2262            let s = s.trim_start_matches("###").trim();
2263
2264            if s.starts_with("Answer:\n") {
2265                s.replace("Answer:\n", "Answer: ")
2266            } else {
2267                s.to_owned()
2268            }
2269        } else {
2270            s.to_owned()
2271        }
2272    } else if *template_ty == PromptTemplateType::Llama2Chat
2273        || *template_ty == PromptTemplateType::NemotronTool
2274        || *template_ty == PromptTemplateType::NemotronChat
2275    {
2276        let s = output.as_ref().trim();
2277        if s.ends_with("</s>") {
2278            s.trim_end_matches("</s>").trim().to_owned()
2279        } else {
2280            s.to_owned()
2281        }
2282    } else if *template_ty == PromptTemplateType::Llama3Chat
2283        || *template_ty == PromptTemplateType::GroqLlama3Tool
2284        || *template_ty == PromptTemplateType::Llama3Tool
2285        || *template_ty == PromptTemplateType::FunctionaryV32
2286    {
2287        let s = output.as_ref().trim();
2288        if s.ends_with("<|eot_id|>") {
2289            s.trim_end_matches("<|eot_id|>").trim().to_owned()
2290        } else {
2291            s.to_owned()
2292        }
2293    } else if *template_ty == PromptTemplateType::Phi3Chat {
2294        let s = output.as_ref().trim();
2295        if s.ends_with("<|end|>") {
2296            s.trim_end_matches("<|end|>").trim().to_owned()
2297        } else {
2298            s.to_owned()
2299        }
2300    } else if *template_ty == PromptTemplateType::Phi4Chat {
2301        let s = output.as_ref().trim();
2302        if s.ends_with("<|im_end|>") {
2303            s.trim_end_matches("<|im_end|>").trim().to_owned()
2304        } else {
2305            s.to_owned()
2306        }
2307    } else if *template_ty == PromptTemplateType::FunctionaryV31 {
2308        let mut s = output.as_ref().trim();
2309        if s.ends_with("<|eot_id|>") {
2310            s = s.trim_end_matches("<|eot_id|>").trim();
2311        }
2312        if s.ends_with("<|eom_id|>") {
2313            s = s.trim_end_matches("<|eom_id|>").trim();
2314        }
2315        s.to_owned()
2316    } else if *template_ty == PromptTemplateType::MoxinChat {
2317        let s = output.as_ref().trim();
2318        if s.ends_with("</s>") {
2319            s.trim_end_matches("</s>").trim().to_owned()
2320        } else if s.ends_with("[INST]") {
2321            s.trim_end_matches("[INST]").trim().to_owned()
2322        } else {
2323            s.to_owned()
2324        }
2325    } else if *template_ty == PromptTemplateType::Falcon3 {
2326        let s = output.as_ref().trim();
2327        if s.ends_with("<|endoftext|>") {
2328            s.trim_end_matches("<|endoftext|>").trim().to_owned()
2329        } else {
2330            s.to_owned()
2331        }
2332    } else if *template_ty == PromptTemplateType::Megrez {
2333        let s = output.as_ref().trim();
2334        if s.ends_with("<|turn_end|>") {
2335            s.trim_end_matches("<|turn_end|>").trim().to_owned()
2336        } else {
2337            s.to_owned()
2338        }
2339    } else if *template_ty == PromptTemplateType::Qwen2vl
2340        || *template_ty == PromptTemplateType::ChatMLThink
2341    {
2342        let mut s = output.as_ref().trim();
2343
2344        if s.starts_with(":") {
2345            s = s.trim_start_matches(":").trim();
2346        }
2347
2348        if s.ends_with("<|im_end|>") {
2349            s.trim_end_matches("<|im_end|>").trim().to_owned()
2350        } else {
2351            s.to_owned()
2352        }
2353    } else if *template_ty == PromptTemplateType::VicunaLlava {
2354        let s = output.as_ref().trim();
2355        if s.ends_with("</s>") {
2356            s.trim_end_matches("</s>").trim().to_owned()
2357        } else {
2358            s.to_owned()
2359        }
2360    } else if *template_ty == PromptTemplateType::ExaoneDeepChat
2361        || *template_ty == PromptTemplateType::ExaoneChat
2362    {
2363        let mut s = output.as_ref().trim();
2364
2365        if s.ends_with("[|endofturn|]") {
2366            s = s.trim_end_matches("[|endofturn|]").trim();
2367        }
2368
2369        s.to_owned()
2370    } else if *template_ty == PromptTemplateType::Llama4Chat {
2371        let mut s = output.as_ref().trim();
2372
2373        if s.ends_with("<|eot|>") {
2374            s = s.trim_end_matches("<|eot|>").trim();
2375        }
2376
2377        s.to_owned()
2378    } else {
2379        output.as_ref().trim().to_owned()
2380    };
2381
2382    Ok(output)
2383}
2384
2385/// Build the chat prompt from the chat messages.
2386///
2387/// # Arguments
2388///
2389/// * `model_name`: The name of the model.
2390///
2391/// * `chat_request`: The chat request.
2392///
2393/// # Returns
2394///
2395/// A tuple containing the prompt, the number of available tokens for completions, and a boolean indicating whether tools are used.
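///
/// # Example (illustrative)
///
/// ```ignore
/// // assumes a chat graph is already initialized; "llama-3-8b" is a hypothetical model name
/// let (prompt, available_completion_tokens, tool_use) =
///     build_prompt(Some(&"llama-3-8b".to_string()), &mut chat_request)?;
/// ```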
2396fn build_prompt(
2397    model_name: Option<&String>,
2398    chat_request: &mut ChatCompletionRequest,
2399) -> Result<(String, u64, bool), LlamaCoreError> {
2400    let metadata = get_model_metadata(model_name)?;
2401    let ctx_size = metadata.ctx_size as u64;
2402    let chat_prompt = ChatPrompt::from(metadata.prompt_template);
2403
2404    // compute max prompt tokens, which is 80% of the context size
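    // e.g. with ctx_size = 8192, max_prompt_tokens = 8192 * 4 / 5 = 6553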
2405    let max_prompt_tokens = ctx_size * 4 / 5;
2406
2407    loop {
2408        // ! DO NOT REMOVE
2409        {
2410            // // build prompt
2411            // let prompt = match chat_prompt.build(&mut chat_request.messages) {
2412            //     Ok(prompt) => prompt,
2413            //     Err(e) => {
2414            //         let err_msg = format!("Fail to build chat prompts. Reason: {}", e);
2415
2416            //         #[cfg(feature = "logging")]
2417            //         error!(target: "stdout", "{}", &err_msg);
2418
2419            //         return Err(LlamaCoreError::Operation(err_msg));
2420            //     }
2421            // };
2422        }
2423
2424        if chat_request.messages.is_empty() {
2425            let err_msg = "The messages in the chat request are empty.";
2426
2427            #[cfg(feature = "logging")]
2428            error!(target: "stdout", "{err_msg}");
2429
2430            return Err(LlamaCoreError::Operation(err_msg.to_owned()));
2431        }
2432
2433        #[cfg(feature = "logging")]
2434        {
2435            let mut role_chain = String::new();
2436            for (idx, message) in chat_request.messages.iter().enumerate() {
2437                if idx == chat_request.messages.len() - 1 {
2438                    role_chain.push_str(&format!("{}", message.role()));
2439                } else {
2440                    role_chain.push_str(&format!("{} -> ", message.role()));
2441                }
2442            }
2443            info!(target: "stdout", "Role chain: {role_chain}");
2444        }
2445
2446        let (prompt, tool_use) = match chat_request.tool_choice.as_ref() {
2447            Some(tool_choice) => match tool_choice {
2448                ToolChoice::None => {
2449                    match chat_prompt.build_with_tools(&mut chat_request.messages, Some(&[])) {
2450                        Ok(prompt) => (prompt, false),
2451                        Err(e) => {
2452                            let err_msg = format!("Fail to build chat prompts. Reason: {e}");
2453
2454                            #[cfg(feature = "logging")]
2455                            error!(target: "stdout", "{}", &err_msg);
2456
2457                            return Err(LlamaCoreError::Operation(err_msg));
2458                        }
2459                    }
2460                }
2461                _ => match chat_request.tools.as_ref() {
2462                    Some(tools) => match chat_prompt
2463                        .build_with_tools(&mut chat_request.messages, Some(tools.as_slice()))
2464                    {
2465                        Ok(prompt) => (prompt, true),
2466                        Err(e) => {
2467                            let err_msg = format!("Fail to build chat prompts. Reason: {e}");
2468
2469                            #[cfg(feature = "logging")]
2470                            error!(target: "stdout", "{}", &err_msg);
2471
2472                            return Err(LlamaCoreError::Operation(err_msg));
2473                        }
2474                    },
2475                    None => {
2476                        #[cfg(feature = "logging")]
2477                        warn!(target: "stdout", "A tool choice was specified but no tools were provided; building the prompt without tools.");
2478
2479                        match chat_prompt.build_with_tools(&mut chat_request.messages, None) {
2480                            Ok(prompt) => (prompt, false),
2481                            Err(e) => {
2482                                let err_msg = format!("Fail to build chat prompts. Reason: {e}");
2483
2484                                #[cfg(feature = "logging")]
2485                                error!(target: "stdout", "{}", &err_msg);
2486
2487                                return Err(LlamaCoreError::Operation(err_msg));
2488                            }
2489                        }
2490                    }
2491                },
2492            },
2493            None => match chat_prompt.build_with_tools(&mut chat_request.messages, None) {
2494                Ok(prompt) => (prompt, false),
2495                Err(e) => {
2496                    let err_msg = format!("Fail to build chat prompts. Reason: {e}");
2497
2498                    #[cfg(feature = "logging")]
2499                    error!(target: "stdout", "{}", &err_msg);
2500
2501                    return Err(LlamaCoreError::Operation(err_msg));
2502                }
2503            },
2504        };
2505        #[cfg(feature = "logging")]
2506        info!(target: "stdout", "Try to set prompt: {prompt}");
2507
2508        // set prompt
2509        set_prompt(model_name, &prompt)?;
2510
2511        // Retrieve the number of prompt tokens.
2512        let token_info = get_token_info_by_graph_name(model_name)?;
2513
2514        match token_info.prompt_tokens > max_prompt_tokens {
2515            true => {
2516                match chat_request.messages[0].role() {
2517                    ChatCompletionRole::System => {
2518                        if chat_request.messages.len() > 2 {
2519                            #[cfg(feature = "logging")]
2520                            info!(target: "stdout", "Prune chat history: current length {}", chat_request.messages.len());
2521
2522                            // remove user_1 if it exists
2523                            // For example, `system -> user_1 -> ... -> user_2 -> ... -> user_latest` will be converted to `system -> ... -> user_2 -> ... -> user_latest`
2524                            if chat_request.messages[1].role() == ChatCompletionRole::User {
2525                                let user_message = chat_request.messages.remove(1);
2526
2527                                #[cfg(feature = "logging")]
2528                                info!(target: "stdout", "Remove a user message from the chat history: {user_message:?}");
2529                            }
2530
2531                            // remove all messages until the message is of `user`
2532                            // For example, `system -> ... -> user_2 -> ... -> user_latest` will be converted to `system -> user_2 -> ... -> user_latest`
2533                            while chat_request.messages[1].role() != ChatCompletionRole::User {
2534                                let message = chat_request.messages.remove(1);
2535
2536                                #[cfg(feature = "logging")]
2537                                info!(target: "stdout", "Remove a {} message from the chat history: {:?}", message.role(), message);
2538
2539                                if chat_request.messages.len() == 1 {
2540                                    let err_msg = format!("The last message in the chat history should be a user message, but found a {} message.", message.role());
2541
2542                                    #[cfg(feature = "logging")]
2543                                    error!(target: "stdout", "{err_msg}");
2544
2545                                    return Err(LlamaCoreError::Operation(err_msg));
2546                                }
2547                            }
2548                        } else if token_info.prompt_tokens > ctx_size {
2549                            let err_msg = format!(
2550                                    "The number of prompt tokens ({}) is greater than the context size ({}). Please increase the context size, or simplify the input message.",
2551                                    token_info.prompt_tokens, ctx_size
2552                                );
2553
2554                            #[cfg(feature = "logging")]
2555                            error!(target: "stdout", "{}", &err_msg);
2556
2557                            return Err(LlamaCoreError::Operation(err_msg));
2558                        } else {
2559                            return Ok((prompt, ctx_size - token_info.prompt_tokens, tool_use));
2560                        }
2561                    }
2562                    ChatCompletionRole::User => {
2563                        if chat_request.messages.len() > 1 {
2564                            // user_1 -> ... -> user_2 -> ... -> user_latest
2565
2566                            // remove user_1 if it exists
2567                            // For example, `user_1 -> ... -> user_2 -> ... -> user_latest` will be converted to `... -> user_2 -> ... -> user_latest`
2568                            if chat_request.messages[0].role() == ChatCompletionRole::User {
2569                                let user_message = chat_request.messages.remove(0);
2570
2571                                #[cfg(feature = "logging")]
2572                                info!(target: "stdout", "Remove a user message from the chat history: {user_message:?}");
2573                            }
2574
2575                            // remove all messages until the message is of `user`
2576                            // For example, `... -> user_2 -> ... -> user_latest` will be converted to `user_2 -> ... -> user_latest`
2577                            while chat_request.messages[0].role() != ChatCompletionRole::User {
2578                                let message = chat_request.messages.remove(0);
2579
2580                                #[cfg(feature = "logging")]
2581                                info!(target: "stdout", "Remove a {} message from the chat history: {:?}", message.role(), message);
2582
2583                                if chat_request.messages.is_empty() {
2584                                    let err_msg = format!("The last message in the chat history should be a user message, but found a {} message.", message.role());
2585
2586                                    #[cfg(feature = "logging")]
2587                                    error!(target: "stdout", "{err_msg}");
2588
2589                                    return Err(LlamaCoreError::Operation(err_msg));
2590                                }
2591                            }
2592                        } else if token_info.prompt_tokens > ctx_size {
2593                            let err_msg = format!(
2594                                    "The number of prompt tokens ({}) is greater than the context size ({}). Please increase the context size, or simplify the input message.",
2595                                    token_info.prompt_tokens, ctx_size
2596                                );
2597
2598                            #[cfg(feature = "logging")]
2599                            error!(target: "stdout", "{}", &err_msg);
2600
2601                            return Err(LlamaCoreError::Operation(err_msg));
2602                        } else {
2603                            return Ok((prompt, ctx_size - token_info.prompt_tokens, tool_use));
2604                        }
2605                    }
2606                    _ => {
2607                        #[cfg(feature = "logging")]
2608                        info!(target: "stdout", "remove a {} message from the message queue", chat_request.messages[0].role());
2609
2610                        chat_request.messages.remove(0);
2611                    }
2612                }
2613
2614                continue;
2615            }
2616            false => return Ok((prompt, ctx_size - max_prompt_tokens, tool_use)),
2617        }
2618    }
2619}
2620
2621/// Downloads an image from the given URL and returns the path of the saved file.
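    ///
    /// # Example
    ///
    /// A minimal sketch (the URL below is hypothetical); the returned value is a relative
    /// path of the form `archives/file_<uuid>/<file name>`:
    ///
    /// ```ignore
    /// let local_path = download_image("https://example.com/cat.png")?;
    /// ```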
2622fn download_image(image_url: impl AsRef<str>) -> Result<String, LlamaCoreError> {
2623    let image_url = image_url.as_ref();
2624    let url = reqwest::Url::parse(image_url).map_err(|e| {
2625        let err_msg = format!("Fail to parse the image URL: {image_url}. Reason: {e}");
2626
2627        #[cfg(feature = "logging")]
2628        error!(target: "stdout", "{}", &err_msg);
2629
2630        LlamaCoreError::Operation(err_msg)
2631    })?;
2632
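        // NOTE: `Handle::block_on` panics if it is invoked from within an async execution
        // context, so this helper must be called from a blocking/synchronous code path
        // (e.g. wrapped in `tokio::task::block_in_place` or run on a `spawn_blocking` thread).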
2633    let curr_rt_handle = tokio::runtime::Handle::current();
2634    let res = curr_rt_handle.block_on(async {
2635        let response = reqwest::get(url).await.map_err(|e| {
2636            let err_msg =
2637                format!("Fail to download the image from the URL: {image_url}. Reason: {e}");
2638
2639            #[cfg(feature = "logging")]
2640            error!(target: "stdout", "{}", &err_msg);
2641
2642            LlamaCoreError::Operation(err_msg)
2643        })?;
2644
2645        let fname = response
2646            .url()
2647            .path_segments()
2648            .and_then(|mut segments| segments.next_back())
2649            .and_then(|name| if name.is_empty() { None } else { Some(name) })
2650            .ok_or(LlamaCoreError::Operation(format!(
2651                "Fail to get the file name: {image_url}"
2652            )))?
2653            .to_string();
2654
2655        // create a unique file id
2656        let id = format!("file_{}", uuid::Uuid::new_v4());
2657        // save the file
2658        let path = Path::new("archives");
2659        if !path.exists() {
2660            fs::create_dir(path).unwrap();
2661        }
2662        let file_path = path.join(&id);
2663        if !file_path.exists() {
2664            fs::create_dir(&file_path).unwrap();
2665        }
2666        let img_path = file_path.join(&fname);
2667        let mut dest: File = File::create(&img_path).map_err(|e| {
2668            let err_msg = format!("Failed to create the archive file {}. Reason: {}", &fname, e);
2669
2670            // log
2671            #[cfg(feature = "logging")]
2672            error!(target: "stdout", "{}", &err_msg);
2673
2674            LlamaCoreError::Operation(err_msg)
2675        })?;
2676
2677        let mut content = response.bytes_stream();
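            // NOTE: a stream error ends this loop silently, so the saved file may be truncated.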
2678        while let Some(Ok(item)) = content.next().await {
2679            std::io::copy(&mut item.as_ref(), &mut dest).map_err(|e| {
2680                let err_msg = format!(
2681                    "Fail to write the image content to the file: {}. Reason: {}",
2682                    &fname, e
2683                );
2684
2685                #[cfg(feature = "logging")]
2686                error!(target: "stdout", "{}", &err_msg);
2687
2688                LlamaCoreError::Operation(err_msg)
2689            })?;
2690        }
2691
2692        Ok(img_path.as_path().to_string_lossy().to_string())
2693    });
2694
2695    res
2696
2698}
2699
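    /// Feed the given prompt into the input tensor (index 0) of the target chat model,
    /// falling back to the first available model when `model_name` is `None` or not found.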
2700fn set_prompt(model_name: Option<&String>, prompt: impl AsRef<str>) -> Result<(), LlamaCoreError> {
2701    let chat_graphs = match CHAT_GRAPHS.get() {
2702        Some(chat_graphs) => chat_graphs,
2703        None => {
2704            let err_msg = "Fail to get the underlying value of `CHAT_GRAPHS`.";
2705
2706            #[cfg(feature = "logging")]
2707            error!(target: "stdout", "{}", &err_msg);
2708
2709            return Err(LlamaCoreError::Operation(err_msg.into()));
2710        }
2711    };
2712
2713    let mut chat_graphs = chat_graphs.lock().map_err(|e| {
2714        let err_msg = format!("Fail to acquire the lock of `CHAT_GRAPHS`. {e}");
2715
2716        #[cfg(feature = "logging")]
2717        error!(target: "stdout", "{}", &err_msg);
2718
2719        LlamaCoreError::Operation(err_msg)
2720    })?;
2721
2722    match model_name {
2723        Some(model_name) => {
2724            #[cfg(feature = "logging")]
2725            info!(target: "stdout", "Set prompt to the chat model named {model_name}");
2726
2727            match chat_graphs.contains_key(model_name) {
2728                true => {
2729                    let graph = chat_graphs.get_mut(model_name).unwrap();
2730                    let tensor_data = prompt.as_ref().as_bytes().to_vec();
2731                    set_tensor_data_u8(graph, 0, &tensor_data)
2732                }
2733                false => match chat_graphs.iter_mut().next() {
2734                    Some((_, graph)) => {
2735                        let tensor_data = prompt.as_ref().as_bytes().to_vec();
2736                        set_tensor_data_u8(graph, 0, &tensor_data)
2737                    }
2738                    None => {
2739                        let err_msg = "There is no model available in the chat graphs.";
2740
2741                        #[cfg(feature = "logging")]
2742                        error!(target: "stdout", "{}", &err_msg);
2743
2744                        Err(LlamaCoreError::Operation(err_msg.into()))
2745                    }
2746                },
2747            }
2748        }
2749        None => {
2750            #[cfg(feature = "logging")]
2751            info!(target: "stdout", "Set prompt to the default chat model.");
2752
2753            match chat_graphs.iter_mut().next() {
2754                Some((_, graph)) => {
2755                    let tensor_data = prompt.as_ref().as_bytes().to_vec();
2756                    set_tensor_data_u8(graph, 0, &tensor_data)
2757                }
2758                None => {
2759                    let err_msg = "There is no model available in the chat graphs while trying to set prompt to the default model.";
2760
2761                    #[cfg(feature = "logging")]
2762                    error!(target: "stdout", "{err_msg}");
2763
2764                    Err(LlamaCoreError::Operation(err_msg.into()))
2765                }
2766            }
2767        }
2768    }
2769}
2770
2771/// Get a copy of the metadata of the model.
2772fn get_model_metadata(model_name: Option<&String>) -> Result<GgmlMetadata, LlamaCoreError> {
2773    let chat_graphs = match CHAT_GRAPHS.get() {
2774        Some(chat_graphs) => chat_graphs,
2775        None => {
2776            let err_msg = "Fail to get the underlying value of `CHAT_GRAPHS`.";
2777
2778            #[cfg(feature = "logging")]
2779            error!(target: "stdout", "{err_msg}");
2780
2781            return Err(LlamaCoreError::Operation(err_msg.into()));
2782        }
2783    };
2784
2785    let chat_graphs = chat_graphs.lock().map_err(|e| {
2786        let err_msg = format!("Fail to acquire the lock of `CHAT_GRAPHS`. {e}");
2787
2788        #[cfg(feature = "logging")]
2789        error!(target: "stdout", "{}", &err_msg);
2790
2791        LlamaCoreError::Operation(err_msg)
2792    })?;
2793
2794    match model_name {
2795        Some(model_name) => match chat_graphs.contains_key(model_name) {
2796            true => {
2797                let graph = chat_graphs.get(model_name).unwrap();
2798                Ok(graph.metadata.clone())
2799            }
2800            false => match chat_graphs.iter().next() {
2801                Some((_, graph)) => Ok(graph.metadata.clone()),
2802                None => {
2803                    let err_msg = "There is no model available in the chat graphs.";
2804
2805                    #[cfg(feature = "logging")]
2806                    error!(target: "stdout", "{}", &err_msg);
2807
2808                    Err(LlamaCoreError::Operation(err_msg.into()))
2809                }
2810            },
2811        },
2812        None => match chat_graphs.iter().next() {
2813            Some((_, graph)) => Ok(graph.metadata.clone()),
2814            None => {
2815                let err_msg = "There is no model available in the chat graphs.";
2816
2817                #[cfg(feature = "logging")]
2818                error!(target: "stdout", "{err_msg}");
2819
2820                Err(LlamaCoreError::Operation(err_msg.into()))
2821            }
2822        },
2823    }
2824}
2825
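    /// Serialize the given metadata to JSON and write it into the metadata tensor (index 1)
    /// of the target chat model, falling back to the first available model when `model_name`
    /// is `None` or not found.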
2826fn update_model_metadata(
2827    model_name: Option<&String>,
2828    metadata: &GgmlMetadata,
2829) -> Result<(), LlamaCoreError> {
2830    let config = match serde_json::to_string(metadata) {
2831        Ok(config) => config,
2832        Err(e) => {
2833            let err_msg = format!("Fail to serialize metadata to a JSON string. {e}");
2834
2835            #[cfg(feature = "logging")]
2836            error!(target: "stdout", "{}", &err_msg);
2837
2838            return Err(LlamaCoreError::Operation(err_msg));
2839        }
2840    };
2841
2842    let chat_graphs = match CHAT_GRAPHS.get() {
2843        Some(chat_graphs) => chat_graphs,
2844        None => {
2845            let err_msg = "Fail to get the underlying value of `CHAT_GRAPHS`.";
2846
2847            #[cfg(feature = "logging")]
2848            error!(target: "stdout", "{err_msg}");
2849
2850            return Err(LlamaCoreError::Operation(err_msg.into()));
2851        }
2852    };
2853
2854    let mut chat_graphs = chat_graphs.lock().map_err(|e| {
2855        let err_msg = format!("Fail to acquire the lock of `CHAT_GRAPHS`. Reason: {e}");
2856
2857        #[cfg(feature = "logging")]
2858        error!(target: "stdout", "{}", &err_msg);
2859
2860        LlamaCoreError::Operation(err_msg)
2861    })?;
2862
2863    match model_name {
2864        Some(model_name) => {
2865            match chat_graphs.contains_key(model_name) {
2866                true => {
2867                    let graph = chat_graphs.get_mut(model_name).unwrap();
2868                    // update metadata
2869                    set_tensor_data_u8(graph, 1, config.as_bytes())
2870                }
2871                false => match chat_graphs.iter_mut().next() {
2872                    Some((_, graph)) => {
2873                        // update metadata
2874                        set_tensor_data_u8(graph, 1, config.as_bytes())
2875                    }
2876                    None => {
2877                        let err_msg = "There is no model available in the chat graphs.";
2878
2879                        #[cfg(feature = "logging")]
2880                        error!(target: "stdout", "{}", &err_msg);
2881
2882                        Err(LlamaCoreError::Operation(err_msg.into()))
2883                    }
2884                },
2885            }
2886        }
2887        None => {
2888            match chat_graphs.iter_mut().next() {
2889                Some((_, graph)) => {
2890                    // update metadata
2891                    set_tensor_data_u8(graph, 1, config.as_bytes())
2892                }
2893                None => {
2894                    let err_msg = "There is no model available in the chat graphs.";
2895
2896                    #[cfg(feature = "logging")]
2897                    error!(target: "stdout", "{err_msg}");
2898
2899                    Err(LlamaCoreError::Operation(err_msg.into()))
2900                }
2901            }
2902        }
2903    }
2904}
2905
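    /// Re-apply the metadata cached on the graph, restoring the model's default runtime
    /// configuration after a request may have overridden it.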
2906fn reset_model_metadata(model_name: Option<&String>) -> Result<(), LlamaCoreError> {
2907    // get metadata
2908    let metadata = get_model_metadata(model_name)?;
2909
2910    // update model with the original metadata
2911    update_model_metadata(model_name, &metadata)
2912}
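
    // Illustrative sketch of how the metadata helpers compose. The `temperature` field is an
    // assumption about `GgmlMetadata`; adjust it to the struct's actual fields:
    //
    //     let mut metadata = get_model_metadata(None)?;
    //     metadata.temperature = 0.2;
    //     update_model_metadata(None, &metadata)?;
    //     // ... run the request ...
    //     reset_model_metadata(None)?;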
2913
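    // Per-stream state machines driven by `compute_stream` to emit the trailing SSE events
    // (an optional final message chunk, an optional usage chunk, `data: [DONE]`, and the
    // end-of-sequence sentinel) once the backend reports end of sequence, a full context,
    // or an overly long prompt.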
2914#[derive(Debug, Clone, Copy, PartialEq, Eq)]
2915enum ContextFullState {
2916    Message,
2917    Usage,
2918    Done,
2919    EndOfSequence,
2920}
2921
2922#[derive(Debug, Clone, Copy, PartialEq, Eq)]
2923enum StreamState {
2924    Usage,
2925    NoUsage,
2926    Done,
2927    EndOfSequence,
2928}
2929
2930#[derive(Debug, Clone, Copy, PartialEq, Eq)]
2931enum PromptTooLongState {
2932    Message,
2933    Usage,
2934    Done,
2935    EndOfSequence,
2936}
2937
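    /// An SSE-formatted stream of chat completion chunks backed by a chat graph.
    ///
    /// Only one `ChatStream` may drive the backend at a time: a new stream tries to set
    /// `CHAT_STREAM_ACTIVE` on creation and, if another stream already holds it, is created in a
    /// waiting state; when polled, it parks its waker in `CHAT_STREAM_WAKER_QUEUE` until the
    /// active stream is dropped.
    ///
    /// Illustrative consumption sketch (not a doctest; assumes a populated, mutable
    /// `ChatCompletionRequest` named `request` with `stream: Some(true)`):
    ///
    /// ```ignore
    /// // requires `futures::StreamExt` for `.next()`
    /// let (either, _include_tool_calls) = chat(&mut request).await?;
    /// if let either::Either::Left(stream) = either {
    ///     let mut stream = Box::pin(stream);
    ///     while let Some(item) = stream.next().await {
    ///         // each item is an SSE line such as `data: {...}\n\n` or `data: [DONE]\n\n`
    ///         print!("{}", item?);
    ///     }
    /// }
    /// ```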
2938struct ChatStream {
2939    id: String,
2940    model: Option<String>,
2941    include_usage: bool,
2942    context_full_state: ContextFullState,
2943    prompt_too_long_state: PromptTooLongState,
2944    stream_state: StreamState,
2945    cache: Option<VecDeque<String>>,
2946    is_waiting: bool,
2947    has_lock: bool,
2948}
2949impl ChatStream {
2950    fn new(
2951        model: Option<String>,
2952        id: String,
2953        include_usage: bool,
2954        cache: Option<Vec<String>>,
2955    ) -> Self {
2956        // Try to acquire lock
2957        let has_lock = CHAT_STREAM_ACTIVE
2958            .compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst)
2959            .is_ok();
2960
2961        #[cfg(feature = "logging")]
2962        if !has_lock {
2963            info!(target: "stdout", "Lock acquisition failed in ChatStream::new, creating with waiting status");
2964        }
2965
2966        ChatStream {
2967            id,
2968            model,
2969            include_usage,
2970            context_full_state: ContextFullState::Message,
2971            prompt_too_long_state: PromptTooLongState::Message,
2972            stream_state: if include_usage {
2973                StreamState::Usage
2974            } else {
2975                StreamState::NoUsage
2976            },
2977            cache: cache.map(VecDeque::from),
2978            is_waiting: !has_lock,
2979            has_lock,
2980        }
2981    }
2982
2983    // Try to acquire lock, returns whether successful
2984    fn try_acquire_lock(&mut self) -> bool {
2985        if self.has_lock {
2986            return true;
2987        }
2988
2989        let acquired = CHAT_STREAM_ACTIVE
2990            .compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst)
2991            .is_ok();
2992
2993        if acquired {
2994            self.has_lock = true;
2995            self.is_waiting = false;
2996        }
2997
2998        acquired
2999    }
3000}
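    // Dropping a `ChatStream` cleans up the backend context via `finish_single` (when this stream
    // held the lock or was actually used), resets the model metadata, and, if the stream held
    // `CHAT_STREAM_ACTIVE`, releases it and wakes the next waiting stream in the queue.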
3001impl Drop for ChatStream {
3002    fn drop(&mut self) {
3003        // Clean-up is only needed if we hold the lock or if the stream was actually used
3004        if self.has_lock || (self.cache.is_none() && !self.is_waiting) {
3005            #[cfg(feature = "logging")]
3006            info!(target: "stdout", "Cleaning up context for ChatStream {}", &self.id);
3007
3008            match &self.model {
3009                Some(model_name) => {
3010                    match CHAT_GRAPHS.get() {
3011                        Some(chat_graphs) => {
3012                            match chat_graphs.lock() {
3013                                Ok(mut chat_graphs) => match chat_graphs.contains_key(model_name) {
3014                                    true => {
3015                                        let graph = chat_graphs.get_mut(model_name).unwrap();
3016
3017                                        // clean up the context
3018                                        if let Err(e) = graph.finish_single() {
3019                                            let err_msg = format!(
3020                                                "Failed to clean up the context. Reason: {e}"
3021                                            );
3022
3023                                            #[cfg(feature = "logging")]
3024                                            error!(target: "stdout", "{}", &err_msg);
3025
3026                                            #[cfg(not(feature = "logging"))]
3027                                            println!("[ERROR][llama_core] {}", &err_msg);
3031                                        }
3032                                    }
3033                                    false => match chat_graphs.iter_mut().next() {
3034                                        Some((_, graph)) => {
3035                                            // clean up the context
3036                                            if let Err(e) = graph.finish_single() {
3037                                                let err_msg = format!(
3038                                                    "Failed to clean up the context. Reason: {e}"
3039                                                );
3040
3041                                                #[cfg(feature = "logging")]
3042                                                error!(target: "stdout", "{}", &err_msg);
3043
3044                                                #[cfg(not(feature = "logging"))]
3045                                                println!("[ERROR][llama_core] {}", &err_msg);
3049                                            }
3050                                        }
3051                                        None => {
3052                                            let err_msg =
3053                                                "There is no model available in the chat graphs.";
3054
3055                                            #[cfg(feature = "logging")]
3056                                            error!(target: "stdout", "{}", &err_msg);
3057
3058                                            #[cfg(not(feature = "logging"))]
3059                                            println!(
3060                                                "[ERROR][llama_core] Failed to clean up the context. Reason: {}",
3061                                                &err_msg
3062                                            );
3063                                        }
3064                                    },
3065                                },
3066                                Err(e) => {
3067                                    let err_msg =
3068                                        format!("Fail to acquire the lock of `CHAT_GRAPHS`. {e}");
3069
3070                                    #[cfg(feature = "logging")]
3071                                    error!(target: "stdout", "{}", &err_msg);
3072
3073                                    #[cfg(not(feature = "logging"))]
3074                                    println!(
3075                                        "[ERROR][llama_core] Failed to clean up the context. Reason: {}",
3076                                        &err_msg
3077                                    );
3078                                }
3079                            }
3080                        }
3081                        None => {
3082                            let err_msg = "Fail to get the underlying value of `CHAT_GRAPHS`.";
3083
3084                            #[cfg(feature = "logging")]
3085                            error!(target: "stdout", "{}", &err_msg);
3086
3087                            #[cfg(not(feature = "logging"))]
3088                            println!(
3089                                "[ERROR][llama_core] Failed to clean up the context. Reason: {}",
3090                                &err_msg
3091                            );
3092                        }
3093                    };
3094                }
3095                None => {
3096                    match CHAT_GRAPHS.get() {
3097                        Some(chat_graphs) => {
3098                            match chat_graphs.lock() {
3099                                Ok(mut chat_graphs) => match chat_graphs.iter_mut().next() {
3100                                    Some((_, graph)) => {
3101                                        // clean up the context
3102                                        if let Err(e) = graph.finish_single() {
3103                                            let err_msg = format!(
3104                                                "Failed to clean up the context. Reason: {e}"
3105                                            );
3106
3107                                            #[cfg(feature = "logging")]
3108                                            error!(target: "stdout", "{}", &err_msg);
3109
3110                                            #[cfg(not(feature = "logging"))]
3111                                            println!("[ERROR][llama_core] {}", &err_msg);
3115                                        }
3116                                    }
3117                                    None => {
3118                                        let err_msg =
3119                                            "There is no model available in the chat graphs.";
3120
3121                                        #[cfg(feature = "logging")]
3122                                        error!(target: "stdout", "{err_msg}");
3123
3124                                        #[cfg(not(feature = "logging"))]
3125                                        println!(
3126                                            "[ERROR][llama_core] Failed to clean up the context. Reason: {}",
3127                                            err_msg
3128                                        );
3129                                    }
3130                                },
3131                                Err(e) => {
3132                                    let err_msg =
3133                                        format!("Fail to acquire the lock of `CHAT_GRAPHS`. {e}");
3134
3135                                    #[cfg(feature = "logging")]
3136                                    error!(target: "stdout", "{}", &err_msg);
3137
3138                                    #[cfg(not(feature = "logging"))]
3139                                    println!(
3140                                        "[ERROR][llama_core] Failed to clean up the context. Reason: {}",
3141                                        &err_msg
3142                                    );
3143                                }
3144                            }
3145                        }
3146                        None => {
3147                            let err_msg = "Fail to get the underlying value of `CHAT_GRAPHS`.";
3148
3149                            #[cfg(feature = "logging")]
3150                            error!(target: "stdout", "{}", &err_msg);
3151
3152                            #[cfg(not(feature = "logging"))]
3153                            println!(
3154                                "[ERROR][llama_core] Failed to clean up the context. Reason: {}",
3155                                &err_msg
3156                            );
3157                        }
3158                    };
3159                }
3160            }
3161
3162            #[cfg(feature = "logging")]
3163            info!(target: "stdout", "Model context cleanup done!");
3164        }
3165
3166        // reset the model metadata
3167        if let Err(e) = reset_model_metadata(self.model.as_ref()) {
3168            let err_msg = format!("Fail to reset model metadata. Reason: {e}");
3169
3170            #[cfg(feature = "logging")]
3171            error!(target: "stdout", "{}", &err_msg);
3172
3173            #[cfg(not(feature = "logging"))]
3174            println!("[ERROR][llama_core] {}", &err_msg);
3175        }
3176        #[cfg(feature = "logging")]
3177        info!(target: "stdout", "Model metadata reset done!");
3178
3179        // When dropping a ChatStream that held the lock, check if there are waiting streams
3180        if self.has_lock {
3181            // Reset the atomic flag
3182            CHAT_STREAM_ACTIVE.store(false, Ordering::SeqCst);
3183
3184            #[cfg(feature = "logging")]
3185            info!(target: "stdout", "Lock from ChatStream {} released", &self.id);
3186
3187            // Wake up waiting streams
3188            if let Ok(mut queue) = get_chat_stream_waker_queue().lock() {
3189                if let Some(waker) = queue.pop_front() {
3190                    #[cfg(feature = "logging")]
3191                    info!(target: "stdout", "Waking up a waiting ChatStream");
3192
3193                    waker.wake();
3194                }
3195            }
3196        }
3197    }
3198}
3199impl futures::Stream for ChatStream {
3200    type Item = Result<String, LlamaCoreError>;
3201
3202    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
3203        let this = self.get_mut();
3204
3205        // If this is a waiting stream, try to acquire the lock
3206        if this.is_waiting {
3207            if !this.try_acquire_lock() {
3208                // Store the waker to be notified when the lock becomes available
3209                if let Ok(mut queue) = get_chat_stream_waker_queue().lock() {
3210                    // Remove any previous instance of this waker
3211                    queue.retain(|w| !w.will_wake(cx.waker()));
3212                    // Add this waker to the queue
3213                    queue.push_back(cx.waker().clone());
3214
3215                    #[cfg(feature = "logging")]
3216                    debug!(target: "stdout", "ChatStream {} is waiting for lock, added waker to queue", &this.id);
3217                }
3218
3219                return Poll::Pending;
3220            }
3221
3222            #[cfg(feature = "logging")]
3223            info!(target: "stdout", "ChatStream {} acquired lock and is now active", &this.id);
3224            // If we got here, we successfully acquired the lock and can proceed
3225        }
3226
3227        // Ensure we still have the lock
3228        if !this.has_lock && !this.try_acquire_lock() {
3229            // Lost the lock, need to wait
3230            this.is_waiting = true;
3231
3232            // Register waker to be notified when lock is available
3233            if let Ok(mut queue) = get_chat_stream_waker_queue().lock() {
3234                queue.retain(|w| !w.will_wake(cx.waker()));
3235                queue.push_back(cx.waker().clone());
3236            }
3237
3238            return Poll::Pending;
3239        }
3240
3241        if this.cache.is_none() {
3242            let res = compute_stream(
3243                this.model.clone(),
3244                this.id.clone(),
3245                this.include_usage,
3246                &mut this.prompt_too_long_state,
3247                &mut this.context_full_state,
3248                &mut this.stream_state,
3249            );
3250
3251            match res {
3252                Ok(x) => {
3253                    #[cfg(feature = "logging")]
3254                    info!(target: "stdout", "next item for ChatStream {}: {}", &this.id, &x);
3255
3256                    if x != "[GGML] End of sequence" && !x.is_empty() {
3257                        Poll::Ready(Some(Ok(x)))
3258                    } else {
3259                        // stopped
3260                        Poll::Ready(None)
3261                    }
3262                }
3263                Err(e) => Poll::Ready(Some(Err(e))),
3264            }
3265        } else {
3266            let x = this.cache.as_mut().unwrap().pop_front();
3267
3268            #[cfg(feature = "logging")]
3269            info!(target: "stdout", "Get the next item from the cache for ChatStream {}: {:?}", &this.id, &x);
3270
3271            match x {
3272                Some(x) => Poll::Ready(Some(Ok(x))),
3273                None => Poll::Ready(None),
3274            }
3275        }
3276    }
3277}
3278
3279/// Helper function to get or initialize the waker queue for waiting ChatStreams
3280fn get_chat_stream_waker_queue() -> &'static Mutex<VecDeque<Waker>> {
3281    CHAT_STREAM_WAKER_QUEUE.get_or_init(|| {
3282        #[cfg(feature = "logging")]
3283        info!(target: "stdout", "Initializing ChatStream waker queue");
3284        Mutex::new(VecDeque::new())
3285    })
3286}
3287
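    /// Run one `compute_single` step on the target chat graph and convert the result into the
    /// next SSE payload for the stream: a content chunk on success, the trailing message /
    /// usage / `data: [DONE]` chunks when the backend reports end of sequence, a full context,
    /// or an overly long prompt, and the `"[GGML] End of sequence"` sentinel once the stream
    /// has finished.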
3288fn compute_stream(
3289    model_name: Option<String>,
3290    id: String,
3291    include_usage: bool,
3292    prompt_too_long_state: &mut PromptTooLongState,
3293    context_full_state: &mut ContextFullState,
3294    stream_state: &mut StreamState,
3295) -> Result<String, LlamaCoreError> {
3296    #[cfg(feature = "logging")]
3297    info!(target: "stdout", "Computing stream chunk for ChatStream {}", &id);
3298
3299    #[cfg(feature = "logging")]
3300    debug!(target: "stdout", "prompt_too_long_state: {:?}", *prompt_too_long_state);
3301    #[cfg(feature = "logging")]
3302    debug!(target: "stdout", "context_full_state: {:?}", *context_full_state);
3303    #[cfg(feature = "logging")]
3304    debug!(target: "stdout", "stream_state: {:?}", *stream_state);
3305
3306    if *prompt_too_long_state == PromptTooLongState::EndOfSequence
3307        || *context_full_state == ContextFullState::EndOfSequence
3308        || *stream_state == StreamState::EndOfSequence
3309    {
3310        #[cfg(feature = "logging")]
3311        info!(target: "stdout", "Return the chat stream chunk!");
3312
3313        return Ok("[GGML] End of sequence".to_string());
3314    }
3315
3316    let chat_graphs = match CHAT_GRAPHS.get() {
3317        Some(chat_graphs) => chat_graphs,
3318        None => {
3319            let err_msg = "Fail to get the underlying value of `CHAT_GRAPHS`.";
3320
3321            #[cfg(feature = "logging")]
3322            error!(target: "stdout", "{}", &err_msg);
3323
3324            return Err(LlamaCoreError::Operation(err_msg.into()));
3325        }
3326    };
3327
3328    // We're already holding the ChatStream lock, so we know we have exclusive access to the graph
3329    let mut chat_graphs = chat_graphs.lock().map_err(|e| {
3330        let err_msg = format!("Fail to acquire the lock of `CHAT_GRAPHS`. {e}");
3331
3332        #[cfg(feature = "logging")]
3333        error!(target: "stdout", "{}", &err_msg);
3334
3335        LlamaCoreError::Operation(err_msg)
3336    })?;
3337
3338    // Get the graph based on model name
3339    let res = match &model_name {
3340        Some(model_name) => {
3341            match chat_graphs.contains_key(model_name) {
3342                true => {
3343                    let graph = chat_graphs.get_mut(model_name).unwrap();
3344                    // compute
3345                    match graph.compute_single() {
3346                        Ok(_) => {
3347                            #[cfg(feature = "logging")]
3348                            debug!(target: "stdout", "Compute the chat stream chunk successfully.");
3349
3350                            // Process according to state
3351                            match stream_state {
3352                                StreamState::Usage | StreamState::NoUsage => {
3353                                    // Retrieve the output
3354                                    let output_buffer =
3355                                        get_output_buffer_single(graph, OUTPUT_TENSOR)?;
3356
3357                                    #[cfg(feature = "logging")]
3358                                    info!(target: "stdout", "retrieved the output buffer");
3359
3360                                    // decode the output buffer to a utf8 string
3361                                    let output = match String::from_utf8(output_buffer.clone()) {
3362                                        Ok(token) => token,
3363                                        Err(_) => {
3364                                            let mutex = CACHED_UTF8_ENCODINGS
3365                                                .get_or_init(|| Mutex::new(Vec::new()));
3366                                            let mut cached_encodings = mutex.lock().map_err(|e| {
3367                                            let err_msg = format!(
3368                                                "Fail to acquire the lock of `CACHED_UTF8_ENCODINGS`. Reason: {e}"
3369                                            );
3370
3371                                            #[cfg(feature = "logging")]
3372                                            error!(target: "stdout", "{}", &err_msg);
3373
3375                                            LlamaCoreError::Operation(err_msg)
3376                                        })?;
3377
3378                                            // cache the bytes for future decoding
3379                                            cached_encodings.extend_from_slice(&output_buffer[..]);
3380
3381                                            match String::from_utf8(cached_encodings.to_vec()) {
3382                                                Ok(token) => {
3383                                                    // clear CACHED_UTF8_ENCODINGS
3384                                                    cached_encodings.clear();
3385
3386                                                    token
3387                                                }
3388                                                Err(e) => {
3389                                                    // TODO: temporary check to guard against the cached encodings growing without bound.
3390                                                    if cached_encodings.len() > 4 {
3391                                                        let err_msg = format!("Fail to convert a vector of bytes to a string. The length of the cached utf8 bytes exceeds 4. {e}");
3392
3393                                                        #[cfg(feature = "logging")]
3394                                                        error!(target: "stdout", "{}", &err_msg);
3395
3396                                                        #[cfg(feature = "logging")]
3397                                                        error!(target: "stdout", "The cached buffer: {:?}", &cached_encodings[..]);
3398
3399                                                        // let token = String::from_utf8_lossy(
3400                                                        //     &cached_encodings,
3401                                                        // )
3402                                                        // .to_string();
3403
3404                                                        // clear CACHED_UTF8_ENCODINGS
3405                                                        cached_encodings.clear();
3406
3407                                                        String::from("")
3408                                                    } else {
3409                                                        let warn_msg = format!("Fail to convert a vector of bytes to a string. {e}");
3410
3411                                                        #[cfg(feature = "logging")]
3412                                                        warn!(target: "stdout", "{}", &warn_msg);
3413
3414                                                        String::from("")
3415                                                    }
3416                                                }
3417                                            }
3418                                        }
3419                                    };
3420
3421                                    #[cfg(feature = "logging")]
3422                                    info!(target: "stdout", "decoded the output buffer");
3423
3424                                    let created = SystemTime::now()
3425                                        .duration_since(std::time::UNIX_EPOCH)
3426                                        .map_err(|e| {
3427                                            let err_msg = format!(
3428                                                "Failed to get the current time. Reason: {e}"
3429                                            );
3430
3431                                            #[cfg(feature = "logging")]
3432                                            error!(target: "stdout", "{}", &err_msg);
3433
3434                                            LlamaCoreError::Operation(err_msg)
3435                                        })?;
3436
3437                                    let chat_completion_chunk = ChatCompletionChunk {
3438                                        id,
3439                                        object: "chat.completion.chunk".to_string(),
3440                                        created: created.as_secs(),
3441                                        model: graph.name().to_owned(),
3442                                        system_fingerprint: "fp_44709d6fcb".to_string(),
3443                                        choices: vec![ChatCompletionChunkChoice {
3444                                            index: 0,
3445                                            delta: ChatCompletionChunkChoiceDelta {
3446                                                role: ChatCompletionRole::Assistant,
3447                                                content: Some(output),
3448                                                tool_calls: vec![],
3449                                            },
3450                                            logprobs: None,
3451                                            finish_reason: None,
3452                                        }],
3453                                        usage: None,
3454                                    };
3455
3456                                    #[cfg(feature = "logging")]
3457                                    info!(target: "stdout", "created chat completion chunk");
3458
3459                                    // serialize chat completion chunk
3460                                    let chunk_str = serde_json::to_string(&chat_completion_chunk)
3461                                        .map_err(|e| {
3462                                        let err_msg = format!(
3463                                            "Failed to serialize chat completion chunk. Reason: {e}"
3464                                        );
3465
3466                                        #[cfg(feature = "logging")]
3467                                        error!(target: "stdout", "{}", &err_msg);
3468
3469                                        LlamaCoreError::Operation(err_msg)
3470                                    })?;
3471
3472                                    Ok(format!("data: {chunk_str}\n\n"))
3473                                }
3474                                StreamState::Done => {
3475                                    *stream_state = StreamState::EndOfSequence;
3476
3477                                    Ok("data: [DONE]\n\n".to_string())
3478                                }
3479                                StreamState::EndOfSequence => {
3480                                    Ok("[GGML] End of sequence".to_string())
3481                                }
3482                            }
3483                        }
3484                        Err(wasmedge_wasi_nn::Error::BackendError(
3485                            wasmedge_wasi_nn::BackendError::EndOfSequence,
3486                        )) => {
3487                            #[cfg(feature = "logging")]
3488                            debug!(target: "stdout", "End of sequence");
3489
3490                            match stream_state {
3491                                StreamState::Usage => {
3492                                    *stream_state = StreamState::Done;
3493
3494                                    // retrieve the number of prompt and completion tokens
3495                                    let token_info = get_token_info_by_graph(graph)?;
3496
3497                                    let usage = Some(Usage {
3498                                        prompt_tokens: token_info.prompt_tokens,
3499                                        completion_tokens: token_info.completion_tokens,
3500                                        total_tokens: token_info.prompt_tokens
3501                                            + token_info.completion_tokens,
3502                                    });
3503
3504                                    #[cfg(feature = "logging")]
3505                                    info!(target: "stdout", "token_info: {} prompt tokens, {} completion tokens", token_info.prompt_tokens, token_info.completion_tokens);
3506
3507                                    let created = SystemTime::now()
3508                                        .duration_since(std::time::UNIX_EPOCH)
3509                                        .map_err(|e| {
3510                                            let err_msg = format!(
3511                                                "Failed to get the current time. Reason: {e}"
3512                                            );
3513
3514                                            #[cfg(feature = "logging")]
3515                                            error!(target: "stdout", "{}", &err_msg);
3516
3517                                            LlamaCoreError::Operation(err_msg)
3518                                        })?;
3519
3520                                    let chat_completion_chunk = ChatCompletionChunk {
3521                                        id,
3522                                        object: "chat.completion.chunk".to_string(),
3523                                        created: created.as_secs(),
3524                                        model: graph.name().to_owned(),
3525                                        system_fingerprint: "fp_44709d6fcb".to_string(),
3526                                        choices: vec![],
3527                                        usage,
3528                                    };
3529
3530                                    // serialize chat completion chunk
3531                                    let chunk_str = serde_json::to_string(&chat_completion_chunk)
3532                                        .map_err(|e| {
3533                                        let err_msg = format!(
3534                                            "Failed to serialize chat completion chunk. Reason: {e}"
3535                                        );
3536
3537                                        #[cfg(feature = "logging")]
3538                                        error!(target: "stdout", "{}", &err_msg);
3539
3540                                        LlamaCoreError::Operation(err_msg)
3541                                    })?;
3542
3543                                    Ok(format!("data: {chunk_str}\n\n"))
3544                                }
3545                                StreamState::Done | StreamState::NoUsage => {
3546                                    *stream_state = StreamState::EndOfSequence;
3547
3548                                    Ok("data: [DONE]\n\n".to_string())
3549                                }
3550                                StreamState::EndOfSequence => {
3551                                    Ok("[GGML] End of sequence".to_string())
3552                                }
3553                            }
3554                        }
3555                        Err(wasmedge_wasi_nn::Error::BackendError(
3556                            wasmedge_wasi_nn::BackendError::ContextFull,
3557                        )) => {
3558                            #[cfg(feature = "logging")]
3559                            debug!(target: "stdout", "Context full");
3560
3561                            match context_full_state {
3562                                ContextFullState::Message => {
3563                                    match include_usage {
3564                                        true => *context_full_state = ContextFullState::Usage,
3565                                        false => *context_full_state = ContextFullState::Done,
3566                                    }
3567
3568                                    let created = SystemTime::now()
3569                                        .duration_since(std::time::UNIX_EPOCH)
3570                                        .map_err(|e| {
3571                                            let err_msg = format!(
3572                                                "Failed to get the current time. Reason: {e}"
3573                                            );
3574
3575                                            #[cfg(feature = "logging")]
3576                                            error!(target: "stdout", "{}", &err_msg);
3577
3578                                            LlamaCoreError::Operation(err_msg)
3579                                        })?;
3580
3581                                    let chat_completion_chunk = ChatCompletionChunk {
3582                                        id,
3583                                        object: "chat.completion.chunk".to_string(),
3584                                        created: created.as_secs(),
3585                                        model: graph.name().to_owned(),
3586                                        system_fingerprint: "fp_44709d6fcb".to_string(),
3587                                        choices: vec![ChatCompletionChunkChoice {
3588                                            index: 0,
3589                                            delta: ChatCompletionChunkChoiceDelta {
3590                                                role: ChatCompletionRole::Assistant,
3591                                                content: Some(
3592                                                    "<|WASMEDGE-GGML-CONTEXT-FULL|>".to_string(),
3593                                                ),
3594                                                tool_calls: vec![],
3595                                            },
3596                                            logprobs: None,
3597                                            finish_reason: Some(FinishReason::length),
3598                                        }],
3599                                        usage: None,
3600                                    };
3601
3602                                    // serialize chat completion chunk
3603                                    let chunk_str = serde_json::to_string(&chat_completion_chunk)
3604                                        .map_err(|e| {
3605                                        let err_msg = format!(
3606                                            "Failed to serialize chat completion chunk. Reason: {e}"
3607                                        );
3608
3609                                        #[cfg(feature = "logging")]
3610                                        error!(target: "stdout", "{}", &err_msg);
3611
3612                                        LlamaCoreError::Operation(err_msg)
3613                                    })?;
3614
3615                                    Ok(format!("data: {chunk_str}\n\n"))
3616                                }
3617                                ContextFullState::Usage => {
3618                                    *context_full_state = ContextFullState::Done;
3619
3620                                    // retrieve the number of prompt and completion tokens
3621                                    let token_info = get_token_info_by_graph(graph)?;
3622
3623                                    let usage = Some(Usage {
3624                                        prompt_tokens: token_info.prompt_tokens,
3625                                        completion_tokens: token_info.completion_tokens,
3626                                        total_tokens: token_info.prompt_tokens
3627                                            + token_info.completion_tokens,
3628                                    });
3629
3630                                    let created = SystemTime::now()
3631                                        .duration_since(std::time::UNIX_EPOCH)
3632                                        .map_err(|e| {
3633                                            let err_msg = format!(
3634                                                "Failed to get the current time. Reason: {e}"
3635                                            );
3636
3637                                            #[cfg(feature = "logging")]
3638                                            error!(target: "stdout", "{}", &err_msg);
3639
3640                                            LlamaCoreError::Operation(err_msg)
3641                                        })?;
3642
3643                                    let chat_completion_chunk = ChatCompletionChunk {
3644                                        id,
3645                                        object: "chat.completion.chunk".to_string(),
3646                                        created: created.as_secs(),
3647                                        model: graph.name().to_owned(),
3648                                        system_fingerprint: "fp_44709d6fcb".to_string(),
3649                                        choices: vec![],
3650                                        usage,
3651                                    };
3652
3653                                    // serialize chat completion chunk
3654                                    let chunk_str = serde_json::to_string(&chat_completion_chunk)
3655                                        .map_err(|e| {
3656                                        let err_msg = format!(
3657                                            "Failed to serialize chat completion chunk. Reason: {e}"
3658                                        );
3659
3660                                        #[cfg(feature = "logging")]
3661                                        error!(target: "stdout", "{}", &err_msg);
3662
3663                                        LlamaCoreError::Operation(err_msg)
3664                                    })?;
3665
3666                                    Ok(format!("data: {chunk_str}\n\n"))
3667                                }
3668                                ContextFullState::Done => {
3669                                    *context_full_state = ContextFullState::EndOfSequence;
3670
3671                                    Ok("data: [DONE]\n\n".to_string())
3672                                }
3673                                ContextFullState::EndOfSequence => {
3674                                    Ok("[GGML] End of sequence".to_string())
3675                                }
3676                            }
3677                        }
3678                        Err(wasmedge_wasi_nn::Error::BackendError(
3679                            wasmedge_wasi_nn::BackendError::PromptTooLong,
3680                        )) => {
3681                            #[cfg(feature = "logging")]
3682                            debug!(target: "stdout", "Prompt too long");
3683
3684                            match prompt_too_long_state {
3685                                PromptTooLongState::Message => {
3686                                    match include_usage {
3687                                        true => *prompt_too_long_state = PromptTooLongState::Usage,
3688                                        false => *prompt_too_long_state = PromptTooLongState::Done,
3689                                    }
3690
3691                                    let created = SystemTime::now()
3692                                        .duration_since(std::time::UNIX_EPOCH)
3693                                        .map_err(|e| {
3694                                            let err_msg = format!(
3695                                                "Failed to get the current time. Reason: {e}"
3696                                            );
3697
3698                                            #[cfg(feature = "logging")]
3699                                            error!(target: "stdout", "{}", &err_msg);
3700
3701                                            LlamaCoreError::Operation(err_msg)
3702                                        })?;
3703
3704                                    let chat_completion_chunk = ChatCompletionChunk {
3705                                        id,
3706                                        object: "chat.completion.chunk".to_string(),
3707                                        created: created.as_secs(),
3708                                        model: graph.name().to_owned(),
3709                                        system_fingerprint: "fp_44709d6fcb".to_string(),
3710                                        choices: vec![ChatCompletionChunkChoice {
3711                                            index: 0,
3712                                            delta: ChatCompletionChunkChoiceDelta {
3713                                                role: ChatCompletionRole::Assistant,
3714                                                content: None,
3715                                                tool_calls: vec![],
3716                                            },
3717                                            logprobs: None,
3718                                            finish_reason: Some(FinishReason::length),
3719                                        }],
3720                                        usage: None,
3721                                    };
3722
3723                                    // serialize chat completion chunk
3724                                    let chunk_str = serde_json::to_string(&chat_completion_chunk)
3725                                        .map_err(|e| {
3726                                        let err_msg = format!(
3727                                            "Failed to serialize chat completion chunk. Reason: {e}"
3728                                        );
3729
3730                                        #[cfg(feature = "logging")]
3731                                        error!(target: "stdout", "{}", &err_msg);
3732
3733                                        LlamaCoreError::Operation(err_msg)
3734                                    })?;
3735
3736                                    Ok(format!("data: {chunk_str}\n\n"))
3737                                }
3738                                PromptTooLongState::Usage => {
3739                                    *prompt_too_long_state = PromptTooLongState::Done;
3740
3741                                    // retrieve the number of prompt and completion tokens
3742                                    let token_info = get_token_info_by_graph(graph)?;
3743
3744                                    let usage = Some(Usage {
3745                                        prompt_tokens: token_info.prompt_tokens,
3746                                        completion_tokens: token_info.completion_tokens,
3747                                        total_tokens: token_info.prompt_tokens
3748                                            + token_info.completion_tokens,
3749                                    });
3750
3751                                    let created = SystemTime::now()
3752                                        .duration_since(std::time::UNIX_EPOCH)
3753                                        .map_err(|e| {
3754                                            let err_msg = format!(
3755                                                "Failed to get the current time. Reason: {e}"
3756                                            );
3757
3758                                            #[cfg(feature = "logging")]
3759                                            error!(target: "stdout", "{}", &err_msg);
3760
3761                                            LlamaCoreError::Operation(err_msg)
3762                                        })?;
3763
3764                                    let chat_completion_chunk = ChatCompletionChunk {
3765                                        id,
3766                                        object: "chat.completion.chunk".to_string(),
3767                                        created: created.as_secs(),
3768                                        model: graph.name().to_owned(),
3769                                        system_fingerprint: "fp_44709d6fcb".to_string(),
3770                                        choices: vec![],
3771                                        usage,
3772                                    };
3773
3774                                    // serialize chat completion chunk
3775                                    let chunk_str = serde_json::to_string(&chat_completion_chunk)
3776                                        .map_err(|e| {
3777                                        let err_msg = format!(
3778                                            "Failed to serialize chat completion chunk. Reason: {e}"
3779                                        );
3780
3781                                        #[cfg(feature = "logging")]
3782                                        error!(target: "stdout", "{}", &err_msg);
3783
3784                                        LlamaCoreError::Operation(err_msg)
3785                                    })?;
3786
3787                                    Ok(format!("data: {chunk_str}\n\n"))
3788                                }
3789                                PromptTooLongState::Done => {
3790                                    *prompt_too_long_state = PromptTooLongState::EndOfSequence;
3791
3792                                    Ok("data: [DONE]\n\n".to_string())
3793                                }
3794                                PromptTooLongState::EndOfSequence => {
3795                                    Ok("[GGML] End of sequence".to_string())
3796                                }
3797                            }
3798                        }
3799                        Err(e) => {
3800                            let err_msg =
3801                                format!("Failed to compute the chat completion. Reason: {e}");
3802
3803                            #[cfg(feature = "logging")]
3804                            error!(target: "stdout", "{}", &err_msg);
3805
3806                            Err(LlamaCoreError::Backend(BackendError::ComputeSingle(
3807                                err_msg,
3808                            )))
3809                        }
3810                    }
3811                }
3812                false => {
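                    // Fall back to the first graph registered in `chat_graphs` to serve this request.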
3813                    match chat_graphs.iter_mut().next() {
3814                        Some((_, graph)) => {
3815                            // compute
3816                            match graph.compute_single() {
3817                                Ok(_) => {
3818                                    #[cfg(feature = "logging")]
3819                                    debug!(target: "stdout", "Computed the chat stream chunk successfully.");
3820
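                                    // A new token was computed: decode it and forward it as an SSE chunk (unless the stream has already finished).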
3821                                    match stream_state {
3822                                        StreamState::Usage | StreamState::NoUsage => {
3823                                            // Retrieve the output
3824                                            let output_buffer =
3825                                                get_output_buffer_single(graph, OUTPUT_TENSOR)?;
3826
3827                                            #[cfg(feature = "logging")]
3828                                            info!(target: "stdout", "retrieved the output buffer");
3829
3830                                            // decode the output buffer to a utf8 string
3831                                            let output = match String::from_utf8(
3832                                                output_buffer.clone(),
3833                                            ) {
3834                                                Ok(token) => token,
3835                                                Err(_) => {
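                                                    // The bytes are not valid UTF-8 on their own: the token may end in the
                                                    // middle of a multi-byte character. Append the bytes to
                                                    // `CACHED_UTF8_ENCODINGS` and retry decoding the accumulated buffer.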
3836                                                    let mutex = CACHED_UTF8_ENCODINGS
3837                                                        .get_or_init(|| Mutex::new(Vec::new()));
3838                                                    let mut cached_encodings = mutex.lock().map_err(|e| {
3839                                            let err_msg = format!(
3840                                                "Failed to acquire the lock on `CACHED_UTF8_ENCODINGS`. Reason: {e}"
3841                                            );
3842
3843                                            #[cfg(feature = "logging")]
3844                                            error!(target: "stdout", "{}", &err_msg);
3845
3846
3847                                            LlamaCoreError::Operation(err_msg)
3848                                        })?;
3849
3850                                                    // cache the bytes for future decoding
3851                                                    cached_encodings
3852                                                        .extend_from_slice(&output_buffer[..]);
3853
3854                                                    match String::from_utf8(
3855                                                        cached_encodings.to_vec(),
3856                                                    ) {
3857                                                        Ok(token) => {
3858                                                            // clear encodings
3859                                                            cached_encodings.clear();
3860
3861                                                            token
3862                                                        }
3863                                                        Err(e) => {
3864                                                            // TODO: temporary guard to keep the cached UTF-8 encodings from growing without bound.
3865                                                            if cached_encodings.len() > 4 {
3866                                                                let err_msg = format!("Failed to convert the cached bytes to a valid UTF-8 string: the cached buffer exceeds 4 bytes, longer than any single UTF-8 character. {e}");
3867
3868                                                                #[cfg(feature = "logging")]
3869                                                                error!(target: "stdout", "{}", &err_msg);
3870
3871                                                                #[cfg(feature = "logging")]
3872                                                                error!(target: "stdout", "The cached buffer: {:?}", &cached_encodings[..]);
3873
3874                                                                // let token =
3875                                                                //     String::from_utf8_lossy(
3876                                                                //         &cached_encodings,
3877                                                                //     )
3878                                                                //     .to_string();
3879
3880                                                                // clear CACHED_UTF8_ENCODINGS
3881                                                                cached_encodings.clear();
3882
3883                                                                String::from("")
3884                                                            } else {
3885                                                                let warn_msg = format!("Failed to convert a vector of bytes to a valid UTF-8 string. {e}");
3886
3887                                                                #[cfg(feature = "logging")]
3888                                                                warn!(target: "stdout", "{}", &warn_msg);
3889
3890                                                                String::from("")
3891                                                            }
3892                                                        }
3893                                                    }
3894                                                }
3895                                            };
3896
3897                                            #[cfg(feature = "logging")]
3898                                            info!(target: "stdout", "decoded the output buffer");
3899
3900                                            let created = SystemTime::now()
3901                                                .duration_since(std::time::UNIX_EPOCH)
3902                                                .map_err(|e| {
3903                                                    let err_msg = format!(
3904                                                "Failed to get the current time. Reason: {e}"
3905                                            );
3906
3907                                                    #[cfg(feature = "logging")]
3908                                                    error!(target: "stdout", "{}", &err_msg);
3909
3910                                                    LlamaCoreError::Operation(err_msg)
3911                                                })?;
3912
3913                                            let chat_completion_chunk = ChatCompletionChunk {
3914                                                id,
3915                                                object: "chat.completion.chunk".to_string(),
3916                                                created: created.as_secs(),
3917                                                model: graph.name().to_owned(),
3918                                                system_fingerprint: "fp_44709d6fcb".to_string(),
3919                                                choices: vec![ChatCompletionChunkChoice {
3920                                                    index: 0,
3921                                                    delta: ChatCompletionChunkChoiceDelta {
3922                                                        role: ChatCompletionRole::Assistant,
3923                                                        content: Some(output),
3924                                                        tool_calls: vec![],
3925                                                    },
3926                                                    logprobs: None,
3927                                                    finish_reason: None,
3928                                                }],
3929                                                usage: None,
3930                                            };
3931
3932                                            #[cfg(feature = "logging")]
3933                                            info!(target: "stdout", "created chat completion chunk");
3934
3935                                            // serialize chat completion chunk
3936                                            let chunk_str =
3937                                                serde_json::to_string(&chat_completion_chunk)
3938                                                    .map_err(|e| {
3939                                                        let err_msg = format!(
3940                                            "Failed to serialize chat completion chunk. Reason: {e}"
3941                                        );
3942
3943                                                        #[cfg(feature = "logging")]
3944                                                        error!(target: "stdout", "{}", &err_msg);
3945
3946                                                        LlamaCoreError::Operation(err_msg)
3947                                                    })?;
3948
3949                                            Ok(format!("data: {chunk_str}\n\n"))
3950                                        }
3951                                        StreamState::Done => {
3952                                            *stream_state = StreamState::EndOfSequence;
3953
3954                                            Ok("data: [DONE]\n\n".to_string())
3955                                        }
3956                                        StreamState::EndOfSequence => {
3957                                            Ok("[GGML] End of sequence".to_string())
3958                                        }
3959                                    }
3960                                }
3961                                Err(wasmedge_wasi_nn::Error::BackendError(
3962                                    wasmedge_wasi_nn::BackendError::EndOfSequence,
3963                                )) => {
3964                                    #[cfg(feature = "logging")]
3965                                    debug!(target: "stdout", "End of sequence");
3966
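                                    // Generation finished normally: send a usage-only chunk first if usage was requested, then `[DONE]`.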
3967                                    match stream_state {
3968                                        StreamState::Usage => {
3969                                            *stream_state = StreamState::Done;
3970
3971                                            // retrieve the number of prompt and completion tokens
3972                                            let token_info = get_token_info_by_graph(graph)?;
3973
3974                                            let usage = Some(Usage {
3975                                                prompt_tokens: token_info.prompt_tokens,
3976                                                completion_tokens: token_info.completion_tokens,
3977                                                total_tokens: token_info.prompt_tokens
3978                                                    + token_info.completion_tokens,
3979                                            });
3980
3981                                            #[cfg(feature = "logging")]
3982                                            info!(target: "stdout", "token_info: {} prompt tokens, {} completion tokens", token_info.prompt_tokens, token_info.completion_tokens);
3983
3984                                            let created = SystemTime::now()
3985                                                .duration_since(std::time::UNIX_EPOCH)
3986                                                .map_err(|e| {
3987                                                    let err_msg = format!(
3988                                                "Failed to get the current time. Reason: {e}"
3989                                            );
3990
3991                                                    #[cfg(feature = "logging")]
3992                                                    error!(target: "stdout", "{}", &err_msg);
3993
3994                                                    LlamaCoreError::Operation(err_msg)
3995                                                })?;
3996
3997                                            let chat_completion_chunk = ChatCompletionChunk {
3998                                                id,
3999                                                object: "chat.completion.chunk".to_string(),
4000                                                created: created.as_secs(),
4001                                                model: graph.name().to_owned(),
4002                                                system_fingerprint: "fp_44709d6fcb".to_string(),
4003                                                choices: vec![],
4004                                                usage,
4005                                            };
4006
4007                                            // serialize chat completion chunk
4008                                            let chunk_str =
4009                                                serde_json::to_string(&chat_completion_chunk)
4010                                                    .map_err(|e| {
4011                                                        let err_msg = format!(
4012                                            "Failed to serialize chat completion chunk. Reason: {e}"
4013                                        );
4014
4015                                                        #[cfg(feature = "logging")]
4016                                                        error!(target: "stdout", "{}", &err_msg);
4017
4018                                                        LlamaCoreError::Operation(err_msg)
4019                                                    })?;
4020
4021                                            Ok(format!("data: {chunk_str}\n\n"))
4022                                        }
4023                                        StreamState::Done | StreamState::NoUsage => {
4024                                            *stream_state = StreamState::EndOfSequence;
4025
4026                                            Ok("data: [DONE]\n\n".to_string())
4027                                        }
4028                                        StreamState::EndOfSequence => {
4029                                            Ok("[GGML] End of sequence".to_string())
4030                                        }
4031                                    }
4032                                }
4033                                Err(wasmedge_wasi_nn::Error::BackendError(
4034                                    wasmedge_wasi_nn::BackendError::ContextFull,
4035                                )) => {
4036                                    #[cfg(feature = "logging")]
4037                                    debug!(target: "stdout", "Context full");
4038
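                                    // The context window is full: emit a final content chunk with `finish_reason: length`, then usage if requested, then `[DONE]`.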
4039                                    match context_full_state {
4040                                        ContextFullState::Message => {
4041                                            match include_usage {
4042                                                true => {
4043                                                    *context_full_state = ContextFullState::Usage
4044                                                }
4045                                                false => {
4046                                                    *context_full_state = ContextFullState::Done
4047                                                }
4048                                            }
4049
4050                                            let created = SystemTime::now()
4051                                                .duration_since(std::time::UNIX_EPOCH)
4052                                                .map_err(|e| {
4053                                                    let err_msg = format!(
4054                                                "Failed to get the current time. Reason: {e}"
4055                                            );
4056
4057                                                    #[cfg(feature = "logging")]
4058                                                    error!(target: "stdout", "{}", &err_msg);
4059
4060                                                    LlamaCoreError::Operation(err_msg)
4061                                                })?;
4062
4063                                            let chat_completion_chunk = ChatCompletionChunk {
4064                                                id,
4065                                                object: "chat.completion.chunk".to_string(),
4066                                                created: created.as_secs(),
4067                                                model: graph.name().to_owned(),
4068                                                system_fingerprint: "fp_44709d6fcb".to_string(),
4069                                                choices: vec![ChatCompletionChunkChoice {
4070                                                    index: 0,
4071                                                    delta: ChatCompletionChunkChoiceDelta {
4072                                                        role: ChatCompletionRole::Assistant,
4073                                                        content: Some(
4074                                                            "<|WASMEDGE-GGML-CONTEXT-FULL|>"
4075                                                                .to_string(),
4076                                                        ),
4077                                                        tool_calls: vec![],
4078                                                    },
4079                                                    logprobs: None,
4080                                                    finish_reason: Some(FinishReason::length),
4081                                                }],
4082                                                usage: None,
4083                                            };
4084
4085                                            // serialize chat completion chunk
4086                                            let chunk_str =
4087                                                serde_json::to_string(&chat_completion_chunk)
4088                                                    .map_err(|e| {
4089                                                        let err_msg = format!(
4090                                            "Failed to serialize chat completion chunk. Reason: {e}"
4091                                        );
4092
4093                                                        #[cfg(feature = "logging")]
4094                                                        error!(target: "stdout", "{}", &err_msg);
4095
4096                                                        LlamaCoreError::Operation(err_msg)
4097                                                    })?;
4098
4099                                            Ok(format!("data: {chunk_str}\n\n"))
4100                                        }
4101                                        ContextFullState::Usage => {
4102                                            *context_full_state = ContextFullState::Done;
4103
4104                                            // retrieve the number of prompt and completion tokens
4105                                            let token_info = get_token_info_by_graph(graph)?;
4106
4107                                            let usage = Some(Usage {
4108                                                prompt_tokens: token_info.prompt_tokens,
4109                                                completion_tokens: token_info.completion_tokens,
4110                                                total_tokens: token_info.prompt_tokens
4111                                                    + token_info.completion_tokens,
4112                                            });
4113
4114                                            let created = SystemTime::now()
4115                                                .duration_since(std::time::UNIX_EPOCH)
4116                                                .map_err(|e| {
4117                                                    let err_msg = format!(
4118                                                "Failed to get the current time. Reason: {e}"
4119                                            );
4120
4121                                                    #[cfg(feature = "logging")]
4122                                                    error!(target: "stdout", "{}", &err_msg);
4123
4124                                                    LlamaCoreError::Operation(err_msg)
4125                                                })?;
4126
4127                                            let chat_completion_chunk = ChatCompletionChunk {
4128                                                id,
4129                                                object: "chat.completion.chunk".to_string(),
4130                                                created: created.as_secs(),
4131                                                model: graph.name().to_owned(),
4132                                                system_fingerprint: "fp_44709d6fcb".to_string(),
4133                                                choices: vec![],
4134                                                usage,
4135                                            };
4136
4137                                            // serialize chat completion chunk
4138                                            let chunk_str =
4139                                                serde_json::to_string(&chat_completion_chunk)
4140                                                    .map_err(|e| {
4141                                                        let err_msg = format!(
4142                                            "Failed to serialize chat completion chunk. Reason: {e}"
4143                                        );
4144
4145                                                        #[cfg(feature = "logging")]
4146                                                        error!(target: "stdout", "{}", &err_msg);
4147
4148                                                        LlamaCoreError::Operation(err_msg)
4149                                                    })?;
4150
4151                                            Ok(format!("data: {chunk_str}\n\n"))
4152                                        }
4153                                        ContextFullState::Done => {
4154                                            *context_full_state = ContextFullState::EndOfSequence;
4155
4156                                            Ok("data: [DONE]\n\n".to_string())
4157                                        }
4158                                        ContextFullState::EndOfSequence => {
4159                                            Ok("[GGML] End of sequence".to_string())
4160                                        }
4161                                    }
4162                                }
4163                                Err(wasmedge_wasi_nn::Error::BackendError(
4164                                    wasmedge_wasi_nn::BackendError::PromptTooLong,
4165                                )) => {
4166                                    #[cfg(feature = "logging")]
4167                                    debug!(target: "stdout", "Prompt too long");
4168
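                                    // The prompt exceeds the context window: emit an empty delta with `finish_reason: length`, then usage if requested, then `[DONE]`.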
4169                                    match prompt_too_long_state {
4170                                        PromptTooLongState::Message => {
4171                                            match include_usage {
4172                                                true => {
4173                                                    *prompt_too_long_state =
4174                                                        PromptTooLongState::Usage
4175                                                }
4176                                                false => {
4177                                                    *prompt_too_long_state =
4178                                                        PromptTooLongState::Done
4179                                                }
4180                                            }
4181
4182                                            let created = SystemTime::now()
4183                                                .duration_since(std::time::UNIX_EPOCH)
4184                                                .map_err(|e| {
4185                                                    let err_msg = format!(
4186                                                "Failed to get the current time. Reason: {e}"
4187                                            );
4188
4189                                                    #[cfg(feature = "logging")]
4190                                                    error!(target: "stdout", "{}", &err_msg);
4191
4192                                                    LlamaCoreError::Operation(err_msg)
4193                                                })?;
4194
4195                                            let chat_completion_chunk = ChatCompletionChunk {
4196                                                id,
4197                                                object: "chat.completion.chunk".to_string(),
4198                                                created: created.as_secs(),
4199                                                model: graph.name().to_owned(),
4200                                                system_fingerprint: "fp_44709d6fcb".to_string(),
4201                                                choices: vec![ChatCompletionChunkChoice {
4202                                                    index: 0,
4203                                                    delta: ChatCompletionChunkChoiceDelta {
4204                                                        role: ChatCompletionRole::Assistant,
4205                                                        content: None,
4206                                                        tool_calls: vec![],
4207                                                    },
4208                                                    logprobs: None,
4209                                                    finish_reason: Some(FinishReason::length),
4210                                                }],
4211                                                usage: None,
4212                                            };
4213
4214                                            // serialize chat completion chunk
4215                                            let chunk_str =
4216                                                serde_json::to_string(&chat_completion_chunk)
4217                                                    .map_err(|e| {
4218                                                        let err_msg = format!(
4219                                            "Failed to serialize chat completion chunk. Reason: {e}"
4220                                        );
4221
4222                                                        #[cfg(feature = "logging")]
4223                                                        error!(target: "stdout", "{}", &err_msg);
4224
4225                                                        LlamaCoreError::Operation(err_msg)
4226                                                    })?;
4227
4228                                            Ok(format!("data: {chunk_str}\n\n"))
4229                                        }
4230                                        PromptTooLongState::Usage => {
4231                                            *prompt_too_long_state = PromptTooLongState::Done;
4232
4233                                            // retrieve the number of prompt and completion tokens
4234                                            let token_info = get_token_info_by_graph(graph)?;
4235
4236                                            let usage = Some(Usage {
4237                                                prompt_tokens: token_info.prompt_tokens,
4238                                                completion_tokens: token_info.completion_tokens,
4239                                                total_tokens: token_info.prompt_tokens
4240                                                    + token_info.completion_tokens,
4241                                            });
4242
4243                                            let created = SystemTime::now()
4244                                                .duration_since(std::time::UNIX_EPOCH)
4245                                                .map_err(|e| {
4246                                                    let err_msg = format!(
4247                                                "Failed to get the current time. Reason: {e}"
4248                                            );
4249
4250                                                    #[cfg(feature = "logging")]
4251                                                    error!(target: "stdout", "{}", &err_msg);
4252
4253                                                    LlamaCoreError::Operation(err_msg)
4254                                                })?;
4255
4256                                            let chat_completion_chunk = ChatCompletionChunk {
4257                                                id,
4258                                                object: "chat.completion.chunk".to_string(),
4259                                                created: created.as_secs(),
4260                                                model: graph.name().to_owned(),
4261                                                system_fingerprint: "fp_44709d6fcb".to_string(),
4262                                                choices: vec![],
4263                                                usage,
4264                                            };
4265
4266                                            // serialize chat completion chunk
4267                                            let chunk_str =
4268                                                serde_json::to_string(&chat_completion_chunk)
4269                                                    .map_err(|e| {
4270                                                        let err_msg = format!(
4271                                            "Failed to serialize chat completion chunk. Reason: {e}"
4272                                        );
4273
4274                                                        #[cfg(feature = "logging")]
4275                                                        error!(target: "stdout", "{}", &err_msg);
4276
4277                                                        LlamaCoreError::Operation(err_msg)
4278                                                    })?;
4279
4280                                            Ok(format!("data: {chunk_str}\n\n"))
4281                                        }
4282                                        PromptTooLongState::Done => {
4283                                            *prompt_too_long_state =
4284                                                PromptTooLongState::EndOfSequence;
4285
4286                                            Ok("data: [DONE]\n\n".to_string())
4287                                        }
4288                                        PromptTooLongState::EndOfSequence => {
4289                                            Ok("[GGML] End of sequence".to_string())
4290                                        }
4291                                    }
4292                                }
4293                                Err(e) => {
4294                                    let err_msg = format!(
4295                                        "Failed to compute the chat completion. Reason: {e}"
4296                                    );
4297
4298                                    #[cfg(feature = "logging")]
4299                                    error!(target: "stdout", "{}", &err_msg);
4300
4301                                    Err(LlamaCoreError::Backend(BackendError::ComputeSingle(
4302                                        err_msg,
4303                                    )))
4304                                }
4305                            }
4306                        }
4307                        None => {
4308                            let err_msg = "There is no model available in the chat graphs.";
4309
4310                            #[cfg(feature = "logging")]
4311                            error!(target: "stdout", "{}", &err_msg);
4312
4313                            Err(LlamaCoreError::Operation(err_msg.into()))
4314                        }
4315                    }
4316                }
4317            }
4318        }
4319        None => {
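            // Use the first graph registered in `chat_graphs` to serve this request.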
4320            match chat_graphs.iter_mut().next() {
4321                Some((_, graph)) => {
4322                    // compute
4323                    match graph.compute_single() {
4324                        Ok(_) => {
4325                            #[cfg(feature = "logging")]
4326                            debug!(target: "stdout", "Computed the chat stream chunk successfully.");
4327
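                            // A new token was computed: decode it and forward it as an SSE chunk (unless the stream has already finished).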
4328                            match stream_state {
4329                                StreamState::Usage | StreamState::NoUsage => {
4330                                    // Retrieve the output
4331                                    let output_buffer =
4332                                        get_output_buffer_single(graph, OUTPUT_TENSOR)?;
4333
4334                                    #[cfg(feature = "logging")]
4335                                    info!(target: "stdout", "retrieved the output buffer");
4336
4337                                    // decode the output buffer to a utf8 string
4338                                    let output = match String::from_utf8(output_buffer.clone()) {
4339                                        Ok(token) => token,
4340                                        Err(_) => {
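                                            // The bytes are not valid UTF-8 on their own: the token may end in the
                                            // middle of a multi-byte character. Append the bytes to
                                            // `CACHED_UTF8_ENCODINGS` and retry decoding the accumulated buffer.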
4341                                            let mutex = CACHED_UTF8_ENCODINGS
4342                                                .get_or_init(|| Mutex::new(Vec::new()));
4343                                            let mut cached_encodings = mutex.lock().map_err(|e| {
4344                                            let err_msg = format!(
4345                                                "Failed to acquire the lock on `CACHED_UTF8_ENCODINGS`. Reason: {e}"
4346                                            );
4347
4348                                            #[cfg(feature = "logging")]
4349                                            error!(target: "stdout", "{}", &err_msg);
4350
4351                                            LlamaCoreError::Operation(err_msg)
4352                                        })?;
4353
4354                                            cached_encodings.extend_from_slice(&output_buffer[..]);
4355
4356                                            match String::from_utf8(cached_encodings.to_vec()) {
4357                                                Ok(token) => {
4358                                                    // clear encodings
4359                                                    cached_encodings.clear();
4360
4361                                                    token
4362                                                }
4363                                                Err(e) => {
4364                                                    // TODO: temporary guard to keep the cached UTF-8 encodings from growing without bound.
4365                                                    if cached_encodings.len() > 4 {
4366                                                        let err_msg = format!("Failed to convert the cached bytes to a valid UTF-8 string: the cached buffer exceeds 4 bytes, longer than any single UTF-8 character. {e}");
4367
4368                                                        #[cfg(feature = "logging")]
4369                                                        error!(target: "stdout", "{}", &err_msg);
4370
4371                                                        #[cfg(feature = "logging")]
4372                                                        error!(target: "stdout", "The cached buffer: {:?}", &cached_encodings[..]);
4373
4374                                                        // let token = String::from_utf8_lossy(
4375                                                        //     &cached_encodings,
4376                                                        // )
4377                                                        // .to_string();
4378
4379                                                        // clear CACHED_UTF8_ENCODINGS
4380                                                        cached_encodings.clear();
4381
4382                                                        String::from("")
4383                                                    } else {
4384                                                        let warn_msg = format!("Failed to convert a vector of bytes to a valid UTF-8 string. {e}");
4385
4386                                                        #[cfg(feature = "logging")]
4387                                                        warn!(target: "stdout", "{}", &warn_msg);
4388
4389                                                        String::from("")
4390                                                    }
4391                                                }
4392                                            }
4393                                        }
4394                                    };
4395
4396                                    #[cfg(feature = "logging")]
4397                                    info!(target: "stdout", "decoded the output buffer");
4398
4399                                    let created = SystemTime::now()
4400                                        .duration_since(std::time::UNIX_EPOCH)
4401                                        .map_err(|e| {
4402                                            let err_msg = format!(
4403                                                "Failed to get the current time. Reason: {e}"
4404                                            );
4405
4406                                            #[cfg(feature = "logging")]
4407                                            error!(target: "stdout", "{}", &err_msg);
4408
4409                                            LlamaCoreError::Operation(err_msg)
4410                                        })?;
4411
4412                                    let chat_completion_chunk = ChatCompletionChunk {
4413                                        id,
4414                                        object: "chat.completion.chunk".to_string(),
4415                                        created: created.as_secs(),
4416                                        model: graph.name().to_owned(),
4417                                        system_fingerprint: "fp_44709d6fcb".to_string(),
4418                                        choices: vec![ChatCompletionChunkChoice {
4419                                            index: 0,
4420                                            delta: ChatCompletionChunkChoiceDelta {
4421                                                role: ChatCompletionRole::Assistant,
4422                                                content: Some(output),
4423                                                tool_calls: vec![],
4424                                            },
4425                                            logprobs: None,
4426                                            finish_reason: None,
4427                                        }],
4428                                        usage: None,
4429                                    };
4430
4431                                    #[cfg(feature = "logging")]
4432                                    info!(target: "stdout", "created chat completion chunk");
4433
4434                                    // serialize chat completion chunk
4435                                    let chunk_str = serde_json::to_string(&chat_completion_chunk)
4436                                        .map_err(|e| {
4437                                        let err_msg = format!(
4438                                            "Failed to serialize chat completion chunk. Reason: {e}"
4439                                        );
4440
4441                                        #[cfg(feature = "logging")]
4442                                        error!(target: "stdout", "{}", &err_msg);
4443
4444                                        LlamaCoreError::Operation(err_msg)
4445                                    })?;
4446
4447                                    Ok(format!("data: {chunk_str}\n\n"))
4448                                }
4449                                StreamState::Done => {
4450                                    *stream_state = StreamState::EndOfSequence;
4451
4452                                    Ok("data: [DONE]\n\n".to_string())
4453                                }
4454                                StreamState::EndOfSequence => {
4455                                    Ok("[GGML] End of sequence".to_string())
4456                                }
4457                            }
4458                        }
4459                        Err(wasmedge_wasi_nn::Error::BackendError(
4460                            wasmedge_wasi_nn::BackendError::EndOfSequence,
4461                        )) => {
4462                            #[cfg(feature = "logging")]
4463                            debug!(target: "stdout", "End of sequence");
4464
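                            // Generation finished normally: send a usage-only chunk first if usage was requested, then `[DONE]`.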
4465                            match stream_state {
4466                                StreamState::Usage => {
4467                                    *stream_state = StreamState::Done;
4468
4469                                    // retrieve the number of prompt and completion tokens
4470                                    let token_info = get_token_info_by_graph(graph)?;
4471
4472                                    let usage = Some(Usage {
4473                                        prompt_tokens: token_info.prompt_tokens,
4474                                        completion_tokens: token_info.completion_tokens,
4475                                        total_tokens: token_info.prompt_tokens
4476                                            + token_info.completion_tokens,
4477                                    });
4478
4479                                    #[cfg(feature = "logging")]
4480                                    info!(target: "stdout", "token_info: {} prompt tokens, {} completion tokens", token_info.prompt_tokens, token_info.completion_tokens);
4481
4482                                    let created = SystemTime::now()
4483                                        .duration_since(std::time::UNIX_EPOCH)
4484                                        .map_err(|e| {
4485                                            let err_msg = format!(
4486                                                "Failed to get the current time. Reason: {e}"
4487                                            );
4488
4489                                            #[cfg(feature = "logging")]
4490                                            error!(target: "stdout", "{}", &err_msg);
4491
4492                                            LlamaCoreError::Operation(err_msg)
4493                                        })?;
4494
4495                                    let chat_completion_chunk = ChatCompletionChunk {
4496                                        id,
4497                                        object: "chat.completion.chunk".to_string(),
4498                                        created: created.as_secs(),
4499                                        model: graph.name().to_owned(),
4500                                        system_fingerprint: "fp_44709d6fcb".to_string(),
4501                                        choices: vec![],
4502                                        usage,
4503                                    };
4504
4505                                    // serialize chat completion chunk
4506                                    let chunk_str = serde_json::to_string(&chat_completion_chunk)
4507                                        .map_err(|e| {
4508                                        let err_msg = format!(
4509                                            "Failed to serialize chat completion chunk. Reason: {e}"
4510                                        );
4511
4512                                        #[cfg(feature = "logging")]
4513                                        error!(target: "stdout", "{}", &err_msg);
4514
4515                                        LlamaCoreError::Operation(err_msg)
4516                                    })?;
4517
4518                                    Ok(format!("data: {chunk_str}\n\n"))
4519                                }
4520                                StreamState::Done | StreamState::NoUsage => {
4521                                    *stream_state = StreamState::EndOfSequence;
4522
4523                                    Ok("data: [DONE]\n\n".to_string())
4524                                }
4525                                StreamState::EndOfSequence => {
4526                                    Ok("[GGML] End of sequence".to_string())
4527                                }
4528                            }
4529                        }
4530                        Err(wasmedge_wasi_nn::Error::BackendError(
4531                            wasmedge_wasi_nn::BackendError::ContextFull,
4532                        )) => {
4533                            #[cfg(feature = "logging")]
4534                            debug!(target: "stdout", "Context full");
4535
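                            // The context window is full: emit a final content chunk with `finish_reason: length`, then usage if requested, then `[DONE]`.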
4536                            match context_full_state {
4537                                ContextFullState::Message => {
4538                                    match include_usage {
4539                                        true => *context_full_state = ContextFullState::Usage,
4540                                        false => *context_full_state = ContextFullState::Done,
4541                                    }
4542
4543                                    let created = SystemTime::now()
4544                                        .duration_since(std::time::UNIX_EPOCH)
4545                                        .map_err(|e| {
4546                                            let err_msg = format!(
4547                                                "Failed to get the current time. Reason: {e}"
4548                                            );
4549
4550                                            #[cfg(feature = "logging")]
4551                                            error!(target: "stdout", "{}", &err_msg);
4552
4553                                            LlamaCoreError::Operation(err_msg)
4554                                        })?;
4555
4556                                    let chat_completion_chunk = ChatCompletionChunk {
4557                                        id,
4558                                        object: "chat.completion.chunk".to_string(),
4559                                        created: created.as_secs(),
4560                                        model: graph.name().to_owned(),
4561                                        system_fingerprint: "fp_44709d6fcb".to_string(),
4562                                        choices: vec![ChatCompletionChunkChoice {
4563                                            index: 0,
4564                                            delta: ChatCompletionChunkChoiceDelta {
4565                                                role: ChatCompletionRole::Assistant,
4566                                                content: Some(
4567                                                    "<|WASMEDGE-GGML-CONTEXT-FULL|>".to_string(),
4568                                                ),
4569                                                tool_calls: vec![],
4570                                            },
4571                                            logprobs: None,
4572                                            finish_reason: Some(FinishReason::length),
4573                                        }],
4574                                        usage: None,
4575                                    };
4576
4577                                    // serialize chat completion chunk
4578                                    let chunk_str = serde_json::to_string(&chat_completion_chunk)
4579                                        .map_err(|e| {
4580                                        let err_msg = format!(
4581                                            "Failed to serialize chat completion chunk. Reason: {e}"
4582                                        );
4583
4584                                        #[cfg(feature = "logging")]
4585                                        error!(target: "stdout", "{}", &err_msg);
4586
4587                                        LlamaCoreError::Operation(err_msg)
4588                                    })?;
4589
4590                                    Ok(format!("data: {chunk_str}\n\n"))
4591                                }
4592                                ContextFullState::Usage => {
4593                                    *context_full_state = ContextFullState::Done;
4594
4595                                    // retrieve the number of prompt and completion tokens
4596                                    let token_info = get_token_info_by_graph(graph)?;
4597
4598                                    let usage = Some(Usage {
4599                                        prompt_tokens: token_info.prompt_tokens,
4600                                        completion_tokens: token_info.completion_tokens,
4601                                        total_tokens: token_info.prompt_tokens
4602                                            + token_info.completion_tokens,
4603                                    });
4604
4605                                    let created = SystemTime::now()
4606                                        .duration_since(std::time::UNIX_EPOCH)
4607                                        .map_err(|e| {
4608                                            let err_msg = format!(
4609                                                "Failed to get the current time. Reason: {e}"
4610                                            );
4611
4612                                            #[cfg(feature = "logging")]
4613                                            error!(target: "stdout", "{}", &err_msg);
4614
4615                                            LlamaCoreError::Operation(err_msg)
4616                                        })?;
4617
4618                                    let chat_completion_chunk = ChatCompletionChunk {
4619                                        id,
4620                                        object: "chat.completion.chunk".to_string(),
4621                                        created: created.as_secs(),
4622                                        model: graph.name().to_owned(),
4623                                        system_fingerprint: "fp_44709d6fcb".to_string(),
4624                                        choices: vec![],
4625                                        usage,
4626                                    };
4627
4628                                    // serialize chat completion chunk
4629                                    let chunk_str = serde_json::to_string(&chat_completion_chunk)
4630                                        .map_err(|e| {
4631                                        let err_msg = format!(
4632                                            "Failed to serialize chat completion chunk. Reason: {e}"
4633                                        );
4634
4635                                        #[cfg(feature = "logging")]
4636                                        error!(target: "stdout", "{}", &err_msg);
4637
4638                                        LlamaCoreError::Operation(err_msg)
4639                                    })?;
4640
4641                                    Ok(format!("data: {chunk_str}\n\n"))
4642                                }
4643                                ContextFullState::Done => {
4644                                    *context_full_state = ContextFullState::EndOfSequence;
4645
4646                                    Ok("data: [DONE]\n\n".to_string())
4647                                }
4648                                ContextFullState::EndOfSequence => {
4649                                    Ok("[GGML] End of sequence".to_string())
4650                                }
4651                            }
4652                        }
4653                        Err(wasmedge_wasi_nn::Error::BackendError(
4654                            wasmedge_wasi_nn::BackendError::PromptTooLong,
4655                        )) => {
4656                            #[cfg(feature = "logging")]
4657                            debug!(target: "stdout", "Prompt too long");
4658
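                            // The prompt exceeds the context window: emit an empty delta with `finish_reason: length`, then usage if requested, then `[DONE]`.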
4659                            match prompt_too_long_state {
                                PromptTooLongState::Message => {
                                    match include_usage {
                                        true => *prompt_too_long_state = PromptTooLongState::Usage,
                                        false => *prompt_too_long_state = PromptTooLongState::Done,
                                    }

                                    let created = SystemTime::now()
                                        .duration_since(std::time::UNIX_EPOCH)
                                        .map_err(|e| {
                                            let err_msg = format!(
                                                "Failed to get the current time. Reason: {e}"
                                            );

                                            #[cfg(feature = "logging")]
                                            error!(target: "stdout", "{}", &err_msg);

                                            LlamaCoreError::Operation(err_msg)
                                        })?;

                                    let chat_completion_chunk = ChatCompletionChunk {
                                        id,
                                        object: "chat.completion.chunk".to_string(),
                                        created: created.as_secs(),
                                        model: graph.name().to_owned(),
                                        system_fingerprint: "fp_44709d6fcb".to_string(),
                                        choices: vec![ChatCompletionChunkChoice {
                                            index: 0,
                                            delta: ChatCompletionChunkChoiceDelta {
                                                role: ChatCompletionRole::Assistant,
                                                content: None,
                                                tool_calls: vec![],
                                            },
                                            logprobs: None,
                                            finish_reason: Some(FinishReason::length),
                                        }],
                                        usage: None,
                                    };

                                    // serialize chat completion chunk
                                    let chunk_str = serde_json::to_string(&chat_completion_chunk)
                                        .map_err(|e| {
                                        let err_msg = format!(
                                            "Failed to serialize chat completion chunk. Reason: {e}"
                                        );

                                        #[cfg(feature = "logging")]
                                        error!(target: "stdout", "{}", &err_msg);

                                        LlamaCoreError::Operation(err_msg)
                                    })?;

                                    Ok(format!("data: {chunk_str}\n\n"))
                                }
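                                // Optional usage chunk: report prompt/completion token counts
                                // before closing the stream.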
                                PromptTooLongState::Usage => {
                                    *prompt_too_long_state = PromptTooLongState::Done;

                                    // retrieve the number of prompt and completion tokens
                                    let token_info = get_token_info_by_graph(graph)?;

                                    let usage = Some(Usage {
                                        prompt_tokens: token_info.prompt_tokens,
                                        completion_tokens: token_info.completion_tokens,
                                        total_tokens: token_info.prompt_tokens
                                            + token_info.completion_tokens,
                                    });

                                    let created = SystemTime::now()
                                        .duration_since(std::time::UNIX_EPOCH)
                                        .map_err(|e| {
                                            let err_msg = format!(
                                                "Failed to get the current time. Reason: {e}"
                                            );

                                            #[cfg(feature = "logging")]
                                            error!(target: "stdout", "{}", &err_msg);

                                            LlamaCoreError::Operation(err_msg)
                                        })?;

                                    let chat_completion_chunk = ChatCompletionChunk {
                                        id,
                                        object: "chat.completion.chunk".to_string(),
                                        created: created.as_secs(),
                                        model: graph.name().to_owned(),
                                        system_fingerprint: "fp_44709d6fcb".to_string(),
                                        choices: vec![],
                                        usage,
                                    };

                                    // serialize chat completion chunk
                                    let chunk_str = serde_json::to_string(&chat_completion_chunk)
                                        .map_err(|e| {
                                        let err_msg = format!(
                                            "Failed to serialize chat completion chunk. Reason: {e}"
                                        );

                                        #[cfg(feature = "logging")]
                                        error!(target: "stdout", "{}", &err_msg);

                                        LlamaCoreError::Operation(err_msg)
                                    })?;

                                    Ok(format!("data: {chunk_str}\n\n"))
                                }
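                                // All data chunks have been sent; emit the SSE terminator.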
                                PromptTooLongState::Done => {
                                    *prompt_too_long_state = PromptTooLongState::EndOfSequence;

                                    Ok("data: [DONE]\n\n".to_string())
                                }
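                                // Terminal state: return the sentinel that signals the end of
                                // the stream to the caller.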
                                PromptTooLongState::EndOfSequence => {
                                    Ok("[GGML] End of sequence".to_string())
                                }
                            }
                        }
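                        // Any other backend error is propagated as a `ComputeSingle` error.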
                        Err(e) => {
                            let err_msg =
                                format!("Failed to compute the chat completion. Reason: {e}");

                            #[cfg(feature = "logging")]
                            error!(target: "stdout", "{}", &err_msg);

                            Err(LlamaCoreError::Backend(BackendError::ComputeSingle(
                                err_msg,
                            )))
                        }
                    }
                }
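                // No model is available in the chat graphs.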
                None => {
                    let err_msg = "There is no model available in the chat graphs.";

                    #[cfg(feature = "logging")]
                    error!(target: "stdout", "{}", &err_msg);

                    Err(LlamaCoreError::Operation(err_msg.into()))
                }
            }
        }
    };

    #[cfg(feature = "logging")]
    info!(target: "stdout", "Return the chat stream chunk!");

    res
}

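/// Parsed model output: `raw` keeps the unmodified generation, `content` the
/// assistant text extracted from it (if any), and `tool_calls` any tool calls
/// recovered from the generation.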
#[allow(dead_code)]
#[derive(Debug)]
struct ParseResult {
    raw: String,
    content: Option<String>,
    tool_calls: Vec<ToolCall>,
}