llama_core/chat/chat_completions.rs

1//! Define APIs for chat completion.
2
3use crate::{
4    error,
5    metadata::ggml::GgmlMetadata,
6    running_mode,
7    utils::{
8        gen_chat_id, get_output_buffer, get_output_buffer_single, get_token_info_by_graph,
9        get_token_info_by_graph_name, set_tensor_data_u8,
10    },
11    Graph, RunningMode, CHAT_GRAPHS, OUTPUT_TENSOR,
12};
13use chat_prompts::{BuildChatPrompt, ChatPrompt, PromptTemplateType};
14use either::{Either, Left, Right};
15use endpoints::{
16    chat::{
17        ChatCompletionChunk, ChatCompletionChunkChoice, ChatCompletionChunkChoiceDelta,
18        ChatCompletionObject, ChatCompletionObjectChoice, ChatCompletionObjectMessage,
19        ChatCompletionRequest, ChatCompletionRequestMessage, ChatCompletionRole,
20        ChatCompletionUserMessageContent, ContentPart, Function, ToolCall, ToolCallForChunk,
21        ToolChoice,
22    },
23    common::{FinishReason, Usage},
24};
25use error::{BackendError, LlamaCoreError};
26use std::{
27    collections::{HashMap, VecDeque},
28    pin::Pin,
29    sync::{
30        atomic::{AtomicBool, Ordering},
31        Arc, Mutex, OnceLock,
32    },
33    task::{Context, Poll, Waker},
34    time::SystemTime,
35};
36
37// ============================================================================
38// Per-Model Stream Lock Infrastructure
39// ============================================================================
40
41/// Per-model stream lock state.
42/// Each model has its own lock to allow parallel inference across different models.
43pub struct ModelStreamLock {
44    /// Whether this model currently has an active stream
45    pub active: AtomicBool,
46    /// Waker queue for requests waiting on this model
47    pub waker_queue: Mutex<VecDeque<Waker>>,
48}
49
50impl ModelStreamLock {
51    /// Create a new model stream lock
52    pub fn new() -> Self {
53        Self {
54            active: AtomicBool::new(false),
55            waker_queue: Mutex::new(VecDeque::new()),
56        }
57    }
58
59    /// Try to acquire the lock. Returns true if successful.
60    pub fn try_acquire(&self) -> bool {
61        self.active
62            .compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst)
63            .is_ok()
64    }
65
66    /// Release the lock and wake the next waiting request.
67    pub fn release(&self) {
68        self.active.store(false, Ordering::SeqCst);
69
70        if let Ok(mut queue) = self.waker_queue.lock() {
71            if let Some(waker) = queue.pop_front() {
72                waker.wake();
73            }
74        }
75    }
76
77    /// Register a waker to be notified when the lock is released.
78    pub fn register_waker(&self, waker: &Waker) {
79        if let Ok(mut queue) = self.waker_queue.lock() {
80            // Remove duplicate wakers
81            queue.retain(|w| !w.will_wake(waker));
82            queue.push_back(waker.clone());
83        }
84    }
85}
86
87impl Default for ModelStreamLock {
88    fn default() -> Self {
89        Self::new()
90    }
91}
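// --- Illustrative example (added; not part of the original crate) ---------
// A minimal sketch of the acquire/release contract: `try_acquire` flips the
// `active` flag from false to true exactly once, and `release` clears the
// flag and wakes at most one queued waiter. The module and test names below
// are hypothetical.
#[cfg(test)]
mod model_stream_lock_example {
    use super::*;

    #[test]
    fn acquire_is_exclusive_until_released() {
        let lock = ModelStreamLock::new();
        assert!(lock.try_acquire()); // first caller takes the lock
        assert!(!lock.try_acquire()); // a second caller is rejected
        lock.release(); // flag cleared; next queued waker (if any) is woken
        assert!(lock.try_acquire()); // the lock can be taken again
    }
}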
92
93/// RAII guard for model lock.
94/// Automatically releases the lock when dropped.
95/// Used by non-stream mode to ensure lock release even on early return or panic.
96pub struct ModelLockGuard {
97    lock: Arc<ModelStreamLock>,
98}
99
100impl ModelLockGuard {
101    /// Create a new guard that holds the given lock.
102    /// The lock should already be acquired before creating the guard.
103    pub fn new(lock: Arc<ModelStreamLock>) -> Self {
104        Self { lock }
105    }
106}
107
108impl Drop for ModelLockGuard {
109    fn drop(&mut self) {
110        self.lock.release();
111
112        #[cfg(feature = "logging")]
113        info!(target: "stdout", "ModelLockGuard: lock released on drop");
114    }
115}
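// --- Illustrative example (added; not part of the original crate) ---------
// A hedged sketch of the RAII contract used by `chat_once` below: the lock is
// acquired first, then wrapped in a guard so that any early return or panic
// still releases it when the guard is dropped. Names are hypothetical.
#[cfg(test)]
mod model_lock_guard_example {
    use super::*;

    #[test]
    fn guard_releases_lock_on_drop() {
        let lock = Arc::new(ModelStreamLock::new());
        assert!(lock.try_acquire()); // the guard expects an already-held lock
        {
            let _guard = ModelLockGuard::new(Arc::clone(&lock));
            assert!(!lock.try_acquire()); // still held while the guard lives
        } // `_guard` dropped here -> `ModelStreamLock::release` is called
        assert!(lock.try_acquire()); // the lock is free again after the drop
    }
}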
116
117/// Global model stream locks manager.
118/// Key: model_name, Value: Arc<ModelStreamLock>
119static MODEL_STREAM_LOCKS: OnceLock<Mutex<HashMap<String, Arc<ModelStreamLock>>>> = OnceLock::new();
120
121/// Get the model locks manager, initializing it if necessary.
122fn get_model_locks() -> &'static Mutex<HashMap<String, Arc<ModelStreamLock>>> {
123    MODEL_STREAM_LOCKS.get_or_init(|| {
124        #[cfg(feature = "logging")]
125        info!(target: "stdout", "Initializing model stream locks manager");
126        Mutex::new(HashMap::new())
127    })
128}
129
130/// Get or create a stream lock for the specified model.
131pub fn get_or_create_model_lock(model_name: &str) -> Result<Arc<ModelStreamLock>, LlamaCoreError> {
132    let locks = get_model_locks();
133    let mut locks_guard = locks.lock().map_err(|e| {
134        let err_msg = format!("Failed to acquire model locks: {e}");
135
136        #[cfg(feature = "logging")]
137        error!(target: "stdout", "{}", &err_msg);
138
139        LlamaCoreError::Operation(err_msg)
140    })?;
141
142    if !locks_guard.contains_key(model_name) {
143        locks_guard.insert(model_name.to_string(), Arc::new(ModelStreamLock::new()));
144
145        #[cfg(feature = "logging")]
146        info!(target: "stdout", "Created new stream lock for model: {}", model_name);
147    }
148
149    Ok(locks_guard.get(model_name).unwrap().clone())
150}
151
152/// Get the lock for the default model (first available model).
153/// Used when the request does not specify a model name.
154pub fn get_default_model_lock() -> Result<(String, Arc<ModelStreamLock>), LlamaCoreError> {
155    let chat_graphs = CHAT_GRAPHS.get().ok_or_else(|| {
156        let err_msg = "CHAT_GRAPHS not initialized";
157
158        #[cfg(feature = "logging")]
159        error!(target: "stdout", "{}", &err_msg);
160
161        LlamaCoreError::Operation(err_msg.into())
162    })?;
163
164    let chat_graphs = chat_graphs.lock().map_err(|e| {
165        let err_msg = format!("Failed to acquire CHAT_GRAPHS lock: {e}");
166
167        #[cfg(feature = "logging")]
168        error!(target: "stdout", "{}", &err_msg);
169
170        LlamaCoreError::Operation(err_msg)
171    })?;
172
173    let model_name = chat_graphs
174        .keys()
175        .next()
176        .ok_or_else(|| {
177            let err_msg = "No model available";
178
179            #[cfg(feature = "logging")]
180            error!(target: "stdout", "{}", &err_msg);
181
182            LlamaCoreError::Operation(err_msg.into())
183        })?
184        .clone();
185
186    drop(chat_graphs); // Release CHAT_GRAPHS lock early
187
188    let lock = get_or_create_model_lock(&model_name)?;
189    Ok((model_name, lock))
190}
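// --- Illustrative example (added; not part of the original crate) ---------
// A hedged sketch of how a caller resolves a request's optional model name to
// its per-model lock, mirroring the logic used by `chat_once` further down.
// `example_resolve_lock` and `requested_model` are hypothetical names.
#[allow(dead_code)]
fn example_resolve_lock(
    requested_model: Option<&str>,
) -> Result<(String, Arc<ModelStreamLock>), LlamaCoreError> {
    match requested_model {
        // The request names a model: create (or reuse) the lock for it.
        Some(name) => Ok((name.to_string(), get_or_create_model_lock(name)?)),
        // No model specified: fall back to the first model in CHAT_GRAPHS.
        None => get_default_model_lock(),
    }
}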
191
192/// Processes a chat-completion request and returns either a stream of ChatCompletionChunk instances or a ChatCompletionObject instance.
193pub async fn chat(
194    chat_request: &mut ChatCompletionRequest,
195) -> Result<
196    (
197        Either<impl futures::TryStream<Ok = String, Error = LlamaCoreError>, ChatCompletionObject>,
198        bool,
199    ),
200    LlamaCoreError,
201> {
202    #[cfg(feature = "logging")]
203    {
204        debug!(target: "stdout", "tool choice: {:?}", chat_request.tool_choice.as_ref());
205        debug!(target: "stdout", "tools: {:?}", chat_request.tools.as_ref());
206        debug!(target: "stdout", "stream mode: {:?}", chat_request.stream);
207    }
208
209    let result = match chat_request.stream {
210        Some(true) => match chat_stream(chat_request).await {
211            Ok((stream, include_tool_calls)) => Ok((Left(stream), include_tool_calls)),
212            Err(e) => Err(e),
213        },
214        Some(false) | None => match chat_once(chat_request).await {
215            Ok((chat_completion_object, include_tool_calls)) => {
216                Ok((Right(chat_completion_object), include_tool_calls))
217            }
218            Err(e) => Err(e),
219        },
220    };
221
222    #[cfg(feature = "logging")]
223    info!(target: "stdout", "Reset the model metadata");
224
225    result
226}
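// --- Illustrative example (added; not part of the original crate) ---------
// A hedged caller-side sketch of how the `Either` returned by `chat` might be
// consumed: `Left` carries a stream of SSE-formatted strings, `Right` a
// single complete `ChatCompletionObject`. The function name is hypothetical.
#[allow(dead_code)]
async fn example_consume_chat(
    request: &mut ChatCompletionRequest,
) -> Result<(), LlamaCoreError> {
    use futures::TryStreamExt;

    match chat(request).await? {
        (Left(stream), _include_tool_calls) => {
            // Stream mode: drain the SSE frames one by one.
            let mut chunks = Box::pin(stream.into_stream());
            while let Some(sse_frame) = chunks.try_next().await? {
                // Each `sse_frame` is already a complete "data: ...\n\n" line.
                let _ = sse_frame;
            }
        }
        (Right(_completion), _include_tool_calls) => {
            // Non-stream mode: a single, fully assembled completion object.
        }
    }
    Ok(())
}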
227
228async fn chat_stream(
229    chat_request: &mut ChatCompletionRequest,
230) -> Result<
231    (
232        impl futures::TryStream<Ok = String, Error = LlamaCoreError>,
233        bool,
234    ),
235    LlamaCoreError,
236> {
237    #[cfg(feature = "logging")]
238    info!(target: "stdout", "Process chat completion request in the stream mode");
239
240    let running_mode = running_mode()?;
241    if !running_mode.contains(RunningMode::CHAT) && !running_mode.contains(RunningMode::RAG) {
242        let err_msg = "Chat completion is only supported in the chat or rag mode.";
243
244        #[cfg(feature = "logging")]
245        error!(target: "stdout", "{err_msg}");
246
247        return Err(LlamaCoreError::Operation(err_msg.to_string()));
248    }
249
250    let model_name = chat_request.model.clone();
251    let id = match &chat_request.user {
252        Some(id) => id.clone(),
253        None => gen_chat_id(),
254    };
255    #[cfg(feature = "logging")]
256    info!(target: "stdout", "user: {}", &id);
257
258    #[cfg(feature = "logging")]
259    info!(target: "stdout", "Check model metadata");
260
261    // update metadata
262    let mut metadata = check_model_metadata(chat_request)?;
263
264    // parse the `include_usage` option
265    let include_usage = match chat_request.stream_options {
266        Some(ref stream_options) => stream_options.include_usage.unwrap_or_default(),
267        None => metadata.include_usage,
268    };
269    #[cfg(feature = "logging")]
270    info!(target: "stdout", "include_usage: {include_usage}");
271
272    #[cfg(feature = "logging")]
273    info!(target: "stdout", "Build the chat prompt");
274
275    // build prompt
276    let (prompt, available_completion_tokens, tool_use) =
277        build_prompt(model_name.as_ref(), chat_request)?;
278
279    #[cfg(feature = "logging")]
280    {
281        info!(target: "stdout", "prompt:\n{}", &prompt);
282        info!(target: "stdout", "available_completion_tokens: {available_completion_tokens}");
283        info!(target: "stdout", "tool_use: {tool_use}");
284    }
285
286    #[cfg(feature = "logging")]
287    info!(target: "stdout", "Update the n_predict");
288
289    // update metadata n_predict
290    update_n_predict(chat_request, &mut metadata, available_completion_tokens)?;
291
292    #[cfg(feature = "logging")]
293    info!(target: "stdout", "Feed the prompt to the model");
294
295    // set prompt
296    set_prompt(chat_request.model.as_ref(), &prompt)?;
297
298    let stream = match tool_use {
299        false => (ChatStream::new(model_name, id, include_usage, None)?, false),
300        true => {
301            let chat_graphs = match CHAT_GRAPHS.get() {
302                Some(chat_graphs) => chat_graphs,
303                None => {
304                    let err_msg = "Failed to get the underlying value of `CHAT_GRAPHS`.";
305
306                    #[cfg(feature = "logging")]
307                    error!(target: "stdout", "{}", &err_msg);
308
309                    return Err(LlamaCoreError::Operation(err_msg.into()));
310                }
311            };
312
313            let mut chat_graphs = chat_graphs.lock().map_err(|e| {
314                let err_msg = format!("Failed to acquire the lock of `CHAT_GRAPHS`. {e}");
315
316                #[cfg(feature = "logging")]
317                error!(target: "stdout", "{}", &err_msg);
318
319                LlamaCoreError::Operation(err_msg)
320            })?;
321
322            match model_name {
323                Some(model_name) => match chat_graphs.contains_key(&model_name) {
324                    true => {
325                        let graph = chat_graphs.get_mut(&model_name).unwrap();
326                        chat_stream_for_tool(graph, id, include_usage)?
327                    }
328                    false => match chat_graphs.iter_mut().next() {
329                        Some((_, graph)) => chat_stream_for_tool(graph, id, include_usage)?,
330                        None => {
331                            let err_msg = "There is no model available in the chat graphs.";
332
333                            #[cfg(feature = "logging")]
334                            error!(target: "stdout", "{}", &err_msg);
335
336                            return Err(LlamaCoreError::Operation(err_msg.into()));
337                        }
338                    },
339                },
340                None => match chat_graphs.iter_mut().next() {
341                    Some((_, graph)) => chat_stream_for_tool(graph, id, include_usage)?,
342                    None => {
343                        let err_msg = "There is no model available in the chat graphs.";
344
345                        #[cfg(feature = "logging")]
346                        error!(target: "stdout", "{}", &err_msg);
347
348                        return Err(LlamaCoreError::Operation(err_msg.into()));
349                    }
350                },
351            }
352        }
353    };
354
355    #[cfg(feature = "logging")]
356    info!(target: "stdout", "End of the chat completion stream.");
357
358    Ok(stream)
359}
360
361fn chat_stream_for_tool(
362    graph: &mut Graph<GgmlMetadata>,
363    id: impl Into<String>,
364    include_usage: bool,
365) -> Result<(ChatStream, bool), LlamaCoreError> {
366    #[cfg(feature = "logging")]
367    info!(target: "stdout", "Handle chat request with available tools by the model named {}.", graph.name());
368
369    let id = id.into();
370
371    match graph.compute() {
372        Ok(_) => {
373            // Retrieve the output.
374            let output_buffer = get_output_buffer(graph, OUTPUT_TENSOR)?;
375            let output = std::str::from_utf8(&output_buffer[..]).map_err(|e| {
376                let err_msg = format!(
377                    "Failed to decode the buffer of the inference result to a utf-8 string. {e}"
378                );
379
380                #[cfg(feature = "logging")]
381                error!(target: "stdout", "{}", &err_msg);
382
383                LlamaCoreError::Operation(err_msg)
384            })?;
385
386            #[cfg(feature = "logging")]
387            info!(target: "stdout", "raw generation:\n{output}");
388
389            // post-process
390            let message = post_process(output, &graph.metadata.prompt_template).map_err(|e| {
391                LlamaCoreError::Operation(format!("Failed to post-process the output. {e}"))
392            })?;
393
394            #[cfg(feature = "logging")]
395            info!(target: "stdout", "post-processed generation:\n{}", &message);
396
397            // retrieve the number of prompt and completion tokens
398            let token_info = get_token_info_by_graph(graph)?;
399
400            #[cfg(feature = "logging")]
401            info!(target: "stdout", "prompt tokens: {}, completion tokens: {}", token_info.prompt_tokens, token_info.completion_tokens);
402
403            let usage = Some(Usage {
404                prompt_tokens: token_info.prompt_tokens,
405                completion_tokens: token_info.completion_tokens,
406                total_tokens: token_info.prompt_tokens + token_info.completion_tokens,
407            });
408
409            let created = SystemTime::now()
410                .duration_since(std::time::UNIX_EPOCH)
411                .map_err(|e| {
412                    let err_msg = format!("Failed to get the current time. Reason: {e}");
413
414                    #[cfg(feature = "logging")]
415                    error!(target: "stdout", "{}", &err_msg);
416
417                    LlamaCoreError::Operation(err_msg)
418                })?;
419
420            if graph.metadata.prompt_template != PromptTemplateType::MistralTool
421                && graph.metadata.prompt_template != PromptTemplateType::ChatMLTool
422                && graph.metadata.prompt_template != PromptTemplateType::GroqLlama3Tool
423                && graph.metadata.prompt_template != PromptTemplateType::Llama3Tool
424                && graph.metadata.prompt_template != PromptTemplateType::InternLM2Tool
425                && graph.metadata.prompt_template != PromptTemplateType::NemotronTool
426                && graph.metadata.prompt_template != PromptTemplateType::FunctionaryV32
427                && graph.metadata.prompt_template != PromptTemplateType::FunctionaryV31
428                && graph.metadata.prompt_template != PromptTemplateType::MistralSmallTool
429                && graph.metadata.prompt_template != PromptTemplateType::Llama4Chat
430                && graph.metadata.prompt_template != PromptTemplateType::Qwen3NoThink
431                && graph.metadata.prompt_template != PromptTemplateType::Smol3NoThink
432                && graph.metadata.prompt_template != PromptTemplateType::Gemma3
433                && graph.metadata.prompt_template != PromptTemplateType::GptOss
434                && graph.metadata.prompt_template != PromptTemplateType::Qwen3Agent
435                && graph.metadata.prompt_template != PromptTemplateType::SeedOssNoThink
436                && graph.metadata.prompt_template != PromptTemplateType::SeedOssThink
437            {
438                let err_msg = format!("Unsupported prompt template: {}. The tool use is only supported for 'mistral-tool', 'chatml-tool', 'groq-llama3-tool', 'llama-3-tool', 'internlm-2-tool', 'nemotron-tool', 'functionary-31', 'functionary-32', 'mistral-small-tool', 'llama-4-chat', 'qwen3-no-think', 'smol-3-no-think', 'gemma-3', 'gpt-oss', 'qwen3-agent', 'seed-oss-no-think', and 'seed-oss-think' prompt templates.", graph.metadata.prompt_template);
439
440                #[cfg(feature = "logging")]
441                error!(target: "stdout", "{}", &err_msg);
442
443                return Err(LlamaCoreError::Operation(err_msg));
444            }
445
446            let parsed_result = parse_tool_calls(&message, graph.metadata.prompt_template)?;
447
448            let content = if parsed_result.tool_calls.is_empty() {
449                Some(parsed_result.raw.clone())
450            } else {
451                parsed_result.content.clone()
452            };
453
454            let (tool_calls, include_tool_calls) = match parsed_result.tool_calls.is_empty() {
455                false => {
456                    let tool_calls: Vec<ToolCallForChunk> = parsed_result
457                        .tool_calls
458                        .into_iter()
459                        .enumerate()
460                        .map(|(index, tool_call)| ToolCallForChunk {
461                            index,
462                            id: tool_call.id,
463                            ty: tool_call.ty,
464                            function: tool_call.function,
465                        })
466                        .collect();
467                    (tool_calls, true)
468                }
469                true => (vec![], false),
470            };
471
472            // tool_calls chunk
473            let tool_call_chunk = {
474                let chat_completion_chunk = ChatCompletionChunk {
475                    id: id.clone(),
476                    object: "chat.completion.chunk".to_string(),
477                    created: created.as_secs(),
478                    model: graph.name().to_owned(),
479                    system_fingerprint: "fp_44709d6fcb".to_string(),
480                    choices: vec![ChatCompletionChunkChoice {
481                        index: 0,
482                        delta: ChatCompletionChunkChoiceDelta {
483                            role: ChatCompletionRole::Assistant,
484                            content,
485                            tool_calls,
486                        },
487                        logprobs: None,
488                        finish_reason: None,
489                    }],
490                    usage: None,
491                };
492                let chunk_str = serde_json::to_string(&chat_completion_chunk).map_err(|e| {
493                    let err_msg = format!("Failed to serialize chat completion chunk. Reason: {e}");
494
495                    #[cfg(feature = "logging")]
496                    error!(target: "stdout", "{}", &err_msg);
497
498                    LlamaCoreError::Operation(err_msg)
499                })?;
500
501                format!("data: {chunk_str}\n\n")
502            };
503
504            // token usage chunk
505            let usage_chunk = {
506                let chat_completion_chunk = ChatCompletionChunk {
507                    id: id.clone(),
508                    object: "chat.completion.chunk".to_string(),
509                    created: created.as_secs(),
510                    model: graph.name().to_owned(),
511                    system_fingerprint: "fp_44709d6fcb".to_string(),
512                    choices: vec![],
513                    usage,
514                };
515                let chunk_str = serde_json::to_string(&chat_completion_chunk).map_err(|e| {
516                    let err_msg = format!("Failed to serialize chat completion chunk. Reason: {e}");
517
518                    #[cfg(feature = "logging")]
519                    error!(target: "stdout", "{}", &err_msg);
520
521                    LlamaCoreError::Operation(err_msg)
522                })?;
523
524                format!("data: {chunk_str}\n\n")
525            };
526
527            // ending chunk
528            let ending_chunk = "data: [DONE]\n\n".to_string();
529
530            let chunks = vec![tool_call_chunk, usage_chunk, ending_chunk];
531
532            let stream = ChatStream::new(
533                Some(graph.name().to_owned()),
534                id,
535                include_usage,
536                Some(chunks),
537            )?;
538
539            Ok((stream, include_tool_calls))
540        }
541        Err(wasmedge_wasi_nn::Error::BackendError(wasmedge_wasi_nn::BackendError::ContextFull)) => {
542            // Retrieve the output.
543            let output_buffer = get_output_buffer(graph, OUTPUT_TENSOR)?;
544            let output = std::str::from_utf8(&output_buffer[..]).map_err(|e| {
545                let err_msg = format!(
546                    "Failed to decode the buffer of the inference result to a utf-8 string. {e}"
547                );
548
549                #[cfg(feature = "logging")]
550                error!(target: "stdout", "{}", &err_msg);
551
552                LlamaCoreError::Operation(err_msg)
553            })?;
554
555            // post-process
556            let message = post_process(output, &graph.metadata.prompt_template).map_err(|e| {
557                let err_msg = format!("Failed to post-process the output. {e}");
558
559                #[cfg(feature = "logging")]
560                error!(target: "stdout", "{}", &err_msg);
561
562                LlamaCoreError::Operation(err_msg)
563            })?;
564
565            // retrieve the number of prompt and completion tokens
566            let token_info = get_token_info_by_graph(graph)?;
567
568            #[cfg(feature = "logging")]
569            info!(target: "stdout", "prompt tokens: {}, completion tokens: {}", token_info.prompt_tokens, token_info.completion_tokens);
570
571            let usage = Some(Usage {
572                prompt_tokens: token_info.prompt_tokens,
573                completion_tokens: token_info.completion_tokens,
574                total_tokens: token_info.prompt_tokens + token_info.completion_tokens,
575            });
576
577            let created = SystemTime::now()
578                .duration_since(std::time::UNIX_EPOCH)
579                .map_err(|e| {
580                    let err_msg = format!("Failed to get the current time. Reason: {e}");
581
582                    #[cfg(feature = "logging")]
583                    error!(target: "stdout", "{}", &err_msg);
584
585                    LlamaCoreError::Operation(err_msg)
586                })?;
587
588            // context full chunk
589            let context_full_chunk = {
590                let chat_completion_chunk = ChatCompletionChunk {
591                    id: id.clone(),
592                    object: "chat.completion.chunk".to_string(),
593                    created: created.as_secs(),
594                    model: graph.name().to_owned(),
595                    system_fingerprint: "fp_44709d6fcb".to_string(),
596                    choices: vec![ChatCompletionChunkChoice {
597                        index: 0,
598                        delta: ChatCompletionChunkChoiceDelta {
599                            role: ChatCompletionRole::Assistant,
600                            content: Some(message),
601                            tool_calls: vec![],
602                        },
603                        logprobs: None,
604                        finish_reason: Some(FinishReason::length),
605                    }],
606                    usage: None,
607                };
608
609                // serialize chat completion chunk
610                let chunk_str = serde_json::to_string(&chat_completion_chunk).map_err(|e| {
611                    let err_msg = format!("Failed to serialize chat completion chunk. Reason: {e}");
612
613                    #[cfg(feature = "logging")]
614                    error!(target: "stdout", "{}", &err_msg);
615
616                    LlamaCoreError::Operation(err_msg)
617                })?;
618
619                format!("data: {chunk_str}\n\n")
620            };
621
622            // usage chunk
623            let usage_chunk = {
624                let chat_completion_chunk = ChatCompletionChunk {
625                    id: id.clone(),
626                    object: "chat.completion.chunk".to_string(),
627                    created: created.as_secs(),
628                    model: graph.name().to_owned(),
629                    system_fingerprint: "fp_44709d6fcb".to_string(),
630                    choices: vec![],
631                    usage,
632                };
633
634                // serialize chat completion chunk
635                let chunk_str = serde_json::to_string(&chat_completion_chunk).map_err(|e| {
636                    let err_msg = format!("Failed to serialize chat completion chunk. Reason: {e}");
637
638                    #[cfg(feature = "logging")]
639                    error!(target: "stdout", "{}", &err_msg);
640
641                    LlamaCoreError::Operation(err_msg)
642                })?;
643
644                format!("data: {chunk_str}\n\n")
645            };
646
647            // ending chunk
648            let ending_chunk = "data: [DONE]\n\n".to_string();
649
650            let chunks = vec![context_full_chunk, usage_chunk, ending_chunk];
651
652            let stream = ChatStream::new(
653                Some(graph.name().to_owned()),
654                id,
655                include_usage,
656                Some(chunks),
657            )?;
658
659            Ok((stream, false))
660        }
661        Err(wasmedge_wasi_nn::Error::BackendError(
662            wasmedge_wasi_nn::BackendError::PromptTooLong,
663        )) => {
664            #[cfg(feature = "logging")]
665            warn!(target: "stdout", "The prompt is too long. Please reduce the length of your input and try again.");
666
667            // Retrieve the output.
668            let output_buffer = get_output_buffer(graph, OUTPUT_TENSOR)?;
669            let output = std::str::from_utf8(&output_buffer[..]).map_err(|e| {
670                let err_msg = format!(
671                    "Failed to decode the buffer of the inference result to a utf-8 string. {e}"
672                );
673
674                #[cfg(feature = "logging")]
675                error!(target: "stdout", "{}", &err_msg);
676
677                LlamaCoreError::Operation(err_msg)
678            })?;
679
680            // post-process
681            let message = post_process(output, &graph.metadata.prompt_template).map_err(|e| {
682                let err_msg = format!("Failed to post-process the output. {e}");
683
684                #[cfg(feature = "logging")]
685                error!(target: "stdout", "{}", &err_msg);
686
687                LlamaCoreError::Operation(err_msg)
688            })?;
689
690            // retrieve the number of prompt and completion tokens
691            let token_info = get_token_info_by_graph(graph)?;
692
693            #[cfg(feature = "logging")]
694            info!(target: "stdout", "prompt tokens: {}, completion tokens: {}", token_info.prompt_tokens, token_info.completion_tokens);
695
696            let usage = Some(Usage {
697                prompt_tokens: token_info.prompt_tokens,
698                completion_tokens: token_info.completion_tokens,
699                total_tokens: token_info.prompt_tokens + token_info.completion_tokens,
700            });
701
702            let created = SystemTime::now()
703                .duration_since(std::time::UNIX_EPOCH)
704                .map_err(|e| {
705                    let err_msg = format!("Failed to get the current time. Reason: {e}");
706
707                    #[cfg(feature = "logging")]
708                    error!(target: "stdout", "{}", &err_msg);
709
710                    LlamaCoreError::Operation(err_msg)
711                })?;
712
713            // prompt too long chunk
714            let prompt_too_long_chunk = {
715                let chat_completion_chunk = ChatCompletionChunk {
716                    id: id.clone(),
717                    object: "chat.completion.chunk".to_string(),
718                    created: created.as_secs(),
719                    model: graph.name().to_owned(),
720                    system_fingerprint: "fp_44709d6fcb".to_string(),
721                    choices: vec![ChatCompletionChunkChoice {
722                        index: 0,
723                        delta: ChatCompletionChunkChoiceDelta {
724                            role: ChatCompletionRole::Assistant,
725                            content: Some(message),
726                            tool_calls: vec![],
727                        },
728                        logprobs: None,
729                        finish_reason: Some(FinishReason::length),
730                    }],
731                    usage: None,
732                };
733
734                // serialize chat completion chunk
735                let chunk_str = serde_json::to_string(&chat_completion_chunk).map_err(|e| {
736                    let err_msg = format!("Failed to serialize chat completion chunk. Reason: {e}");
737
738                    #[cfg(feature = "logging")]
739                    error!(target: "stdout", "{}", &err_msg);
740
741                    LlamaCoreError::Operation(err_msg)
742                })?;
743
744                format!("data: {chunk_str}\n\n")
745            };
746
747            // usage chunk
748            let usage_chunk = {
749                let chat_completion_chunk = ChatCompletionChunk {
750                    id: id.clone(),
751                    object: "chat.completion.chunk".to_string(),
752                    created: created.as_secs(),
753                    model: graph.name().to_owned(),
754                    system_fingerprint: "fp_44709d6fcb".to_string(),
755                    choices: vec![],
756                    usage,
757                };
758
759                // serialize chat completion chunk
760                let chunk_str = serde_json::to_string(&chat_completion_chunk).map_err(|e| {
761                    let err_msg = format!("Failed to serialize chat completion chunk. Reason: {e}");
762
763                    #[cfg(feature = "logging")]
764                    error!(target: "stdout", "{}", &err_msg);
765
766                    LlamaCoreError::Operation(err_msg)
767                })?;
768
769                format!("data: {chunk_str}\n\n")
770            };
771
772            // ending chunk
773            let ending_chunk = "data: [DONE]\n\n".to_string();
774
775            let chunks = vec![prompt_too_long_chunk, usage_chunk, ending_chunk];
776
777            let stream = ChatStream::new(
778                Some(graph.name().to_owned()),
779                id,
780                include_usage,
781                Some(chunks),
782            )?;
783
784            Ok((stream, false))
785        }
786        Err(e) => {
787            let err_msg = format!("Failed to compute the chat completion. Reason: {e}");
788
789            #[cfg(feature = "logging")]
790            error!(target: "stdout", "{}", &err_msg);
791
792            Err(LlamaCoreError::Backend(BackendError::Compute(err_msg)))
793        }
794    }
795}
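// --- Illustrative note (added; not part of the original crate) ------------
// In the tool-call path above, the returned `ChatStream` does no further
// inference; it simply replays three pre-built SSE frames in order:
//   1. a tool-call chunk whose delta carries the assistant role, the optional
//      content, and the parsed `tool_calls`;
//   2. a usage chunk with empty `choices` and the prompt/completion token counts;
//   3. the terminating frame `data: [DONE]\n\n`.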
796
797async fn chat_once(
798    chat_request: &mut ChatCompletionRequest,
799) -> Result<(ChatCompletionObject, bool), LlamaCoreError> {
800    #[cfg(feature = "logging")]
801    info!(target: "stdout", "Processing chat completion request in non-stream mode");
802
803    let running_mode = running_mode()?;
804    if !running_mode.contains(RunningMode::CHAT) && !running_mode.contains(RunningMode::RAG) {
805        let err_msg = "Chat completion is only supported in the chat or rag mode.";
806
807        #[cfg(feature = "logging")]
808        error!(target: "stdout", "{err_msg}");
809
810        return Err(LlamaCoreError::Operation(err_msg.to_string()));
811    }
812
813    let model_name = chat_request.model.clone();
814
815    // Get or create the per-model lock
816    let (resolved_model, model_lock) = match &model_name {
817        Some(name) => (name.clone(), get_or_create_model_lock(name)?),
818        None => get_default_model_lock()?,
819    };
820
821    // Acquire the lock with a spin-wait loop.
822    // This ensures mutual exclusion with stream requests on the same model.
823    while !model_lock.try_acquire() {
824        #[cfg(feature = "logging")]
825        debug!(target: "stdout", "Non-stream request waiting for model lock: {}", &resolved_model);
826
827        // Busy-wait with a CPU spin-loop hint until the lock is released
828        std::hint::spin_loop();
829    }
830
831    #[cfg(feature = "logging")]
832    info!(target: "stdout", "Non-stream request acquired lock for model: {}", &resolved_model);
833
834    // Use RAII guard to ensure lock is released even on early return or panic
835    let _lock_guard = ModelLockGuard::new(model_lock);
836
837    let id = match &chat_request.user {
838        Some(id) => id.clone(),
839        None => gen_chat_id(),
840    };
841
842    #[cfg(feature = "logging")]
843    info!(target: "stdout", "user: {}", &id);
844
845    #[cfg(feature = "logging")]
846    info!(target: "stdout", "Check model metadata");
847
848    // update metadata
849    let mut metadata = check_model_metadata(chat_request)?;
850
851    #[cfg(feature = "logging")]
852    info!(target: "stdout", "Build the chat prompt");
853
854    // build prompt
855    let (prompt, available_completion_tokens, tool_use) =
856        build_prompt(Some(&resolved_model), chat_request)?;
857
858    #[cfg(feature = "logging")]
859    {
860        info!(target: "stdout", "prompt:\n{}", &prompt);
861        info!(target: "stdout", "available_completion_tokens: {available_completion_tokens}");
862        info!(target: "stdout", "tool_use: {tool_use}");
863    }
864
865    #[cfg(feature = "logging")]
866    info!(target: "stdout", "Update n_predict");
867
868    // update metadata n_predict
869    update_n_predict(chat_request, &mut metadata, available_completion_tokens)?;
870
871    #[cfg(feature = "logging")]
872    info!(target: "stdout", "Feed the prompt to the model");
873
874    // feed the prompt to the model
875    set_prompt(Some(&resolved_model), &prompt)?;
876
877    #[cfg(feature = "logging")]
878    info!(target: "stdout", "Compute chat completion.");
879
880    // compute
881    let res = compute(Some(&resolved_model), id, tool_use);
882
883    #[cfg(feature = "logging")]
884    info!(target: "stdout", "End of the chat completion");
885
886    // reset the model metadata
887    reset_model_metadata(Some(&resolved_model))?;
888
889    res
890    // _lock_guard is dropped here, releasing the lock
891}
892
893fn compute(
894    model_name: Option<&String>,
895    id: impl Into<String>,
896    tool_use: bool,
897) -> Result<(ChatCompletionObject, bool), LlamaCoreError> {
898    let chat_graphs = match CHAT_GRAPHS.get() {
899        Some(chat_graphs) => chat_graphs,
900        None => {
901            let err_msg = "Failed to get the underlying value of `CHAT_GRAPHS`.";
902
903            #[cfg(feature = "logging")]
904            error!(target: "stdout", "{}", &err_msg);
905
906            return Err(LlamaCoreError::Operation(err_msg.into()));
907        }
908    };
909
910    let mut chat_graphs = chat_graphs.lock().map_err(|e| {
911        let err_msg = format!("Failed to acquire the lock of `CHAT_GRAPHS`. {e}");
912
913        #[cfg(feature = "logging")]
914        error!(target: "stdout", "{}", &err_msg);
915
916        LlamaCoreError::Operation(err_msg)
917    })?;
918
919    match model_name {
920        Some(model_name) => match chat_graphs.contains_key(model_name) {
921            true => {
922                let graph = chat_graphs.get_mut(model_name).unwrap();
923                compute_by_graph(graph, id, tool_use)
924            }
925            false => match chat_graphs.iter_mut().next() {
926                Some((_, graph)) => compute_by_graph(graph, id, tool_use),
927                None => {
928                    let err_msg = "There is no model available in the chat graphs.";
929
930                    #[cfg(feature = "logging")]
931                    error!(target: "stdout", "{}", &err_msg);
932
933                    Err(LlamaCoreError::Operation(err_msg.into()))
934                }
935            },
936        },
937        None => match chat_graphs.iter_mut().next() {
938            Some((_, graph)) => compute_by_graph(graph, id, tool_use),
939            None => {
940                let err_msg = "There is no model available in the chat graphs.";
941
942                #[cfg(feature = "logging")]
943                error!(target: "stdout", "{}", &err_msg);
944
945                Err(LlamaCoreError::Operation(err_msg.into()))
946            }
947        },
948    }
949}
950
951fn compute_by_graph(
952    graph: &mut Graph<GgmlMetadata>,
953    id: impl Into<String>,
954    tool_use: bool,
955) -> Result<(ChatCompletionObject, bool), LlamaCoreError> {
956    #[cfg(feature = "logging")]
957    info!(target: "stdout", "Compute chat completion by the model named {}.", graph.name());
958
959    match graph.compute() {
960        Ok(_) => {
961            // Retrieve the output.
962            let output_buffer = get_output_buffer(graph, OUTPUT_TENSOR)?;
963            let output = std::str::from_utf8(&output_buffer[..]).map_err(|e| {
964                let err_msg = format!(
965                    "Failed to decode the buffer of the inference result to a utf-8 string. {e}"
966                );
967
968                #[cfg(feature = "logging")]
969                error!(target: "stdout", "{}", &err_msg);
970
971                LlamaCoreError::Operation(err_msg)
972            })?;
973
974            #[cfg(feature = "logging")]
975            info!(target: "stdout", "raw generation: {output}");
976
977            // post-process
978            let message = post_process(output, &graph.metadata.prompt_template).map_err(|e| {
979                LlamaCoreError::Operation(format!("Failed to post-process the output. {e}"))
980            })?;
981
982            #[cfg(feature = "logging")]
983            info!(target: "stdout", "post-processed generation:\n{}", &message);
984
985            // retrieve the number of prompt and completion tokens
986            let token_info = get_token_info_by_graph(graph)?;
987
988            #[cfg(feature = "logging")]
989            info!(target: "stdout", "prompt tokens: {}, completion tokens: {}", token_info.prompt_tokens, token_info.completion_tokens);
990
991            let created = SystemTime::now()
992                .duration_since(std::time::UNIX_EPOCH)
993                .map_err(|e| {
994                    let err_msg = format!("Failed to get the current time. Reason: {e}");
995
996                    #[cfg(feature = "logging")]
997                    error!(target: "stdout", "{}", &err_msg);
998
999                    LlamaCoreError::Operation(err_msg)
1000                })?;
1001
1002            match tool_use {
1003                true => {
1004                    if graph.metadata.prompt_template != PromptTemplateType::MistralTool
1005                        && graph.metadata.prompt_template != PromptTemplateType::ChatMLTool
1006                        && graph.metadata.prompt_template != PromptTemplateType::GroqLlama3Tool
1007                        && graph.metadata.prompt_template != PromptTemplateType::Llama3Tool
1008                        && graph.metadata.prompt_template != PromptTemplateType::InternLM2Tool
1009                        && graph.metadata.prompt_template != PromptTemplateType::NemotronTool
1010                        && graph.metadata.prompt_template != PromptTemplateType::FunctionaryV32
1011                        && graph.metadata.prompt_template != PromptTemplateType::FunctionaryV31
1012                        && graph.metadata.prompt_template != PromptTemplateType::MistralSmallTool
1013                        && graph.metadata.prompt_template != PromptTemplateType::Llama4Chat
1014                        && graph.metadata.prompt_template != PromptTemplateType::Qwen3NoThink
1015                        && graph.metadata.prompt_template != PromptTemplateType::Smol3NoThink
1016                        && graph.metadata.prompt_template != PromptTemplateType::Gemma3
1017                        && graph.metadata.prompt_template != PromptTemplateType::GptOss
1018                        && graph.metadata.prompt_template != PromptTemplateType::Qwen3Agent
1019                        && graph.metadata.prompt_template != PromptTemplateType::SeedOssNoThink
1020                        && graph.metadata.prompt_template != PromptTemplateType::SeedOssThink
1021                    {
1022                        let err_msg = format!("Unsupported prompt template: {}. The tool use is only supported for 'mistral-tool', 'chatml-tool', 'groq-llama3-tool', 'llama-3-tool', 'internlm-2-tool', 'nemotron-tool', 'functionary-31', 'functionary-32', 'mistral-small-tool', 'llama-4-chat', 'qwen3-no-think', 'smol-3-no-think', 'gemma-3', 'gpt-oss', 'qwen3-agent', 'seed-oss-no-think', and 'seed-oss-think' prompt templates.", graph.metadata.prompt_template);
1023
1024                        #[cfg(feature = "logging")]
1025                        error!(target: "stdout", "{}", &err_msg);
1026
1027                        return Err(LlamaCoreError::Operation(err_msg));
1028                    }
1029
1030                    let parsed_result = parse_tool_calls(&message, graph.metadata.prompt_template)?;
1031
1032                    let (finish_reason, content, include_tool_calls) =
1033                        if parsed_result.tool_calls.is_empty() {
1034                            (FinishReason::stop, Some(parsed_result.raw.clone()), false)
1035                        } else if graph.metadata.prompt_template != PromptTemplateType::Qwen3Agent {
1036                            (
1037                                FinishReason::tool_calls,
1038                                Some(parsed_result.raw.clone()),
1039                                true,
1040                            )
1041                        } else {
1042                            (
1043                                FinishReason::tool_calls,
1044                                parsed_result.content.clone(),
1045                                true,
1046                            )
1047                        };
1048
1049                    let res = ChatCompletionObject {
1050                        id: id.into(),
1051                        object: String::from("chat.completion"),
1052                        created: created.as_secs(),
1053                        model: graph.name().to_owned(),
1054                        choices: vec![ChatCompletionObjectChoice {
1055                            index: 0,
1056                            message: ChatCompletionObjectMessage {
1057                                role: ChatCompletionRole::Assistant,
1058                                content,
1059                                tool_calls: parsed_result.tool_calls,
1060                                function_call: None,
1061                            },
1062                            finish_reason,
1063                            logprobs: None,
1064                        }],
1065                        usage: Usage {
1066                            prompt_tokens: token_info.prompt_tokens,
1067                            completion_tokens: token_info.completion_tokens,
1068                            total_tokens: token_info.prompt_tokens + token_info.completion_tokens,
1069                        },
1070                    };
1071
1072                    // create ChatCompletionResponse
1073                    Ok((res, include_tool_calls))
1074                }
1075                false => {
1076                    // create ChatCompletionResponse
1077                    let res = ChatCompletionObject {
1078                        id: id.into(),
1079                        object: String::from("chat.completion"),
1080                        created: created.as_secs(),
1081                        model: graph.name().to_owned(),
1082                        choices: vec![ChatCompletionObjectChoice {
1083                            index: 0,
1084                            message: ChatCompletionObjectMessage {
1085                                role: ChatCompletionRole::Assistant,
1086                                content: Some(message),
1087                                tool_calls: vec![],
1088                                function_call: None,
1089                            },
1090                            finish_reason: FinishReason::stop,
1091                            logprobs: None,
1092                        }],
1093                        usage: Usage {
1094                            prompt_tokens: token_info.prompt_tokens,
1095                            completion_tokens: token_info.completion_tokens,
1096                            total_tokens: token_info.prompt_tokens + token_info.completion_tokens,
1097                        },
1098                    };
1099
1100                    Ok((res, false))
1101                }
1102            }
1103        }
1104        Err(wasmedge_wasi_nn::Error::BackendError(wasmedge_wasi_nn::BackendError::ContextFull)) => {
1105            // Retrieve the output.
1106            let output_buffer = get_output_buffer(graph, OUTPUT_TENSOR)?;
1107            let output = std::str::from_utf8(&output_buffer[..]).map_err(|e| {
1108                let err_msg = format!(
1109                    "Failed to decode the buffer of the inference result to a utf-8 string. {e}"
1110                );
1111
1112                #[cfg(feature = "logging")]
1113                error!(target: "stdout", "{}", &err_msg);
1114
1115                LlamaCoreError::Operation(err_msg)
1116            })?;
1117
1118            // post-process
1119            let message = post_process(output, &graph.metadata.prompt_template).map_err(|e| {
1120                let err_msg = format!("Failed to post-process the output. {e}");
1121
1122                #[cfg(feature = "logging")]
1123                error!(target: "stdout", "{}", &err_msg);
1124
1125                LlamaCoreError::Operation(err_msg)
1126            })?;
1127
1128            // retrieve the number of prompt and completion tokens
1129            let token_info = get_token_info_by_graph(graph)?;
1130
1131            #[cfg(feature = "logging")]
1132            info!(target: "stdout", "prompt tokens: {}, completion tokens: {}", token_info.prompt_tokens, token_info.completion_tokens);
1133
1134            let created = SystemTime::now()
1135                .duration_since(std::time::UNIX_EPOCH)
1136                .map_err(|e| {
1137                    let err_msg = format!("Failed to get the current time. Reason: {e}");
1138
1139                    #[cfg(feature = "logging")]
1140                    error!(target: "stdout", "{}", &err_msg);
1141
1142                    LlamaCoreError::Operation(err_msg)
1143                })?;
1144
1145            // create ChatCompletionResponse
1146            let res = ChatCompletionObject {
1147                id: id.into(),
1148                object: String::from("chat.completion"),
1149                created: created.as_secs(),
1150                model: graph.name().to_owned(),
1151                choices: vec![ChatCompletionObjectChoice {
1152                    index: 0,
1153                    message: ChatCompletionObjectMessage {
1154                        role: ChatCompletionRole::Assistant,
1155                        content: Some(message),
1156                        tool_calls: vec![],
1157                        function_call: None,
1158                    },
1159                    finish_reason: FinishReason::length,
1160                    logprobs: None,
1161                }],
1162                usage: Usage {
1163                    prompt_tokens: token_info.prompt_tokens,
1164                    completion_tokens: token_info.completion_tokens,
1165                    total_tokens: token_info.prompt_tokens + token_info.completion_tokens,
1166                },
1167            };
1168
1169            Ok((res, false))
1170        }
1171        Err(wasmedge_wasi_nn::Error::BackendError(
1172            wasmedge_wasi_nn::BackendError::PromptTooLong,
1173        )) => {
1174            #[cfg(feature = "logging")]
1175            warn!(target: "stdout", "The prompt is too long. Please reduce the length of your input and try again.");
1176
1177            // Retrieve the output.
1178            let output_buffer = get_output_buffer(graph, OUTPUT_TENSOR)?;
1179            let output = std::str::from_utf8(&output_buffer[..]).map_err(|e| {
1180                let err_msg = format!(
1181                    "Failed to decode the buffer of the inference result to a utf-8 string. {e}"
1182                );
1183
1184                #[cfg(feature = "logging")]
1185                error!(target: "stdout", "{}", &err_msg);
1186
1187                LlamaCoreError::Operation(err_msg)
1188            })?;
1189
1190            // post-process
1191            let message = post_process(output, &graph.metadata.prompt_template).map_err(|e| {
1192                let err_msg = format!("Failed to post-process the output. {e}");
1193
1194                #[cfg(feature = "logging")]
1195                error!(target: "stdout", "{}", &err_msg);
1196
1197                LlamaCoreError::Operation(err_msg)
1198            })?;
1199
1200            // retrieve the number of prompt and completion tokens
1201            let token_info = get_token_info_by_graph(graph)?;
1202
1203            #[cfg(feature = "logging")]
1204            info!(target: "stdout", "prompt tokens: {}, completion tokens: {}", token_info.prompt_tokens, token_info.completion_tokens);
1205
1206            let usage = Usage {
1207                prompt_tokens: token_info.prompt_tokens,
1208                completion_tokens: token_info.completion_tokens,
1209                total_tokens: token_info.prompt_tokens + token_info.completion_tokens,
1210            };
1211
1212            let created = SystemTime::now()
1213                .duration_since(std::time::UNIX_EPOCH)
1214                .map_err(|e| {
1215                    let err_msg = format!("Failed to get the current time. Reason: {e}");
1216
1217                    #[cfg(feature = "logging")]
1218                    error!(target: "stdout", "{}", &err_msg);
1219
1220                    LlamaCoreError::Operation(err_msg)
1221                })?;
1222
1223            // create ChatCompletionResponse
1224            let res = ChatCompletionObject {
1225                id: id.into(),
1226                object: String::from("chat.completion"),
1227                created: created.as_secs(),
1228                model: graph.name().to_owned(),
1229                choices: vec![ChatCompletionObjectChoice {
1230                    index: 0,
1231                    message: ChatCompletionObjectMessage {
1232                        role: ChatCompletionRole::Assistant,
1233                        content: Some(message),
1234                        tool_calls: vec![],
1235                        function_call: None,
1236                    },
1237                    finish_reason: FinishReason::length,
1238                    logprobs: None,
1239                }],
1240                usage,
1241            };
1242
1243            Ok((res, false))
1244        }
1245        Err(e) => {
1246            let err_msg = format!("Failed to compute the chat completion. Reason: {e}");
1247
1248            #[cfg(feature = "logging")]
1249            error!(target: "stdout", "{}", &err_msg);
1250
1251            Err(LlamaCoreError::Backend(BackendError::Compute(err_msg)))
1252        }
1253    }
1254}
1255
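// --- Illustrative example (added; not part of the original crate) ---------
// A hedged sketch of what the MistralTool branch below expects: the model
// emits a JSON array of {"name", "arguments"} objects, which the
// `\[\{.*?\}\]` regex captures and converts into `ToolCall`s. The sample
// text and function name are hypothetical.
#[cfg(test)]
mod parse_tool_calls_example {
    use super::*;

    #[test]
    fn parses_mistral_tool_output() {
        let raw = r#"[{"name": "get_weather", "arguments": {"location": "Boston"}}]"#;
        let parsed = parse_tool_calls(raw, PromptTemplateType::MistralTool).unwrap();
        assert_eq!(parsed.tool_calls.len(), 1);
        assert_eq!(parsed.tool_calls[0].function.name, "get_weather");
    }
}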
1256fn parse_tool_calls(
1257    input: &str,
1258    prompt_template: PromptTemplateType,
1259) -> Result<ParseResult, LlamaCoreError> {
1260    match prompt_template {
1261        PromptTemplateType::MistralTool => match regex::Regex::new(r"\[\{.*?\}\]") {
1262            Ok(re) => {
1263                let mut values: Vec<serde_json::Value> = vec![];
1264                for cap in re.captures_iter(input) {
1265                    let matched = &cap[0];
1266
1267                    #[cfg(feature = "logging")]
1268                    info!(target: "stdout", "captured: {matched}");
1269
1270                    match serde_json::from_str::<Vec<serde_json::Value>>(matched) {
1271                        Ok(group) => values.extend(group),
1272                        Err(e) => {
1273                            let err_msg =
1274                                format!("Failed to deserialize generated tool calls. Reason: {e}");
1275
1276                            #[cfg(feature = "logging")]
1277                            error!(target: "stdout", "{}", &err_msg);
1278
1279                            return Err(LlamaCoreError::Operation(err_msg));
1280                        }
1281                    }
1282                }
1283
1284                let mut tool_calls: Vec<ToolCall> = vec![];
1285                for value in values.iter() {
1286                    let name = match value.get("name") {
1287                        Some(name) => name.to_string().replace("\"", ""),
1288                        None => {
1289                            let err_msg = format!(
1290                                "Failed to get the name of the function. Tool call: {value:?}"
1291                            );
1292
1293                            #[cfg(feature = "logging")]
1294                            error!(target: "stdout", "{}", &err_msg);
1295
1296                            return Err(LlamaCoreError::Operation(err_msg));
1297                        }
1298                    };
1299
1300                    let arguments = match value.get("arguments") {
1301                        Some(arguments) => arguments.to_string(),
1302                        None => {
1303                            let err_msg = format!(
1304                                "Failed to get the arguments of the function. Tool call: {value:?}"
1305                            );
1306
1307                            #[cfg(feature = "logging")]
1308                            error!(target: "stdout", "{}", &err_msg);
1309
1310                            return Err(LlamaCoreError::Operation(err_msg));
1311                        }
1312                    };
1313
1314                    let function = Function { name, arguments };
1315
1316                    let tool_call = ToolCall {
1317                        id: "call_abc123".to_string(),
1318                        ty: "function".to_string(),
1319                        function,
1320                    };
1321
1322                    tool_calls.push(tool_call);
1323                }
1324
1325                let parsed = ParseResult {
1326                    raw: input.to_owned(),
1327                    content: None,
1328                    tool_calls,
1329                };
1330
1331                #[cfg(feature = "logging")]
1332                info!(target: "stdout", "parsed result: {parsed:?}");
1333
1334                Ok(parsed)
1335            }
1336            Err(e) => {
1337                let err_msg = format!("Failed to create a regex pattern. Reason: {e}");
1338
1339                #[cfg(feature = "logging")]
1340                error!(target: "stdout", "{}", &err_msg);
1341
1342                Err(LlamaCoreError::Operation(err_msg))
1343            }
1344        },
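        // ChatML-style tool calls: one JSON object per <tool_call>...</tool_call> pair,
        // matched on a single line.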
1345        PromptTemplateType::ChatMLTool => {
1346            match regex::Regex::new(r"<tool_call>(.*?)</tool_call>") {
1347                Ok(re) => {
1348                    let mut values: Vec<serde_json::Value> = vec![];
1349                    for cap in re.captures_iter(input) {
1350                        let matched = cap[1].replace("\\n", ""); // Remove "\\n" from the captured group
1351
1352                        #[cfg(feature = "logging")]
1353                        info!(target: "stdout", "captured: {}", &matched);
1354
1355                        match serde_json::from_str::<serde_json::Value>(&matched) {
1356                            Ok(value) => values.push(value),
1357                            Err(e) => {
1358                                let err_msg = format!(
1359                                    "Failed to deserialize generated tool calls. Reason: {e}"
1360                                );
1361
1362                                #[cfg(feature = "logging")]
1363                                error!(target: "stdout", "{}", &err_msg);
1364
1365                                return Err(LlamaCoreError::Operation(err_msg));
1366                            }
1367                        }
1368                    }
1369
1370                    let mut tool_calls: Vec<ToolCall> = vec![];
1371                    for value in values.iter() {
1372                        let name = match value.get("name") {
1373                            Some(name) => name.to_string().replace("\"", ""),
1374                            None => {
1375                                let err_msg = format!(
1376                                    "Failed to get the name of the function. Tool call: {value:?}"
1377                                );
1378
1379                                #[cfg(feature = "logging")]
1380                                error!(target: "stdout", "{}", &err_msg);
1381
1382                                return Err(LlamaCoreError::Operation(err_msg));
1383                            }
1384                        };
1385
1386                        let arguments = match value.get("arguments") {
1387                            Some(arguments) => arguments.to_string(),
1388                            None => {
1389                                let err_msg = format!(
1390                                    "Failed to get the arguments of the function. Tool call: {value:?}"
1391                                );
1392
1393                                #[cfg(feature = "logging")]
1394                                error!(target: "stdout", "{}", &err_msg);
1395
1396                                return Err(LlamaCoreError::Operation(err_msg));
1397                            }
1398                        };
1399
1400                        let function = Function { name, arguments };
1401
1402                        let tool_call = ToolCall {
1403                            id: "call_abc123".to_string(),
1404                            ty: "function".to_string(),
1405                            function,
1406                        };
1407
1408                        tool_calls.push(tool_call);
1409                    }
1410
1411                    let parsed = ParseResult {
1412                        raw: input.to_owned(),
1413                        content: None,
1414                        tool_calls,
1415                    };
1416
1417                    #[cfg(feature = "logging")]
1418                    info!(target: "stdout", "parsed result: {parsed:?}");
1419
1420                    Ok(parsed)
1421                }
1422                Err(e) => {
1423                    let err_msg = format!("Failed to create a regex pattern. Reason: {e}");
1424
1425                    #[cfg(feature = "logging")]
1426                    error!(target: "stdout", "{}", &err_msg);
1427
1428                    Err(LlamaCoreError::Operation(err_msg))
1429                }
1430            }
1431        }
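        // Groq Llama-3 tool calls: <tool_call>...</tool_call> blocks that may span multiple
        // lines; when no block is found, the whole input is returned as plain content.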
1432        PromptTemplateType::GroqLlama3Tool => {
1433            #[cfg(feature = "logging")]
1434            info!(target: "stdout", "raw input: {input}");
1435
1436            match regex::Regex::new(r"(?s)<tool_call>((.|\r|\n)*?)</tool_call>") {
1437                Ok(re) => {
1438                    let mut values: Vec<serde_json::Value> = vec![];
1439                    for cap in re.captures_iter(input) {
1440                        let matched = cap[1].trim();
1441
1442                        #[cfg(feature = "logging")]
1443                        info!(target: "stdout", "captured: {matched}");
1444
1445                        match serde_json::from_str::<serde_json::Value>(matched) {
1446                            Ok(value) => values.push(value),
1447                            Err(e) => {
1448                                let err_msg = format!(
1449                                    "Failed to deserialize generated tool calls. Reason: {e}"
1450                                );
1451
1452                                #[cfg(feature = "logging")]
1453                                error!(target: "stdout", "{}", &err_msg);
1454
1455                                return Err(LlamaCoreError::Operation(err_msg));
1456                            }
1457                        }
1458                    }
1459
1460                    let mut tool_calls: Vec<ToolCall> = vec![];
1461                    for value in values.iter() {
1462                        let name = match value.get("name") {
1463                            Some(name) => name.to_string().replace("\"", ""),
1464                            None => {
1465                                let err_msg = format!(
1466                                    "Failed to get the name of the function. Tool call: {value:?}"
1467                                );
1468
1469                                #[cfg(feature = "logging")]
1470                                error!(target: "stdout", "{}", &err_msg);
1471
1472                                return Err(LlamaCoreError::Operation(err_msg));
1473                            }
1474                        };
1475
1476                        let arguments = match value.get("arguments") {
1477                            Some(arguments) => {
1478                                if arguments.is_string() {
1479                                    arguments.as_str().unwrap().to_string()
1480                                } else if arguments.is_object() {
1481                                    let map = arguments.as_object().unwrap();
1482
1483                                    #[cfg(feature = "logging")]
1484                                    info!(target: "stdout", "func arguments: {map:?}");
1485
1486                                    serde_json::to_string(map).unwrap()
1487                                } else {
1488                                    serde_json::to_string(arguments).unwrap()
1489                                }
1490                            }
1491                            None => {
1492                                let err_msg = format!(
1493                                    "Failed to get the arguments of the function. Tool call: {value:?}"
1494                                );
1495
1496                                #[cfg(feature = "logging")]
1497                                error!(target: "stdout", "{}", &err_msg);
1498
1499                                return Err(LlamaCoreError::Operation(err_msg));
1500                            }
1501                        };
1502
1503                        let function = Function { name, arguments };
1504
1505                        let tool_call = ToolCall {
1506                            id: "call_abc123".to_string(),
1507                            ty: "function".to_string(),
1508                            function,
1509                        };
1510
1511                        tool_calls.push(tool_call);
1512                    }
1513
1514                    let parsed = if tool_calls.is_empty() {
1515                        ParseResult {
1516                            raw: input.to_owned(),
1517                            content: Some(input.to_owned()),
1518                            tool_calls: vec![],
1519                        }
1520                    } else {
1521                        ParseResult {
1522                            raw: input.to_owned(),
1523                            content: None,
1524                            tool_calls,
1525                        }
1526                    };
1527
1528                    #[cfg(feature = "logging")]
1529                    info!(target: "stdout", "parsed result: {parsed:?}");
1530
1531                    Ok(parsed)
1532                }
1533                Err(e) => {
1534                    let err_msg = format!("Failed to create a regex pattern. Reason: {e}");
1535
1536                    #[cfg(feature = "logging")]
1537                    error!(target: "stdout", "{}", &err_msg);
1538
1539                    Err(LlamaCoreError::Operation(err_msg))
1540                }
1541            }
1542        }
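        // Llama-3 tool calls: the entire response is a single JSON object whose "name" and
        // "parameters" keys become the function name and arguments.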
1543        PromptTemplateType::Llama3Tool => {
1544            #[cfg(feature = "logging")]
1545            info!(target: "stdout", "raw input: {input}");
1546
1547            let re = match regex::Regex::new(r"^\{(.|\r|\n)*\}$") {
1548                Ok(re) => re,
1549                Err(e) => {
1550                    let err_msg = format!("Failed to create a regex pattern. Reason: {e}");
1551
1552                    #[cfg(feature = "logging")]
1553                    error!(target: "stdout", "{}", &err_msg);
1554
1555                    return Err(LlamaCoreError::Operation(err_msg));
1556                }
1557            };
1558
1559            if re.is_match(input) {
1560                match serde_json::from_str::<serde_json::Value>(input) {
1561                    Ok(value) => {
1562                        let values: Vec<serde_json::Value> = vec![value];
1563
1564                        let mut tool_calls: Vec<ToolCall> = vec![];
1565                        for value in values.iter() {
1566                            let name = match value.get("name") {
1567                                Some(name) => name.to_string().replace("\"", ""),
1568                                None => {
1569                                    let err_msg = format!(
1570                                        "Failed to get the name of the function. Tool call: {value:?}"
1571                                    );
1572
1573                                    #[cfg(feature = "logging")]
1574                                    error!(target: "stdout", "{}", &err_msg);
1575
1576                                    return Err(LlamaCoreError::Operation(err_msg));
1577                                }
1578                            };
1579
1580                            let arguments = match value.get("parameters") {
1581                                Some(arguments) => arguments.to_string(),
1582                                None => {
1583                                    let err_msg = format!(
1584                                        "Failed to get the arguments of the function. Tool call: {value:?}"
1585                                    );
1586
1587                                    #[cfg(feature = "logging")]
1588                                    error!(target: "stdout", "{}", &err_msg);
1589
1590                                    return Err(LlamaCoreError::Operation(err_msg));
1591                                }
1592                            };
1593
1594                            let function = Function { name, arguments };
1595
1596                            let tool_call = ToolCall {
1597                                id: "call_abc123".to_string(),
1598                                ty: "function".to_string(),
1599                                function,
1600                            };
1601
1602                            tool_calls.push(tool_call);
1603                        }
1604
1605                        let parsed = ParseResult {
1606                            raw: input.to_owned(),
1607                            content: None,
1608                            tool_calls,
1609                        };
1610
1611                        #[cfg(feature = "logging")]
1612                        info!(target: "stdout", "parsed result: {parsed:?}");
1613
1614                        Ok(parsed)
1615                    }
1616                    Err(e) => {
1617                        let err_msg =
1618                            format!("Failed to deserialize generated tool calls. Reason: {e}");
1619
1620                        #[cfg(feature = "logging")]
1621                        error!(target: "stdout", "{}", &err_msg);
1622
1623                        Err(LlamaCoreError::Operation(err_msg))
1624                    }
1625                }
1626            } else {
1627                let parsed = ParseResult {
1628                    raw: input.to_owned(),
1629                    content: None,
1630                    tool_calls: vec![],
1631                };
1632
1633                #[cfg(feature = "logging")]
1634                info!(target: "stdout", "parsed result: {parsed:?}");
1635
1636                Ok(parsed)
1637            }
1638        }
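        // InternLM2 tool calls: JSON payloads delimited by <|action_start|><|plugin|> and
        // <|action_end|>; any surrounding text is collected as content.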
1639        PromptTemplateType::InternLM2Tool => {
1640            #[cfg(feature = "logging")]
1641            info!(target: "stdout", "raw input: {input}");
1642
1643            let blocks: Vec<&str> = input.trim().split("<|action_start|><|plugin|>").collect();
1644
1645            #[cfg(feature = "logging")]
1646            info!(target: "stdout", "blocks: {blocks:?}");
1647
1648            let mut tool_calls: Vec<ToolCall> = vec![];
1649            let mut content = String::new();
1650            for block in blocks {
1651                let block = block.trim();
1652                if !block.is_empty() {
1653                    if block.ends_with("<|action_end|>") {
1654                        let value = block.trim().trim_end_matches("<|action_end|>");
1655
1656                        #[cfg(feature = "logging")]
1657                        info!(target: "stdout", "tool call: {value}");
1658
1659                        match serde_json::from_str::<serde_json::Value>(value) {
1660                            Ok(value) => {
1661                                let name = match value.get("name") {
1662                                    Some(name) => name.to_string().replace("\"", ""),
1663                                    None => {
1664                                        let err_msg = format!(
1665                                            "Failed to get the name of the function. Tool call: {value:?}"
1666                                        );
1667
1668                                        #[cfg(feature = "logging")]
1669                                        error!(target: "stdout", "{}", &err_msg);
1670
1671                                        return Err(LlamaCoreError::Operation(err_msg));
1672                                    }
1673                                };
1674
1675                                let arguments = match value.get("parameters") {
1676                                    Some(arguments) => arguments.to_string(),
1677                                    None => {
1678                                        let err_msg = format!(
1679                                            "Failed to get the arguments of the function. Tool call: {value:?}"
1680                                        );
1681
1682                                        #[cfg(feature = "logging")]
1683                                        error!(target: "stdout", "{}", &err_msg);
1684
1685                                        return Err(LlamaCoreError::Operation(err_msg));
1686                                    }
1687                                };
1688
1689                                let function = Function { name, arguments };
1690
1691                                let tool_call = ToolCall {
1692                                    id: "call_abc123".to_string(),
1693                                    ty: "function".to_string(),
1694                                    function,
1695                                };
1696
1697                                tool_calls.push(tool_call);
1698                            }
1699                            Err(e) => {
1700                                let err_msg = format!(
1701                                    "Failed to deserialize generated tool calls. Reason: {e}"
1702                                );
1703
1704                                #[cfg(feature = "logging")]
1705                                error!(target: "stdout", "{}", &err_msg);
1706
1707                                return Err(LlamaCoreError::Operation(err_msg));
1708                            }
1709                        }
1710                    } else {
1711                        content.push_str(block);
1712                        content.push('\n');
1713                    }
1714                }
1715            }
1716
1717            let parsed = match content.is_empty() {
1718                true => ParseResult {
1719                    raw: input.to_owned(),
1720                    content: None,
1721                    tool_calls,
1722                },
1723                false => ParseResult {
1724                    raw: input.to_owned(),
1725                    content: Some(content.trim().to_owned()),
1726                    tool_calls,
1727                },
1728            };
1729
1730            #[cfg(feature = "logging")]
1731            info!(target: "stdout", "parsed result: {parsed:?}");
1732
1733            Ok(parsed)
1734        }
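        // Nemotron tool calls: JSON objects wrapped in <toolcall>...</toolcall> tags.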
1735        PromptTemplateType::NemotronTool => {
1736            #[cfg(feature = "logging")]
1737            info!(target: "stdout", "raw input: {input}");
1738
1739            match regex::Regex::new(r"(?s)<toolcall>\s*(.*?)\s*</toolcall>") {
1740                Ok(re) => {
1741                    let mut values: Vec<serde_json::Value> = vec![];
1742                    for cap in re.captures_iter(input) {
1743                        #[cfg(feature = "logging")]
1744                        info!(target: "stdout", "captured: {}", &cap[0]);
1745
1746                        #[cfg(feature = "logging")]
1747                        info!(target: "stdout", "extracted: {}", &cap[1]);
1748
1749                        let matched = cap[1].trim();
1750
1751                        #[cfg(feature = "logging")]
1752                        info!(target: "stdout", "captured: {matched}");
1753
1754                        match serde_json::from_str::<serde_json::Value>(matched) {
1755                            Ok(value) => values.push(value),
1756                            Err(e) => {
1757                                let err_msg = format!(
1758                                    "Failed to deserialize generated tool calls. Reason: {e}"
1759                                );
1760
1761                                #[cfg(feature = "logging")]
1762                                error!(target: "stdout", "{}", &err_msg);
1763
1764                                return Err(LlamaCoreError::Operation(err_msg));
1765                            }
1766                        }
1767                    }
1768
1769                    let mut tool_calls: Vec<ToolCall> = vec![];
1770                    for value in values.iter() {
1771                        let name = match value.get("name") {
1772                            Some(name) => name.to_string().replace("\"", ""),
1773                            None => {
1774                                let err_msg = format!(
1775                                    "Failed to get the name of the function. Tool call: {value:?}"
1776                                );
1777
1778                                #[cfg(feature = "logging")]
1779                                error!(target: "stdout", "{}", &err_msg);
1780
1781                                return Err(LlamaCoreError::Operation(err_msg));
1782                            }
1783                        };
1784
1785                        let arguments = match value.get("arguments") {
1786                            Some(arguments) => arguments.to_string(),
1787                            None => {
1788                                let err_msg = format!(
1789                                    "Failed to get the arguments of the function. Tool call: {value:?}"
1790                                );
1791
1792                                #[cfg(feature = "logging")]
1793                                error!(target: "stdout", "{}", &err_msg);
1794
1795                                return Err(LlamaCoreError::Operation(err_msg));
1796                            }
1797                        };
1798
1799                        let function = Function { name, arguments };
1800
1801                        let tool_call = ToolCall {
1802                            id: "call_abc123".to_string(),
1803                            ty: "function".to_string(),
1804                            function,
1805                        };
1806
1807                        tool_calls.push(tool_call);
1808                    }
1809
1810                    let parsed = ParseResult {
1811                        raw: input.to_owned(),
1812                        content: None,
1813                        tool_calls,
1814                    };
1815
1816                    #[cfg(feature = "logging")]
1817                    info!(target: "stdout", "parsed result: {parsed:?}");
1818
1819                    Ok(parsed)
1820                }
1821                Err(e) => {
1822                    let err_msg = format!("Failed to create a regex pattern. Reason: {e}");
1823
1824                    #[cfg(feature = "logging")]
1825                    error!(target: "stdout", "{}", &err_msg);
1826
1827                    Err(LlamaCoreError::Operation(err_msg))
1828                }
1829            }
1830        }
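        // Functionary v3.2 tool calls: ">>>name{...}<|eot_id|>" segments; the first capture is
        // the function name, the second is the argument text between the braces.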
1831        PromptTemplateType::FunctionaryV32 => {
1832            #[cfg(feature = "logging")]
1833            info!(target: "stdout", "raw input: {input}");
1834
1835            match regex::Regex::new(r">>>\s*(\w+)\s*\{(.*)\}<\|eot_id\|>") {
1836                Ok(re) => {
1837                    let mut tool_calls: Vec<ToolCall> = vec![];
1838                    for cap in re.captures_iter(input) {
1839                        #[cfg(feature = "logging")]
1840                        info!(target: "stdout", "func_name: {}", &cap[1]);
1841
1842                        #[cfg(feature = "logging")]
1843                        info!(target: "stdout", "arguments: {}", &cap[2]);
1844
1845                        let tool_call = ToolCall {
1846                            id: "call_abc123".to_string(),
1847                            ty: "function".to_string(),
1848                            function: Function {
1849                                name: cap[1].to_string(),
1850                                arguments: cap[2].to_string(),
1851                            },
1852                        };
1853
1854                        tool_calls.push(tool_call);
1855                    }
1856
1857                    let parsed = ParseResult {
1858                        raw: input.to_owned(),
1859                        content: None,
1860                        tool_calls,
1861                    };
1862
1863                    #[cfg(feature = "logging")]
1864                    info!(target: "stdout", "parsed result: {parsed:?}");
1865
1866                    Ok(parsed)
1867                }
1868                Err(e) => {
1869                    let warn_msg = format!("Failed to create a regex pattern. Reason: {e}");
1870
1871                    #[cfg(feature = "logging")]
1872                    warn!(target: "stdout", "{}", &warn_msg);
1873
1874                    Ok(ParseResult {
1875                        raw: input.to_owned(),
1876                        content: None,
1877                        tool_calls: vec![],
1878                    })
1879                }
1880            }
1881        }
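        // Functionary v3.1 tool calls: <function=name>{...}</function> segments, e.g.
        // <function=get_weather>{"city": "Paris"}</function> (illustrative example).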
1882        PromptTemplateType::FunctionaryV31 => {
1883            #[cfg(feature = "logging")]
1884            info!(target: "stdout", "raw input: {input}");
1885
1886            match regex::Regex::new(r"<function=(\w+)>\s*(\{.*?\})</function>") {
1887                Ok(re) => {
1888                    let mut tool_calls: Vec<ToolCall> = vec![];
1889                    for cap in re.captures_iter(input) {
1890                        #[cfg(feature = "logging")]
1891                        info!(target: "stdout", "func_name: {}", &cap[1]);
1892
1893                        #[cfg(feature = "logging")]
1894                        info!(target: "stdout", "arguments: {}", &cap[2]);
1895
1896                        let tool_call = ToolCall {
1897                            id: "call_abc123".to_string(),
1898                            ty: "function".to_string(),
1899                            function: Function {
1900                                name: cap[1].to_string(),
1901                                arguments: cap[2].to_string(),
1902                            },
1903                        };
1904
1905                        tool_calls.push(tool_call);
1906                    }
1907
1908                    let parsed = ParseResult {
1909                        raw: input.to_owned(),
1910                        content: None,
1911                        tool_calls,
1912                    };
1913
1914                    #[cfg(feature = "logging")]
1915                    info!(target: "stdout", "parsed result: {parsed:?}");
1916
1917                    Ok(parsed)
1918                }
1919                Err(e) => {
1920                    let warn_msg = format!("Failed to create a regex pattern. Reason: {e}");
1921
1922                    #[cfg(feature = "logging")]
1923                    warn!(target: "stdout", "{}", &warn_msg);
1924
1925                    Ok(ParseResult {
1926                        raw: input.to_owned(),
1927                        content: None,
1928                        tool_calls: vec![],
1929                    })
1930                }
1931            }
1932        }
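        // Mistral Small tool calls: a [TOOL_CALLS] marker followed by a JSON array whose
        // entries carry either a nested "function" object or top-level "name"/"arguments" keys.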
1933        PromptTemplateType::MistralSmallTool => {
1934            #[cfg(feature = "logging")]
1935            info!(target: "stdout", "raw input: {input}");
1936
1937            match regex::Regex::new(r"\[TOOL_CALLS\]\s*(\[(.*?)\])") {
1938                Ok(re) => {
1939                    let mut values: Vec<serde_json::Value> = vec![];
1940                    if let Some(cap) = re.captures(input) {
1941                        let matched = cap[1].trim();
1942
1943                        #[cfg(feature = "logging")]
1944                        info!(target: "stdout", "captured: {matched}");
1945
1946                        match serde_json::from_str::<Vec<serde_json::Value>>(matched) {
1947                            Ok(vals) => values = vals,
1948                            Err(e) => {
1949                                let err_msg = format!(
1950                                    "Failed to deserialize generated tool calls. Reason: {e}"
1951                                );
1952
1953                                #[cfg(feature = "logging")]
1954                                error!(target: "stdout", "{}", &err_msg);
1955
1956                                return Err(LlamaCoreError::Operation(err_msg));
1957                            }
1958                        }
1959                    };
1960
1961                    let mut tool_calls: Vec<ToolCall> = vec![];
1962                    for value in values.iter() {
1963                        if let Some(object_map) = value.as_object() {
1964                            if object_map.contains_key("function") {
1965                                let mut function = Function {
1966                                    name: String::new(),
1967                                    arguments: String::new(),
1968                                };
1969
1970                                let value = object_map.get("function").unwrap();
1971                                let func_map = value.as_object().unwrap();
1972                                if func_map.contains_key("name") {
1973                                    let func_name = func_map.get("name").unwrap().as_str().unwrap();
1974                                    #[cfg(feature = "logging")] info!(target: "stdout", "func name: {func_name:?}");
1975
1976                                    function.name = func_name.to_string();
1977                                }
1978                                if func_map.contains_key("arguments") {
1979                                    let args = func_map.get("arguments").unwrap();
1980                                    let arguments = args.to_string();
1981                                    #[cfg(feature = "logging")] info!(target: "stdout", "func arguments: {arguments:?}");
1982
1983                                    function.arguments = arguments;
1984                                }
1985
1986                                let tool_call = ToolCall {
1987                                    id: "call_abc123".to_string(),
1988                                    ty: "function".to_string(),
1989                                    function,
1990                                };
1991
1992                                tool_calls.push(tool_call);
1993                            } else if object_map.contains_key("name") {
1994                                let mut function = Function {
1995                                    name: String::new(),
1996                                    arguments: String::new(),
1997                                };
1998
1999                                let name = object_map.get("name").unwrap().as_str().unwrap();
2000                                #[cfg(feature = "logging")] info!(target: "stdout", "func name: {name:?}");
2001                                function.name = name.to_string();
2002
2003                                if object_map.contains_key("arguments") {
2004                                    let args = object_map.get("arguments").unwrap();
2005                                    let arguments = args.to_string();
2006                                    #[cfg(feature = "logging")] info!(target: "stdout", "func arguments: {arguments:?}");
2007
2008                                    function.arguments = arguments;
2009                                }
2010
2011                                let tool_call = ToolCall {
2012                                    id: "call_abc123".to_string(),
2013                                    ty: "function".to_string(),
2014                                    function,
2015                                };
2016
2017                                tool_calls.push(tool_call);
2018                            }
2019                        }
2020                    }
2021
2022                    let parsed = ParseResult {
2023                        raw: input.to_owned(),
2024                        content: None,
2025                        tool_calls,
2026                    };
2027
2028                    #[cfg(feature = "logging")]
2029                    info!(target: "stdout", "parsed result: {parsed:?}");
2030
2031                    Ok(parsed)
2032                }
2033                Err(e) => {
2034                    let err_msg = format!("Failed to create a regex pattern. Reason: {e}");
2035
2036                    #[cfg(feature = "logging")]
2037                    error!(target: "stdout", "{}", &err_msg);
2038
2039                    Err(LlamaCoreError::Operation(err_msg))
2040                }
2041            }
2042        }
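        // Llama-4 chat tool calls: the response is a single JSON object with "name" and
        // "parameters" keys; non-JSON input yields an empty tool-call list.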
2043        PromptTemplateType::Llama4Chat => {
2044            #[cfg(feature = "logging")]
2045            info!(target: "stdout", "raw input: {input:?}");
2046
2047            let mut tool_calls: Vec<ToolCall> = vec![];
2048            if let Ok(value) = serde_json::from_str::<serde_json::Value>(input) {
2049                match value.as_object() {
2050                    Some(object_map) => {
2051                        #[cfg(feature = "logging")]
2052                        debug!(target: "stdout", "object_map: {object_map:?}");
2053
2054                        // parse function name
2055                        if object_map.contains_key("name") {
2056                            let name = object_map.get("name").unwrap().as_str().unwrap();
2057
2058                            #[cfg(feature = "logging")]
2059                            debug!(target: "stdout", "name: {name:?}");
2060
2061                            let mut function = Function {
2062                                name: name.to_string(),
2063                                arguments: String::new(),
2064                            };
2065
2066                            // parse function arguments
2067                            if object_map.contains_key("parameters") {
2068                                let args = object_map.get("parameters").unwrap();
2069                                let arguments = args.to_string();
2070
2071                                #[cfg(feature = "logging")]
2072                                debug!(target: "stdout", "arguments: {:?}", &arguments);
2073
2074                                function.arguments = arguments;
2075                            }
2076
2077                            tool_calls.push(ToolCall {
2078                                id: "call_abc123".to_string(),
2079                                ty: "function".to_string(),
2080                                function,
2081                            });
2082                        } else {
2083                            let err_msg = format!(
2084                                "Failed to get the name of the function. raw input: {input:?}"
2085                            );
2086
2087                            #[cfg(feature = "logging")]
2088                            error!(target: "stdout", "{}", &err_msg);
2089
2090                            return Err(LlamaCoreError::Operation(err_msg));
2091                        }
2092                    }
2093                    None => {
2094                        let err_msg = format!("Failed to parse the JSON string. JSON: {input}");
2095
2096                        #[cfg(feature = "logging")]
2097                        error!(target: "stdout", "{}", &err_msg);
2098
2099                        return Err(LlamaCoreError::Operation(err_msg));
2100                    }
2101                }
2102            }
2103
2104            let parsed = ParseResult {
2105                raw: input.to_owned(),
2106                content: None,
2107                tool_calls,
2108            };
2109
2110            #[cfg(feature = "logging")]
2111            info!(target: "stdout", "parsed result: {parsed:?}");
2112
2113            Ok(parsed)
2114        }
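        // Qwen3 / Smol3 "no-think" tool calls: multi-line <tool_call>...</tool_call> blocks;
        // a response without any block is returned as plain content.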
2115        PromptTemplateType::Qwen3NoThink | PromptTemplateType::Smol3NoThink => {
2116            #[cfg(feature = "logging")]
2117            info!(target: "stdout", "raw input: {input:?}");
2118
2119            match regex::Regex::new(r"(?s)<tool_call>((.|\r|\n)*?)</tool_call>") {
2120                Ok(re) => {
2121                    let mut values: Vec<serde_json::Value> = vec![];
2122                    for cap in re.captures_iter(input) {
2123                        let mut matched = cap[1].trim();
2124
2125                        if matched.starts_with("\\n") {
2126                            matched = matched.trim_start_matches("\\n");
2127                        }
2128
2129                        if matched.ends_with("\\n") {
2130                            matched = matched.trim_end_matches("\\n");
2131                        }
2132
2133                        #[cfg(feature = "logging")]
2134                        info!(target: "stdout", "captured: {matched:#?}");
2135
2136                        if !matched.is_empty() {
2137                            match serde_json::from_str::<serde_json::Value>(matched) {
2138                                Ok(value) => values.push(value),
2139                                Err(e) => {
2140                                    let err_msg = format!(
2141                                    "Failed to deserialize generated tool calls: {matched:#?}. Reason: {e}"
2142                                );
2143
2144                                    #[cfg(feature = "logging")]
2145                                    error!(target: "stdout", "{}", &err_msg);
2146
2147                                    return Err(LlamaCoreError::Operation(err_msg));
2148                                }
2149                            }
2150                        }
2151                    }
2152
2153                    let mut tool_calls: Vec<ToolCall> = vec![];
2154                    for value in values.iter() {
2155                        let name = match value.get("name") {
2156                            Some(name) => name.to_string().replace("\"", ""),
2157                            None => {
2158                                let err_msg = format!(
2159                                    "Failed to get the name of the function. Tool call: {value:?}"
2160                                );
2161
2162                                #[cfg(feature = "logging")]
2163                                error!(target: "stdout", "{}", &err_msg);
2164
2165                                return Err(LlamaCoreError::Operation(err_msg));
2166                            }
2167                        };
2168
2169                        let arguments = match value.get("arguments") {
2170                            Some(arguments) => {
2171                                if arguments.is_string() {
2172                                    arguments.as_str().unwrap().to_string()
2173                                } else if arguments.is_object() {
2174                                    let map = arguments.as_object().unwrap();
2175
2176                                    #[cfg(feature = "logging")]
2177                                    info!(target: "stdout", "func arguments: {map:?}");
2178
2179                                    serde_json::to_string(map).unwrap()
2180                                } else {
2181                                    serde_json::to_string(arguments).unwrap()
2182                                }
2183                            }
2184                            None => {
2185                                let err_msg = format!(
2186                                    "Failed to get the arguments of the function. Tool call: {value:?}"
2187                                );
2188
2189                                #[cfg(feature = "logging")]
2190                                error!(target: "stdout", "{}", &err_msg);
2191
2192                                return Err(LlamaCoreError::Operation(err_msg));
2193                            }
2194                        };
2195
2196                        let function = Function { name, arguments };
2197
2198                        let tool_call = ToolCall {
2199                            id: "call_abc123".to_string(),
2200                            ty: "function".to_string(),
2201                            function,
2202                        };
2203
2204                        tool_calls.push(tool_call);
2205                    }
2206
2207                    let parsed = if tool_calls.is_empty() {
2208                        ParseResult {
2209                            raw: input.to_owned(),
2210                            content: Some(input.to_owned()),
2211                            tool_calls: vec![],
2212                        }
2213                    } else {
2214                        ParseResult {
2215                            raw: input.to_owned(),
2216                            content: None,
2217                            tool_calls,
2218                        }
2219                    };
2220
2221                    #[cfg(feature = "logging")]
2222                    info!(target: "stdout", "parsed result: {parsed:?}");
2223
2224                    Ok(parsed)
2225                }
2226                Err(e) => {
2227                    let err_msg = format!("Failed to create a regex pattern. Reason: {e}");
2228
2229                    #[cfg(feature = "logging")]
2230                    error!(target: "stdout", "{}", &err_msg);
2231
2232                    Err(LlamaCoreError::Operation(err_msg))
2233                }
2234            }
2235        }
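        // Gemma-3 tool calls: JSON objects inside fenced ```json ... ``` blocks.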
2236        PromptTemplateType::Gemma3 => {
2237            #[cfg(feature = "logging")]
2238            info!(target: "stdout", "raw input: {input:?}");
2239
2240            match regex::Regex::new(r"(?s)```json\s*(.*?)\s*```") {
2241                Ok(re) => {
2242                    let mut values: Vec<serde_json::Value> = vec![];
2243                    for cap in re.captures_iter(input) {
2244                        let mut matched = cap[1].trim();
2245
2246                        if matched.starts_with("\\n") {
2247                            matched = matched.trim_start_matches("\\n");
2248                        }
2249
2250                        if matched.ends_with("\\n") {
2251                            matched = matched.trim_end_matches("\\n");
2252                        }
2253
2254                        #[cfg(feature = "logging")]
2255                        info!(target: "stdout", "captured: {matched:#?}");
2256
2257                        if !matched.is_empty() {
2258                            match serde_json::from_str::<serde_json::Value>(matched) {
2259                                Ok(value) => values.push(value),
2260                                Err(e) => {
2261                                    let err_msg = format!(
2262                                    "Failed to deserialize generated tool calls: {matched:#?}. Reason: {e}"
2263                                );
2264
2265                                    #[cfg(feature = "logging")]
2266                                    error!(target: "stdout", "{}", &err_msg);
2267
2268                                    return Err(LlamaCoreError::Operation(err_msg));
2269                                }
2270                            }
2271                        }
2272                    }
2273
2274                    let mut tool_calls: Vec<ToolCall> = vec![];
2275                    for value in values.iter() {
2276                        let name = match value.get("name") {
2277                            Some(name) => name.to_string().replace("\"", ""),
2278                            None => {
2279                                let err_msg = format!(
2280                                    "Failed to get the name of the function. Tool call: {value:?}"
2281                                );
2282
2283                                #[cfg(feature = "logging")]
2284                                error!(target: "stdout", "{}", &err_msg);
2285
2286                                return Err(LlamaCoreError::Operation(err_msg));
2287                            }
2288                        };
2289
2290                        let arguments = match value.get("arguments") {
2291                            Some(arguments) => {
2292                                if arguments.is_string() {
2293                                    arguments.as_str().unwrap().to_string()
2294                                } else if arguments.is_object() {
2295                                    let map = arguments.as_object().unwrap();
2296
2297                                    #[cfg(feature = "logging")]
2298                                    info!(target: "stdout", "func arguments: {map:?}");
2299
2300                                    serde_json::to_string(map).unwrap()
2301                                } else {
2302                                    serde_json::to_string(arguments).unwrap()
2303                                }
2304                            }
2305                            None => {
2306                                let err_msg = format!(
2307                                    "Failed to get the arguments of the function. Tool call: {value:?}"
2308                                );
2309
2310                                #[cfg(feature = "logging")]
2311                                error!(target: "stdout", "{}", &err_msg);
2312
2313                                return Err(LlamaCoreError::Operation(err_msg));
2314                            }
2315                        };
2316
2317                        let function = Function { name, arguments };
2318
2319                        let tool_call = ToolCall {
2320                            id: "call_abc123".to_string(),
2321                            ty: "function".to_string(),
2322                            function,
2323                        };
2324
2325                        tool_calls.push(tool_call);
2326                    }
2327
2328                    let parsed = if tool_calls.is_empty() {
2329                        ParseResult {
2330                            raw: input.to_owned(),
2331                            content: Some(input.to_owned()),
2332                            tool_calls: vec![],
2333                        }
2334                    } else {
2335                        ParseResult {
2336                            raw: input.to_owned(),
2337                            content: None,
2338                            tool_calls,
2339                        }
2340                    };
2341
2342                    #[cfg(feature = "logging")]
2343                    info!(target: "stdout", "parsed result: {parsed:?}");
2344
2345                    Ok(parsed)
2346                }
2347                Err(e) => {
2348                    let err_msg = format!("Failed to create a regex pattern. Reason: {e}");
2349
2350                    #[cfg(feature = "logging")]
2351                    error!(target: "stdout", "{}", &err_msg);
2352
2353                    Err(LlamaCoreError::Operation(err_msg))
2354                }
2355            }
2356        }
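        // GPT-OSS tool calls: try the commentary-channel form described below first; if it
        // does not match, fall back to fenced ```json ... ``` blocks.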
2357        PromptTemplateType::GptOss => {
2358            #[cfg(feature = "logging")]
2359            info!(target: "stdout", "raw input: {input:?}");
2360
2361            // Match strings ending with: <|channel|>commentary to=functions.xxxxx <|constrain|>json<|message|>yyyyy<|call|>
2362            match regex::Regex::new(
2363                r"<\|channel\|>commentary to=functions\.([^<\s]+)\s*<\|constrain\|>json<\|message\|>([^<]*)<\|call\|>$",
2364            ) {
2365                Ok(re) => {
2366                    if let Some(cap) = re.captures(input) {
2367                        let function_name = cap[1].trim();
2368                        let arguments = cap[2].trim();
2369
2370                        #[cfg(feature = "logging")]
2371                        info!(target: "stdout", "extracted function_name: {function_name}, arguments: {arguments}");
2372
2373                        let function = Function {
2374                            name: function_name.to_string(),
2375                            arguments: arguments.to_string(),
2376                        };
2377
2378                        let tool_call = ToolCall {
2379                            id: "call_abc123".to_string(),
2380                            ty: "function".to_string(),
2381                            function,
2382                        };
2383
2384                        let parsed = ParseResult {
2385                            raw: input.to_owned(),
2386                            content: None,
2387                            tool_calls: vec![tool_call],
2388                        };
2389
2390                        #[cfg(feature = "logging")]
2391                        info!(target: "stdout", "parsed result: {parsed:?}");
2392
2393                        Ok(parsed)
2394                    } else {
2395                        match regex::Regex::new(r"(?s)```json\s*(.*?)\s*```") {
2396                            Ok(re) => {
2397                                let mut values: Vec<serde_json::Value> = vec![];
2398                                for cap in re.captures_iter(input) {
2399                                    let mut matched = cap[1].trim();
2400
2401                                    if matched.starts_with("\\n") {
2402                                        matched = matched.trim_start_matches("\\n");
2403                                    }
2404
2405                                    if matched.ends_with("\\n") {
2406                                        matched = matched.trim_end_matches("\\n");
2407                                    }
2408
2409                                    #[cfg(feature = "logging")]
2410                                    info!(target: "stdout", "captured: {matched:#?}");
2411
2412                                    if !matched.is_empty() {
2413                                        match serde_json::from_str::<serde_json::Value>(matched) {
2414                                            Ok(value) => values.push(value),
2415                                            Err(e) => {
2416                                                let err_msg = format!(
2417                                                "Failed to deserialize generated tool calls: {matched:#?}. Reason: {e}"
2418                                            );
2419
2420                                                #[cfg(feature = "logging")]
2421                                                error!(target: "stdout", "{}", &err_msg);
2422
2423                                                return Err(LlamaCoreError::Operation(err_msg));
2424                                            }
2425                                        }
2426                                    }
2427                                }
2428
2429                                let mut tool_calls: Vec<ToolCall> = vec![];
2430                                for value in values.iter() {
2431                                    let name = match value.get("name") {
2432                                        Some(name) => name.to_string().replace("\"", ""),
2433                                        None => {
2434                                            let err_msg = format!(
2435                                                "Failed to get the name of the function. Tool call: {value:?}"
2436                                            );
2437
2438                                            #[cfg(feature = "logging")]
2439                                            error!(target: "stdout", "{}", &err_msg);
2440
2441                                            return Err(LlamaCoreError::Operation(err_msg));
2442                                        }
2443                                    };
2444
2445                                    let arguments = match value.get("arguments") {
2446                                        Some(arguments) => {
2447                                            if arguments.is_string() {
2448                                                arguments.as_str().unwrap().to_string()
2449                                            } else if arguments.is_object() {
2450                                                let map = arguments.as_object().unwrap();
2451
2452                                                #[cfg(feature = "logging")]
2453                                                info!(target: "stdout", "func arguments: {map:?}");
2454
2455                                                serde_json::to_string(map).unwrap()
2456                                            } else {
2457                                                serde_json::to_string(arguments).unwrap()
2458                                            }
2459                                        }
2460                                        None => {
2461                                            let err_msg = format!(
2462                                                "Failed to get the arguments of the function. Tool call: {value:?}"
2463                                            );
2464
2465                                            #[cfg(feature = "logging")]
2466                                            error!(target: "stdout", "{}", &err_msg);
2467
2468                                            return Err(LlamaCoreError::Operation(err_msg));
2469                                        }
2470                                    };
2471
2472                                    let function = Function { name, arguments };
2473
2474                                    let tool_call = ToolCall {
2475                                        id: "call_abc123".to_string(),
2476                                        ty: "function".to_string(),
2477                                        function,
2478                                    };
2479
2480                                    tool_calls.push(tool_call);
2481                                }
2482
2483                                let parsed = if tool_calls.is_empty() {
2484                                    ParseResult {
2485                                        raw: input.to_owned(),
2486                                        content: Some(input.to_owned()),
2487                                        tool_calls: vec![],
2488                                    }
2489                                } else {
2490                                    ParseResult {
2491                                        raw: input.to_owned(),
2492                                        content: Some(input.to_owned()),
2493                                        tool_calls,
2494                                    }
2495                                };
2496
2497                                #[cfg(feature = "logging")]
2498                                info!(target: "stdout", "parsed result: {parsed:?}");
2499
2500                                Ok(parsed)
2501                            }
2502                            Err(e) => {
2503                                let err_msg =
2504                                    format!("Failed to create a regex pattern. Reason: {e}");
2505
2506                                #[cfg(feature = "logging")]
2507                                error!(target: "stdout", "{}", &err_msg);
2508
2509                                Err(LlamaCoreError::Operation(err_msg))
2510                            }
2511                        }
2512                    }
2513                }
2514                Err(e) => {
2515                    let err_msg = format!("Failed to create a regex pattern. Reason: {e}");
2516
2517                    #[cfg(feature = "logging")]
2518                    error!(target: "stdout", "{}", &err_msg);
2519
2520                    Err(LlamaCoreError::Operation(err_msg))
2521                }
2522            }
2523        }
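        // Qwen-3 agent tool calls: a JSON object inside <action>...</action>; without an
        // action tag, the input (wrapped in <final_answer> tags if needed) is returned as
        // content. The unwrap on Regex::new below cannot fail since the pattern is a valid constant.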
2524        PromptTemplateType::Qwen3Agent => {
2525            #[cfg(feature = "logging")]
2526            info!(target: "stdout", "Raw input to tool call parser: {input:?}");
2527
2528            // detect <action> tags
2529            match regex::Regex::new(r"<action>(.*?)</action>")
2530                .unwrap()
2531                .captures(input)
2532            {
2533                Some(captures) => {
2534                    let action = captures.get(1).unwrap().as_str();
2535
2536                    #[cfg(feature = "logging")]
2537                    info!(target: "stdout", "Action: {action}");
2538
2539                    match serde_json::from_str::<serde_json::Value>(action) {
2540                        Ok(value) => {
2541                            let name = match value.get("name") {
2542                                Some(name) => name.to_string().replace("\"", ""),
2543                                None => {
2544                                    let err_msg = format!(
2545                                        "Failed to get the name of the function. Tool call: {value:?}"
2546                                    );
2547
2548                                    #[cfg(feature = "logging")]
2549                                    error!(target: "stdout", "{}", &err_msg);
2550
2551                                    return Err(LlamaCoreError::Operation(err_msg));
2552                                }
2553                            };
2554
2555                            let arguments = match value.get("arguments") {
2556                                Some(arguments) => {
2557                                    if arguments.is_string() {
2558                                        arguments.as_str().unwrap().to_string()
2559                                    } else if arguments.is_object() {
2560                                        let map = arguments.as_object().unwrap();
2561
2562                                        #[cfg(feature = "logging")]
2563                                        info!(target: "stdout", "func arguments: {map:?}");
2564
2565                                        serde_json::to_string(map).unwrap()
2566                                    } else {
2567                                        serde_json::to_string(arguments).unwrap()
2568                                    }
2569                                }
2570                                None => {
2571                                    let err_msg = format!(
2572                                        "Failed to get the arguments of the function. Tool call: {value:?}"
2573                                    );
2574
2575                                    #[cfg(feature = "logging")]
2576                                    error!(target: "stdout", "{}", &err_msg);
2577
2578                                    return Err(LlamaCoreError::Operation(err_msg));
2579                                }
2580                            };
2581
2582                            let function = Function { name, arguments };
2583
2584                            let tool_call = ToolCall {
2585                                id: "call_abc123".to_string(),
2586                                ty: "function".to_string(),
2587                                function,
2588                            };
2589
2590                            Ok(ParseResult {
2591                                raw: input.to_owned(),
2592                                content: Some(input.to_owned()),
2593                                tool_calls: vec![tool_call],
2594                            })
2595                        }
2596                        Err(e) => {
2597                            let err_msg = format!(
2598                                "Failed to deserialize generated tool calls: {action:#?}. Reason: {e}"
2599                            );
2600
2601                            #[cfg(feature = "logging")]
2602                            error!(target: "stdout", "{}", &err_msg);
2603
2604                            Err(LlamaCoreError::Operation(err_msg))
2605                        }
2606                    }
2607                }
2608                None => match input.contains("<final_answer>") {
2609                    true => Ok(ParseResult {
2610                        raw: input.to_owned(),
2611                        content: Some(input.to_owned()),
2612                        tool_calls: vec![],
2613                    }),
2614                    false => {
2615                        let content = format!("<final_answer>{}</final_answer>", input.trim());
2616
2617                        Ok(ParseResult {
2618                            raw: input.to_owned(),
2619                            content: Some(content),
2620                            tool_calls: vec![],
2621                        })
2622                    }
2623                },
2624            }
2625        }
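        // Illustrative example (not part of the original source): a raw model output such as
        //     <action>{"name": "get_weather", "arguments": {"city": "Paris"}}</action>
        // is parsed by the branch above into a single `ToolCall` whose function name is
        // "get_weather" and whose arguments are the JSON string {"city":"Paris"}; an output
        // that contains neither <action> nor <final_answer> tags is wrapped in <final_answer>
        // tags and returned as plain content.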
2626        PromptTemplateType::SeedOssNoThink | PromptTemplateType::SeedOssThink => {
2627            #[cfg(feature = "logging")]
2628            info!(target: "stdout", "Raw input to tool call parser: {input:?}");
2629
2630            match regex::Regex::new(r"```json\n([\s\S]*?)\n") {
2631                Ok(re) => {
2632                    let mut values: Vec<serde_json::Value> = vec![];
2633                    for cap in re.captures_iter(input) {
2634                        let mut matched = cap[1].trim();
2635
2636                        if matched.starts_with("\\n") {
2637                            matched = matched.trim_start_matches("\\n");
2638                        }
2639
2640                        if matched.ends_with("\\n") {
2641                            matched = matched.trim_end_matches("\\n");
2642                        }
2643
2644                        #[cfg(feature = "logging")]
2645                        info!(target: "stdout", "captured: {matched:#?}");
2646
2647                        if !matched.is_empty() {
2648                            match serde_json::from_str::<serde_json::Value>(matched) {
2649                                Ok(value) => values.push(value),
2650                                Err(e) => {
2651                                    let err_msg = format!(
2652                                        "Failed to deserialize generated tool calls: {matched:#?}. Reason: {e}"
2653                                    );
2654
2655                                    #[cfg(feature = "logging")]
2656                                    error!(target: "stdout", "{}", &err_msg);
2657
2658                                    return Err(LlamaCoreError::Operation(err_msg));
2659                                }
2660                            }
2661                        }
2662                    }
2663
2664                    let mut tool_calls: Vec<ToolCall> = vec![];
2665                    for value in values.iter() {
2666                        let name = match value.get("name") {
2667                            Some(name) => name.to_string().replace("\"", ""),
2668                            None => {
2669                                let err_msg = format!(
2670                                    "Failed to get the name of the function. Tool call: {value:?}"
2671                                );
2672
2673                                #[cfg(feature = "logging")]
2674                                error!(target: "stdout", "{}", &err_msg);
2675
2676                                return Err(LlamaCoreError::Operation(err_msg));
2677                            }
2678                        };
2679
2680                        let arguments = match value.get("arguments") {
2681                            Some(arguments) => {
2682                                if arguments.is_string() {
2683                                    arguments.as_str().unwrap().to_string()
2684                                } else if arguments.is_object() {
2685                                    let map = arguments.as_object().unwrap();
2686
2687                                    #[cfg(feature = "logging")]
2688                                    info!(target: "stdout", "func arguments: {map:?}");
2689
2690                                    serde_json::to_string(map).unwrap()
2691                                } else {
2692                                    serde_json::to_string(arguments).unwrap()
2693                                }
2694                            }
2695                            None => {
2696                                let err_msg = format!(
2697                                    "Failed to get the arguments of the function. Tool call: {value:?}"
2698                                );
2699
2700                                #[cfg(feature = "logging")]
2701                                error!(target: "stdout", "{}", &err_msg);
2702
2703                                return Err(LlamaCoreError::Operation(err_msg));
2704                            }
2705                        };
2706
2707                        let function = Function { name, arguments };
2708
2709                        let tool_call = ToolCall {
2710                            id: "call_abc123".to_string(),
2711                            ty: "function".to_string(),
2712                            function,
2713                        };
2714
2715                        tool_calls.push(tool_call);
2716                    }
2717
2718                    let parsed = if tool_calls.is_empty() {
2719                        ParseResult {
2720                            raw: input.to_owned(),
2721                            content: Some(input.to_owned()),
2722                            tool_calls: vec![],
2723                        }
2724                    } else {
2725                        ParseResult {
2726                            raw: input.to_owned(),
2727                            content: None,
2728                            tool_calls,
2729                        }
2730                    };
2731
2732                    #[cfg(feature = "logging")]
2733                    info!(target: "stdout", "parsed result: {parsed:?}");
2734
2735                    Ok(parsed)
2736                }
2737                Err(e) => {
2738                    let err_msg = format!("Failed to create a regex pattern. Reason: {e}");
2739
2740                    #[cfg(feature = "logging")]
2741                    error!(target: "stdout", "{}", &err_msg);
2742
2743                    Err(LlamaCoreError::Operation(err_msg))
2744                }
2745            }
2746        }
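        // Illustrative example (not part of the original source): for the SeedOss templates,
        // a tool call is expected inside a fenced JSON block in the raw output, e.g.
        //
        //     ```json
        //     {"name": "get_weather", "arguments": {"city": "Paris"}}
        //     ```
        //
        // Each captured JSON object is converted into a `ToolCall` by the branch above;
        // when no such block is present, the raw output is returned as plain content.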
2747        _ => {
2748            let err_msg = format!(
2749                "Tool use is only supported for the following prompt templates: {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, and {}.",
2750                PromptTemplateType::MistralTool,
2751                PromptTemplateType::ChatMLTool,
2752                PromptTemplateType::GroqLlama3Tool,
2753                PromptTemplateType::Llama3Tool,
2754                PromptTemplateType::InternLM2Tool,
2755                PromptTemplateType::NemotronTool,
2756                PromptTemplateType::FunctionaryV32,
2757                PromptTemplateType::MistralSmallTool,
2758                PromptTemplateType::Llama4Chat,
2759                PromptTemplateType::Qwen3NoThink,
2760                PromptTemplateType::Smol3NoThink,
2761                PromptTemplateType::Gemma3,
2762                PromptTemplateType::GptOss,
2763                PromptTemplateType::Qwen3Agent,
2764                PromptTemplateType::SeedOssNoThink,
2765                PromptTemplateType::SeedOssThink
2766            );
2767
2768            #[cfg(feature = "logging")]
2769            error!(target: "stdout", "{}", &err_msg);
2770
2771            Err(LlamaCoreError::Operation(err_msg))
2772        }
2773    }
2774}
2775
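/// Check the chat request against the model metadata: sync per-request sampling options
/// (temperature, top_p, frequency/presence penalties) into the target graph, validate any
/// attached image, and disable the `embeddings` option for chat use. Returns a copy of the
/// (possibly updated) metadata.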
2776fn check_model_metadata(
2777    chat_request: &ChatCompletionRequest,
2778) -> Result<GgmlMetadata, LlamaCoreError> {
2779    let mut should_update = false;
2780    let mut metadata = get_model_metadata(chat_request.model.as_ref())?;
2781
2782    // if the prompt template supports images, validate that the image in the latest user message is base64-encoded (URL images are rejected)
2783    if metadata.prompt_template.is_image_supported() {
2784        if let Some(ChatCompletionRequestMessage::User(user_message)) = chat_request.messages.last()
2785        {
2786            if let ChatCompletionUserMessageContent::Parts(parts) = user_message.content() {
2787                for part in parts {
2788                    if let ContentPart::Image(image_part) = part {
2789                        let image = image_part.image();
2790
2791                        if image.is_url() {
2792                            let err_msg = "The image is provided in URL format. Only base64 format is supported.".to_string();
2793
2794                            #[cfg(feature = "logging")]
2795                            error!(target: "stdout", "{}", &err_msg);
2796
2797                            return Err(LlamaCoreError::Operation(err_msg));
2798                        } else {
2799                            #[cfg(feature = "logging")]
2800                            info!(target: "stdout", "The image is provided in base64 format.");
2801
2802                            // TODO: only a single image is supported for now
2803
2804                            break;
2805                        }
2806                    }
2807                }
2808            }
2809        }
2810    }
2811
2812    // check if necessary to update temperature
2813    if let Some(temp) = chat_request.temperature {
2814        if metadata.temperature != temp {
2815            // update temperature
2816            metadata.temperature = temp;
2817
2818            if !should_update {
2819                should_update = true;
2820            }
2821        }
2822    }
2823
2824    // check if necessary to update top_p
2825    if let Some(top_p) = chat_request.top_p {
2826        if metadata.top_p != top_p {
2827            // update top_p
2828            metadata.top_p = top_p;
2829
2830            if !should_update {
2831                should_update = true;
2832            }
2833        }
2834    }
2835
2836    // check if necessary to update frequency_penalty
2837    if let Some(frequency_penalty) = chat_request.frequency_penalty {
2838        if metadata.frequency_penalty != frequency_penalty {
2839            // update frequency_penalty
2840            metadata.frequency_penalty = frequency_penalty;
2841
2842            if !should_update {
2843                should_update = true;
2844            }
2845        }
2846    }
2847
2848    // check if necessary to update presence_penalty
2849    if let Some(presence_penalty) = chat_request.presence_penalty {
2850        if metadata.presence_penalty != presence_penalty {
2851            // update presence_penalty
2852            metadata.presence_penalty = presence_penalty;
2853
2854            if !should_update {
2855                should_update = true;
2856            }
2857        }
2858    }
2859
2860    // ensure the `embeddings` option is disabled for chat completions
2861    if metadata.embeddings {
2862        metadata.embeddings = false;
2863
2864        if !should_update {
2865            should_update = true;
2866        }
2867    }
2868
2869    if should_update {
2870        #[cfg(feature = "logging")]
2871        info!(target: "stdout", "Update the model metadata.");
2872
2873        // update the target graph with the new metadata
2874        update_model_metadata(chat_request.model.as_ref(), &metadata)?;
2875    }
2876
2877    Ok(metadata)
2878}
2879
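/// Update `n_predict` in the model metadata from the request's `max_completion_tokens` and
/// the number of completion tokens still available in the context window, then push the
/// updated metadata to the target graph if anything changed.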
2880fn update_n_predict(
2881    chat_request: &ChatCompletionRequest,
2882    metadata: &mut GgmlMetadata,
2883    available_completion_tokens: u64,
2884) -> Result<(), LlamaCoreError> {
2885    let mut should_update = false;
2886
2887    #[cfg(feature = "logging")]
2888    info!(target: "stdout", "n_predict: {}", metadata.n_predict);
2889
2890    // From high to low priority
2891    // 1. chat_request.max_completion_tokens
2892    // 2. available_completion_tokens
2893    // 3. n_predict
2894
2895    if let Some(max_completion_tokens) = chat_request.max_completion_tokens {
2896        if metadata.n_predict != max_completion_tokens {
2897            #[cfg(feature = "logging")]
2898            info!(target: "stdout", "Update n_predict with max_completion_tokens from {} to {}", metadata.n_predict, max_completion_tokens);
2899
2900            metadata.n_predict = max_completion_tokens;
2901
2902            if !should_update {
2903                should_update = true;
2904            }
2905        }
2906    }
2907
2908    // TODO: remove this condition after [Issue #3958 on WasmEdge](https://github.com/WasmEdge/WasmEdge/issues/3958) is fixed
2909    if metadata.n_predict == -2 {
2910        #[cfg(feature = "logging")]
2911        info!(target: "stdout", "Update n_predict with available_completion_tokens from {} to {}", metadata.n_predict, available_completion_tokens);
2912
2913        // update n_predict
2914        metadata.n_predict = available_completion_tokens as i32;
2915
2916        if !should_update {
2917            should_update = true;
2918        }
2919    }
2920
2921    if metadata.n_predict == -1
2922        || (metadata.n_predict > 0 && metadata.n_predict < available_completion_tokens as i32)
2923        || (metadata.n_predict < 0 && metadata.n_predict != -2)
2924    // TODO: remove this condition after [Issue #3958 on WasmEdge](https://github.com/WasmEdge/WasmEdge/issues/3958) is fixed
2925    {
2926        #[cfg(feature = "logging")]
2927        info!(target: "stdout", "Update n_predict with available_completion_tokens from {} to {}", metadata.n_predict, available_completion_tokens);
2928
2929        // update n_predict
2930        metadata.n_predict = available_completion_tokens as i32;
2931
2932        if !should_update {
2933            should_update = true;
2934        }
2935    }
2936
2937    if should_update {
2938        #[cfg(feature = "logging")]
2939        info!(target: "stdout", "Update the model metadata.");
2940
2941        // update the target graph with the new metadata
2942        update_model_metadata(chat_request.model.as_ref(), metadata)?;
2943    }
2944
2945    Ok(())
2946}
2947
2948/// Post-process the generated output according to the prompt template type, stripping template-specific stop tokens and markers.
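///
/// # Example (illustrative, not part of the original docs)
///
/// A minimal sketch of the intended behavior for a ChatML-style template:
///
/// ```ignore
/// // trailing end-of-turn markers such as `<|im_end|>` are stripped
/// let cleaned = post_process("Hello!<|im_end|>", &PromptTemplateType::ChatML)?;
/// assert_eq!(cleaned, "Hello!");
/// ```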
2949fn post_process(
2950    output: impl AsRef<str>,
2951    template_ty: &PromptTemplateType,
2952) -> Result<String, String> {
2953    let output = if *template_ty == PromptTemplateType::Baichuan2 {
2954        if output.as_ref().contains("用户:") {
2955            output.as_ref().trim_end_matches("用户:").trim().to_owned()
2956        } else {
2957            output.as_ref().trim().to_owned()
2958        }
2959    } else if *template_ty == PromptTemplateType::OpenChat {
2960        if output.as_ref().contains("<|end_of_turn|>") {
2961            output
2962                .as_ref()
2963                .trim_end_matches("<|end_of_turn|>")
2964                .trim()
2965                .to_owned()
2966        } else {
2967            output.as_ref().trim().to_owned()
2968        }
2969    } else if *template_ty == PromptTemplateType::GemmaInstruct
2970        || *template_ty == PromptTemplateType::Gemma3
2971    {
2972        let s = output.as_ref().trim();
2973        if s.ends_with("<end_of_turn>") {
2974            s.trim_end_matches("<end_of_turn>").trim().to_owned()
2975        } else {
2976            s.to_owned()
2977        }
2978    } else if *template_ty == PromptTemplateType::ChatML
2979        || *template_ty == PromptTemplateType::ChatMLTool
2980        || *template_ty == PromptTemplateType::InternLM2Tool
2981        || *template_ty == PromptTemplateType::MiniCPMV
2982    {
2983        let mut s = output.as_ref().trim();
2984        if s.ends_with("<|endoftext|>") {
2985            s = s.trim_end_matches("<|endoftext|>").trim();
2986        }
2987
2988        if s.starts_with(":") {
2989            s = s.trim_start_matches(":").trim();
2990        }
2991
2992        // handle Qwen3 empty think tags
2993        let x = {
2994            let pat = r#"<think>
2995
2996</think>
2997"#;
2998            if s.contains(pat) {
2999                let x = s.replace(pat, "");
3000                if x.starts_with("()") {
3001                    x.trim_start_matches("()").to_owned()
3002                } else {
3003                    x.to_owned()
3004                }
3005            } else {
3006                s.to_owned()
3007            }
3008        };
3009        s = x.trim();
3010
3011        if s.contains("<|im_start|>") && s.contains("<|im_end|>") {
3012            let idx_start = s.find("<|im_start|>").unwrap();
3013            let idx_end = s.find("<|im_end|>").unwrap();
3014
3015            match idx_start <= idx_end {
3016                true => s.split("<|im_start|>").collect::<Vec<_>>()[0]
3017                    .trim()
3018                    .to_owned(),
3019                false => s.split("<|im_end|>").collect::<Vec<_>>()[0]
3020                    .trim()
3021                    .to_owned(),
3022            }
3023        } else if s.contains("<|im_start|>") {
3024            s.split("<|im_start|>").collect::<Vec<_>>()[0]
3025                .trim()
3026                .to_owned()
3027        } else if s.contains("<|im_end|>") {
3028            let output = s.trim_end_matches("<|im_end|>").trim();
3029            if output.starts_with(": ") {
3030                output.trim_start_matches(": ").to_owned()
3031            } else {
3032                output.to_owned()
3033            }
3034        } else {
3035            s.to_owned()
3036        }
3037    } else if *template_ty == PromptTemplateType::Zephyr
3038        || *template_ty == PromptTemplateType::MistralLite
3039        || *template_ty == PromptTemplateType::MistralTool
3040        || *template_ty == PromptTemplateType::MistralInstruct
3041        || *template_ty == PromptTemplateType::MistralSmallChat
3042        || *template_ty == PromptTemplateType::MistralSmallTool
3043        || *template_ty == PromptTemplateType::BreezeInstruct
3044    {
3045        if output.as_ref().contains("</s><") {
3046            output.as_ref().trim_end_matches("</s><").trim().to_owned()
3047        } else if output.as_ref().contains("</s>") {
3048            output
3049                .as_ref()
3050                .strip_suffix("</s>")
3051                .unwrap()
3052                .trim()
3053                .to_owned()
3054        } else {
3055            output.as_ref().trim().to_owned()
3056        }
3057    } else if *template_ty == PromptTemplateType::DeepseekChat {
3058        if output.as_ref().contains("<|end_of_sentence|>") {
3059            output
3060                .as_ref()
3061                .trim_end_matches("<|end_of_sentence|>")
3062                .trim()
3063                .replace("<|end_of_sentence|>", " ")
3064                .trim()
3065                .to_owned()
3066        } else {
3067            output.as_ref().trim().to_owned()
3068        }
3069    } else if *template_ty == PromptTemplateType::HumanAssistant {
3070        if output.as_ref().contains("Human:") {
3071            output.as_ref().trim_end_matches("Human:").trim().to_owned()
3072        } else {
3073            output.as_ref().trim().to_owned()
3074        }
3075    } else if *template_ty == PromptTemplateType::SolarInstruct {
3076        let s = output.as_ref().trim();
3077
3078        if s.starts_with("### Answer") {
3079            let s = s.trim_start_matches("###").trim();
3080
3081            if s.starts_with("Answer:\n") {
3082                s.replace("Answer:\n", "Answer: ")
3083            } else {
3084                s.to_owned()
3085            }
3086        } else {
3087            s.to_owned()
3088        }
3089    } else if *template_ty == PromptTemplateType::Llama2Chat
3090        || *template_ty == PromptTemplateType::NemotronTool
3091        || *template_ty == PromptTemplateType::NemotronChat
3092    {
3093        let s = output.as_ref().trim();
3094        if s.ends_with("</s>") {
3095            s.trim_end_matches("</s>").trim().to_owned()
3096        } else {
3097            s.to_owned()
3098        }
3099    } else if *template_ty == PromptTemplateType::Llama3Chat
3100        || *template_ty == PromptTemplateType::GroqLlama3Tool
3101        || *template_ty == PromptTemplateType::Llama3Tool
3102        || *template_ty == PromptTemplateType::FunctionaryV32
3103    {
3104        let s = output.as_ref().trim();
3105        if s.ends_with("<|eot_id|>") {
3106            s.trim_end_matches("<|eot_id|>").trim().to_owned()
3107        } else {
3108            s.to_owned()
3109        }
3110    } else if *template_ty == PromptTemplateType::Phi3Chat {
3111        let s = output.as_ref().trim();
3112        if s.ends_with("<|end|>") {
3113            s.trim_end_matches("<|end|>").trim().to_owned()
3114        } else {
3115            s.to_owned()
3116        }
3117    } else if *template_ty == PromptTemplateType::Phi4Chat {
3118        let mut s = output.as_ref().trim();
3119
3120        if s.starts_with("think>") {
3121            s = s.trim_start_matches("think>").trim();
3122        }
3123
3124        if s.ends_with("<|im_end|>") {
3125            s.trim_end_matches("<|im_end|>").trim().to_owned()
3126        } else if s.ends_with("<|end|>") {
3127            s.trim_end_matches("<|end|>").trim().to_owned()
3128        } else {
3129            s.to_owned()
3130        }
3131    } else if *template_ty == PromptTemplateType::FunctionaryV31 {
3132        let mut s = output.as_ref().trim();
3133        if s.ends_with("<|eot_id|>") {
3134            s = s.trim_end_matches("<|eot_id|>").trim();
3135        }
3136        if s.ends_with("<|eom_id|>") {
3137            s = s.trim_end_matches("<|eom_id|>").trim();
3138        }
3139        s.to_owned()
3140    } else if *template_ty == PromptTemplateType::MoxinChat
3141        || *template_ty == PromptTemplateType::MoxinInstruct
3142    {
3143        let s = output.as_ref().trim();
3144        if s.ends_with("</s>") {
3145            s.trim_end_matches("</s>").trim().to_owned()
3146        } else if s.ends_with("[INST]") {
3147            s.trim_end_matches("[INST]").trim().to_owned()
3148        } else {
3149            s.to_owned()
3150        }
3151    } else if *template_ty == PromptTemplateType::Falcon3 {
3152        let s = output.as_ref().trim();
3153        if s.ends_with("<|endoftext|>") {
3154            s.trim_end_matches("<|endoftext|>").trim().to_owned()
3155        } else {
3156            s.to_owned()
3157        }
3158    } else if *template_ty == PromptTemplateType::Megrez {
3159        let s = output.as_ref().trim();
3160        if s.ends_with("<|turn_end|>") {
3161            s.trim_end_matches("<|turn_end|>").trim().to_owned()
3162        } else {
3163            s.to_owned()
3164        }
3165    } else if *template_ty == PromptTemplateType::Qwen2vl
3166        || *template_ty == PromptTemplateType::Qwen3NoThink
3167        || *template_ty == PromptTemplateType::ChatMLThink
3168    {
3169        let mut s = output.as_ref().trim();
3170
3171        if s.starts_with(":") {
3172            s = s.trim_start_matches(":").trim();
3173        }
3174
3175        if s.starts_with("</think>") {
3176            s = s.trim_start_matches("</think>").trim();
3177        }
3178
3179        if s.ends_with("<|im_end|>") {
3180            s.trim_end_matches("<|im_end|>").trim().to_owned()
3181        } else {
3182            s.to_owned()
3183        }
3184    } else if *template_ty == PromptTemplateType::VicunaLlava {
3185        let s = output.as_ref().trim();
3186        if s.ends_with("</s>") {
3187            s.trim_end_matches("</s>").trim().to_owned()
3188        } else {
3189            s.to_owned()
3190        }
3191    } else if *template_ty == PromptTemplateType::ExaoneDeepChat
3192        || *template_ty == PromptTemplateType::ExaoneChat
3193    {
3194        let mut s = output.as_ref().trim();
3195
3196        if s.ends_with("[|endofturn|]") {
3197            s = s.trim_end_matches("[|endofturn|]").trim();
3198        }
3199
3200        s.to_owned()
3201    } else if *template_ty == PromptTemplateType::Llama4Chat {
3202        let mut s = output.as_ref().trim();
3203
3204        if s.ends_with("<|eot|>") {
3205            s = s.trim_end_matches("<|eot|>").trim();
3206        }
3207
3208        s.to_owned()
3209    } else if *template_ty == PromptTemplateType::Smolvl {
3210        let mut s = output.as_ref().trim();
3211
3212        if s.starts_with(":") {
3213            s = s.trim_start_matches(":").trim();
3214        }
3215
3216        if s.ends_with("<end_of_utterance>") {
3217            s = s.trim_end_matches("<end_of_utterance>").trim();
3218        }
3219
3220        if s.contains("<end_of_utterance>:") {
3221            let parts = s.split("<end_of_utterance>:").collect::<Vec<_>>();
3222            parts.last().unwrap().trim().to_owned()
3223        } else {
3224            s.to_owned()
3225        }
3226    } else if *template_ty == PromptTemplateType::Smol3NoThink {
3227        let mut s = output.as_ref().trim();
3228
3229        if s.ends_with("<|im_end|>") {
3230            s = s.trim_end_matches("<|im_end|>").trim();
3231        }
3232
3233        let re = regex::Regex::new(r"(?s)^<think>.*?</think>\s*").unwrap();
3234        re.replace(s, "").to_string()
3235    } else if *template_ty == PromptTemplateType::GptOss {
3236        let s = output.as_ref().trim();
3237
3238        let re =
3239            regex::Regex::new(r"(?s).*<\|channel\|>final<\|message\|>(.*?)<\|return\|>$").unwrap();
3240
3241        if let Some(caps) = re.captures(s) {
3242            let extracted = &caps[1];
3243            extracted.to_owned()
3244        } else {
3245            s.to_owned()
3246        }
3247    } else if *template_ty == PromptTemplateType::Qwen3Agent {
3248        let mut s = output.as_ref().trim();
3249
3250        if s.starts_with(":") {
3251            s = s.trim_start_matches(":").trim();
3252        }
3253
3254        if s.starts_with("</think>") {
3255            s = s.trim_start_matches("</think>").trim();
3256        }
3257
3258        if s.ends_with("<|im_end|>") {
3259            s = s.trim_end_matches("<|im_end|>").trim();
3260        }
3261
3262        if s.contains("<final_answer>") && !s.contains("</final_answer>") {
3263            format!("{s}</final_answer>")
3264        } else {
3265            s.to_owned()
3266        }
3267    } else if *template_ty == PromptTemplateType::SeedOssNoThink {
3268        let s = output.as_ref().trim();
3269
3270        let re = regex::Regex::new(r"(?s)</seed:think>\s*(.*?)\s*<seed:eos>").unwrap();
3271
3272        if let Some(caps) = re.captures(s) {
3273            let extracted = &caps[1];
3274            extracted.to_owned()
3275        } else {
3276            s.to_owned()
3277        }
3278    } else {
3279        output.as_ref().trim().to_owned()
3280    };
3281
3282    Ok(output)
3283}
3284
3285/// Build the chat prompt from the chat messages.
3286///
3287/// # Arguments
3288///
3289/// * `model_name`: The name of the model.
3290///
3291/// * `chat_request`: The chat request.
3292///
3293/// # Returns
3294///
3295/// A tuple containing the prompt, the number of available tokens for completions, and a boolean indicating whether tools are used.
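///
/// # Example (illustrative, not part of the original docs)
///
/// A minimal sketch, assuming a chat graph has been initialized and a `chat_request` built
/// from user input is in scope (the model name below is hypothetical):
///
/// ```ignore
/// let model = "llama-3-8b".to_string();
/// let (prompt, available_completion_tokens, tool_use) =
///     build_prompt(Some(&model), &mut chat_request)?;
/// ```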
3296fn build_prompt(
3297    model_name: Option<&String>,
3298    chat_request: &mut ChatCompletionRequest,
3299) -> Result<(String, u64, bool), LlamaCoreError> {
3300    let metadata = get_model_metadata(model_name)?;
3301    let ctx_size = metadata.ctx_size as u64;
3302    let chat_prompt = ChatPrompt::from(metadata.prompt_template);
3303
3304    // compute max prompt tokens, which is 80% of the context size
3305    let max_prompt_tokens = ctx_size * 4 / 5;
3306
3307    loop {
3308        // ! DO NOT REMOVE
3309        {
3310            // // build prompt
3311            // let prompt = match chat_prompt.build(&mut chat_request.messages) {
3312            //     Ok(prompt) => prompt,
3313            //     Err(e) => {
3314            //         let err_msg = format!("Fail to build chat prompts. Reason: {}", e);
3315
3316            //         #[cfg(feature = "logging")]
3317            //         error!(target: "stdout", "{}", &err_msg);
3318
3319            //         return Err(LlamaCoreError::Operation(err_msg));
3320            //     }
3321            // };
3322        }
3323
3324        if chat_request.messages.is_empty() {
3325            let err_msg = "The messages in the chat request are empty.";
3326
3327            #[cfg(feature = "logging")]
3328            error!(target: "stdout", "{err_msg}");
3329
3330            return Err(LlamaCoreError::Operation(err_msg.to_owned()));
3331        }
3332
3333        #[cfg(feature = "logging")]
3334        {
3335            let mut role_chain = String::new();
3336            for (idx, message) in chat_request.messages.iter().enumerate() {
3337                if idx == chat_request.messages.len() - 1 {
3338                    role_chain.push_str(&format!("{}", message.role()));
3339                } else {
3340                    role_chain.push_str(&format!("{} -> ", message.role()));
3341                }
3342            }
3343            info!(target: "stdout", "Role chain: {role_chain}");
3344        }
3345
3346        let (prompt, tool_use) = match chat_request.tool_choice.as_ref() {
3347            Some(tool_choice) => match tool_choice {
3348                ToolChoice::None => {
3349                    match chat_prompt.build_with_tools(&mut chat_request.messages, Some(&[])) {
3350                        Ok(prompt) => (prompt, false),
3351                        Err(e) => {
3352                            let err_msg = format!("Fail to build chat prompts. Reason: {e}");
3353
3354                            #[cfg(feature = "logging")]
3355                            error!(target: "stdout", "{}", &err_msg);
3356
3357                            return Err(LlamaCoreError::Operation(err_msg));
3358                        }
3359                    }
3360                }
3361                _ => match chat_request.tools.as_ref() {
3362                    Some(tools) => match chat_prompt
3363                        .build_with_tools(&mut chat_request.messages, Some(tools.as_slice()))
3364                    {
3365                        Ok(prompt) => (prompt, true),
3366                        Err(e) => {
3367                            let err_msg = format!("Fail to build chat prompts. Reason: {e}");
3368
3369                            #[cfg(feature = "logging")]
3370                            error!(target: "stdout", "{}", &err_msg);
3371
3372                            return Err(LlamaCoreError::Operation(err_msg));
3373                        }
3374                    },
3375                    None => {
3376                        #[cfg(feature = "logging")]
3377                        warn!(target: "stdout", "A tool choice was specified without any tools; building the prompt without tools.");
3378
3379                        match chat_prompt.build_with_tools(&mut chat_request.messages, None) {
3380                            Ok(prompt) => (prompt, false),
3381                            Err(e) => {
3382                                let err_msg = format!("Fail to build chat prompts. Reason: {e}");
3383
3384                                #[cfg(feature = "logging")]
3385                                error!(target: "stdout", "{}", &err_msg);
3386
3387                                return Err(LlamaCoreError::Operation(err_msg));
3388                            }
3389                        }
3390                    }
3391                },
3392            },
3393            None => match chat_prompt.build_with_tools(&mut chat_request.messages, None) {
3394                Ok(prompt) => (prompt, false),
3395                Err(e) => {
3396                    let err_msg = format!("Fail to build chat prompts. Reason: {e}");
3397
3398                    #[cfg(feature = "logging")]
3399                    error!(target: "stdout", "{}", &err_msg);
3400
3401                    return Err(LlamaCoreError::Operation(err_msg));
3402                }
3403            },
3404        };
3405        #[cfg(feature = "logging")]
3406        info!(target: "stdout", "Try to set prompt: {prompt}");
3407
3408        // set prompt
3409        set_prompt(model_name, &prompt)?;
3410
3411        // Retrieve the number of prompt tokens.
3412        let token_info = get_token_info_by_graph_name(model_name)?;
3413
3414        match token_info.prompt_tokens > max_prompt_tokens {
3415            true => {
3416                match chat_request.messages[0].role() {
3417                    ChatCompletionRole::System => {
3418                        // corner case: context size is too small, `system -> user -> assistant -> tool` cannot be trimmed.
3419                        if chat_request.messages.len() == 4
3420                            && chat_request.messages[1].role() == ChatCompletionRole::User
3421                            && chat_request.messages[2].role() == ChatCompletionRole::Assistant
3422                            && chat_request.messages[3].role() == ChatCompletionRole::Tool
3423                        {
3424                            let err_msg = format!(
3425                                "The number of prompt tokens ({}) is greater than the max prompt tokens ({}). Please increase the context size.",
3426                                token_info.prompt_tokens, max_prompt_tokens
3427                            );
3428
3429                            #[cfg(feature = "logging")]
3430                            error!(target: "stdout", "{}", &err_msg);
3431
3432                            return Err(LlamaCoreError::Operation(err_msg));
3433                        }
3434
3435                        if chat_request.messages.len() > 2 {
3436                            #[cfg(feature = "logging")]
3437                            info!(target: "stdout", "Prune chat history: current length {}", chat_request.messages.len());
3438
3439                            // remove user_1 if it exists
3440                            // For example, `system -> user_1 -> ... -> user_2 -> ... -> user_latest` will be converted to `system -> ... -> user_2 -> ... -> user_latest`
3441                            if chat_request.messages[1].role() == ChatCompletionRole::User {
3442                                let user_message = chat_request.messages.remove(1);
3443
3444                                #[cfg(feature = "logging")]
3445                                info!(target: "stdout", "Remove a user message from the chat history: {user_message:?}");
3446                            }
3447
3448                            // remove all messages until the message is of `user`
3449                            // For example, `system -> ... -> user_2 -> ... -> user_latest` will be converted to `system -> user_2 -> ... -> user_latest`
3450                            while chat_request.messages[1].role() != ChatCompletionRole::User {
3451                                let message = chat_request.messages.remove(1);
3452
3453                                #[cfg(feature = "logging")]
3454                                info!(target: "stdout", "Remove a {} message from the chat history: {:?}", message.role(), message);
3455
3456                                if chat_request.messages.len() == 1 {
3457                                    let err_msg = format!("The last message in the chat history should be a user message, but found a {} message.", message.role());
3458
3459                                    #[cfg(feature = "logging")]
3460                                    error!(target: "stdout", "{err_msg}");
3461
3462                                    return Err(LlamaCoreError::Operation(err_msg));
3463                                }
3464                            }
3465                        } else if token_info.prompt_tokens > ctx_size {
3466                            let err_msg = format!(
3467                                "The number of prompt tokens ({}) is greater than the context size ({}). Please increase the context size, or simplify the input message.",
3468                                token_info.prompt_tokens, ctx_size
3469                            );
3470
3471                            #[cfg(feature = "logging")]
3472                            error!(target: "stdout", "{}", &err_msg);
3473
3474                            return Err(LlamaCoreError::Operation(err_msg));
3475                        } else {
3476                            return Ok((prompt, ctx_size - token_info.prompt_tokens, tool_use));
3477                        }
3478                    }
3479                    ChatCompletionRole::User => {
3480                        // corner case: context size is too small, `user -> assistant -> tool` cannot be trimmed.
3481                        if chat_request.messages.len() == 3
3482                            && chat_request.messages[0].role() == ChatCompletionRole::User
3483                            && chat_request.messages[1].role() == ChatCompletionRole::Assistant
3484                            && chat_request.messages[2].role() == ChatCompletionRole::Tool
3485                        {
3486                            let err_msg = format!(
3487                                "The number of prompt tokens ({}) is greater than the max prompt tokens ({}). Please increase the context size.",
3488                                token_info.prompt_tokens, max_prompt_tokens
3489                            );
3490
3491                            #[cfg(feature = "logging")]
3492                            error!(target: "stdout", "{}", &err_msg);
3493
3494                            return Err(LlamaCoreError::Operation(err_msg));
3495                        }
3496
3497                        if chat_request.messages.len() > 1 {
3498                            // user_1 -> ... -> user_2 -> ... -> user_latest
3499
3500                            // remove user_1 if it exists
3501                            // For example, `user_1 -> ... -> user_2 -> ... -> user_latest` will be converted to `... -> user_2 -> ... -> user_latest`
3502                            if chat_request.messages[0].role() == ChatCompletionRole::User {
3503                                let user_message = chat_request.messages.remove(0);
3504
3505                                #[cfg(feature = "logging")]
3506                                info!(target: "stdout", "Remove a user message from the chat history: {user_message:?}");
3507                            }
3508
3509                            // remove all messages until the message is of `user`
3510                            // For example, `... -> user_2 -> ... -> user_latest` will be converted to `user_2 -> ... -> user_latest`
3511                            while chat_request.messages[0].role() != ChatCompletionRole::User {
3512                                let message = chat_request.messages.remove(0);
3513
3514                                #[cfg(feature = "logging")]
3515                                info!(target: "stdout", "Remove a {} message from the chat history: {:?}", message.role(), message);
3516
3517                                if chat_request.messages.is_empty() {
3518                                    let err_msg = format!("The last message in the chat history should be a user message, but found a {} message.", message.role());
3519
3520                                    #[cfg(feature = "logging")]
3521                                    error!(target: "stdout", "{err_msg}");
3522
3523                                    return Err(LlamaCoreError::Operation(err_msg));
3524                                }
3525                            }
3526                        } else if token_info.prompt_tokens > ctx_size {
3527                            let err_msg = format!(
3528                                "The number of prompt tokens ({}) is greater than the context size ({}). Please increase the context size, or simplify the input message.",
3529                                token_info.prompt_tokens, ctx_size
3530                            );
3531
3532                            #[cfg(feature = "logging")]
3533                            error!(target: "stdout", "{}", &err_msg);
3534
3535                            return Err(LlamaCoreError::Operation(err_msg));
3536                        } else {
3537                            return Ok((prompt, ctx_size - token_info.prompt_tokens, tool_use));
3538                        }
3539                    }
3540                    _ => {
3541                        #[cfg(feature = "logging")]
3542                        info!(target: "stdout", "remove a {} message from the message queue", chat_request.messages[0].role());
3543
3544                        chat_request.messages.remove(0);
3545                    }
3546                }
3547
3548                continue;
3549            }
3550            false => return Ok((prompt, ctx_size - max_prompt_tokens, tool_use)),
3551        }
3552    }
3553}
3554
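/// Write the prompt string into the input tensor (index 0) of the target chat graph,
/// falling back to the first available graph when the named model is not found.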
3555fn set_prompt(model_name: Option<&String>, prompt: impl AsRef<str>) -> Result<(), LlamaCoreError> {
3556    let chat_graphs = match CHAT_GRAPHS.get() {
3557        Some(chat_graphs) => chat_graphs,
3558        None => {
3559            let err_msg = "Fail to get the underlying value of `CHAT_GRAPHS`.";
3560
3561            #[cfg(feature = "logging")]
3562            error!(target: "stdout", "{}", &err_msg);
3563
3564            return Err(LlamaCoreError::Operation(err_msg.into()));
3565        }
3566    };
3567
3568    let mut chat_graphs = chat_graphs.lock().map_err(|e| {
3569        let err_msg = format!("Fail to acquire the lock of `CHAT_GRAPHS`. {e}");
3570
3571        #[cfg(feature = "logging")]
3572        error!(target: "stdout", "{}", &err_msg);
3573
3574        LlamaCoreError::Operation(err_msg)
3575    })?;
3576
3577    match model_name {
3578        Some(model_name) => {
3579            #[cfg(feature = "logging")]
3580            info!(target: "stdout", "Set prompt to the chat model named {model_name}");
3581
3582            match chat_graphs.contains_key(model_name) {
3583                true => {
3584                    let graph = chat_graphs.get_mut(model_name).unwrap();
3585                    let tensor_data = prompt.as_ref().as_bytes().to_vec();
3586                    set_tensor_data_u8(graph, 0, &tensor_data)
3587                }
3588                false => match chat_graphs.iter_mut().next() {
3589                    Some((_, graph)) => {
3590                        let tensor_data = prompt.as_ref().as_bytes().to_vec();
3591                        set_tensor_data_u8(graph, 0, &tensor_data)
3592                    }
3593                    None => {
3594                        let err_msg = "There is no model available in the chat graphs.";
3595
3596                        #[cfg(feature = "logging")]
3597                        error!(target: "stdout", "{}", &err_msg);
3598
3599                        Err(LlamaCoreError::Operation(err_msg.into()))
3600                    }
3601                },
3602            }
3603        }
3604        None => {
3605            #[cfg(feature = "logging")]
3606            info!(target: "stdout", "Set prompt to the default chat model.");
3607
3608            match chat_graphs.iter_mut().next() {
3609                Some((_, graph)) => {
3610                    let tensor_data = prompt.as_ref().as_bytes().to_vec();
3611                    set_tensor_data_u8(graph, 0, &tensor_data)
3612                }
3613                None => {
3614                    let err_msg = "There is no model available in the chat graphs while trying to set prompt to the default model.";
3615
3616                    #[cfg(feature = "logging")]
3617                    error!(target: "stdout", "{err_msg}");
3618
3619                    Err(LlamaCoreError::Operation(err_msg.into()))
3620                }
3621            }
3622        }
3623    }
3624}
3625
3626/// Get a copy of the metadata of the model.
3627fn get_model_metadata(model_name: Option<&String>) -> Result<GgmlMetadata, LlamaCoreError> {
3628    let chat_graphs = match CHAT_GRAPHS.get() {
3629        Some(chat_graphs) => chat_graphs,
3630        None => {
3631            let err_msg = "Fail to get the underlying value of `CHAT_GRAPHS`.";
3632
3633            #[cfg(feature = "logging")]
3634            error!(target: "stdout", "{err_msg}");
3635
3636            return Err(LlamaCoreError::Operation(err_msg.into()));
3637        }
3638    };
3639
3640    let chat_graphs = chat_graphs.lock().map_err(|e| {
3641        let err_msg = format!("Fail to acquire the lock of `CHAT_GRAPHS`. {e}");
3642
3643        #[cfg(feature = "logging")]
3644        error!(target: "stdout", "{}", &err_msg);
3645
3646        LlamaCoreError::Operation(err_msg)
3647    })?;
3648
3649    match model_name {
3650        Some(model_name) => match chat_graphs.contains_key(model_name) {
3651            true => {
3652                let graph = chat_graphs.get(model_name).unwrap();
3653                Ok(graph.metadata.clone())
3654            }
3655            false => match chat_graphs.iter().next() {
3656                Some((_, graph)) => Ok(graph.metadata.clone()),
3657                None => {
3658                    let err_msg = "There is no model available in the chat graphs.";
3659
3660                    #[cfg(feature = "logging")]
3661                    error!(target: "stdout", "{}", &err_msg);
3662
3663                    Err(LlamaCoreError::Operation(err_msg.into()))
3664                }
3665            },
3666        },
3667        None => match chat_graphs.iter().next() {
3668            Some((_, graph)) => Ok(graph.metadata.clone()),
3669            None => {
3670                let err_msg = "There is no model available in the chat graphs.";
3671
3672                #[cfg(feature = "logging")]
3673                error!(target: "stdout", "{err_msg}");
3674
3675                Err(LlamaCoreError::Operation(err_msg.into()))
3676            }
3677        },
3678    }
3679}
3680
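/// Serialize the metadata to JSON and write it into the metadata tensor (index 1) of the
/// target chat graph, falling back to the first available graph when the named model is
/// not found.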
3681fn update_model_metadata(
3682    model_name: Option<&String>,
3683    metadata: &GgmlMetadata,
3684) -> Result<(), LlamaCoreError> {
3685    let config = match serde_json::to_string(metadata) {
3686        Ok(config) => config,
3687        Err(e) => {
3688            let err_msg = format!("Fail to serialize metadata to a JSON string. {e}");
3689
3690            #[cfg(feature = "logging")]
3691            error!(target: "stdout", "{}", &err_msg);
3692
3693            return Err(LlamaCoreError::Operation(err_msg));
3694        }
3695    };
3696
3697    let chat_graphs = match CHAT_GRAPHS.get() {
3698        Some(chat_graphs) => chat_graphs,
3699        None => {
3700            let err_msg = "Fail to get the underlying value of `CHAT_GRAPHS`.";
3701
3702            #[cfg(feature = "logging")]
3703            error!(target: "stdout", "{err_msg}");
3704
3705            return Err(LlamaCoreError::Operation(err_msg.into()));
3706        }
3707    };
3708
3709    let mut chat_graphs = chat_graphs.lock().map_err(|e| {
3710        let err_msg = format!("Fail to acquire the lock of `CHAT_GRAPHS`. Reason: {e}");
3711
3712        #[cfg(feature = "logging")]
3713        error!(target: "stdout", "{}", &err_msg);
3714
3715        LlamaCoreError::Operation(err_msg)
3716    })?;
3717
3718    match model_name {
3719        Some(model_name) => {
3720            match chat_graphs.contains_key(model_name) {
3721                true => {
3722                    let graph = chat_graphs.get_mut(model_name).unwrap();
3723                    // update metadata
3724                    set_tensor_data_u8(graph, 1, config.as_bytes())
3725                }
3726                false => match chat_graphs.iter_mut().next() {
3727                    Some((_, graph)) => {
3728                        // update metadata
3729                        set_tensor_data_u8(graph, 1, config.as_bytes())
3730                    }
3731                    None => {
3732                        let err_msg = "There is no model available in the chat graphs.";
3733
3734                        #[cfg(feature = "logging")]
3735                        error!(target: "stdout", "{}", &err_msg);
3736
3737                        Err(LlamaCoreError::Operation(err_msg.into()))
3738                    }
3739                },
3740            }
3741        }
3742        None => {
3743            match chat_graphs.iter_mut().next() {
3744                Some((_, graph)) => {
3745                    // update metadata
3746                    set_tensor_data_u8(graph, 1, config.as_bytes())
3747                }
3748                None => {
3749                    let err_msg = "There is no model available in the chat graphs.";
3750
3751                    #[cfg(feature = "logging")]
3752                    error!(target: "stdout", "{err_msg}");
3753
3754                    Err(LlamaCoreError::Operation(err_msg.into()))
3755                }
3756            }
3757        }
3758    }
3759}
3760
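/// Re-apply the metadata stored on the model's graph, discarding any per-request overrides
/// previously written via `update_model_metadata`.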
3761fn reset_model_metadata(model_name: Option<&String>) -> Result<(), LlamaCoreError> {
3762    // get metadata
3763    let metadata = get_model_metadata(model_name)?;
3764
3765    // update model with the original metadata
3766    update_model_metadata(model_name, &metadata)
3767}
3768
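// State machines driving the closing chunks of a `ChatStream`: `StreamState` for the normal
// streaming path, `ContextFullState` for generations cut short by a full context, and
// `PromptTooLongState` for prompts that exceed the context window. Each advances to
// `EndOfSequence` once the final message, usage, and done chunks have been emitted.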
3769#[derive(Debug, Clone, Copy, PartialEq, Eq)]
3770enum ContextFullState {
3771    Message,
3772    Usage,
3773    Done,
3774    EndOfSequence,
3775}
3776
3777#[derive(Debug, Clone, Copy, PartialEq, Eq)]
3778enum StreamState {
3779    Usage,
3780    NoUsage,
3781    Done,
3782    EndOfSequence,
3783}
3784
3785#[derive(Debug, Clone, Copy, PartialEq, Eq)]
3786enum PromptTooLongState {
3787    Message,
3788    Usage,
3789    Done,
3790    EndOfSequence,
3791}
3792
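/// Per-request state for a streaming chat completion.
///
/// Holds the per-model stream lock while the response is generated, tracks the
/// chunk-emission state for the normal, context-full, and prompt-too-long paths, and
/// carries a byte cache (`utf8_cache`) so that multi-byte UTF-8 sequences split across
/// output chunks can be reassembled.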
3793struct ChatStream {
3794    id: String,
3795    model: Option<String>,
3796    include_usage: bool,
3797    context_full_state: ContextFullState,
3798    prompt_too_long_state: PromptTooLongState,
3799    stream_state: StreamState,
3800    cache: Option<VecDeque<String>>,
3801    is_waiting: bool,
3802    has_lock: bool,
3803    // Per-model lock fields
3804    model_lock: Arc<ModelStreamLock>,
3805    utf8_cache: Vec<u8>,
3806}
3807impl ChatStream {
3808    fn new(
3809        model: Option<String>,
3810        id: String,
3811        include_usage: bool,
3812        cache: Option<Vec<String>>,
3813    ) -> Result<Self, LlamaCoreError> {
3814        // Get or create the per-model lock
3815        let (resolved_model, model_lock) = match &model {
3816            Some(name) => (name.clone(), get_or_create_model_lock(name)?),
3817            None => get_default_model_lock()?,
3818        };
3819
3820        // Try to acquire the per-model lock
3821        let has_lock = model_lock.try_acquire();
3822
3823        #[cfg(feature = "logging")]
3824        if !has_lock {
3825            info!(target: "stdout", "Lock acquisition failed for model {}, creating with waiting status", &resolved_model);
3826        } else {
3827            info!(target: "stdout", "Lock acquired for model {}", &resolved_model);
3828        }
3829
3830        Ok(ChatStream {
3831            id,
3832            model: Some(resolved_model),
3833            include_usage,
3834            context_full_state: ContextFullState::Message,
3835            prompt_too_long_state: PromptTooLongState::Message,
3836            stream_state: if include_usage {
3837                StreamState::Usage
3838            } else {
3839                StreamState::NoUsage
3840            },
3841            cache: cache.map(VecDeque::from),
3842            is_waiting: !has_lock,
3843            has_lock,
3844            model_lock,
3845            utf8_cache: Vec::new(),
3846        })
3847    }
3848
3849    // Try to acquire lock, returns whether successful
3850    fn try_acquire_lock(&mut self) -> bool {
3851        if self.has_lock {
3852            return true;
3853        }
3854
3855        let acquired = self.model_lock.try_acquire();
3856
3857        if acquired {
3858            self.has_lock = true;
3859            self.is_waiting = false;
3860
3861            #[cfg(feature = "logging")]
3862            info!(target: "stdout", "ChatStream {} acquired lock for model {:?}", &self.id, &self.model);
3863        }
3864
3865        acquired
3866    }
3867}
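// Illustrative sketch (not part of the original source): the streaming path typically
// constructs a `ChatStream` after the prompt has been set, e.g.
//
//     let stream = ChatStream::new(
//         chat_request.model.clone(), // may be `None` to target the default model
//         gen_chat_id(),
//         /* include_usage */ true,
//         None,
//     )?;
//
// If another request already holds the per-model lock, the stream starts in the waiting
// state and re-attempts `try_acquire_lock` when it is next polled.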
3868impl Drop for ChatStream {
3869    fn drop(&mut self) {
3870        // Cleanup is only needed if we hold the lock or if the stream was actually used
3871        if self.has_lock || (self.cache.is_none() && !self.is_waiting) {
3872            #[cfg(feature = "logging")]
3873            info!(target: "stdout", "Cleaning up context for ChatStream {}", &self.id);
3874
3875            match &self.model {
3876                Some(model_name) => {
3877                    match CHAT_GRAPHS.get() {
3878                        Some(chat_graphs) => {
3879                            match chat_graphs.lock() {
3880                                Ok(mut chat_graphs) => match chat_graphs.contains_key(model_name) {
3881                                    true => {
3882                                        let graph = chat_graphs.get_mut(model_name).unwrap();
3883
3884                                        // clean up the context
3885                                        if let Err(e) = graph.finish_single() {
3886                                            let err_msg = format!(
3887                                                "Failed to clean up the context. Reason: {e}"
3888                                            );
3889
3890                                            #[cfg(feature = "logging")]
3891                                            error!(target: "stdout", "{}", &err_msg);
3892
3893                                            #[cfg(not(feature = "logging"))]
3894                                            println!(
3895                                                "[ERROR][llama_core] {}",
3896                                                &err_msg
3897                                            );
3898                                        }
3899                                    }
3900                                    false => match chat_graphs.iter_mut().next() {
3901                                        Some((_, graph)) => {
3902                                            // clean up the context
3903                                            if let Err(e) = graph.finish_single() {
3904                                                let err_msg = format!(
3905                                                    "Failed to clean up the context. Reason: {e}"
3906                                                );
3907
3908                                                #[cfg(feature = "logging")]
3909                                                error!(target: "stdout", "{}", &err_msg);
3910
3911                                                #[cfg(not(feature = "logging"))]
3912                                                println!(
3913                                                    "[ERROR][llama_core] {}",
3914                                                    &err_msg
3915                                                );
3916                                            }
3917                                        }
3918                                        None => {
3919                                            let err_msg =
3920                                                "There is no model available in the chat graphs.";
3921
3922                                            #[cfg(feature = "logging")]
3923                                            error!(target: "stdout", "{}", &err_msg);
3924
3925                                            #[cfg(not(feature = "logging"))]
3926                                            println!(
3927                                                "[ERROR][llama_core] Failed to clean up the context. Reason: {}",
3928                                                &err_msg
3929                                            );
3930                                        }
3931                                    },
3932                                },
3933                                Err(e) => {
3934                                    let err_msg =
3935                                        format!("Fail to acquire the lock of `CHAT_GRAPHS`. {e}");
3936
3937                                    #[cfg(feature = "logging")]
3938                                    error!(target: "stdout", "{}", &err_msg);
3939
3940                                    #[cfg(not(feature = "logging"))]
3941                                    println!(
3942                                        "[ERROR][llama_core] Failed to clean up the context. Reason: {}",
3943                                        &err_msg
3944                                    );
3945                                }
3946                            }
3947                        }
3948                        None => {
3949                            let err_msg = "Fail to get the underlying value of `CHAT_GRAPHS`.";
3950
3951                            #[cfg(feature = "logging")]
3952                            error!(target: "stdout", "{}", &err_msg);
3953
3954                            #[cfg(not(feature = "logging"))]
3955                            println!(
3956                                "[ERROR][llama_core] Failed to clean up the context. Reason: {}",
3957                                &err_msg
3958                            );
3959                        }
3960                    };
3961                }
3962                None => {
3963                    match CHAT_GRAPHS.get() {
3964                        Some(chat_graphs) => {
3965                            match chat_graphs.lock() {
3966                                Ok(mut chat_graphs) => match chat_graphs.iter_mut().next() {
3967                                    Some((_, graph)) => {
3968                                        // clean up the context
3969                                        if let Err(e) = graph.finish_single() {
3970                                            let err_msg = format!(
3971                                                "Failed to clean up the context. Reason: {e}"
3972                                            );
3973
3974                                            #[cfg(feature = "logging")]
3975                                            error!(target: "stdout", "{}", &err_msg);
3976
3977                                            #[cfg(not(feature = "logging"))]
3978                                            println!(
3979                                                "[ERROR][llama_core] {}",
3980                                                &err_msg
3981                                            );
3982                                        }
3983                                    }
3984                                    None => {
3985                                        let err_msg =
3986                                            "There is no model available in the chat graphs.";
3987
3988                                        #[cfg(feature = "logging")]
3989                                        error!(target: "stdout", "{err_msg}");
3990
3991                                        #[cfg(not(feature = "logging"))]
3992                                        println!(
3993                                            "[ERROR][llama_core] Failed to clean up the context. Reason: {}",
3994                                            err_msg
3995                                        );
3996                                    }
3997                                },
3998                                Err(e) => {
3999                                    let err_msg =
4000                                        format!("Fail to acquire the lock of `CHAT_GRAPHS`. {e}");
4001
4002                                    #[cfg(feature = "logging")]
4003                                    error!(target: "stdout", "{}", &err_msg);
4004
4005                                    #[cfg(not(feature = "logging"))]
4006                                    println!(
4007                                        "[ERROR][llama_core] Failed to clean up the context. Reason: {}",
4008                                        &err_msg
4009                                    );
4010                                }
4011                            }
4012                        }
4013                        None => {
4014                            let err_msg = "Fail to get the underlying value of `CHAT_GRAPHS`.";
4015
4016                            #[cfg(feature = "logging")]
4017                            error!(target: "stdout", "{}", &err_msg);
4018
4019                            #[cfg(not(feature = "logging"))]
4020                            println!(
4021                                "[ERROR][llama_core] Failed to clean up the context. Reason: {}",
4022                                &err_msg
4023                            );
4024                        }
4025                    };
4026                }
4027            }
4028
4029            #[cfg(feature = "logging")]
4030            info!(target: "stdout", "Model context cleanup done!");
4031        }
4032
4033        // reset the model metadata
4034        if let Err(e) = reset_model_metadata(self.model.as_ref()) {
4035            let err_msg = format!("Fail to reset model metadata. Reason: {e}");
4036
4037            #[cfg(feature = "logging")]
4038            error!(target: "stdout", "{}", &err_msg);
4039
4040            #[cfg(not(feature = "logging"))]
4041            println!("[ERROR][llama_core] {}", &err_msg);
4042        }
4043        #[cfg(feature = "logging")]
4044        info!(target: "stdout", "Model metadata reset done!");
4045
4046        // Release the per-model lock and wake up waiting streams
4047        if self.has_lock {
4048            self.model_lock.release();
4049
4050            #[cfg(feature = "logging")]
4051            info!(target: "stdout", "Lock released for ChatStream {} (model {:?})", &self.id, &self.model);
4052        }
4053    }
4054}
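
// `ChatStream` yields SSE-formatted frames ("data: {json}\n\n", terminated by "data: [DONE]\n\n");
// the internal "[GGML] End of sequence" sentinel is filtered out in `poll_next` and never yielded.
// A consumer typically drives the stream with `futures::StreamExt` — a hypothetical sketch, assuming
// an async context and a `forward_to_client` helper that is not part of this crate:
//
//     use futures::StreamExt;
//     while let Some(item) = chat_stream.next().await {
//         match item {
//             Ok(frame) => forward_to_client(frame).await,
//             Err(e) => break, // surface `e` to the caller as appropriate
//         }
//     }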
4055impl futures::Stream for ChatStream {
4056    type Item = Result<String, LlamaCoreError>;
4057
4058    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
4059        let this = self.get_mut();
4060
4061        // If this is a waiting stream, try to acquire the per-model lock
4062        if this.is_waiting {
4063            if !this.try_acquire_lock() {
4064                // Register waker to be notified when the per-model lock becomes available
4065                this.model_lock.register_waker(cx.waker());
4066
4067                #[cfg(feature = "logging")]
4068                debug!(target: "stdout", "ChatStream {} waiting for model {:?}", &this.id, &this.model);
4069
4070                return Poll::Pending;
4071            }
4072
4073            #[cfg(feature = "logging")]
4074            info!(target: "stdout", "ChatStream {} acquired lock and is now active", &this.id);
4075            // If we got here, we successfully acquired the lock and can proceed
4076        }
4077
4078        // Ensure we still have the lock
4079        if !this.has_lock && !this.try_acquire_lock() {
4080            // Lost the lock, need to wait
4081            this.is_waiting = true;
4082
4083            // Register waker to be notified when the per-model lock is available
4084            this.model_lock.register_waker(cx.waker());
4085
4086            return Poll::Pending;
4087        }
4088
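        // Streams constructed with pre-computed chunks drain them from the cache; otherwise each
        // poll computes the next chunk from the model graph via `compute_stream`.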
4089        if let Some(cache) = &mut this.cache {
4090            let x = cache.pop_front();
4091
4092            #[cfg(feature = "logging")]
4093            info!(target: "stdout", "Get the next item from the cache for ChatStream {}: {:?}", &this.id, &x);
4094
4095            match x {
4096                Some(x) => Poll::Ready(Some(Ok(x))),
4097                None => Poll::Ready(None),
4098            }
4099        } else {
4100            let res = compute_stream(
4101                this.model.clone(),
4102                this.id.clone(),
4103                this.include_usage,
4104                &mut this.prompt_too_long_state,
4105                &mut this.context_full_state,
4106                &mut this.stream_state,
4107                &mut this.utf8_cache,
4108            );
4109
4110            match res {
4111                Ok(x) => {
4112                    #[cfg(feature = "logging")]
4113                    info!(target: "stdout", "next item for ChatStream {}: {}", &this.id, &x);
4114
4115                    if x != "[GGML] End of sequence" && !x.is_empty() {
4116                        Poll::Ready(Some(Ok(x)))
4117                    } else {
4118                        // stopped
4119                        Poll::Ready(None)
4120                    }
4121                }
4122                Err(e) => Poll::Ready(Some(Err(e))),
4123            }
4124        }
4125    }
4126}
4127
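/// Compute the next chunk for a chat stream.
///
/// Performs one `compute_single` step on the selected graph and translates the result
/// (a decoded token, end of sequence, a full context, or an over-long prompt) into an
/// SSE frame via the three per-stream state machines. Returns the sentinel string
/// "[GGML] End of sequence" once the stream has nothing more to emit.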
4128fn compute_stream(
4129    model_name: Option<String>,
4130    id: String,
4131    include_usage: bool,
4132    prompt_too_long_state: &mut PromptTooLongState,
4133    context_full_state: &mut ContextFullState,
4134    stream_state: &mut StreamState,
4135    utf8_cache: &mut Vec<u8>,
4136) -> Result<String, LlamaCoreError> {
4137    #[cfg(feature = "logging")]
4138    info!(target: "stdout", "Computing stream chunk for ChatStream {}", &id);
4139
4140    #[cfg(feature = "logging")]
4141    debug!(target: "stdout", "prompt_too_long_state: {:?}", *prompt_too_long_state);
4142    #[cfg(feature = "logging")]
4143    debug!(target: "stdout", "context_full_state: {:?}", *context_full_state);
4144    #[cfg(feature = "logging")]
4145    debug!(target: "stdout", "stream_state: {:?}", *stream_state);
4146
4147    if *prompt_too_long_state == PromptTooLongState::EndOfSequence
4148        || *context_full_state == ContextFullState::EndOfSequence
4149        || *stream_state == StreamState::EndOfSequence
4150    {
4151        #[cfg(feature = "logging")]
4152        info!(target: "stdout", "Return the chat stream chunk!");
4153
4154        return Ok("[GGML] End of sequence".to_string());
4155    }
4156
4157    let chat_graphs = match CHAT_GRAPHS.get() {
4158        Some(chat_graphs) => chat_graphs,
4159        None => {
4160            let err_msg = "Fail to get the underlying value of `CHAT_GRAPHS`.";
4161
4162            #[cfg(feature = "logging")]
4163            error!(target: "stdout", "{}", &err_msg);
4164
4165            return Err(LlamaCoreError::Operation(err_msg.into()));
4166        }
4167    };
4168
4169    // The per-model stream lock is already held here, so this stream has exclusive access to its model's graph
4170    let mut chat_graphs = chat_graphs.lock().map_err(|e| {
4171        let err_msg = format!("Fail to acquire the lock of `CHAT_GRAPHS`. {e}");
4172
4173        #[cfg(feature = "logging")]
4174        error!(target: "stdout", "{}", &err_msg);
4175
4176        LlamaCoreError::Operation(err_msg)
4177    })?;
4178
4179    // Get the graph based on model name
4180    let res = match &model_name {
4181        Some(model_name) => {
4182            match chat_graphs.contains_key(model_name) {
4183                true => {
4184                    let graph = chat_graphs.get_mut(model_name).unwrap();
4185                    // compute
4186                    match graph.compute_single() {
4187                        Ok(_) => {
4188                            #[cfg(feature = "logging")]
4189                            debug!(target: "stdout", "Compute the chat stream chunk successfully.");
4190
4191                            // Process according to state
4192                            match stream_state {
4193                                StreamState::Usage | StreamState::NoUsage => {
4194                                    // Retrieve the output
4195                                    let output_buffer =
4196                                        get_output_buffer_single(graph, OUTPUT_TENSOR)?;
4197
4198                                    #[cfg(feature = "logging")]
4199                                    info!(target: "stdout", "retrieved the output buffer");
4200
4201                                    // decode the output buffer to a utf8 string
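                                    // A single compute step can return only part of a multi-byte
                                    // UTF-8 character; bytes that fail to decode are buffered in
                                    // `utf8_cache` and retried on the next chunk. A code point is
                                    // at most 4 bytes, so a longer cache is treated as invalid and
                                    // discarded below.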
4202                                    let output = match String::from_utf8(output_buffer.clone()) {
4203                                        Ok(token) => token,
4204                                        Err(_) => {
4205                                            // Use the per-stream utf8_cache instead of global CACHED_UTF8_ENCODINGS
4206                                            utf8_cache.extend_from_slice(&output_buffer[..]);
4207
4208                                            match String::from_utf8(utf8_cache.clone()) {
4209                                                Ok(token) => {
4210                                                    utf8_cache.clear();
4211                                                    token
4212                                                }
4213                                                Err(e) => {
4214                                                    if utf8_cache.len() > 4 {
4215                                                        #[cfg(feature = "logging")]
4216                                                        error!(target: "stdout", "UTF-8 decode failed, cache too long: {e}");
4217                                                        #[cfg(feature = "logging")]
4218                                                        error!(target: "stdout", "The cached buffer: {:?}", &utf8_cache[..]);
4219                                                        utf8_cache.clear();
4220                                                    } else {
4221                                                        #[cfg(feature = "logging")]
4222                                                        warn!(target: "stdout", "UTF-8 decode incomplete: {e}");
4223                                                    }
4224                                                    String::new()
4225                                                }
4226                                            }
4227                                        }
4228                                    };
4229
4230                                    #[cfg(feature = "logging")]
4231                                    info!(target: "stdout", "decoded the output buffer");
4232
4233                                    let created = SystemTime::now()
4234                                        .duration_since(std::time::UNIX_EPOCH)
4235                                        .map_err(|e| {
4236                                            let err_msg = format!(
4237                                                "Failed to get the current time. Reason: {e}"
4238                                            );
4239
4240                                            #[cfg(feature = "logging")]
4241                                            error!(target: "stdout", "{}", &err_msg);
4242
4243                                            LlamaCoreError::Operation(err_msg)
4244                                        })?;
4245
4246                                    let chat_completion_chunk = ChatCompletionChunk {
4247                                        id,
4248                                        object: "chat.completion.chunk".to_string(),
4249                                        created: created.as_secs(),
4250                                        model: graph.name().to_owned(),
4251                                        system_fingerprint: "fp_44709d6fcb".to_string(),
4252                                        choices: vec![ChatCompletionChunkChoice {
4253                                            index: 0,
4254                                            delta: ChatCompletionChunkChoiceDelta {
4255                                                role: ChatCompletionRole::Assistant,
4256                                                content: Some(output),
4257                                                tool_calls: vec![],
4258                                            },
4259                                            logprobs: None,
4260                                            finish_reason: None,
4261                                        }],
4262                                        usage: None,
4263                                    };
4264
4265                                    #[cfg(feature = "logging")]
4266                                    info!(target: "stdout", "created chat completion chunk");
4267
4268                                    // serialize chat completion chunk
4269                                    let chunk_str = serde_json::to_string(&chat_completion_chunk)
4270                                        .map_err(|e| {
4271                                        let err_msg = format!(
4272                                            "Failed to serialize chat completion chunk. Reason: {e}"
4273                                        );
4274
4275                                        #[cfg(feature = "logging")]
4276                                        error!(target: "stdout", "{}", &err_msg);
4277
4278                                        LlamaCoreError::Operation(err_msg)
4279                                    })?;
4280
4281                                    Ok(format!("data: {chunk_str}\n\n"))
4282                                }
4283                                StreamState::Done => {
4284                                    *stream_state = StreamState::EndOfSequence;
4285
4286                                    Ok("data: [DONE]\n\n".to_string())
4287                                }
4288                                StreamState::EndOfSequence => {
4289                                    Ok("[GGML] End of sequence".to_string())
4290                                }
4291                            }
4292                        }
4293                        Err(wasmedge_wasi_nn::Error::BackendError(
4294                            wasmedge_wasi_nn::BackendError::EndOfSequence,
4295                        )) => {
4296                            #[cfg(feature = "logging")]
4297                            debug!(target: "stdout", "End of sequence");
4298
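                            // End of sequence: when usage was requested, emit a final usage chunk
                            // first and the "[DONE]" frame on the following poll; otherwise emit
                            // "[DONE]" immediately.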
4299                            match stream_state {
4300                                StreamState::Usage => {
4301                                    *stream_state = StreamState::Done;
4302
4303                                    // retrieve the number of prompt and completion tokens
4304                                    let token_info = get_token_info_by_graph(graph)?;
4305
4306                                    let usage = Some(Usage {
4307                                        prompt_tokens: token_info.prompt_tokens,
4308                                        completion_tokens: token_info.completion_tokens,
4309                                        total_tokens: token_info.prompt_tokens
4310                                            + token_info.completion_tokens,
4311                                    });
4312
4313                                    #[cfg(feature = "logging")]
4314                                    info!(target: "stdout", "token_info: {} prompt tokens, {} completion tokens", token_info.prompt_tokens, token_info.completion_tokens);
4315
4316                                    let created = SystemTime::now()
4317                                        .duration_since(std::time::UNIX_EPOCH)
4318                                        .map_err(|e| {
4319                                            let err_msg = format!(
4320                                                "Failed to get the current time. Reason: {e}"
4321                                            );
4322
4323                                            #[cfg(feature = "logging")]
4324                                            error!(target: "stdout", "{}", &err_msg);
4325
4326                                            LlamaCoreError::Operation(err_msg)
4327                                        })?;
4328
4329                                    let chat_completion_chunk = ChatCompletionChunk {
4330                                        id,
4331                                        object: "chat.completion.chunk".to_string(),
4332                                        created: created.as_secs(),
4333                                        model: graph.name().to_owned(),
4334                                        system_fingerprint: "fp_44709d6fcb".to_string(),
4335                                        choices: vec![],
4336                                        usage,
4337                                    };
4338
4339                                    // serialize chat completion chunk
4340                                    let chunk_str = serde_json::to_string(&chat_completion_chunk)
4341                                        .map_err(|e| {
4342                                        let err_msg = format!(
4343                                            "Failed to serialize chat completion chunk. Reason: {e}"
4344                                        );
4345
4346                                        #[cfg(feature = "logging")]
4347                                        error!(target: "stdout", "{}", &err_msg);
4348
4349                                        LlamaCoreError::Operation(err_msg)
4350                                    })?;
4351
4352                                    Ok(format!("data: {chunk_str}\n\n"))
4353                                }
4354                                StreamState::Done | StreamState::NoUsage => {
4355                                    *stream_state = StreamState::EndOfSequence;
4356
4357                                    Ok("data: [DONE]\n\n".to_string())
4358                                }
4359                                StreamState::EndOfSequence => {
4360                                    Ok("[GGML] End of sequence".to_string())
4361                                }
4362                            }
4363                        }
4364                        Err(wasmedge_wasi_nn::Error::BackendError(
4365                            wasmedge_wasi_nn::BackendError::ContextFull,
4366                        )) => {
4367                            #[cfg(feature = "logging")]
4368                            debug!(target: "stdout", "Context full");
4369
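                            // Context full: first emit a "<|WASMEDGE-GGML-CONTEXT-FULL|>" chunk with
                            // finish_reason `length`, then (if requested) a usage chunk, then "[DONE]".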
4370                            match context_full_state {
4371                                ContextFullState::Message => {
4372                                    match include_usage {
4373                                        true => *context_full_state = ContextFullState::Usage,
4374                                        false => *context_full_state = ContextFullState::Done,
4375                                    }
4376
4377                                    let created = SystemTime::now()
4378                                        .duration_since(std::time::UNIX_EPOCH)
4379                                        .map_err(|e| {
4380                                            let err_msg = format!(
4381                                                "Failed to get the current time. Reason: {e}"
4382                                            );
4383
4384                                            #[cfg(feature = "logging")]
4385                                            error!(target: "stdout", "{}", &err_msg);
4386
4387                                            LlamaCoreError::Operation(err_msg)
4388                                        })?;
4389
4390                                    let chat_completion_chunk = ChatCompletionChunk {
4391                                        id,
4392                                        object: "chat.completion.chunk".to_string(),
4393                                        created: created.as_secs(),
4394                                        model: graph.name().to_owned(),
4395                                        system_fingerprint: "fp_44709d6fcb".to_string(),
4396                                        choices: vec![ChatCompletionChunkChoice {
4397                                            index: 0,
4398                                            delta: ChatCompletionChunkChoiceDelta {
4399                                                role: ChatCompletionRole::Assistant,
4400                                                content: Some(
4401                                                    "<|WASMEDGE-GGML-CONTEXT-FULL|>".to_string(),
4402                                                ),
4403                                                tool_calls: vec![],
4404                                            },
4405                                            logprobs: None,
4406                                            finish_reason: Some(FinishReason::length),
4407                                        }],
4408                                        usage: None,
4409                                    };
4410
4411                                    // serialize chat completion chunk
4412                                    let chunk_str = serde_json::to_string(&chat_completion_chunk)
4413                                        .map_err(|e| {
4414                                        let err_msg = format!(
4415                                            "Failed to serialize chat completion chunk. Reason: {e}"
4416                                        );
4417
4418                                        #[cfg(feature = "logging")]
4419                                        error!(target: "stdout", "{}", &err_msg);
4420
4421                                        LlamaCoreError::Operation(err_msg)
4422                                    })?;
4423
4424                                    Ok(format!("data: {chunk_str}\n\n"))
4425                                }
4426                                ContextFullState::Usage => {
4427                                    *context_full_state = ContextFullState::Done;
4428
4429                                    // retrieve the number of prompt and completion tokens
4430                                    let token_info = get_token_info_by_graph(graph)?;
4431
4432                                    let usage = Some(Usage {
4433                                        prompt_tokens: token_info.prompt_tokens,
4434                                        completion_tokens: token_info.completion_tokens,
4435                                        total_tokens: token_info.prompt_tokens
4436                                            + token_info.completion_tokens,
4437                                    });
4438
4439                                    let created = SystemTime::now()
4440                                        .duration_since(std::time::UNIX_EPOCH)
4441                                        .map_err(|e| {
4442                                            let err_msg = format!(
4443                                                "Failed to get the current time. Reason: {e}"
4444                                            );
4445
4446                                            #[cfg(feature = "logging")]
4447                                            error!(target: "stdout", "{}", &err_msg);
4448
4449                                            LlamaCoreError::Operation(err_msg)
4450                                        })?;
4451
4452                                    let chat_completion_chunk = ChatCompletionChunk {
4453                                        id,
4454                                        object: "chat.completion.chunk".to_string(),
4455                                        created: created.as_secs(),
4456                                        model: graph.name().to_owned(),
4457                                        system_fingerprint: "fp_44709d6fcb".to_string(),
4458                                        choices: vec![],
4459                                        usage,
4460                                    };
4461
4462                                    // serialize chat completion chunk
4463                                    let chunk_str = serde_json::to_string(&chat_completion_chunk)
4464                                        .map_err(|e| {
4465                                        let err_msg = format!(
4466                                            "Failed to serialize chat completion chunk. Reason: {e}"
4467                                        );
4468
4469                                        #[cfg(feature = "logging")]
4470                                        error!(target: "stdout", "{}", &err_msg);
4471
4472                                        LlamaCoreError::Operation(err_msg)
4473                                    })?;
4474
4475                                    Ok(format!("data: {chunk_str}\n\n"))
4476                                }
4477                                ContextFullState::Done => {
4478                                    *context_full_state = ContextFullState::EndOfSequence;
4479
4480                                    Ok("data: [DONE]\n\n".to_string())
4481                                }
4482                                ContextFullState::EndOfSequence => {
4483                                    Ok("[GGML] End of sequence".to_string())
4484                                }
4485                            }
4486                        }
4487                        Err(wasmedge_wasi_nn::Error::BackendError(
4488                            wasmedge_wasi_nn::BackendError::PromptTooLong,
4489                        )) => {
4490                            #[cfg(feature = "logging")]
4491                            debug!(target: "stdout", "Prompt too long");
4492
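                            // Prompt too long: emit an empty delta with finish_reason `length`,
                            // then (if requested) a usage chunk, then "[DONE]".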
4493                            match prompt_too_long_state {
4494                                PromptTooLongState::Message => {
4495                                    match include_usage {
4496                                        true => *prompt_too_long_state = PromptTooLongState::Usage,
4497                                        false => *prompt_too_long_state = PromptTooLongState::Done,
4498                                    }
4499
4500                                    let created = SystemTime::now()
4501                                        .duration_since(std::time::UNIX_EPOCH)
4502                                        .map_err(|e| {
4503                                            let err_msg = format!(
4504                                                "Failed to get the current time. Reason: {e}"
4505                                            );
4506
4507                                            #[cfg(feature = "logging")]
4508                                            error!(target: "stdout", "{}", &err_msg);
4509
4510                                            LlamaCoreError::Operation(err_msg)
4511                                        })?;
4512
4513                                    let chat_completion_chunk = ChatCompletionChunk {
4514                                        id,
4515                                        object: "chat.completion.chunk".to_string(),
4516                                        created: created.as_secs(),
4517                                        model: graph.name().to_owned(),
4518                                        system_fingerprint: "fp_44709d6fcb".to_string(),
4519                                        choices: vec![ChatCompletionChunkChoice {
4520                                            index: 0,
4521                                            delta: ChatCompletionChunkChoiceDelta {
4522                                                role: ChatCompletionRole::Assistant,
4523                                                content: None,
4524                                                tool_calls: vec![],
4525                                            },
4526                                            logprobs: None,
4527                                            finish_reason: Some(FinishReason::length),
4528                                        }],
4529                                        usage: None,
4530                                    };
4531
4532                                    // serialize chat completion chunk
4533                                    let chunk_str = serde_json::to_string(&chat_completion_chunk)
4534                                        .map_err(|e| {
4535                                        let err_msg = format!(
4536                                            "Failed to serialize chat completion chunk. Reason: {e}"
4537                                        );
4538
4539                                        #[cfg(feature = "logging")]
4540                                        error!(target: "stdout", "{}", &err_msg);
4541
4542                                        LlamaCoreError::Operation(err_msg)
4543                                    })?;
4544
4545                                    Ok(format!("data: {chunk_str}\n\n"))
4546                                }
4547                                PromptTooLongState::Usage => {
4548                                    *prompt_too_long_state = PromptTooLongState::Done;
4549
4550                                    // retrieve the number of prompt and completion tokens
4551                                    let token_info = get_token_info_by_graph(graph)?;
4552
4553                                    let usage = Some(Usage {
4554                                        prompt_tokens: token_info.prompt_tokens,
4555                                        completion_tokens: token_info.completion_tokens,
4556                                        total_tokens: token_info.prompt_tokens
4557                                            + token_info.completion_tokens,
4558                                    });
4559
4560                                    let created = SystemTime::now()
4561                                        .duration_since(std::time::UNIX_EPOCH)
4562                                        .map_err(|e| {
4563                                            let err_msg = format!(
4564                                                "Failed to get the current time. Reason: {e}"
4565                                            );
4566
4567                                            #[cfg(feature = "logging")]
4568                                            error!(target: "stdout", "{}", &err_msg);
4569
4570                                            LlamaCoreError::Operation(err_msg)
4571                                        })?;
4572
4573                                    let chat_completion_chunk = ChatCompletionChunk {
4574                                        id,
4575                                        object: "chat.completion.chunk".to_string(),
4576                                        created: created.as_secs(),
4577                                        model: graph.name().to_owned(),
4578                                        system_fingerprint: "fp_44709d6fcb".to_string(),
4579                                        choices: vec![],
4580                                        usage,
4581                                    };
4582
4583                                    // serialize chat completion chunk
4584                                    let chunk_str = serde_json::to_string(&chat_completion_chunk)
4585                                        .map_err(|e| {
4586                                        let err_msg = format!(
4587                                            "Failed to serialize chat completion chunk. Reason: {e}"
4588                                        );
4589
4590                                        #[cfg(feature = "logging")]
4591                                        error!(target: "stdout", "{}", &err_msg);
4592
4593                                        LlamaCoreError::Operation(err_msg)
4594                                    })?;
4595
4596                                    Ok(format!("data: {chunk_str}\n\n"))
4597                                }
4598                                PromptTooLongState::Done => {
4599                                    *prompt_too_long_state = PromptTooLongState::EndOfSequence;
4600
4601                                    Ok("data: [DONE]\n\n".to_string())
4602                                }
4603                                PromptTooLongState::EndOfSequence => {
4604                                    Ok("[GGML] End of sequence".to_string())
4605                                }
4606                            }
4607                        }
4608                        Err(e) => {
4609                            let err_msg =
4610                                format!("Failed to compute the chat completion. Reason: {e}");
4611
4612                            #[cfg(feature = "logging")]
4613                            error!(target: "stdout", "{}", &err_msg);
4614
4615                            Err(LlamaCoreError::Backend(BackendError::ComputeSingle(
4616                                err_msg,
4617                            )))
4618                        }
4619                    }
4620                }
4621                false => {
4622                    match chat_graphs.iter_mut().next() {
4623                        Some((_, graph)) => {
4624                            // compute
4625                            match graph.compute_single() {
4626                                Ok(_) => {
4627                                    #[cfg(feature = "logging")]
4628                                    debug!(target: "stdout", "Compute the chat stream chunk successfully.");
4629
4630                                    match stream_state {
4631                                        StreamState::Usage | StreamState::NoUsage => {
4632                                            // Retrieve the output
4633                                            let output_buffer =
4634                                                get_output_buffer_single(graph, OUTPUT_TENSOR)?;
4635
4636                                            #[cfg(feature = "logging")]
4637                                            info!(target: "stdout", "retrieved the output buffer");
4638
4639                                            // decode the output buffer to a utf8 string
4640                                            let output = match String::from_utf8(
4641                                                output_buffer.clone(),
4642                                            ) {
4643                                                Ok(token) => token,
4644                                                Err(_) => {
4645                                                    // Use the per-stream utf8_cache instead of global CACHED_UTF8_ENCODINGS
4646                                                    utf8_cache
4647                                                        .extend_from_slice(&output_buffer[..]);
4648
4649                                                    match String::from_utf8(utf8_cache.clone()) {
4650                                                        Ok(token) => {
4651                                                            utf8_cache.clear();
4652                                                            token
4653                                                        }
4654                                                        Err(e) => {
4655                                                            if utf8_cache.len() > 4 {
4656                                                                #[cfg(feature = "logging")]
4657                                                                error!(target: "stdout", "UTF-8 decode failed, cache too long: {e}");
4658                                                                #[cfg(feature = "logging")]
4659                                                                error!(target: "stdout", "The cached buffer: {:?}", &utf8_cache[..]);
4660                                                                utf8_cache.clear();
4661                                                            } else {
4662                                                                #[cfg(feature = "logging")]
4663                                                                warn!(target: "stdout", "UTF-8 decode incomplete: {e}");
4664                                                            }
4665                                                            String::new()
4666                                                        }
4667                                                    }
4668                                                }
4669                                            };
4670
4671                                            #[cfg(feature = "logging")]
4672                                            info!(target: "stdout", "decoded the output buffer");
4673
4674                                            let created = SystemTime::now()
4675                                                .duration_since(std::time::UNIX_EPOCH)
4676                                                .map_err(|e| {
4677                                                    let err_msg = format!(
4678                                                "Failed to get the current time. Reason: {e}"
4679                                            );
4680
4681                                                    #[cfg(feature = "logging")]
4682                                                    error!(target: "stdout", "{}", &err_msg);
4683
4684                                                    LlamaCoreError::Operation(err_msg)
4685                                                })?;
4686
4687                                            let chat_completion_chunk = ChatCompletionChunk {
4688                                                id,
4689                                                object: "chat.completion.chunk".to_string(),
4690                                                created: created.as_secs(),
4691                                                model: graph.name().to_owned(),
4692                                                system_fingerprint: "fp_44709d6fcb".to_string(),
4693                                                choices: vec![ChatCompletionChunkChoice {
4694                                                    index: 0,
4695                                                    delta: ChatCompletionChunkChoiceDelta {
4696                                                        role: ChatCompletionRole::Assistant,
4697                                                        content: Some(output),
4698                                                        tool_calls: vec![],
4699                                                    },
4700                                                    logprobs: None,
4701                                                    finish_reason: None,
4702                                                }],
4703                                                usage: None,
4704                                            };
4705
4706                                            #[cfg(feature = "logging")]
4707                                            info!(target: "stdout", "created chat completion chunk");
4708
4709                                            // serialize chat completion chunk
4710                                            let chunk_str =
4711                                                serde_json::to_string(&chat_completion_chunk)
4712                                                    .map_err(|e| {
4713                                                        let err_msg = format!(
4714                                            "Failed to serialize chat completion chunk. Reason: {e}"
4715                                        );
4716
4717                                                        #[cfg(feature = "logging")]
4718                                                        error!(target: "stdout", "{}", &err_msg);
4719
4720                                                        LlamaCoreError::Operation(err_msg)
4721                                                    })?;
4722
4723                                            Ok(format!("data: {chunk_str}\n\n"))
4724                                        }
4725                                        StreamState::Done => {
4726                                            *stream_state = StreamState::EndOfSequence;
4727
4728                                            Ok("data: [DONE]\n\n".to_string())
4729                                        }
4730                                        StreamState::EndOfSequence => {
4731                                            Ok("[GGML] End of sequence".to_string())
4732                                        }
4733                                    }
4734                                }
4735                                Err(wasmedge_wasi_nn::Error::BackendError(
4736                                    wasmedge_wasi_nn::BackendError::EndOfSequence,
4737                                )) => {
4738                                    #[cfg(feature = "logging")]
4739                                    debug!(target: "stdout", "End of sequence");
4740
4741                                    match stream_state {
4742                                        StreamState::Usage => {
4743                                            *stream_state = StreamState::Done;
4744
4745                                            // retrieve the number of prompt and completion tokens
4746                                            let token_info = get_token_info_by_graph(graph)?;
4747
4748                                            let usage = Some(Usage {
4749                                                prompt_tokens: token_info.prompt_tokens,
4750                                                completion_tokens: token_info.completion_tokens,
4751                                                total_tokens: token_info.prompt_tokens
4752                                                    + token_info.completion_tokens,
4753                                            });
4754
4755                                            #[cfg(feature = "logging")]
4756                                            info!(target: "stdout", "token_info: {} prompt tokens, {} completion tokens", token_info.prompt_tokens, token_info.completion_tokens);
4757
4758                                            let created = SystemTime::now()
4759                                                .duration_since(std::time::UNIX_EPOCH)
4760                                                .map_err(|e| {
4761                                                    let err_msg = format!(
4762                                                "Failed to get the current time. Reason: {e}"
4763                                            );
4764
4765                                                    #[cfg(feature = "logging")]
4766                                                    error!(target: "stdout", "{}", &err_msg);
4767
4768                                                    LlamaCoreError::Operation(err_msg)
4769                                                })?;
4770
4771                                            let chat_completion_chunk = ChatCompletionChunk {
4772                                                id,
4773                                                object: "chat.completion.chunk".to_string(),
4774                                                created: created.as_secs(),
4775                                                model: graph.name().to_owned(),
4776                                                system_fingerprint: "fp_44709d6fcb".to_string(),
4777                                                choices: vec![],
4778                                                usage,
4779                                            };
4780
4781                                            // serialize chat completion chunk
4782                                            let chunk_str =
4783                                                serde_json::to_string(&chat_completion_chunk)
4784                                                    .map_err(|e| {
4785                                                        let err_msg = format!(
4786                                            "Failed to serialize chat completion chunk. Reason: {e}"
4787                                        );
4788
4789                                                        #[cfg(feature = "logging")]
4790                                                        error!(target: "stdout", "{}", &err_msg);
4791
4792                                                        LlamaCoreError::Operation(err_msg)
4793                                                    })?;
4794
4795                                            Ok(format!("data: {chunk_str}\n\n"))
4796                                        }
4797                                        StreamState::Done | StreamState::NoUsage => {
4798                                            *stream_state = StreamState::EndOfSequence;
4799
4800                                            Ok("data: [DONE]\n\n".to_string())
4801                                        }
4802                                        StreamState::EndOfSequence => {
4803                                            Ok("[GGML] End of sequence".to_string())
4804                                        }
4805                                    }
4806                                }
4807                                Err(wasmedge_wasi_nn::Error::BackendError(
4808                                    wasmedge_wasi_nn::BackendError::ContextFull,
4809                                )) => {
4810                                    #[cfg(feature = "logging")]
4811                                    debug!(target: "stdout", "Context full");
4812
4813                                    match context_full_state {
4814                                        ContextFullState::Message => {
4815                                            match include_usage {
4816                                                true => {
4817                                                    *context_full_state = ContextFullState::Usage
4818                                                }
4819                                                false => {
4820                                                    *context_full_state = ContextFullState::Done
4821                                                }
4822                                            }
4823
4824                                            let created = SystemTime::now()
4825                                                .duration_since(std::time::UNIX_EPOCH)
4826                                                .map_err(|e| {
4827                                                    let err_msg = format!(
4828                                                "Failed to get the current time. Reason: {e}"
4829                                            );
4830
4831                                                    #[cfg(feature = "logging")]
4832                                                    error!(target: "stdout", "{}", &err_msg);
4833
4834                                                    LlamaCoreError::Operation(err_msg)
4835                                                })?;
4836
4837                                            let chat_completion_chunk = ChatCompletionChunk {
4838                                                id,
4839                                                object: "chat.completion.chunk".to_string(),
4840                                                created: created.as_secs(),
4841                                                model: graph.name().to_owned(),
4842                                                system_fingerprint: "fp_44709d6fcb".to_string(),
4843                                                choices: vec![ChatCompletionChunkChoice {
4844                                                    index: 0,
4845                                                    delta: ChatCompletionChunkChoiceDelta {
4846                                                        role: ChatCompletionRole::Assistant,
4847                                                        content: Some(
4848                                                            "<|WASMEDGE-GGML-CONTEXT-FULL|>"
4849                                                                .to_string(),
4850                                                        ),
4851                                                        tool_calls: vec![],
4852                                                    },
4853                                                    logprobs: None,
4854                                                    finish_reason: Some(FinishReason::length),
4855                                                }],
4856                                                usage: None,
4857                                            };
4858
4859                                            // serialize chat completion chunk
4860                                            let chunk_str =
4861                                                serde_json::to_string(&chat_completion_chunk)
4862                                                    .map_err(|e| {
4863                                                        let err_msg = format!(
4864                                            "Failed to serialize chat completion chunk. Reason: {e}"
4865                                        );
4866
4867                                                        #[cfg(feature = "logging")]
4868                                                        error!(target: "stdout", "{}", &err_msg);
4869
4870                                                        LlamaCoreError::Operation(err_msg)
4871                                                    })?;
4872
4873                                            Ok(format!("data: {chunk_str}\n\n"))
4874                                        }
4875                                        ContextFullState::Usage => {
4876                                            *context_full_state = ContextFullState::Done;
4877
4878                                            // retrieve the number of prompt and completion tokens
4879                                            let token_info = get_token_info_by_graph(graph)?;
4880
4881                                            let usage = Some(Usage {
4882                                                prompt_tokens: token_info.prompt_tokens,
4883                                                completion_tokens: token_info.completion_tokens,
4884                                                total_tokens: token_info.prompt_tokens
4885                                                    + token_info.completion_tokens,
4886                                            });
4887
4888                                            let created = SystemTime::now()
4889                                                .duration_since(std::time::UNIX_EPOCH)
4890                                                .map_err(|e| {
4891                                                    let err_msg = format!(
4892                                                "Failed to get the current time. Reason: {e}"
4893                                            );
4894
4895                                                    #[cfg(feature = "logging")]
4896                                                    error!(target: "stdout", "{}", &err_msg);
4897
4898                                                    LlamaCoreError::Operation(err_msg)
4899                                                })?;
4900
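                                            // Usage was requested for this stream, so send one final chunk with an
                                            // empty `choices` list and the populated `usage` counters.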
4901                                            let chat_completion_chunk = ChatCompletionChunk {
4902                                                id,
4903                                                object: "chat.completion.chunk".to_string(),
4904                                                created: created.as_secs(),
4905                                                model: graph.name().to_owned(),
4906                                                system_fingerprint: "fp_44709d6fcb".to_string(),
4907                                                choices: vec![],
4908                                                usage,
4909                                            };
4910
4911                                            // serialize chat completion chunk
4912                                            let chunk_str =
4913                                                serde_json::to_string(&chat_completion_chunk)
4914                                                    .map_err(|e| {
4915                                                        let err_msg = format!(
4916                                            "Failed to serialize chat completion chunk. Reason: {e}"
4917                                        );
4918
4919                                                        #[cfg(feature = "logging")]
4920                                                        error!(target: "stdout", "{}", &err_msg);
4921
4922                                                        LlamaCoreError::Operation(err_msg)
4923                                                    })?;
4924
4925                                            Ok(format!("data: {chunk_str}\n\n"))
4926                                        }
4927                                        ContextFullState::Done => {
4928                                            *context_full_state = ContextFullState::EndOfSequence;
4929
4930                                            Ok("data: [DONE]\n\n".to_string())
4931                                        }
4932                                        ContextFullState::EndOfSequence => {
4933                                            Ok("[GGML] End of sequence".to_string())
4934                                        }
4935                                    }
4936                                }
4937                                Err(wasmedge_wasi_nn::Error::BackendError(
4938                                    wasmedge_wasi_nn::BackendError::PromptTooLong,
4939                                )) => {
4940                                    #[cfg(feature = "logging")]
4941                                    debug!(target: "stdout", "Prompt too long");
4942
4943                                    match prompt_too_long_state {
4944                                        PromptTooLongState::Message => {
4945                                            match include_usage {
4946                                                true => {
4947                                                    *prompt_too_long_state =
4948                                                        PromptTooLongState::Usage
4949                                                }
4950                                                false => {
4951                                                    *prompt_too_long_state =
4952                                                        PromptTooLongState::Done
4953                                                }
4954                                            }
4955
4956                                            let created = SystemTime::now()
4957                                                .duration_since(std::time::UNIX_EPOCH)
4958                                                .map_err(|e| {
4959                                                    let err_msg = format!(
4960                                                "Failed to get the current time. Reason: {e}"
4961                                            );
4962
4963                                                    #[cfg(feature = "logging")]
4964                                                    error!(target: "stdout", "{}", &err_msg);
4965
4966                                                    LlamaCoreError::Operation(err_msg)
4967                                                })?;
4968
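                                            // The prompt exceeded the context window: emit an empty delta and close
                                            // the choice with a `length` finish reason so the client stops cleanly.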
4969                                            let chat_completion_chunk = ChatCompletionChunk {
4970                                                id,
4971                                                object: "chat.completion.chunk".to_string(),
4972                                                created: created.as_secs(),
4973                                                model: graph.name().to_owned(),
4974                                                system_fingerprint: "fp_44709d6fcb".to_string(),
4975                                                choices: vec![ChatCompletionChunkChoice {
4976                                                    index: 0,
4977                                                    delta: ChatCompletionChunkChoiceDelta {
4978                                                        role: ChatCompletionRole::Assistant,
4979                                                        content: None,
4980                                                        tool_calls: vec![],
4981                                                    },
4982                                                    logprobs: None,
4983                                                    finish_reason: Some(FinishReason::length),
4984                                                }],
4985                                                usage: None,
4986                                            };
4987
4988                                            // serialize chat completion chunk
4989                                            let chunk_str =
4990                                                serde_json::to_string(&chat_completion_chunk)
4991                                                    .map_err(|e| {
4992                                                        let err_msg = format!(
4993                                            "Failed to serialize chat completion chunk. Reason: {e}"
4994                                        );
4995
4996                                                        #[cfg(feature = "logging")]
4997                                                        error!(target: "stdout", "{}", &err_msg);
4998
4999                                                        LlamaCoreError::Operation(err_msg)
5000                                                    })?;
5001
5002                                            Ok(format!("data: {chunk_str}\n\n"))
5003                                        }
5004                                        PromptTooLongState::Usage => {
5005                                            *prompt_too_long_state = PromptTooLongState::Done;
5006
5007                                            // retrieve the number of prompt and completion tokens
5008                                            let token_info = get_token_info_by_graph(graph)?;
5009
5010                                            let usage = Some(Usage {
5011                                                prompt_tokens: token_info.prompt_tokens,
5012                                                completion_tokens: token_info.completion_tokens,
5013                                                total_tokens: token_info.prompt_tokens
5014                                                    + token_info.completion_tokens,
5015                                            });
5016
5017                                            let created = SystemTime::now()
5018                                                .duration_since(std::time::UNIX_EPOCH)
5019                                                .map_err(|e| {
5020                                                    let err_msg = format!(
5021                                                "Failed to get the current time. Reason: {e}"
5022                                            );
5023
5024                                                    #[cfg(feature = "logging")]
5025                                                    error!(target: "stdout", "{}", &err_msg);
5026
5027                                                    LlamaCoreError::Operation(err_msg)
5028                                                })?;
5029
5030                                            let chat_completion_chunk = ChatCompletionChunk {
5031                                                id,
5032                                                object: "chat.completion.chunk".to_string(),
5033                                                created: created.as_secs(),
5034                                                model: graph.name().to_owned(),
5035                                                system_fingerprint: "fp_44709d6fcb".to_string(),
5036                                                choices: vec![],
5037                                                usage,
5038                                            };
5039
5040                                            // serialize chat completion chunk
5041                                            let chunk_str =
5042                                                serde_json::to_string(&chat_completion_chunk)
5043                                                    .map_err(|e| {
5044                                                        let err_msg = format!(
5045                                            "Failed to serialize chat completion chunk. Reason: {e}"
5046                                        );
5047
5048                                                        #[cfg(feature = "logging")]
5049                                                        error!(target: "stdout", "{}", &err_msg);
5050
5051                                                        LlamaCoreError::Operation(err_msg)
5052                                                    })?;
5053
5054                                            Ok(format!("data: {chunk_str}\n\n"))
5055                                        }
5056                                        PromptTooLongState::Done => {
5057                                            *prompt_too_long_state =
5058                                                PromptTooLongState::EndOfSequence;
5059
5060                                            Ok("data: [DONE]\n\n".to_string())
5061                                        }
5062                                        PromptTooLongState::EndOfSequence => {
5063                                            Ok("[GGML] End of sequence".to_string())
5064                                        }
5065                                    }
5066                                }
5067                                Err(e) => {
5068                                    let err_msg = format!(
5069                                        "Failed to compute the chat completion. Reason: {e}"
5070                                    );
5071
5072                                    #[cfg(feature = "logging")]
5073                                    error!(target: "stdout", "{}", &err_msg);
5074
5075                                    Err(LlamaCoreError::Backend(BackendError::ComputeSingle(
5076                                        err_msg,
5077                                    )))
5078                                }
5079                            }
5080                        }
5081                        None => {
5082                            let err_msg = "There is no model available in the chat graphs.";
5083
5084                            #[cfg(feature = "logging")]
5085                            error!(target: "stdout", "{}", &err_msg);
5086
5087                            Err(LlamaCoreError::Operation(err_msg.into()))
5088                        }
5089                    }
5090                }
5091            }
5092        }
5093        None => {
5094            match chat_graphs.iter_mut().next() {
5095                Some((_, graph)) => {
5096                    // compute the next token for the chat stream
5097                    match graph.compute_single() {
5098                        Ok(_) => {
5099                            #[cfg(feature = "logging")]
5100                            debug!(target: "stdout", "Computed the chat stream chunk successfully.");
5101
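                            // Stream state machine: Usage/NoUsage keep emitting content deltas, Done
                            // sends the `[DONE]` terminator, and EndOfSequence returns the internal end marker.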
5102                            match stream_state {
5103                                StreamState::Usage | StreamState::NoUsage => {
5104                                    // Retrieve the output
5105                                    let output_buffer =
5106                                        get_output_buffer_single(graph, OUTPUT_TENSOR)?;
5107
5108                                    #[cfg(feature = "logging")]
5109                                    info!(target: "stdout", "retrieved the output buffer");
5110
5111                                    // decode the output buffer to a UTF-8 string
5112                                    let output = match String::from_utf8(output_buffer.clone()) {
5113                                        Ok(token) => token,
5114                                        Err(_) => {
5115                                            // Use the per-stream utf8_cache instead of the global CACHED_UTF8_ENCODINGS
5116                                            utf8_cache.extend_from_slice(&output_buffer[..]);
5117
5118                                            match String::from_utf8(utf8_cache.clone()) {
5119                                                Ok(token) => {
5120                                                    utf8_cache.clear();
5121                                                    token
5122                                                }
5123                                                Err(e) => {
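                                                    // A single UTF-8 code point is at most 4 bytes, so a cache longer
                                                    // than that cannot be a partially received character; drop it.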
5124                                                    if utf8_cache.len() > 4 {
5125                                                        #[cfg(feature = "logging")]
5126                                                        error!(target: "stdout", "UTF-8 decode failed, cache too long: {e}");
5127                                                        #[cfg(feature = "logging")]
5128                                                        error!(target: "stdout", "The cached buffer: {:?}", &utf8_cache[..]);
5129                                                        utf8_cache.clear();
5130                                                    } else {
5131                                                        #[cfg(feature = "logging")]
5132                                                        warn!(target: "stdout", "UTF-8 decode incomplete: {e}");
5133                                                    }
5134                                                    String::new()
5135                                                }
5136                                            }
5137                                        }
5138                                    };
5139
5140                                    #[cfg(feature = "logging")]
5141                                    info!(target: "stdout", "decoded the output buffer");
5142
5143                                    let created = SystemTime::now()
5144                                        .duration_since(std::time::UNIX_EPOCH)
5145                                        .map_err(|e| {
5146                                            let err_msg = format!(
5147                                                "Failed to get the current time. Reason: {e}"
5148                                            );
5149
5150                                            #[cfg(feature = "logging")]
5151                                            error!(target: "stdout", "{}", &err_msg);
5152
5153                                            LlamaCoreError::Operation(err_msg)
5154                                        })?;
5155
5156                                    let chat_completion_chunk = ChatCompletionChunk {
5157                                        id,
5158                                        object: "chat.completion.chunk".to_string(),
5159                                        created: created.as_secs(),
5160                                        model: graph.name().to_owned(),
5161                                        system_fingerprint: "fp_44709d6fcb".to_string(),
5162                                        choices: vec![ChatCompletionChunkChoice {
5163                                            index: 0,
5164                                            delta: ChatCompletionChunkChoiceDelta {
5165                                                role: ChatCompletionRole::Assistant,
5166                                                content: Some(output),
5167                                                tool_calls: vec![],
5168                                            },
5169                                            logprobs: None,
5170                                            finish_reason: None,
5171                                        }],
5172                                        usage: None,
5173                                    };
5174
5175                                    #[cfg(feature = "logging")]
5176                                    info!(target: "stdout", "created chat completion chunk");
5177
5178                                    // serialize chat completion chunk
5179                                    let chunk_str = serde_json::to_string(&chat_completion_chunk)
5180                                        .map_err(|e| {
5181                                        let err_msg = format!(
5182                                            "Failed to serialize chat completion chunk. Reason: {e}"
5183                                        );
5184
5185                                        #[cfg(feature = "logging")]
5186                                        error!(target: "stdout", "{}", &err_msg);
5187
5188                                        LlamaCoreError::Operation(err_msg)
5189                                    })?;
5190
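                                    // Frame the chunk as a Server-Sent Events message: `data: <json>` followed
                                    // by a blank line, matching the OpenAI streaming format.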
5191                                    Ok(format!("data: {chunk_str}\n\n"))
5192                                }
5193                                StreamState::Done => {
5194                                    *stream_state = StreamState::EndOfSequence;
5195
5196                                    Ok("data: [DONE]\n\n".to_string())
5197                                }
5198                                StreamState::EndOfSequence => {
5199                                    Ok("[GGML] End of sequence".to_string())
5200                                }
5201                            }
5202                        }
5203                        Err(wasmedge_wasi_nn::Error::BackendError(
5204                            wasmedge_wasi_nn::BackendError::EndOfSequence,
5205                        )) => {
5206                            #[cfg(feature = "logging")]
5207                            debug!(target: "stdout", "End of sequence");
5208
5209                            match stream_state {
5210                                StreamState::Usage => {
5211                                    *stream_state = StreamState::Done;
5212
5213                                    // retrieve the number of prompt and completion tokens
5214                                    let token_info = get_token_info_by_graph(graph)?;
5215
5216                                    let usage = Some(Usage {
5217                                        prompt_tokens: token_info.prompt_tokens,
5218                                        completion_tokens: token_info.completion_tokens,
5219                                        total_tokens: token_info.prompt_tokens
5220                                            + token_info.completion_tokens,
5221                                    });
5222
5223                                    #[cfg(feature = "logging")]
5224                                    info!(target: "stdout", "token_info: {} prompt tokens, {} completion tokens", token_info.prompt_tokens, token_info.completion_tokens);
5225
5226                                    let created = SystemTime::now()
5227                                        .duration_since(std::time::UNIX_EPOCH)
5228                                        .map_err(|e| {
5229                                            let err_msg = format!(
5230                                                "Failed to get the current time. Reason: {e}"
5231                                            );
5232
5233                                            #[cfg(feature = "logging")]
5234                                            error!(target: "stdout", "{}", &err_msg);
5235
5236                                            LlamaCoreError::Operation(err_msg)
5237                                        })?;
5238
5239                                    let chat_completion_chunk = ChatCompletionChunk {
5240                                        id,
5241                                        object: "chat.completion.chunk".to_string(),
5242                                        created: created.as_secs(),
5243                                        model: graph.name().to_owned(),
5244                                        system_fingerprint: "fp_44709d6fcb".to_string(),
5245                                        choices: vec![],
5246                                        usage,
5247                                    };
5248
5249                                    // serialize chat completion chunk
5250                                    let chunk_str = serde_json::to_string(&chat_completion_chunk)
5251                                        .map_err(|e| {
5252                                        let err_msg = format!(
5253                                            "Failed to serialize chat completion chunk. Reason: {e}"
5254                                        );
5255
5256                                        #[cfg(feature = "logging")]
5257                                        error!(target: "stdout", "{}", &err_msg);
5258
5259                                        LlamaCoreError::Operation(err_msg)
5260                                    })?;
5261
5262                                    Ok(format!("data: {chunk_str}\n\n"))
5263                                }
5264                                StreamState::Done | StreamState::NoUsage => {
5265                                    *stream_state = StreamState::EndOfSequence;
5266
5267                                    Ok("data: [DONE]\n\n".to_string())
5268                                }
5269                                StreamState::EndOfSequence => {
5270                                    Ok("[GGML] End of sequence".to_string())
5271                                }
5272                            }
5273                        }
5274                        Err(wasmedge_wasi_nn::Error::BackendError(
5275                            wasmedge_wasi_nn::BackendError::ContextFull,
5276                        )) => {
5277                            #[cfg(feature = "logging")]
5278                            debug!(target: "stdout", "Context full");
5279
5280                            match context_full_state {
5281                                ContextFullState::Message => {
5282                                    match include_usage {
5283                                        true => *context_full_state = ContextFullState::Usage,
5284                                        false => *context_full_state = ContextFullState::Done,
5285                                    }
5286
5287                                    let created = SystemTime::now()
5288                                        .duration_since(std::time::UNIX_EPOCH)
5289                                        .map_err(|e| {
5290                                            let err_msg = format!(
5291                                                "Failed to get the current time. Reason: {e}"
5292                                            );
5293
5294                                            #[cfg(feature = "logging")]
5295                                            error!(target: "stdout", "{}", &err_msg);
5296
5297                                            LlamaCoreError::Operation(err_msg)
5298                                        })?;
5299
5300                                    let chat_completion_chunk = ChatCompletionChunk {
5301                                        id,
5302                                        object: "chat.completion.chunk".to_string(),
5303                                        created: created.as_secs(),
5304                                        model: graph.name().to_owned(),
5305                                        system_fingerprint: "fp_44709d6fcb".to_string(),
5306                                        choices: vec![ChatCompletionChunkChoice {
5307                                            index: 0,
5308                                            delta: ChatCompletionChunkChoiceDelta {
5309                                                role: ChatCompletionRole::Assistant,
5310                                                content: Some(
5311                                                    "<|WASMEDGE-GGML-CONTEXT-FULL|>".to_string(),
5312                                                ),
5313                                                tool_calls: vec![],
5314                                            },
5315                                            logprobs: None,
5316                                            finish_reason: Some(FinishReason::length),
5317                                        }],
5318                                        usage: None,
5319                                    };
5320
5321                                    // serialize chat completion chunk
5322                                    let chunk_str = serde_json::to_string(&chat_completion_chunk)
5323                                        .map_err(|e| {
5324                                        let err_msg = format!(
5325                                            "Failed to serialize chat completion chunk. Reason: {e}"
5326                                        );
5327
5328                                        #[cfg(feature = "logging")]
5329                                        error!(target: "stdout", "{}", &err_msg);
5330
5331                                        LlamaCoreError::Operation(err_msg)
5332                                    })?;
5333
5334                                    Ok(format!("data: {chunk_str}\n\n"))
5335                                }
5336                                ContextFullState::Usage => {
5337                                    *context_full_state = ContextFullState::Done;
5338
5339                                    // retrieve the number of prompt and completion tokens
5340                                    let token_info = get_token_info_by_graph(graph)?;
5341
5342                                    let usage = Some(Usage {
5343                                        prompt_tokens: token_info.prompt_tokens,
5344                                        completion_tokens: token_info.completion_tokens,
5345                                        total_tokens: token_info.prompt_tokens
5346                                            + token_info.completion_tokens,
5347                                    });
5348
5349                                    let created = SystemTime::now()
5350                                        .duration_since(std::time::UNIX_EPOCH)
5351                                        .map_err(|e| {
5352                                            let err_msg = format!(
5353                                                "Failed to get the current time. Reason: {e}"
5354                                            );
5355
5356                                            #[cfg(feature = "logging")]
5357                                            error!(target: "stdout", "{}", &err_msg);
5358
5359                                            LlamaCoreError::Operation(err_msg)
5360                                        })?;
5361
5362                                    let chat_completion_chunk = ChatCompletionChunk {
5363                                        id,
5364                                        object: "chat.completion.chunk".to_string(),
5365                                        created: created.as_secs(),
5366                                        model: graph.name().to_owned(),
5367                                        system_fingerprint: "fp_44709d6fcb".to_string(),
5368                                        choices: vec![],
5369                                        usage,
5370                                    };
5371
5372                                    // serialize chat completion chunk
5373                                    let chunk_str = serde_json::to_string(&chat_completion_chunk)
5374                                        .map_err(|e| {
5375                                        let err_msg = format!(
5376                                            "Failed to serialize chat completion chunk. Reason: {e}"
5377                                        );
5378
5379                                        #[cfg(feature = "logging")]
5380                                        error!(target: "stdout", "{}", &err_msg);
5381
5382                                        LlamaCoreError::Operation(err_msg)
5383                                    })?;
5384
5385                                    Ok(format!("data: {chunk_str}\n\n"))
5386                                }
5387                                ContextFullState::Done => {
5388                                    *context_full_state = ContextFullState::EndOfSequence;
5389
5390                                    Ok("data: [DONE]\n\n".to_string())
5391                                }
5392                                ContextFullState::EndOfSequence => {
5393                                    Ok("[GGML] End of sequence".to_string())
5394                                }
5395                            }
5396                        }
5397                        Err(wasmedge_wasi_nn::Error::BackendError(
5398                            wasmedge_wasi_nn::BackendError::PromptTooLong,
5399                        )) => {
5400                            #[cfg(feature = "logging")]
5401                            debug!(target: "stdout", "Prompt too long");
5402
5403                            match prompt_too_long_state {
5404                                PromptTooLongState::Message => {
5405                                    match include_usage {
5406                                        true => *prompt_too_long_state = PromptTooLongState::Usage,
5407                                        false => *prompt_too_long_state = PromptTooLongState::Done,
5408                                    }
5409
5410                                    let created = SystemTime::now()
5411                                        .duration_since(std::time::UNIX_EPOCH)
5412                                        .map_err(|e| {
5413                                            let err_msg = format!(
5414                                                "Failed to get the current time. Reason: {e}"
5415                                            );
5416
5417                                            #[cfg(feature = "logging")]
5418                                            error!(target: "stdout", "{}", &err_msg);
5419
5420                                            LlamaCoreError::Operation(err_msg)
5421                                        })?;
5422
5423                                    let chat_completion_chunk = ChatCompletionChunk {
5424                                        id,
5425                                        object: "chat.completion.chunk".to_string(),
5426                                        created: created.as_secs(),
5427                                        model: graph.name().to_owned(),
5428                                        system_fingerprint: "fp_44709d6fcb".to_string(),
5429                                        choices: vec![ChatCompletionChunkChoice {
5430                                            index: 0,
5431                                            delta: ChatCompletionChunkChoiceDelta {
5432                                                role: ChatCompletionRole::Assistant,
5433                                                content: None,
5434                                                tool_calls: vec![],
5435                                            },
5436                                            logprobs: None,
5437                                            finish_reason: Some(FinishReason::length),
5438                                        }],
5439                                        usage: None,
5440                                    };
5441
5442                                    // serialize chat completion chunk
5443                                    let chunk_str = serde_json::to_string(&chat_completion_chunk)
5444                                        .map_err(|e| {
5445                                        let err_msg = format!(
5446                                            "Failed to serialize chat completion chunk. Reason: {e}"
5447                                        );
5448
5449                                        #[cfg(feature = "logging")]
5450                                        error!(target: "stdout", "{}", &err_msg);
5451
5452                                        LlamaCoreError::Operation(err_msg)
5453                                    })?;
5454
5455                                    Ok(format!("data: {chunk_str}\n\n"))
5456                                }
5457                                PromptTooLongState::Usage => {
5458                                    *prompt_too_long_state = PromptTooLongState::Done;
5459
5460                                    // retrieve the number of prompt and completion tokens
5461                                    let token_info = get_token_info_by_graph(graph)?;
5462
5463                                    let usage = Some(Usage {
5464                                        prompt_tokens: token_info.prompt_tokens,
5465                                        completion_tokens: token_info.completion_tokens,
5466                                        total_tokens: token_info.prompt_tokens
5467                                            + token_info.completion_tokens,
5468                                    });
5469
5470                                    let created = SystemTime::now()
5471                                        .duration_since(std::time::UNIX_EPOCH)
5472                                        .map_err(|e| {
5473                                            let err_msg = format!(
5474                                                "Failed to get the current time. Reason: {e}"
5475                                            );
5476
5477                                            #[cfg(feature = "logging")]
5478                                            error!(target: "stdout", "{}", &err_msg);
5479
5480                                            LlamaCoreError::Operation(err_msg)
5481                                        })?;
5482
5483                                    let chat_completion_chunk = ChatCompletionChunk {
5484                                        id,
5485                                        object: "chat.completion.chunk".to_string(),
5486                                        created: created.as_secs(),
5487                                        model: graph.name().to_owned(),
5488                                        system_fingerprint: "fp_44709d6fcb".to_string(),
5489                                        choices: vec![],
5490                                        usage,
5491                                    };
5492
5493                                    // serialize chat completion chunk
5494                                    let chunk_str = serde_json::to_string(&chat_completion_chunk)
5495                                        .map_err(|e| {
5496                                        let err_msg = format!(
5497                                            "Failed to serialize chat completion chunk. Reason: {e}"
5498                                        );
5499
5500                                        #[cfg(feature = "logging")]
5501                                        error!(target: "stdout", "{}", &err_msg);
5502
5503                                        LlamaCoreError::Operation(err_msg)
5504                                    })?;
5505
5506                                    Ok(format!("data: {chunk_str}\n\n"))
5507                                }
5508                                PromptTooLongState::Done => {
5509                                    *prompt_too_long_state = PromptTooLongState::EndOfSequence;
5510
5511                                    Ok("data: [DONE]\n\n".to_string())
5512                                }
5513                                PromptTooLongState::EndOfSequence => {
5514                                    Ok("[GGML] End of sequence".to_string())
5515                                }
5516                            }
5517                        }
5518                        Err(e) => {
5519                            let err_msg =
5520                                format!("Failed to compute the chat completion. Reason: {e}");
5521
5522                            #[cfg(feature = "logging")]
5523                            error!(target: "stdout", "{}", &err_msg);
5524
5525                            Err(LlamaCoreError::Backend(BackendError::ComputeSingle(
5526                                err_msg,
5527                            )))
5528                        }
5529                    }
5530                }
5531                None => {
5532                    let err_msg = "There is no model available in the chat graphs.";
5533
5534                    #[cfg(feature = "logging")]
5535                    error!(target: "stdout", "{}", &err_msg);
5536
5537                    Err(LlamaCoreError::Operation(err_msg.into()))
5538                }
5539            }
5540        }
5541    };
5542
5543    #[cfg(feature = "logging")]
5544    info!(target: "stdout", "Returning the chat stream chunk.");
5545
5546    res
5547}
5548
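/// Result of parsing raw model output: the original text, any extracted
/// assistant content, and any tool calls found in it.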
5549#[allow(dead_code)]
5550#[derive(Debug)]
5551struct ParseResult {
5552    raw: String,
5553    content: Option<String>,
5554    tool_calls: Vec<ToolCall>,
5555}
5556
5557// ============================================================================
5558// Unit Tests for Per-Model Stream Lock
5559// ============================================================================
5560
5561#[cfg(test)]
5562mod tests {
5563    use super::*;
5564
5565    #[test]
5566    fn test_model_stream_lock_new() {
5567        let lock = ModelStreamLock::new();
5568        assert!(!lock.active.load(Ordering::SeqCst));
5569        assert!(lock.waker_queue.lock().unwrap().is_empty());
5570    }
5571
5572    #[test]
5573    fn test_model_stream_lock_acquire_release() {
5574        let lock = ModelStreamLock::new();
5575
5576        // First acquisition should succeed
5577        assert!(lock.try_acquire());
5578        assert!(lock.active.load(Ordering::SeqCst));
5579
5580        // Second acquisition should fail (already held)
5581        assert!(!lock.try_acquire());
5582
5583        // Release the lock
5584        lock.release();
5585        assert!(!lock.active.load(Ordering::SeqCst));
5586
5587        // After release, acquisition should succeed again
5588        assert!(lock.try_acquire());
5589    }
5590
5591    #[test]
5592    fn test_different_models_can_acquire_parallel() {
5593        let lock_a = Arc::new(ModelStreamLock::new());
5594        let lock_b = Arc::new(ModelStreamLock::new());
5595
5596        // Both locks can be acquired simultaneously
5597        assert!(lock_a.try_acquire());
5598        assert!(lock_b.try_acquire());
5599
5600        // Both are active
5601        assert!(lock_a.active.load(Ordering::SeqCst));
5602        assert!(lock_b.active.load(Ordering::SeqCst));
5603
5604        // Release both
5605        lock_a.release();
5606        lock_b.release();
5607
5608        assert!(!lock_a.active.load(Ordering::SeqCst));
5609        assert!(!lock_b.active.load(Ordering::SeqCst));
5610    }
5611
5612    #[test]
5613    fn test_model_stream_lock_default() {
5614        let lock = ModelStreamLock::default();
5615        assert!(!lock.active.load(Ordering::SeqCst));
5616    }
5617
5618    #[test]
5619    fn test_get_or_create_model_lock() {
5620        // First call should create a new lock
5621        let lock1 = get_or_create_model_lock("test-model-1").unwrap();
5622
5623        // Second call with same name should return the same lock
5624        let lock2 = get_or_create_model_lock("test-model-1").unwrap();
5625
5626        // They should be the same Arc (point to same lock)
5627        assert!(Arc::ptr_eq(&lock1, &lock2));
5628
5629        // Different model name should create a different lock
5630        let lock3 = get_or_create_model_lock("test-model-2").unwrap();
5631        assert!(!Arc::ptr_eq(&lock1, &lock3));
5632
5633        // Operations on one lock should not affect the other
5634        assert!(lock1.try_acquire());
5635        assert!(lock3.try_acquire()); // Different model, should succeed
5636
5637        lock1.release();
5638        lock3.release();
5639    }
5640
5641    #[test]
5642    fn test_concurrent_access_same_model() {
5643        let lock = get_or_create_model_lock("test-concurrent-model").unwrap();
5644
5645        // Simulate first request acquiring the lock
5646        assert!(lock.try_acquire());
5647
5648        // Simulate second request trying to acquire (should fail)
5649        let lock_clone = lock.clone();
5650        assert!(!lock_clone.try_acquire());
5651
5652        // Release from first request
5653        lock.release();
5654
5655        // Now second request can acquire
5656        assert!(lock_clone.try_acquire());
5657
5658        lock_clone.release();
5659    }
5660
5661    #[test]
5662    fn test_model_lock_guard() {
5663        let lock = Arc::new(ModelStreamLock::new());
5664
5665        // Acquire the lock
5666        assert!(lock.try_acquire());
5667        assert!(lock.active.load(Ordering::SeqCst));
5668
5669        // Create a guard (simulating non-stream mode usage)
5670        {
5671            let _guard = ModelLockGuard::new(lock.clone());
5672            // Lock should still be active while guard exists
5673            assert!(lock.active.load(Ordering::SeqCst));
5674        }
5675        // Guard dropped, lock should be released
5676        assert!(!lock.active.load(Ordering::SeqCst));
5677
5678        // Lock should be acquirable again
5679        assert!(lock.try_acquire());
5680        lock.release();
5681    }
5682
5683    #[test]
5684    fn test_model_lock_guard_ensures_release() {
5685        let lock = get_or_create_model_lock("test-guard-model").unwrap();
5686
5687        // Acquire the lock and create guard
5688        assert!(lock.try_acquire());
5689        let _guard = ModelLockGuard::new(lock.clone());
5690
5691        // Another request should fail to acquire
5692        let lock2 = get_or_create_model_lock("test-guard-model").unwrap();
5693        assert!(!lock2.try_acquire());
5694
5695        // Drop the guard explicitly
5696        drop(_guard);
5697
5698        // Now the other request should be able to acquire
5699        assert!(lock2.try_acquire());
5700        lock2.release();
5701    }
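
    // A sketch of an extra check: it assumes `std::task::Wake` is available
    // (stable since Rust 1.51) and verifies that `release()` wakes exactly one
    // waiter registered via `register_waker()`.
    #[test]
    fn test_release_wakes_registered_waker() {
        use std::sync::atomic::AtomicUsize;
        use std::task::Wake;

        // A waker that counts how many times it has been woken.
        struct CountingWaker(AtomicUsize);
        impl Wake for CountingWaker {
            fn wake(self: Arc<Self>) {
                self.0.fetch_add(1, Ordering::SeqCst);
            }
        }

        let counter = Arc::new(CountingWaker(AtomicUsize::new(0)));
        let waker = Waker::from(counter.clone());

        let lock = ModelStreamLock::new();
        assert!(lock.try_acquire());

        // A second request registers its waker while the lock is held.
        lock.register_waker(&waker);
        assert!(!lock.try_acquire());

        // Releasing should pop and wake exactly one queued waiter.
        lock.release();
        assert_eq!(counter.0.load(Ordering::SeqCst), 1);
    }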
5702}