use crate::{
    error,
    metadata::ggml::GgmlMetadata,
    running_mode,
    utils::{
        gen_chat_id, get_output_buffer, get_output_buffer_single, get_token_info_by_graph,
        get_token_info_by_graph_name, set_tensor_data_u8,
    },
    Graph, RunningMode, CACHED_UTF8_ENCODINGS, CHAT_GRAPHS, OUTPUT_TENSOR,
};
use chat_prompts::{BuildChatPrompt, ChatPrompt, PromptTemplateType};
use either::{Either, Left, Right};
use endpoints::{
    chat::{
        ChatCompletionChunk, ChatCompletionChunkChoice, ChatCompletionChunkChoiceDelta,
        ChatCompletionObject, ChatCompletionObjectChoice, ChatCompletionObjectMessage,
        ChatCompletionRequest, ChatCompletionRequestMessage, ChatCompletionRole,
        ChatCompletionUserMessageContent, ContentPart, Function, ToolCall, ToolCallForChunk,
        ToolChoice,
    },
    common::{FinishReason, Usage},
};
use error::{BackendError, LlamaCoreError};
use std::{
    collections::VecDeque,
    pin::Pin,
    sync::{
        atomic::{AtomicBool, Ordering},
        Mutex, OnceLock,
    },
    task::{Context, Poll, Waker},
    time::SystemTime,
};

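/// Wakers parked by chat streams that are waiting for the currently active stream to finish
/// (presumably woken by the `ChatStream` implementation elsewhere in this module).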
static CHAT_STREAM_WAKER_QUEUE: OnceLock<Mutex<VecDeque<Waker>>> = OnceLock::new();

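/// Whether a chat stream is currently being consumed; presumably used to serialize streams so
/// that only one runs against the chat graphs at a time.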
static CHAT_STREAM_ACTIVE: AtomicBool = AtomicBool::new(false);

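/// Processes a chat completion request and returns either a stream of SSE-formatted chunks
/// (`Left`) or a complete `ChatCompletionObject` (`Right`), depending on `chat_request.stream`.
/// The accompanying `bool` indicates whether the response carries tool calls.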
pub async fn chat(
    chat_request: &mut ChatCompletionRequest,
) -> Result<
    (
        Either<impl futures::TryStream<Ok = String, Error = LlamaCoreError>, ChatCompletionObject>,
        bool,
    ),
    LlamaCoreError,
> {
    #[cfg(feature = "logging")]
    {
        debug!(target: "stdout", "tool choice: {:?}", chat_request.tool_choice.as_ref());
        debug!(target: "stdout", "tools: {:?}", chat_request.tools.as_ref());
        debug!(target: "stdout", "stream mode: {:?}", chat_request.stream);
    }

    let result = match chat_request.stream {
        Some(true) => match chat_stream(chat_request).await {
            Ok((stream, include_tool_calls)) => Ok((Left(stream), include_tool_calls)),
            Err(e) => Err(e),
        },
        Some(false) | None => match chat_once(chat_request).await {
            Ok((chat_completion_object, include_tool_calls)) => {
                Ok((Right(chat_completion_object), include_tool_calls))
            }
            Err(e) => Err(e),
        },
    };

    #[cfg(feature = "logging")]
    info!(target: "stdout", "Reset the model metadata");

    result
}

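/// Handles a chat completion request in stream mode: validates the running mode, builds the
/// chat prompt, feeds it to the model, and returns a `ChatStream` of SSE chunks together with a
/// flag indicating whether the stream carries tool calls.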
async fn chat_stream(
    chat_request: &mut ChatCompletionRequest,
) -> Result<
    (
        impl futures::TryStream<Ok = String, Error = LlamaCoreError>,
        bool,
    ),
    LlamaCoreError,
> {
    #[cfg(feature = "logging")]
    info!(target: "stdout", "Processing chat completion request in stream mode");

    let running_mode = running_mode()?;
    if !running_mode.contains(RunningMode::CHAT) && !running_mode.contains(RunningMode::RAG) {
        let err_msg = "The chat completion is only supported in the chat or rag mode.";

        #[cfg(feature = "logging")]
        error!(target: "stdout", "{err_msg}");

        return Err(LlamaCoreError::Operation(err_msg.to_string()));
    }

    let model_name = chat_request.model.clone();
    let id = match &chat_request.user {
        Some(id) => id.clone(),
        None => gen_chat_id(),
    };
    #[cfg(feature = "logging")]
    info!(target: "stdout", "user: {}", &id);

    #[cfg(feature = "logging")]
    info!(target: "stdout", "Check model metadata");

    let mut metadata = check_model_metadata(chat_request)?;

    let include_usage = match chat_request.stream_options {
        Some(ref stream_options) => stream_options.include_usage.unwrap_or_default(),
        None => metadata.include_usage,
    };
    #[cfg(feature = "logging")]
    info!(target: "stdout", "include_usage: {include_usage}");

    #[cfg(feature = "logging")]
    info!(target: "stdout", "Build the chat prompt");

    let (prompt, available_completion_tokens, tool_use) =
        build_prompt(model_name.as_ref(), chat_request)?;

    #[cfg(feature = "logging")]
    {
        info!(target: "stdout", "prompt:\n{}", &prompt);
        info!(target: "stdout", "available_completion_tokens: {available_completion_tokens}");
        info!(target: "stdout", "tool_use: {tool_use}");
    }

    #[cfg(feature = "logging")]
    info!(target: "stdout", "Update the n_predict");

    update_n_predict(chat_request, &mut metadata, available_completion_tokens)?;

    #[cfg(feature = "logging")]
    info!(target: "stdout", "Feed the prompt to the model");

    set_prompt(chat_request.model.as_ref(), &prompt)?;

    let stream = match tool_use {
        false => (ChatStream::new(model_name, id, include_usage, None), false),
        true => {
            let chat_graphs = match CHAT_GRAPHS.get() {
                Some(chat_graphs) => chat_graphs,
                None => {
                    let err_msg = "Fail to get the underlying value of `CHAT_GRAPHS`.";

                    #[cfg(feature = "logging")]
                    error!(target: "stdout", "{}", &err_msg);

                    return Err(LlamaCoreError::Operation(err_msg.into()));
                }
            };

            let mut chat_graphs = chat_graphs.lock().map_err(|e| {
                let err_msg = format!("Fail to acquire the lock of `CHAT_GRAPHS`. {e}");

                #[cfg(feature = "logging")]
                error!(target: "stdout", "{}", &err_msg);

                LlamaCoreError::Operation(err_msg)
            })?;

            match model_name {
                Some(model_name) => match chat_graphs.contains_key(&model_name) {
                    true => {
                        let graph = chat_graphs.get_mut(&model_name).unwrap();
                        chat_stream_for_tool(graph, id, include_usage)?
                    }
                    false => match chat_graphs.iter_mut().next() {
                        Some((_, graph)) => chat_stream_for_tool(graph, id, include_usage)?,
                        None => {
                            let err_msg = "There is no model available in the chat graphs.";

                            #[cfg(feature = "logging")]
                            error!(target: "stdout", "{}", &err_msg);

                            return Err(LlamaCoreError::Operation(err_msg.into()));
                        }
                    },
                },
                None => match chat_graphs.iter_mut().next() {
                    Some((_, graph)) => chat_stream_for_tool(graph, id, include_usage)?,
                    None => {
                        let err_msg = "There is no model available in the chat graphs.";

                        #[cfg(feature = "logging")]
                        error!(target: "stdout", "{}", &err_msg);

                        return Err(LlamaCoreError::Operation(err_msg.into()));
                    }
                },
            }
        }
    };

    #[cfg(feature = "logging")]
    info!(target: "stdout", "End of the chat completion stream.");

    Ok(stream)
}

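/// Runs a single inference on `graph` for a tool-use request and wraps the result in a
/// `ChatStream` of pre-built SSE chunks (a tool-call delta chunk, a usage chunk, and the
/// terminating `[DONE]` chunk). Context-full and prompt-too-long backend errors are reported as
/// chunks with a `length` finish reason rather than failing the request.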
212fn chat_stream_for_tool(
213 graph: &mut Graph<GgmlMetadata>,
214 id: impl Into<String>,
215 include_usage: bool,
216) -> Result<(ChatStream, bool), LlamaCoreError> {
217 #[cfg(feature = "logging")]
218 info!(target: "stdout", "Handle chat request with available tools by the model named {}.", graph.name());
219
220 let id = id.into();
221
222 match graph.compute() {
223 Ok(_) => {
224 let output_buffer = get_output_buffer(graph, OUTPUT_TENSOR)?;
226 let output = std::str::from_utf8(&output_buffer[..]).map_err(|e| {
227 let err_msg = format!(
228 "Failed to decode the buffer of the inference result to a utf-8 string. {e}"
229 );
230
231 #[cfg(feature = "logging")]
232 error!(target: "stdout", "{}", &err_msg);
233
234 LlamaCoreError::Operation(err_msg)
235 })?;
236
237 #[cfg(feature = "logging")]
238 info!(target: "stdout", "raw generation:\n{output}");
239
240 let message = post_process(output, &graph.metadata.prompt_template).map_err(|e| {
242 LlamaCoreError::Operation(format!("Failed to post-process the output. {e}"))
243 })?;
244
245 #[cfg(feature = "logging")]
246 info!(target: "stdout", "post-processed generation:\n{}", &message);
247
248 let token_info = get_token_info_by_graph(graph)?;
250
251 #[cfg(feature = "logging")]
252 info!(target: "stdout", "prompt tokens: {}, completion tokens: {}", token_info.prompt_tokens, token_info.completion_tokens);
253
254 let usage = Some(Usage {
255 prompt_tokens: token_info.prompt_tokens,
256 completion_tokens: token_info.completion_tokens,
257 total_tokens: token_info.prompt_tokens + token_info.completion_tokens,
258 });
259
260 let created = SystemTime::now()
261 .duration_since(std::time::UNIX_EPOCH)
262 .map_err(|e| {
263 let err_msg = format!("Failed to get the current time. Reason: {e}");
264
265 #[cfg(feature = "logging")]
266 error!(target: "stdout", "{}", &err_msg);
267
268 LlamaCoreError::Operation(err_msg)
269 })?;
270
271 if graph.metadata.prompt_template != PromptTemplateType::MistralTool
272 && graph.metadata.prompt_template != PromptTemplateType::ChatMLTool
273 && graph.metadata.prompt_template != PromptTemplateType::GroqLlama3Tool
274 && graph.metadata.prompt_template != PromptTemplateType::Llama3Tool
275 && graph.metadata.prompt_template != PromptTemplateType::InternLM2Tool
276 && graph.metadata.prompt_template != PromptTemplateType::NemotronTool
277 && graph.metadata.prompt_template != PromptTemplateType::FunctionaryV32
278 && graph.metadata.prompt_template != PromptTemplateType::FunctionaryV31
279 && graph.metadata.prompt_template != PromptTemplateType::MistralSmallTool
280 && graph.metadata.prompt_template != PromptTemplateType::Llama4Chat
281 && graph.metadata.prompt_template != PromptTemplateType::Qwen3NoThink
282 && graph.metadata.prompt_template != PromptTemplateType::Smol3NoThink
283 && graph.metadata.prompt_template != PromptTemplateType::Gemma3
284 && graph.metadata.prompt_template != PromptTemplateType::GptOss
285 && graph.metadata.prompt_template != PromptTemplateType::Qwen3Agent
286 {
287 let err_msg = format!("Unsupported prompt template: {}. The tool use is only supported for 'mistral-tool', 'chatml-tool', 'groq-llama3-tool', 'llama-3-tool', 'internlm-2-tool', 'nemotron-tool', 'functionary-31', 'functionary-32', 'mistral-small-tool', 'llama-4-chat', 'qwen3-no-think', 'smol-3-no-think', 'gemma-3', 'gpt-oss' and 'qwen3-agent' prompt templates.", graph.metadata.prompt_template);
288
289 #[cfg(feature = "logging")]
290 error!(target: "stdout", "{}", &err_msg);
291
292 return Err(LlamaCoreError::Operation(err_msg));
293 }
294
295 let parsed_result = parse_tool_calls(&message, graph.metadata.prompt_template)?;
296
297 let content = if parsed_result.tool_calls.is_empty() {
298 Some(parsed_result.raw.clone())
299 } else {
300 parsed_result.content.clone()
301 };
302
303 let (tool_calls, include_tool_calls) = match parsed_result.tool_calls.is_empty() {
304 false => {
305 let tool_calls: Vec<ToolCallForChunk> = parsed_result
306 .tool_calls
307 .into_iter()
308 .enumerate()
309 .map(|(index, tool_call)| ToolCallForChunk {
310 index,
311 id: tool_call.id,
312 ty: tool_call.ty,
313 function: tool_call.function,
314 })
315 .collect();
316 (tool_calls, true)
317 }
318 true => (vec![], false),
319 };
320
321 let tool_call_chunk = {
323 let chat_completion_chunk = ChatCompletionChunk {
324 id: id.clone(),
325 object: "chat.completion.chunk".to_string(),
326 created: created.as_secs(),
327 model: graph.name().to_owned(),
328 system_fingerprint: "fp_44709d6fcb".to_string(),
329 choices: vec![ChatCompletionChunkChoice {
330 index: 0,
331 delta: ChatCompletionChunkChoiceDelta {
332 role: ChatCompletionRole::Assistant,
333 content,
334 tool_calls,
335 },
336 logprobs: None,
337 finish_reason: None,
338 }],
339 usage: None,
340 };
341 let chunk_str = serde_json::to_string(&chat_completion_chunk).map_err(|e| {
342 let err_msg = format!("Failed to serialize chat completion chunk. Reason: {e}");
343
344 #[cfg(feature = "logging")]
345 error!(target: "stdout", "{}", &err_msg);
346
347 LlamaCoreError::Operation(err_msg)
348 })?;
349
350 format!("data: {chunk_str}\n\n")
351 };
352
353 let usage_chunk = {
355 let chat_completion_chunk = ChatCompletionChunk {
356 id: id.clone(),
357 object: "chat.completion.chunk".to_string(),
358 created: created.as_secs(),
359 model: graph.name().to_owned(),
360 system_fingerprint: "fp_44709d6fcb".to_string(),
361 choices: vec![],
362 usage,
363 };
364 let chunk_str = serde_json::to_string(&chat_completion_chunk).map_err(|e| {
365 let err_msg = format!("Failed to serialize chat completion chunk. Reason: {e}");
366
367 #[cfg(feature = "logging")]
368 error!(target: "stdout", "{}", &err_msg);
369
370 LlamaCoreError::Operation(err_msg)
371 })?;
372
373 format!("data: {chunk_str}\n\n")
374 };
375
376 let ending_chunk = "data: [DONE]\n\n".to_string();
378
379 let chunks = vec![tool_call_chunk, usage_chunk, ending_chunk];
380
381 let stream = ChatStream::new(
382 Some(graph.name().to_owned()),
383 id,
384 include_usage,
385 Some(chunks),
386 );
387
388 Ok((stream, include_tool_calls))
389 }
390 Err(wasmedge_wasi_nn::Error::BackendError(wasmedge_wasi_nn::BackendError::ContextFull)) => {
391 let output_buffer = get_output_buffer(graph, OUTPUT_TENSOR)?;
393 let output = std::str::from_utf8(&output_buffer[..]).map_err(|e| {
394 let err_msg = format!(
395 "Failed to decode the buffer of the inference result to a utf-8 string. {e}"
396 );
397
398 #[cfg(feature = "logging")]
399 error!(target: "stdout", "{}", &err_msg);
400
401 LlamaCoreError::Operation(err_msg)
402 })?;
403
404 let message = post_process(output, &graph.metadata.prompt_template).map_err(|e| {
406 let err_msg = format!("Failed to post-process the output. {e}");
407
408 #[cfg(feature = "logging")]
409 error!(target: "stdout", "{}", &err_msg);
410
411 LlamaCoreError::Operation(err_msg)
412 })?;
413
414 let token_info = get_token_info_by_graph(graph)?;
416
417 #[cfg(feature = "logging")]
418 info!(target: "stdout", "prompt tokens: {}, completion tokens: {}", token_info.prompt_tokens, token_info.completion_tokens);
419
420 let usage = Some(Usage {
421 prompt_tokens: token_info.prompt_tokens,
422 completion_tokens: token_info.completion_tokens,
423 total_tokens: token_info.prompt_tokens + token_info.completion_tokens,
424 });
425
426 let created = SystemTime::now()
427 .duration_since(std::time::UNIX_EPOCH)
428 .map_err(|e| {
429 let err_msg = format!("Failed to get the current time. Reason: {e}");
430
431 #[cfg(feature = "logging")]
432 error!(target: "stdout", "{}", &err_msg);
433
434 LlamaCoreError::Operation(err_msg)
435 })?;
436
437 let context_full_chunk = {
439 let chat_completion_chunk = ChatCompletionChunk {
440 id: id.clone(),
441 object: "chat.completion.chunk".to_string(),
442 created: created.as_secs(),
443 model: graph.name().to_owned(),
444 system_fingerprint: "fp_44709d6fcb".to_string(),
445 choices: vec![ChatCompletionChunkChoice {
446 index: 0,
447 delta: ChatCompletionChunkChoiceDelta {
448 role: ChatCompletionRole::Assistant,
449 content: Some(message),
450 tool_calls: vec![],
451 },
452 logprobs: None,
453 finish_reason: Some(FinishReason::length),
454 }],
455 usage: None,
456 };
457
458 let chunk_str = serde_json::to_string(&chat_completion_chunk).map_err(|e| {
460 let err_msg = format!("Failed to serialize chat completion chunk. Reason: {e}");
461
462 #[cfg(feature = "logging")]
463 error!(target: "stdout", "{}", &err_msg);
464
465 LlamaCoreError::Operation(err_msg)
466 })?;
467
468 format!("data: {chunk_str}\n\n")
469 };
470
471 let usage_chunk = {
473 let chat_completion_chunk = ChatCompletionChunk {
474 id: id.clone(),
475 object: "chat.completion.chunk".to_string(),
476 created: created.as_secs(),
477 model: graph.name().to_owned(),
478 system_fingerprint: "fp_44709d6fcb".to_string(),
479 choices: vec![],
480 usage,
481 };
482
483 let chunk_str = serde_json::to_string(&chat_completion_chunk).map_err(|e| {
485 let err_msg = format!("Failed to serialize chat completion chunk. Reason: {e}");
486
487 #[cfg(feature = "logging")]
488 error!(target: "stdout", "{}", &err_msg);
489
490 LlamaCoreError::Operation(err_msg)
491 })?;
492
493 format!("data: {chunk_str}\n\n")
494 };
495
496 let ending_chunk = "data: [DONE]\n\n".to_string();
498
499 let chunks = vec![context_full_chunk, usage_chunk, ending_chunk];
500
501 let stream = ChatStream::new(
502 Some(graph.name().to_owned()),
503 id,
504 include_usage,
505 Some(chunks),
506 );
507
508 Ok((stream, false))
509 }
510 Err(wasmedge_wasi_nn::Error::BackendError(
511 wasmedge_wasi_nn::BackendError::PromptTooLong,
512 )) => {
513 #[cfg(feature = "logging")]
514 warn!(target: "stdout", "The prompt is too long. Please reduce the length of your input and try again.");
515
516 let output_buffer = get_output_buffer(graph, OUTPUT_TENSOR)?;
518 let output = std::str::from_utf8(&output_buffer[..]).map_err(|e| {
519 let err_msg = format!(
520 "Failed to decode the buffer of the inference result to a utf-8 string. {e}"
521 );
522
523 #[cfg(feature = "logging")]
524 error!(target: "stdout", "{}", &err_msg);
525
526 LlamaCoreError::Operation(err_msg)
527 })?;
528
529 let message = post_process(output, &graph.metadata.prompt_template).map_err(|e| {
531 let err_msg = format!("Failed to post-process the output. {e}");
532
533 #[cfg(feature = "logging")]
534 error!(target: "stdout", "{}", &err_msg);
535
536 LlamaCoreError::Operation(err_msg)
537 })?;
538
539 let token_info = get_token_info_by_graph(graph)?;
541
542 #[cfg(feature = "logging")]
543 info!(target: "stdout", "prompt tokens: {}, completion tokens: {}", token_info.prompt_tokens, token_info.completion_tokens);
544
545 let usage = Some(Usage {
546 prompt_tokens: token_info.prompt_tokens,
547 completion_tokens: token_info.completion_tokens,
548 total_tokens: token_info.prompt_tokens + token_info.completion_tokens,
549 });
550
551 let created = SystemTime::now()
552 .duration_since(std::time::UNIX_EPOCH)
553 .map_err(|e| {
554 let err_msg = format!("Failed to get the current time. Reason: {e}");
555
556 #[cfg(feature = "logging")]
557 error!(target: "stdout", "{}", &err_msg);
558
559 LlamaCoreError::Operation(err_msg)
560 })?;
561
562 let prompt_too_long_chunk = {
564 let chat_completion_chunk = ChatCompletionChunk {
565 id: id.clone(),
566 object: "chat.completion.chunk".to_string(),
567 created: created.as_secs(),
568 model: graph.name().to_owned(),
569 system_fingerprint: "fp_44709d6fcb".to_string(),
570 choices: vec![ChatCompletionChunkChoice {
571 index: 0,
572 delta: ChatCompletionChunkChoiceDelta {
573 role: ChatCompletionRole::Assistant,
574 content: Some(message),
575 tool_calls: vec![],
576 },
577 logprobs: None,
578 finish_reason: Some(FinishReason::length),
579 }],
580 usage: None,
581 };
582
583 let chunk_str = serde_json::to_string(&chat_completion_chunk).map_err(|e| {
585 let err_msg = format!("Failed to serialize chat completion chunk. Reason: {e}");
586
587 #[cfg(feature = "logging")]
588 error!(target: "stdout", "{}", &err_msg);
589
590 LlamaCoreError::Operation(err_msg)
591 })?;
592
593 format!("data: {chunk_str}\n\n")
594 };
595
596 let usage_chunk = {
598 let chat_completion_chunk = ChatCompletionChunk {
599 id: id.clone(),
600 object: "chat.completion.chunk".to_string(),
601 created: created.as_secs(),
602 model: graph.name().to_owned(),
603 system_fingerprint: "fp_44709d6fcb".to_string(),
604 choices: vec![],
605 usage,
606 };
607
608 let chunk_str = serde_json::to_string(&chat_completion_chunk).map_err(|e| {
610 let err_msg = format!("Failed to serialize chat completion chunk. Reason: {e}");
611
612 #[cfg(feature = "logging")]
613 error!(target: "stdout", "{}", &err_msg);
614
615 LlamaCoreError::Operation(err_msg)
616 })?;
617
618 format!("data: {chunk_str}\n\n")
619 };
620
621 let ending_chunk = "data: [DONE]\n\n".to_string();
623
624 let chunks = vec![prompt_too_long_chunk, usage_chunk, ending_chunk];
625
626 let stream = ChatStream::new(
627 Some(graph.name().to_owned()),
628 id,
629 include_usage,
630 Some(chunks),
631 );
632
633 Ok((stream, false))
634 }
635 Err(e) => {
636 let err_msg = format!("Failed to compute the chat completion. Reason: {e}");
637
638 #[cfg(feature = "logging")]
639 error!(target: "stdout", "{}", &err_msg);
640
641 Err(LlamaCoreError::Backend(BackendError::Compute(err_msg)))
642 }
643 }
644}
645
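/// Handles a chat completion request in non-stream mode: validates the running mode, builds the
/// chat prompt, feeds it to the model, computes the completion, and finally resets the model
/// metadata before returning.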
async fn chat_once(
    chat_request: &mut ChatCompletionRequest,
) -> Result<(ChatCompletionObject, bool), LlamaCoreError> {
    #[cfg(feature = "logging")]
    info!(target: "stdout", "Processing chat completion request in non-stream mode");

    let running_mode = running_mode()?;
    if !running_mode.contains(RunningMode::CHAT) && !running_mode.contains(RunningMode::RAG) {
        let err_msg = "The chat completion is only supported in the chat or rag mode.";

        #[cfg(feature = "logging")]
        error!(target: "stdout", "{err_msg}");

        return Err(LlamaCoreError::Operation(err_msg.to_string()));
    }

    let model_name = chat_request.model.clone();
    let id = match &chat_request.user {
        Some(id) => id.clone(),
        None => gen_chat_id(),
    };

    #[cfg(feature = "logging")]
    info!(target: "stdout", "user: {}", &id);

    #[cfg(feature = "logging")]
    info!(target: "stdout", "Check model metadata");

    let mut metadata = check_model_metadata(chat_request)?;

    #[cfg(feature = "logging")]
    info!(target: "stdout", "Build the chat prompt");

    let (prompt, available_completion_tokens, tool_use) =
        build_prompt(model_name.as_ref(), chat_request)?;

    #[cfg(feature = "logging")]
    {
        info!(target: "stdout", "prompt:\n{}", &prompt);
        info!(target: "stdout", "available_completion_tokens: {available_completion_tokens}");
        info!(target: "stdout", "tool_use: {tool_use}");
    }

    #[cfg(feature = "logging")]
    info!(target: "stdout", "Update n_predict");

    update_n_predict(chat_request, &mut metadata, available_completion_tokens)?;

    #[cfg(feature = "logging")]
    info!(target: "stdout", "Feed the prompt to the model");

    set_prompt(model_name.as_ref(), &prompt)?;

    #[cfg(feature = "logging")]
    info!(target: "stdout", "Compute chat completion.");

    let res = compute(model_name.as_ref(), id, tool_use);

    #[cfg(feature = "logging")]
    info!(target: "stdout", "End of the chat completion");

    reset_model_metadata(model_name.as_ref())?;

    res
}

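/// Resolves the target chat graph by `model_name`, falling back to the first available graph,
/// and delegates the actual inference to `compute_by_graph`.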
fn compute(
    model_name: Option<&String>,
    id: impl Into<String>,
    tool_use: bool,
) -> Result<(ChatCompletionObject, bool), LlamaCoreError> {
    let chat_graphs = match CHAT_GRAPHS.get() {
        Some(chat_graphs) => chat_graphs,
        None => {
            let err_msg = "Fail to get the underlying value of `CHAT_GRAPHS`.";

            #[cfg(feature = "logging")]
            error!(target: "stdout", "{}", &err_msg);

            return Err(LlamaCoreError::Operation(err_msg.into()));
        }
    };

    let mut chat_graphs = chat_graphs.lock().map_err(|e| {
        let err_msg = format!("Fail to acquire the lock of `CHAT_GRAPHS`. {e}");

        #[cfg(feature = "logging")]
        error!(target: "stdout", "{}", &err_msg);

        LlamaCoreError::Operation(err_msg)
    })?;

    match model_name {
        Some(model_name) => match chat_graphs.contains_key(model_name) {
            true => {
                let graph = chat_graphs.get_mut(model_name).unwrap();
                compute_by_graph(graph, id, tool_use)
            }
            false => match chat_graphs.iter_mut().next() {
                Some((_, graph)) => compute_by_graph(graph, id, tool_use),
                None => {
                    let err_msg = "There is no model available in the chat graphs.";

                    #[cfg(feature = "logging")]
                    error!(target: "stdout", "{}", &err_msg);

                    Err(LlamaCoreError::Operation(err_msg.into()))
                }
            },
        },
        None => match chat_graphs.iter_mut().next() {
            Some((_, graph)) => compute_by_graph(graph, id, tool_use),
            None => {
                let err_msg = "There is no model available in the chat graphs.";

                #[cfg(feature = "logging")]
                error!(target: "stdout", "{}", &err_msg);

                Err(LlamaCoreError::Operation(err_msg.into()))
            }
        },
    }
}

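/// Runs inference on the given graph and assembles a `ChatCompletionObject`. When `tool_use` is
/// set, the post-processed output is parsed for tool calls; context-full and prompt-too-long
/// backend errors are mapped to a `length` finish reason instead of an error.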
776fn compute_by_graph(
777 graph: &mut Graph<GgmlMetadata>,
778 id: impl Into<String>,
779 tool_use: bool,
780) -> Result<(ChatCompletionObject, bool), LlamaCoreError> {
781 #[cfg(feature = "logging")]
782 info!(target: "stdout", "Compute chat completion by the model named {}.", graph.name());
783
784 match graph.compute() {
785 Ok(_) => {
786 let output_buffer = get_output_buffer(graph, OUTPUT_TENSOR)?;
788 let output = std::str::from_utf8(&output_buffer[..]).map_err(|e| {
789 let err_msg = format!(
790 "Failed to decode the buffer of the inference result to a utf-8 string. {e}"
791 );
792
793 #[cfg(feature = "logging")]
794 error!(target: "stdout", "{}", &err_msg);
795
796 LlamaCoreError::Operation(err_msg)
797 })?;
798
799 #[cfg(feature = "logging")]
800 info!(target: "stdout", "raw generation: {output}");
801
802 let message = post_process(output, &graph.metadata.prompt_template).map_err(|e| {
804 LlamaCoreError::Operation(format!("Failed to post-process the output. {e}"))
805 })?;
806
807 #[cfg(feature = "logging")]
808 info!(target: "stdout", "post-processed generation:\n{}", &message);
809
810 let token_info = get_token_info_by_graph(graph)?;
812
813 #[cfg(feature = "logging")]
814 info!(target: "stdout", "prompt tokens: {}, completion tokens: {}", token_info.prompt_tokens, token_info.completion_tokens);
815
816 let created = SystemTime::now()
817 .duration_since(std::time::UNIX_EPOCH)
818 .map_err(|e| {
819 let err_msg = format!("Failed to get the current time. Reason: {e}");
820
821 #[cfg(feature = "logging")]
822 error!(target: "stdout", "{}", &err_msg);
823
824 LlamaCoreError::Operation(err_msg)
825 })?;
826
827 match tool_use {
828 true => {
829 if graph.metadata.prompt_template != PromptTemplateType::MistralTool
830 && graph.metadata.prompt_template != PromptTemplateType::ChatMLTool
831 && graph.metadata.prompt_template != PromptTemplateType::GroqLlama3Tool
832 && graph.metadata.prompt_template != PromptTemplateType::Llama3Tool
833 && graph.metadata.prompt_template != PromptTemplateType::InternLM2Tool
834 && graph.metadata.prompt_template != PromptTemplateType::NemotronTool
835 && graph.metadata.prompt_template != PromptTemplateType::FunctionaryV32
836 && graph.metadata.prompt_template != PromptTemplateType::FunctionaryV31
837 && graph.metadata.prompt_template != PromptTemplateType::MistralSmallTool
838 && graph.metadata.prompt_template != PromptTemplateType::Llama4Chat
839 && graph.metadata.prompt_template != PromptTemplateType::Qwen3NoThink
840 && graph.metadata.prompt_template != PromptTemplateType::Smol3NoThink
841 && graph.metadata.prompt_template != PromptTemplateType::Gemma3
842 && graph.metadata.prompt_template != PromptTemplateType::GptOss
843 && graph.metadata.prompt_template != PromptTemplateType::Qwen3Agent
844 {
845 let err_msg = format!("Unsupported prompt template: {}. The tool use is only supported for 'mistral-tool', 'chatml-tool', 'groq-llama3-tool', 'llama-3-tool', 'internlm-2-tool', 'nemotron-tool', 'functionary-31', 'functionary-32', 'mistral-small-tool', 'llama-4-chat', 'qwen3-no-think', 'smol-3-no-think', 'gemma-3', 'gpt-oss', and 'qwen3-agent' prompt templates.", graph.metadata.prompt_template);
846
847 #[cfg(feature = "logging")]
848 error!(target: "stdout", "{}", &err_msg);
849
850 return Err(LlamaCoreError::Operation(err_msg));
851 }
852
853 let parsed_result = parse_tool_calls(&message, graph.metadata.prompt_template)?;
854
855 let (finish_reason, content, include_tool_calls) =
856 if parsed_result.tool_calls.is_empty() {
857 (FinishReason::stop, Some(parsed_result.raw.clone()), false)
858 } else if graph.metadata.prompt_template != PromptTemplateType::Qwen3Agent {
859 (
860 FinishReason::tool_calls,
861 Some(parsed_result.raw.clone()),
862 true,
863 )
864 } else {
865 (
866 FinishReason::tool_calls,
867 parsed_result.content.clone(),
868 true,
869 )
870 };
871
872 let res = ChatCompletionObject {
873 id: id.into(),
874 object: String::from("chat.completion"),
875 created: created.as_secs(),
876 model: graph.name().to_owned(),
877 choices: vec![ChatCompletionObjectChoice {
878 index: 0,
879 message: ChatCompletionObjectMessage {
880 role: ChatCompletionRole::Assistant,
881 content,
882 tool_calls: parsed_result.tool_calls,
883 function_call: None,
884 },
885 finish_reason,
886 logprobs: None,
887 }],
888 usage: Usage {
889 prompt_tokens: token_info.prompt_tokens,
890 completion_tokens: token_info.completion_tokens,
891 total_tokens: token_info.prompt_tokens + token_info.completion_tokens,
892 },
893 };
894
895 Ok((res, include_tool_calls))
897 }
898 false => {
899 let res = ChatCompletionObject {
901 id: id.into(),
902 object: String::from("chat.completion"),
903 created: created.as_secs(),
904 model: graph.name().to_owned(),
905 choices: vec![ChatCompletionObjectChoice {
906 index: 0,
907 message: ChatCompletionObjectMessage {
908 role: ChatCompletionRole::Assistant,
909 content: Some(message),
910 tool_calls: vec![],
911 function_call: None,
912 },
913 finish_reason: FinishReason::stop,
914 logprobs: None,
915 }],
916 usage: Usage {
917 prompt_tokens: token_info.prompt_tokens,
918 completion_tokens: token_info.completion_tokens,
919 total_tokens: token_info.prompt_tokens + token_info.completion_tokens,
920 },
921 };
922
923 Ok((res, false))
924 }
925 }
926 }
927 Err(wasmedge_wasi_nn::Error::BackendError(wasmedge_wasi_nn::BackendError::ContextFull)) => {
928 let output_buffer = get_output_buffer(graph, OUTPUT_TENSOR)?;
930 let output = std::str::from_utf8(&output_buffer[..]).map_err(|e| {
931 let err_msg = format!(
932 "Failed to decode the buffer of the inference result to a utf-8 string. {e}"
933 );
934
935 #[cfg(feature = "logging")]
936 error!(target: "stdout", "{}", &err_msg);
937
938 LlamaCoreError::Operation(err_msg)
939 })?;
940
941 let message = post_process(output, &graph.metadata.prompt_template).map_err(|e| {
943 let err_msg = format!("Failed to post-process the output. {e}");
944
945 #[cfg(feature = "logging")]
946 error!(target: "stdout", "{}", &err_msg);
947
948 LlamaCoreError::Operation(err_msg)
949 })?;
950
951 let token_info = get_token_info_by_graph(graph)?;
953
954 #[cfg(feature = "logging")]
955 info!(target: "stdout", "prompt tokens: {}, completion tokens: {}", token_info.prompt_tokens, token_info.completion_tokens);
956
957 let created = SystemTime::now()
958 .duration_since(std::time::UNIX_EPOCH)
959 .map_err(|e| {
960 let err_msg = format!("Failed to get the current time. Reason: {e}");
961
962 #[cfg(feature = "logging")]
963 error!(target: "stdout", "{}", &err_msg);
964
965 LlamaCoreError::Operation(err_msg)
966 })?;
967
968 let res = ChatCompletionObject {
970 id: id.into(),
971 object: String::from("chat.completion"),
972 created: created.as_secs(),
973 model: graph.name().to_owned(),
974 choices: vec![ChatCompletionObjectChoice {
975 index: 0,
976 message: ChatCompletionObjectMessage {
977 role: ChatCompletionRole::Assistant,
978 content: Some(message),
979 tool_calls: vec![],
980 function_call: None,
981 },
982 finish_reason: FinishReason::length,
983 logprobs: None,
984 }],
985 usage: Usage {
986 prompt_tokens: token_info.prompt_tokens,
987 completion_tokens: token_info.completion_tokens,
988 total_tokens: token_info.prompt_tokens + token_info.completion_tokens,
989 },
990 };
991
992 Ok((res, false))
993 }
994 Err(wasmedge_wasi_nn::Error::BackendError(
995 wasmedge_wasi_nn::BackendError::PromptTooLong,
996 )) => {
997 #[cfg(feature = "logging")]
998 warn!(target: "stdout", "The prompt is too long. Please reduce the length of your input and try again.");
999
1000 let output_buffer = get_output_buffer(graph, OUTPUT_TENSOR)?;
1002 let output = std::str::from_utf8(&output_buffer[..]).map_err(|e| {
1003 let err_msg = format!(
1004 "Failed to decode the buffer of the inference result to a utf-8 string. {e}"
1005 );
1006
1007 #[cfg(feature = "logging")]
1008 error!(target: "stdout", "{}", &err_msg);
1009
1010 LlamaCoreError::Operation(err_msg)
1011 })?;
1012
1013 let message = post_process(output, &graph.metadata.prompt_template).map_err(|e| {
1015 let err_msg = format!("Failed to post-process the output. {e}");
1016
1017 #[cfg(feature = "logging")]
1018 error!(target: "stdout", "{}", &err_msg);
1019
1020 LlamaCoreError::Operation(err_msg)
1021 })?;
1022
1023 let token_info = get_token_info_by_graph(graph)?;
1025
1026 #[cfg(feature = "logging")]
1027 info!(target: "stdout", "prompt tokens: {}, completion tokens: {}", token_info.prompt_tokens, token_info.completion_tokens);
1028
1029 let usage = Usage {
1030 prompt_tokens: token_info.prompt_tokens,
1031 completion_tokens: token_info.completion_tokens,
1032 total_tokens: token_info.prompt_tokens + token_info.completion_tokens,
1033 };
1034
1035 let created = SystemTime::now()
1036 .duration_since(std::time::UNIX_EPOCH)
1037 .map_err(|e| {
1038 let err_msg = format!("Failed to get the current time. Reason: {e}");
1039
1040 #[cfg(feature = "logging")]
1041 error!(target: "stdout", "{}", &err_msg);
1042
1043 LlamaCoreError::Operation(err_msg)
1044 })?;
1045
1046 let res = ChatCompletionObject {
1048 id: id.into(),
1049 object: String::from("chat.completion"),
1050 created: created.as_secs(),
1051 model: graph.name().to_owned(),
1052 choices: vec![ChatCompletionObjectChoice {
1053 index: 0,
1054 message: ChatCompletionObjectMessage {
1055 role: ChatCompletionRole::Assistant,
1056 content: Some(message),
1057 tool_calls: vec![],
1058 function_call: None,
1059 },
1060 finish_reason: FinishReason::length,
1061 logprobs: None,
1062 }],
1063 usage,
1064 };
1065
1066 Ok((res, false))
1067 }
1068 Err(e) => {
1069 let err_msg = format!("Failed to compute the chat completion. Reason: {e}");
1070
1071 #[cfg(feature = "logging")]
1072 error!(target: "stdout", "{}", &err_msg);
1073
1074 Err(LlamaCoreError::Backend(BackendError::Compute(err_msg)))
1075 }
1076 }
1077}
1078
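/// Extracts tool calls from the raw model output according to the tool-call syntax of the given
/// prompt template (e.g. `<tool_call>...</tool_call>`, `[TOOL_CALLS] [...]`, or a bare JSON
/// object), returning the raw text, any remaining content, and the parsed `ToolCall`s.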
1079fn parse_tool_calls(
1080 input: &str,
1081 prompt_template: PromptTemplateType,
1082) -> Result<ParseResult, LlamaCoreError> {
1083 match prompt_template {
1084 PromptTemplateType::MistralTool => match regex::Regex::new(r"\[\{.*?\}\]") {
1085 Ok(re) => {
1086 let mut values: Vec<serde_json::Value> = vec![];
1087 for cap in re.captures_iter(input) {
1088 let matched = &cap[0];
1089
1090 #[cfg(feature = "logging")]
1091 info!(target: "stdout", "captured: {matched}");
1092
1093 match serde_json::from_str::<Vec<serde_json::Value>>(matched) {
1094 Ok(group) => values.extend(group),
1095 Err(e) => {
1096 let err_msg =
1097 format!("Failed to deserialize generated tool calls. Reason: {e}");
1098
1099 #[cfg(feature = "logging")]
1100 error!(target: "stdout", "{}", &err_msg);
1101
1102 return Err(LlamaCoreError::Operation(err_msg));
1103 }
1104 }
1105 }
1106
1107 let mut tool_calls: Vec<ToolCall> = vec![];
1108 for value in values.iter() {
1109 let name = match value.get("name") {
1110 Some(name) => name.to_string().replace("\"", ""),
1111 None => {
1112 let err_msg = format!(
1113 "Failed to get the name of the function. Tool call: {value:?}"
1114 );
1115
1116 #[cfg(feature = "logging")]
1117 error!(target: "stdout", "{}", &err_msg);
1118
1119 return Err(LlamaCoreError::Operation(err_msg));
1120 }
1121 };
1122
1123 let arguments = match value.get("arguments") {
1124 Some(arguments) => arguments.to_string(),
1125 None => {
1126 let err_msg = format!(
1127 "Failed to get the arguments of the function. Tool call: {value:?}"
1128 );
1129
1130 #[cfg(feature = "logging")]
1131 error!(target: "stdout", "{}", &err_msg);
1132
1133 return Err(LlamaCoreError::Operation(err_msg));
1134 }
1135 };
1136
1137 let function = Function { name, arguments };
1138
1139 let tool_call = ToolCall {
1140 id: "call_abc123".to_string(),
1141 ty: "function".to_string(),
1142 function,
1143 };
1144
1145 tool_calls.push(tool_call);
1146 }
1147
1148 let parsed = ParseResult {
1149 raw: input.to_owned(),
1150 content: None,
1151 tool_calls,
1152 };
1153
1154 #[cfg(feature = "logging")]
1155 info!(target: "stdout", "parsed result: {parsed:?}");
1156
1157 Ok(parsed)
1158 }
1159 Err(e) => {
1160 let err_msg = format!("Failed to create a regex pattern. Reason: {e}");
1161
1162 #[cfg(feature = "logging")]
1163 error!(target: "stdout", "{}", &err_msg);
1164
1165 Err(LlamaCoreError::Operation(err_msg))
1166 }
1167 },
1168 PromptTemplateType::ChatMLTool => {
1169 match regex::Regex::new(r"<tool_call>(.*?)</tool_call>") {
1170 Ok(re) => {
1171 let mut values: Vec<serde_json::Value> = vec![];
1172 for cap in re.captures_iter(input) {
                        let matched = cap[1].replace("\\n", "");

                        #[cfg(feature = "logging")]
                        info!(target: "stdout", "captured: {}", &matched);
1177
1178 match serde_json::from_str::<serde_json::Value>(&matched) {
1179 Ok(value) => values.push(value),
1180 Err(e) => {
1181 let err_msg = format!(
1182 "Failed to deserialize generated tool calls. Reason: {e}"
1183 );
1184
1185 #[cfg(feature = "logging")]
1186 error!(target: "stdout", "{}", &err_msg);
1187
1188 return Err(LlamaCoreError::Operation(err_msg));
1189 }
1190 }
1191 }
1192
1193 let mut tool_calls: Vec<ToolCall> = vec![];
1194 for value in values.iter() {
1195 let name = match value.get("name") {
1196 Some(name) => name.to_string().replace("\"", ""),
1197 None => {
1198 let err_msg = format!(
1199 "Failed to get the name of the function. Tool call: {value:?}"
1200 );
1201
1202 #[cfg(feature = "logging")]
1203 error!(target: "stdout", "{}", &err_msg);
1204
1205 return Err(LlamaCoreError::Operation(err_msg));
1206 }
1207 };
1208
1209 let arguments = match value.get("arguments") {
1210 Some(arguments) => arguments.to_string(),
1211 None => {
1212 let err_msg = format!(
1213 "Failed to get the arguments of the function. Tool call: {value:?}"
1214 );
1215
1216 #[cfg(feature = "logging")]
1217 error!(target: "stdout", "{}", &err_msg);
1218
1219 return Err(LlamaCoreError::Operation(err_msg));
1220 }
1221 };
1222
1223 let function = Function { name, arguments };
1224
1225 let tool_call = ToolCall {
1226 id: "call_abc123".to_string(),
1227 ty: "function".to_string(),
1228 function,
1229 };
1230
1231 tool_calls.push(tool_call);
1232 }
1233
1234 let parsed = ParseResult {
1235 raw: input.to_owned(),
1236 content: None,
1237 tool_calls,
1238 };
1239
1240 #[cfg(feature = "logging")]
1241 info!(target: "stdout", "parsed result: {parsed:?}");
1242
1243 Ok(parsed)
1244 }
1245 Err(e) => {
1246 let err_msg = format!("Failed to create a regex pattern. Reason: {e}");
1247
1248 #[cfg(feature = "logging")]
1249 error!(target: "stdout", "{}", &err_msg);
1250
1251 Err(LlamaCoreError::Operation(err_msg))
1252 }
1253 }
1254 }
1255 PromptTemplateType::GroqLlama3Tool => {
1256 #[cfg(feature = "logging")]
1257 info!(target: "stdout", "raw input: {input}");
1258
1259 match regex::Regex::new(r"(?s)<tool_call>((.|\r|\n)*?)</tool_call>") {
1260 Ok(re) => {
1261 let mut values: Vec<serde_json::Value> = vec![];
1262 for cap in re.captures_iter(input) {
1263 let matched = cap[1].trim();
1264
1265 #[cfg(feature = "logging")]
1266 info!(target: "stdout", "captured: {matched}");
1267
1268 match serde_json::from_str::<serde_json::Value>(matched) {
1269 Ok(value) => values.push(value),
1270 Err(e) => {
1271 let err_msg = format!(
1272 "Failed to deserialize generated tool calls. Reason: {e}"
1273 );
1274
1275 #[cfg(feature = "logging")]
1276 error!(target: "stdout", "{}", &err_msg);
1277
1278 return Err(LlamaCoreError::Operation(err_msg));
1279 }
1280 }
1281 }
1282
1283 let mut tool_calls: Vec<ToolCall> = vec![];
1284 for value in values.iter() {
1285 let name = match value.get("name") {
1286 Some(name) => name.to_string().replace("\"", ""),
1287 None => {
1288 let err_msg = format!(
1289 "Failed to get the name of the function. Tool call: {value:?}"
1290 );
1291
1292 #[cfg(feature = "logging")]
1293 error!(target: "stdout", "{}", &err_msg);
1294
1295 return Err(LlamaCoreError::Operation(err_msg));
1296 }
1297 };
1298
1299 let arguments = match value.get("arguments") {
1300 Some(arguments) => {
1301 if arguments.is_string() {
1302 arguments.as_str().unwrap().to_string()
1303 } else if arguments.is_object() {
1304 let map = arguments.as_object().unwrap();
1305
1306 #[cfg(feature = "logging")]
1307 info!(target: "stdout", "func arguments: {map:?}");
1308
1309 serde_json::to_string(map).unwrap()
1310 } else {
1311 serde_json::to_string(arguments).unwrap()
1312 }
1313 }
1314 None => {
1315 let err_msg = format!(
1316 "Failed to get the arguments of the function. Tool call: {value:?}"
1317 );
1318
1319 #[cfg(feature = "logging")]
1320 error!(target: "stdout", "{}", &err_msg);
1321
1322 return Err(LlamaCoreError::Operation(err_msg));
1323 }
1324 };
1325
1326 let function = Function { name, arguments };
1327
1328 let tool_call = ToolCall {
1329 id: "call_abc123".to_string(),
1330 ty: "function".to_string(),
1331 function,
1332 };
1333
1334 tool_calls.push(tool_call);
1335 }
1336
1337 let parsed = if tool_calls.is_empty() {
1338 ParseResult {
1339 raw: input.to_owned(),
1340 content: Some(input.to_owned()),
1341 tool_calls: vec![],
1342 }
1343 } else {
1344 ParseResult {
1345 raw: input.to_owned(),
1346 content: None,
1347 tool_calls,
1348 }
1349 };
1350
1351 #[cfg(feature = "logging")]
1352 info!(target: "stdout", "parsed result: {parsed:?}");
1353
1354 Ok(parsed)
1355 }
1356 Err(e) => {
1357 let err_msg = format!("Failed to create a regex pattern. Reason: {e}");
1358
1359 #[cfg(feature = "logging")]
1360 error!(target: "stdout", "{}", &err_msg);
1361
1362 Err(LlamaCoreError::Operation(err_msg))
1363 }
1364 }
1365 }
1366 PromptTemplateType::Llama3Tool => {
1367 #[cfg(feature = "logging")]
1368 info!(target: "stdout", "raw input: {input}");
1369
1370 let re = match regex::Regex::new(r"^\{(.|\r|\n)*\}$") {
1371 Ok(re) => re,
1372 Err(e) => {
1373 let err_msg = format!("Failed to create a regex pattern. Reason: {e}");
1374
1375 #[cfg(feature = "logging")]
1376 error!(target: "stdout", "{}", &err_msg);
1377
1378 return Err(LlamaCoreError::Operation(err_msg));
1379 }
1380 };
1381
1382 if re.is_match(input) {
1383 match serde_json::from_str::<serde_json::Value>(input) {
1384 Ok(value) => {
1385 let values: Vec<serde_json::Value> = vec![value];
1386
1387 let mut tool_calls: Vec<ToolCall> = vec![];
1388 for value in values.iter() {
1389 let name = match value.get("name") {
1390 Some(name) => name.to_string().replace("\"", ""),
1391 None => {
1392 let err_msg = format!(
1393 "Failed to get the name of the function. Tool call: {value:?}"
1394 );
1395
1396 #[cfg(feature = "logging")]
1397 error!(target: "stdout", "{}", &err_msg);
1398
1399 return Err(LlamaCoreError::Operation(err_msg));
1400 }
1401 };
1402
1403 let arguments = match value.get("parameters") {
1404 Some(arguments) => arguments.to_string(),
1405 None => {
1406 let err_msg = format!(
1407 "Failed to get the arguments of the function. Tool call: {value:?}"
1408 );
1409
1410 #[cfg(feature = "logging")]
1411 error!(target: "stdout", "{}", &err_msg);
1412
1413 return Err(LlamaCoreError::Operation(err_msg));
1414 }
1415 };
1416
1417 let function = Function { name, arguments };
1418
1419 let tool_call = ToolCall {
1420 id: "call_abc123".to_string(),
1421 ty: "function".to_string(),
1422 function,
1423 };
1424
1425 tool_calls.push(tool_call);
1426 }
1427
1428 let parsed = ParseResult {
1429 raw: input.to_owned(),
1430 content: None,
1431 tool_calls,
1432 };
1433
1434 #[cfg(feature = "logging")]
1435 info!(target: "stdout", "parsed result: {parsed:?}");
1436
1437 Ok(parsed)
1438 }
1439 Err(e) => {
1440 let err_msg =
1441 format!("Failed to deserialize generated tool calls. Reason: {e}");
1442
1443 #[cfg(feature = "logging")]
1444 error!(target: "stdout", "{}", &err_msg);
1445
1446 Err(LlamaCoreError::Operation(err_msg))
1447 }
1448 }
1449 } else {
1450 let parsed = ParseResult {
1451 raw: input.to_owned(),
1452 content: None,
1453 tool_calls: vec![],
1454 };
1455
1456 #[cfg(feature = "logging")]
1457 info!(target: "stdout", "parsed result: {parsed:?}");
1458
1459 Ok(parsed)
1460 }
1461 }
1462 PromptTemplateType::InternLM2Tool => {
1463 #[cfg(feature = "logging")]
1464 info!(target: "stdout", "raw input: {input}");
1465
1466 let blocks: Vec<&str> = input.trim().split("<|action_start|><|plugin|>").collect();
1467
1468 #[cfg(feature = "logging")]
1469 info!(target: "stdout", "blocks: {blocks:?}");
1470
1471 let mut tool_calls: Vec<ToolCall> = vec![];
1472 let mut content = String::new();
1473 for block in blocks {
1474 let block = block.trim();
1475 if !block.is_empty() {
1476 if block.ends_with("<|action_end|>") {
1477 let value = block.trim().trim_end_matches("<|action_end|>");
1478
1479 #[cfg(feature = "logging")]
1480 info!(target: "stdout", "tool call: {value}");
1481
1482 match serde_json::from_str::<serde_json::Value>(value) {
1483 Ok(value) => {
1484 let name = match value.get("name") {
1485 Some(name) => name.to_string().replace("\"", ""),
1486 None => {
1487 let err_msg = format!(
1488 "Failed to get the name of the function. Tool call: {value:?}"
1489 );
1490
1491 #[cfg(feature = "logging")]
1492 error!(target: "stdout", "{}", &err_msg);
1493
1494 return Err(LlamaCoreError::Operation(err_msg));
1495 }
1496 };
1497
1498 let arguments = match value.get("parameters") {
1499 Some(arguments) => arguments.to_string(),
1500 None => {
1501 let err_msg = format!(
1502 "Failed to get the arguments of the function. Tool call: {value:?}"
1503 );
1504
1505 #[cfg(feature = "logging")]
1506 error!(target: "stdout", "{}", &err_msg);
1507
1508 return Err(LlamaCoreError::Operation(err_msg));
1509 }
1510 };
1511
1512 let function = Function { name, arguments };
1513
1514 let tool_call = ToolCall {
1515 id: "call_abc123".to_string(),
1516 ty: "function".to_string(),
1517 function,
1518 };
1519
1520 tool_calls.push(tool_call);
1521 }
1522 Err(e) => {
1523 let err_msg = format!(
1524 "Failed to deserialize generated tool calls. Reason: {e}"
1525 );
1526
1527 #[cfg(feature = "logging")]
1528 error!(target: "stdout", "{}", &err_msg);
1529
1530 return Err(LlamaCoreError::Operation(err_msg));
1531 }
1532 }
1533 } else {
1534 content.push_str(block);
1535 content.push('\n');
1536 }
1537 }
1538 }
1539
1540 let parsed = match content.is_empty() {
1541 true => ParseResult {
1542 raw: input.to_owned(),
1543 content: None,
1544 tool_calls,
1545 },
1546 false => ParseResult {
1547 raw: input.to_owned(),
1548 content: Some(content.trim().to_owned()),
1549 tool_calls,
1550 },
1551 };
1552
1553 #[cfg(feature = "logging")]
1554 info!(target: "stdout", "parsed result: {parsed:?}");
1555
1556 Ok(parsed)
1557 }
1558 PromptTemplateType::NemotronTool => {
1559 #[cfg(feature = "logging")]
1560 info!(target: "stdout", "raw input: {input}");
1561
1562 match regex::Regex::new(r"(?s)<toolcall>\s*(.*?)\s*</toolcall>") {
1563 Ok(re) => {
1564 let mut values: Vec<serde_json::Value> = vec![];
1565 for cap in re.captures_iter(input) {
1566 #[cfg(feature = "logging")]
1567 info!(target: "stdout", "captured: {}", &cap[0]);
1568
1569 #[cfg(feature = "logging")]
1570 info!(target: "stdout", "extracted: {}", &cap[1]);
1571
1572 let matched = cap[1].trim();
1573
1574 #[cfg(feature = "logging")]
1575 info!(target: "stdout", "captured: {matched}");
1576
1577 match serde_json::from_str::<serde_json::Value>(matched) {
1578 Ok(value) => values.push(value),
1579 Err(e) => {
1580 let err_msg = format!(
1581 "Failed to deserialize generated tool calls. Reason: {e}"
1582 );
1583
1584 #[cfg(feature = "logging")]
1585 error!(target: "stdout", "{}", &err_msg);
1586
1587 return Err(LlamaCoreError::Operation(err_msg));
1588 }
1589 }
1590 }
1591
1592 let mut tool_calls: Vec<ToolCall> = vec![];
1593 for value in values.iter() {
1594 let name = match value.get("name") {
1595 Some(name) => name.to_string().replace("\"", ""),
1596 None => {
1597 let err_msg = format!(
1598 "Failed to get the name of the function. Tool call: {value:?}"
1599 );
1600
1601 #[cfg(feature = "logging")]
1602 error!(target: "stdout", "{}", &err_msg);
1603
1604 return Err(LlamaCoreError::Operation(err_msg));
1605 }
1606 };
1607
1608 let arguments = match value.get("arguments") {
1609 Some(arguments) => arguments.to_string(),
1610 None => {
1611 let err_msg = format!(
1612 "Failed to get the arguments of the function. Tool call: {value:?}"
1613 );
1614
1615 #[cfg(feature = "logging")]
1616 error!(target: "stdout", "{}", &err_msg);
1617
1618 return Err(LlamaCoreError::Operation(err_msg));
1619 }
1620 };
1621
1622 let function = Function { name, arguments };
1623
1624 let tool_call = ToolCall {
1625 id: "call_abc123".to_string(),
1626 ty: "function".to_string(),
1627 function,
1628 };
1629
1630 tool_calls.push(tool_call);
1631 }
1632
1633 let parsed = ParseResult {
1634 raw: input.to_owned(),
1635 content: None,
1636 tool_calls,
1637 };
1638
1639 #[cfg(feature = "logging")]
1640 info!(target: "stdout", "parsed result: {parsed:?}");
1641
1642 Ok(parsed)
1643 }
1644 Err(e) => {
1645 let err_msg = format!("Failed to create a regex pattern. Reason: {e}");
1646
1647 #[cfg(feature = "logging")]
1648 error!(target: "stdout", "{}", &err_msg);
1649
1650 Err(LlamaCoreError::Operation(err_msg))
1651 }
1652 }
1653 }
1654 PromptTemplateType::FunctionaryV32 => {
1655 #[cfg(feature = "logging")]
1656 info!(target: "stdout", "raw input: {input}");
1657
1658 match regex::Regex::new(r">>>\s*(\w+)\s*\{(.*)\}<\|eot_id\|>") {
1659 Ok(re) => {
1660 let mut tool_calls: Vec<ToolCall> = vec![];
1661 for cap in re.captures_iter(input) {
1662 #[cfg(feature = "logging")]
1663 info!(target: "stdout", "func_name: {}", &cap[1]);
1664
1665 #[cfg(feature = "logging")]
1666 info!(target: "stdout", "arguments: {}", &cap[2]);
1667
1668 let tool_call = ToolCall {
1669 id: "call_abc123".to_string(),
1670 ty: "function".to_string(),
1671 function: Function {
1672 name: cap[1].to_string(),
1673 arguments: cap[2].to_string(),
1674 },
1675 };
1676
1677 tool_calls.push(tool_call);
1678 }
1679
1680 let parsed = ParseResult {
1681 raw: input.to_owned(),
1682 content: None,
1683 tool_calls,
1684 };
1685
1686 #[cfg(feature = "logging")]
1687 info!(target: "stdout", "parsed result: {parsed:?}");
1688
1689 Ok(parsed)
1690 }
1691 Err(e) => {
1692 let warn_msg = format!("Failed to create a regex pattern. Reason: {e}");
1693
1694 #[cfg(feature = "logging")]
1695 warn!(target: "stdout", "{}", &warn_msg);
1696
1697 Ok(ParseResult {
1698 raw: input.to_owned(),
1699 content: None,
1700 tool_calls: vec![],
1701 })
1702 }
1703 }
1704 }
1705 PromptTemplateType::FunctionaryV31 => {
1706 #[cfg(feature = "logging")]
1707 info!(target: "stdout", "raw input: {input}");
1708
1709 match regex::Regex::new(r"<function=(\w+)>\s*(\{.*?\})</function>") {
1710 Ok(re) => {
1711 let mut tool_calls: Vec<ToolCall> = vec![];
1712 for cap in re.captures_iter(input) {
1713 #[cfg(feature = "logging")]
1714 info!(target: "stdout", "func_name: {}", &cap[1]);
1715
1716 #[cfg(feature = "logging")]
1717 info!(target: "stdout", "arguments: {}", &cap[2]);
1718
1719 let tool_call = ToolCall {
1720 id: "call_abc123".to_string(),
1721 ty: "function".to_string(),
1722 function: Function {
1723 name: cap[1].to_string(),
1724 arguments: cap[2].to_string(),
1725 },
1726 };
1727
1728 tool_calls.push(tool_call);
1729 }
1730
1731 let parsed = ParseResult {
1732 raw: input.to_owned(),
1733 content: None,
1734 tool_calls,
1735 };
1736
1737 #[cfg(feature = "logging")]
1738 info!(target: "stdout", "parsed result: {parsed:?}");
1739
1740 Ok(parsed)
1741 }
1742 Err(e) => {
1743 let warn_msg = format!("Failed to create a regex pattern. Reason: {e}");
1744
1745 #[cfg(feature = "logging")]
1746 warn!(target: "stdout", "{}", &warn_msg);
1747
1748 Ok(ParseResult {
1749 raw: input.to_owned(),
1750 content: None,
1751 tool_calls: vec![],
1752 })
1753 }
1754 }
1755 }
1756 PromptTemplateType::MistralSmallTool => {
1757 #[cfg(feature = "logging")]
1758 info!(target: "stdout", "raw input: {input}");
1759
1760 match regex::Regex::new(r"\[TOOL_CALLS\]\s*(\[(.*?)\])") {
1761 Ok(re) => {
1762 let mut values: Vec<serde_json::Value> = vec![];
1763 if let Some(cap) = re.captures(input) {
1764 let matched = cap[1].trim();
1765
1766 #[cfg(feature = "logging")]
1767 info!(target: "stdout", "captured: {matched}");
1768
1769 match serde_json::from_str::<Vec<serde_json::Value>>(matched) {
1770 Ok(vals) => values = vals,
1771 Err(e) => {
1772 let err_msg = format!(
1773 "Failed to deserialize generated tool calls. Reason: {e}"
1774 );
1775
1776 #[cfg(feature = "logging")]
1777 error!(target: "stdout", "{}", &err_msg);
1778
1779 return Err(LlamaCoreError::Operation(err_msg));
1780 }
1781 }
1782 };
1783
1784 let mut tool_calls: Vec<ToolCall> = vec![];
1785 for value in values.iter() {
1786 if let Some(object_map) = value.as_object() {
1787 if object_map.contains_key("function") {
1788 let mut function = Function {
1789 name: String::new(),
1790 arguments: String::new(),
1791 };
1792
1793 let value = object_map.get("function").unwrap();
1794 let func_map = value.as_object().unwrap();
1795 if func_map.contains_key("name") {
1796 let func_name = func_map.get("name").unwrap().as_str().unwrap();
1797 println!("Function name: {func_name:?}");
1798
1799 function.name = func_name.to_string();
1800 }
1801 if func_map.contains_key("arguments") {
1802 let args = func_map.get("arguments").unwrap();
1803 let arguments = args.to_string();
1804 println!("Arguments: {arguments:?}");
1805
1806 function.arguments = arguments;
1807 }
1808
1809 let tool_call = ToolCall {
1810 id: "call_abc123".to_string(),
1811 ty: "function".to_string(),
1812 function,
1813 };
1814
1815 tool_calls.push(tool_call);
1816 } else if object_map.contains_key("name") {
1817 let mut function = Function {
1818 name: String::new(),
1819 arguments: String::new(),
1820 };
1821
1822 let name = object_map.get("name").unwrap().as_str().unwrap();
1823 println!("name: {name:?}");
1824 function.name = name.to_string();
1825
1826 if object_map.contains_key("arguments") {
1827 let args = object_map.get("arguments").unwrap();
1828 let arguments = args.to_string();
1829 println!("Arguments: {arguments:?}");
1830
1831 function.arguments = arguments;
1832 }
1833
1834 let tool_call = ToolCall {
1835 id: "call_abc123".to_string(),
1836 ty: "function".to_string(),
1837 function,
1838 };
1839
1840 tool_calls.push(tool_call);
1841 }
1842 }
1843 }
1844
1845 let parsed = ParseResult {
1846 raw: input.to_owned(),
1847 content: None,
1848 tool_calls,
1849 };
1850
1851 #[cfg(feature = "logging")]
1852 info!(target: "stdout", "parsed result: {parsed:?}");
1853
1854 Ok(parsed)
1855 }
1856 Err(e) => {
1857 let err_msg = format!("Failed to create a regex pattern. Reason: {e}");
1858
1859 #[cfg(feature = "logging")]
1860 error!(target: "stdout", "{}", &err_msg);
1861
1862 Err(LlamaCoreError::Operation(err_msg))
1863 }
1864 }
1865 }
1866 PromptTemplateType::Llama4Chat => {
1867 #[cfg(feature = "logging")]
1868 info!(target: "stdout", "raw input: {input:?}");
1869
1870 let mut tool_calls: Vec<ToolCall> = vec![];
1871 if let Ok(value) = serde_json::from_str::<serde_json::Value>(input) {
1872 match value.as_object() {
1873 Some(object_map) => {
1874 #[cfg(feature = "logging")]
1875 debug!(target: "stdout", "object_map: {object_map:?}");
1876
1877 if object_map.contains_key("name") {
1879 let name = object_map.get("name").unwrap().as_str().unwrap();
1880
1881 #[cfg(feature = "logging")]
1882 debug!(target: "stdout", "name: {name:?}");
1883
1884 let mut function = Function {
1885 name: name.to_string(),
1886 arguments: String::new(),
1887 };
1888
1889 if object_map.contains_key("parameters") {
1891 let args = object_map.get("parameters").unwrap();
1892 let arguments = args.to_string();
1893
1894 #[cfg(feature = "logging")]
1895 debug!(target: "stdout", "arguments: {:?}", &arguments);
1896
1897 function.arguments = arguments;
1898 }
1899
1900 tool_calls.push(ToolCall {
1901 id: "call_abc123".to_string(),
1902 ty: "function".to_string(),
1903 function,
1904 });
1905 } else {
1906 let err_msg = format!(
1907 "Failed to get the name of the function. raw input: {input:?}"
1908 );
1909
1910 #[cfg(feature = "logging")]
1911 error!(target: "stdout", "{}", &err_msg);
1912
1913 return Err(LlamaCoreError::Operation(err_msg));
1914 }
1915 }
1916 None => {
1917 let err_msg = format!("Failed to parse the JSON string. JSON: {input}");
1918
1919 #[cfg(feature = "logging")]
1920 error!(target: "stdout", "{}", &err_msg);
1921
1922 return Err(LlamaCoreError::Operation(err_msg));
1923 }
1924 }
1925 }
1926
1927 let parsed = ParseResult {
1928 raw: input.to_owned(),
1929 content: None,
1930 tool_calls,
1931 };
1932
1933 #[cfg(feature = "logging")]
1934 info!(target: "stdout", "parsed result: {parsed:?}");
1935
1936 Ok(parsed)
1937 }
1938 PromptTemplateType::Qwen3NoThink | PromptTemplateType::Smol3NoThink => {
1939 #[cfg(feature = "logging")]
1940 info!(target: "stdout", "raw input: {input:?}");
1941
1942 match regex::Regex::new(r"(?s)<tool_call>((.|\r|\n)*?)</tool_call>") {
1943 Ok(re) => {
1944 let mut values: Vec<serde_json::Value> = vec![];
1945 for cap in re.captures_iter(input) {
1946 let mut matched = cap[1].trim();
1947
1948 if matched.starts_with("\\n") {
1949 matched = matched.trim_start_matches("\\n");
1950 }
1951
1952 if matched.ends_with("\\n") {
1953 matched = matched.trim_end_matches("\\n");
1954 }
1955
1956 #[cfg(feature = "logging")]
1957 info!(target: "stdout", "captured: {matched:#?}");
1958
1959 if !matched.is_empty() {
1960 match serde_json::from_str::<serde_json::Value>(matched) {
1961 Ok(value) => values.push(value),
1962 Err(e) => {
1963 let err_msg = format!(
1964 "Failed to deserialize generated tool calls: {matched:#?}. Reason: {e}"
1965 );
1966
1967 #[cfg(feature = "logging")]
1968 error!(target: "stdout", "{}", &err_msg);
1969
1970 return Err(LlamaCoreError::Operation(err_msg));
1971 }
1972 }
1973 }
1974 }
1975
1976 let mut tool_calls: Vec<ToolCall> = vec![];
1977 for value in values.iter() {
1978 let name = match value.get("name") {
1979 Some(name) => name.to_string().replace("\"", ""),
1980 None => {
1981 let err_msg = format!(
1982 "Failed to get the name of the function. Tool call: {value:?}"
1983 );
1984
1985 #[cfg(feature = "logging")]
1986 error!(target: "stdout", "{}", &err_msg);
1987
1988 return Err(LlamaCoreError::Operation(err_msg));
1989 }
1990 };
1991
1992 let arguments = match value.get("arguments") {
1993 Some(arguments) => {
1994 if arguments.is_string() {
1995 arguments.as_str().unwrap().to_string()
1996 } else if arguments.is_object() {
1997 let map = arguments.as_object().unwrap();
1998
1999 #[cfg(feature = "logging")]
2000 info!(target: "stdout", "func arguments: {map:?}");
2001
2002 serde_json::to_string(map).unwrap()
2003 } else {
2004 serde_json::to_string(arguments).unwrap()
2005 }
2006 }
2007 None => {
2008 let err_msg = format!(
2009 "Failed to get the arguments of the function. Tool call: {value:?}"
2010 );
2011
2012 #[cfg(feature = "logging")]
2013 error!(target: "stdout", "{}", &err_msg);
2014
2015 return Err(LlamaCoreError::Operation(err_msg));
2016 }
2017 };
2018
2019 let function = Function { name, arguments };
2020
2021 let tool_call = ToolCall {
2022 id: "call_abc123".to_string(),
2023 ty: "function".to_string(),
2024 function,
2025 };
2026
2027 tool_calls.push(tool_call);
2028 }
2029
2030 let parsed = if tool_calls.is_empty() {
2031 ParseResult {
2032 raw: input.to_owned(),
2033 content: Some(input.to_owned()),
2034 tool_calls: vec![],
2035 }
2036 } else {
2037 ParseResult {
2038 raw: input.to_owned(),
2039 content: None,
2040 tool_calls,
2041 }
2042 };
2043
2044 #[cfg(feature = "logging")]
2045 info!(target: "stdout", "parsed result: {parsed:?}");
2046
2047 Ok(parsed)
2048 }
2049 Err(e) => {
2050 let err_msg = format!("Failed to create a regex pattern. Reason: {e}");
2051
2052 #[cfg(feature = "logging")]
2053 error!(target: "stdout", "{}", &err_msg);
2054
2055 Err(LlamaCoreError::Operation(err_msg))
2056 }
2057 }
2058 }
2059 PromptTemplateType::Gemma3 => {
2060 #[cfg(feature = "logging")]
2061 info!(target: "stdout", "raw input: {input:?}");
2062
2063 match regex::Regex::new(r"(?s)```json\s*(.*?)\s*```") {
2064 Ok(re) => {
2065 let mut values: Vec<serde_json::Value> = vec![];
2066 for cap in re.captures_iter(input) {
2067 let mut matched = cap[1].trim();
2068
2069 if matched.starts_with("\\n") {
2070 matched = matched.trim_start_matches("\\n");
2071 }
2072
2073 if matched.ends_with("\\n") {
2074 matched = matched.trim_end_matches("\\n");
2075 }
2076
2077 #[cfg(feature = "logging")]
2078 info!(target: "stdout", "captured: {matched:#?}");
2079
2080 if !matched.is_empty() {
2081 match serde_json::from_str::<serde_json::Value>(matched) {
2082 Ok(value) => values.push(value),
2083 Err(e) => {
2084 let err_msg = format!(
2085 "Failed to deserialize generated tool calls: {matched:#?}. Reason: {e}"
2086 );
2087
2088 #[cfg(feature = "logging")]
2089 error!(target: "stdout", "{}", &err_msg);
2090
2091 return Err(LlamaCoreError::Operation(err_msg));
2092 }
2093 }
2094 }
2095 }
2096
2097 let mut tool_calls: Vec<ToolCall> = vec![];
2098 for value in values.iter() {
2099 let name = match value.get("name") {
2100 Some(name) => name.to_string().replace("\"", ""),
2101 None => {
2102 let err_msg = format!(
2103 "Failed to get the name of the function. Tool call: {value:?}"
2104 );
2105
2106 #[cfg(feature = "logging")]
2107 error!(target: "stdout", "{}", &err_msg);
2108
2109 return Err(LlamaCoreError::Operation(err_msg));
2110 }
2111 };
2112
2113 let arguments = match value.get("arguments") {
2114 Some(arguments) => {
2115 if arguments.is_string() {
2116 arguments.as_str().unwrap().to_string()
2117 } else if arguments.is_object() {
2118 let map = arguments.as_object().unwrap();
2119
2120 #[cfg(feature = "logging")]
2121 info!(target: "stdout", "func arguments: {map:?}");
2122
2123 serde_json::to_string(map).unwrap()
2124 } else {
2125 serde_json::to_string(arguments).unwrap()
2126 }
2127 }
2128 None => {
2129 let err_msg = format!(
2130 "Failed to get the arguments of the function. Tool call: {value:?}"
2131 );
2132
2133 #[cfg(feature = "logging")]
2134 error!(target: "stdout", "{}", &err_msg);
2135
2136 return Err(LlamaCoreError::Operation(err_msg));
2137 }
2138 };
2139
2140 let function = Function { name, arguments };
2141
2142 let tool_call = ToolCall {
2143 id: "call_abc123".to_string(),
2144 ty: "function".to_string(),
2145 function,
2146 };
2147
2148 tool_calls.push(tool_call);
2149 }
2150
2151 let parsed = if tool_calls.is_empty() {
2152 ParseResult {
2153 raw: input.to_owned(),
2154 content: Some(input.to_owned()),
2155 tool_calls: vec![],
2156 }
2157 } else {
2158 ParseResult {
2159 raw: input.to_owned(),
2160 content: None,
2161 tool_calls,
2162 }
2163 };
2164
2165 #[cfg(feature = "logging")]
2166 info!(target: "stdout", "parsed result: {parsed:?}");
2167
2168 Ok(parsed)
2169 }
2170 Err(e) => {
2171 let err_msg = format!("Failed to create a regex pattern. Reason: {e}");
2172
2173 #[cfg(feature = "logging")]
2174 error!(target: "stdout", "{}", &err_msg);
2175
2176 Err(LlamaCoreError::Operation(err_msg))
2177 }
2178 }
2179 }
2180 PromptTemplateType::GptOss => {
2181 #[cfg(feature = "logging")]
2182 info!(target: "stdout", "raw input: {input:?}");
2183
2184 match regex::Regex::new(
2186 r"<\|channel\|>commentary to=functions\.([^<\s]+)\s*<\|constrain\|>json<\|message\|>([^<]*)<\|call\|>$",
2187 ) {
2188 Ok(re) => {
2189 if let Some(cap) = re.captures(input) {
2190 let function_name = cap[1].trim();
2191 let arguments = cap[2].trim();
2192
2193 #[cfg(feature = "logging")]
2194 info!(target: "stdout", "extracted function_name: {function_name}, arguments: {arguments}");
2195
2196 let function = Function {
2197 name: function_name.to_string(),
2198 arguments: arguments.to_string(),
2199 };
2200
2201 let tool_call = ToolCall {
2202 id: "call_abc123".to_string(),
2203 ty: "function".to_string(),
2204 function,
2205 };
2206
2207 let parsed = ParseResult {
2208 raw: input.to_owned(),
2209 content: None,
2210 tool_calls: vec![tool_call],
2211 };
2212
2213 #[cfg(feature = "logging")]
2214 info!(target: "stdout", "parsed result: {parsed:?}");
2215
2216 Ok(parsed)
2217 } else {
2218 match regex::Regex::new(r"(?s)```json\s*(.*?)\s*```") {
2219 Ok(re) => {
2220 let mut values: Vec<serde_json::Value> = vec![];
2221 for cap in re.captures_iter(input) {
2222 let mut matched = cap[1].trim();
2223
2224 if matched.starts_with("\\n") {
2225 matched = matched.trim_start_matches("\\n");
2226 }
2227
2228 if matched.ends_with("\\n") {
2229 matched = matched.trim_end_matches("\\n");
2230 }
2231
2232 #[cfg(feature = "logging")]
2233 info!(target: "stdout", "captured: {matched:#?}");
2234
2235 if !matched.is_empty() {
2236 match serde_json::from_str::<serde_json::Value>(matched) {
2237 Ok(value) => values.push(value),
2238 Err(e) => {
2239 let err_msg = format!(
2240 "Failed to deserialize generated tool calls: {matched:#?}. Reason: {e}"
2241 );
2242
2243 #[cfg(feature = "logging")]
2244 error!(target: "stdout", "{}", &err_msg);
2245
2246 return Err(LlamaCoreError::Operation(err_msg));
2247 }
2248 }
2249 }
2250 }
2251
2252 let mut tool_calls: Vec<ToolCall> = vec![];
2253 for value in values.iter() {
2254 let name = match value.get("name") {
2255 Some(name) => name.to_string().replace("\"", ""),
2256 None => {
2257 let err_msg = format!(
2258 "Failed to get the name of the function. Tool call: {value:?}"
2259 );
2260
2261 #[cfg(feature = "logging")]
2262 error!(target: "stdout", "{}", &err_msg);
2263
2264 return Err(LlamaCoreError::Operation(err_msg));
2265 }
2266 };
2267
2268 let arguments = match value.get("arguments") {
2269 Some(arguments) => {
2270 if arguments.is_string() {
2271 arguments.as_str().unwrap().to_string()
2272 } else if arguments.is_object() {
2273 let map = arguments.as_object().unwrap();
2274
2275 #[cfg(feature = "logging")]
2276 info!(target: "stdout", "func arguments: {map:?}");
2277
2278 serde_json::to_string(map).unwrap()
2279 } else {
2280 serde_json::to_string(arguments).unwrap()
2281 }
2282 }
2283 None => {
2284 let err_msg = format!(
2285 "Failed to get the arguments of the function. Tool call: {value:?}"
2286 );
2287
2288 #[cfg(feature = "logging")]
2289 error!(target: "stdout", "{}", &err_msg);
2290
2291 return Err(LlamaCoreError::Operation(err_msg));
2292 }
2293 };
2294
2295 let function = Function { name, arguments };
2296
2297 let tool_call = ToolCall {
2298 id: "call_abc123".to_string(),
2299 ty: "function".to_string(),
2300 function,
2301 };
2302
2303 tool_calls.push(tool_call);
2304 }
2305
2306 let parsed = if tool_calls.is_empty() {
2307 ParseResult {
2308 raw: input.to_owned(),
2309 content: Some(input.to_owned()),
2310 tool_calls: vec![],
2311 }
2312 } else {
2313 ParseResult {
2314 raw: input.to_owned(),
2315 content: Some(input.to_owned()),
2316 tool_calls,
2317 }
2318 };
2319
2320 #[cfg(feature = "logging")]
2321 info!(target: "stdout", "parsed result: {parsed:?}");
2322
2323 Ok(parsed)
2324 }
2325 Err(e) => {
2326 let err_msg =
2327 format!("Failed to create a regex pattern. Reason: {e}");
2328
2329 #[cfg(feature = "logging")]
2330 error!(target: "stdout", "{}", &err_msg);
2331
2332 Err(LlamaCoreError::Operation(err_msg))
2333 }
2334 }
2335 }
2336 }
2337 Err(e) => {
2338 let err_msg = format!("Failed to create a regex pattern. Reason: {e}");
2339
2340 #[cfg(feature = "logging")]
2341 error!(target: "stdout", "{}", &err_msg);
2342
2343 Err(LlamaCoreError::Operation(err_msg))
2344 }
2345 }
2346 }
2347 PromptTemplateType::Qwen3Agent => {
2348 #[cfg(feature = "logging")]
2349 info!(target: "stdout", "Raw input to tool call parser: {input:?}");
2350
2351 match regex::Regex::new(r"<action>(.*?)</action>")
2353 .unwrap()
2354 .captures(input)
2355 {
2356 Some(captures) => {
2357 let action = captures.get(1).unwrap().as_str();
2358
2359 #[cfg(feature = "logging")]
2360 info!(target: "stdout", "Action: {action}");
2361
2362 match serde_json::from_str::<serde_json::Value>(action) {
2363 Ok(value) => {
2364 let name = match value.get("name") {
2365 Some(name) => name.to_string().replace("\"", ""),
2366 None => {
2367 let err_msg = format!(
2368 "Failed to get the name of the function. Tool call: {value:?}"
2369 );
2370
2371 #[cfg(feature = "logging")]
2372 error!(target: "stdout", "{}", &err_msg);
2373
2374 return Err(LlamaCoreError::Operation(err_msg));
2375 }
2376 };
2377
2378 let arguments = match value.get("arguments") {
2379 Some(arguments) => {
2380 if arguments.is_string() {
2381 arguments.as_str().unwrap().to_string()
2382 } else if arguments.is_object() {
2383 let map = arguments.as_object().unwrap();
2384
2385 #[cfg(feature = "logging")]
2386 info!(target: "stdout", "func arguments: {map:?}");
2387
2388 serde_json::to_string(map).unwrap()
2389 } else {
2390 serde_json::to_string(arguments).unwrap()
2391 }
2392 }
2393 None => {
2394 let err_msg = format!(
2395 "Failed to get the arguments of the function. Tool call: {value:?}"
2396 );
2397
2398 #[cfg(feature = "logging")]
2399 error!(target: "stdout", "{}", &err_msg);
2400
2401 return Err(LlamaCoreError::Operation(err_msg));
2402 }
2403 };
2404
2405 let function = Function { name, arguments };
2406
2407 let tool_call = ToolCall {
2408 id: "call_abc123".to_string(),
2409 ty: "function".to_string(),
2410 function,
2411 };
2412
2413 Ok(ParseResult {
2414 raw: input.to_owned(),
2415 content: Some(input.to_owned()),
2416 tool_calls: vec![tool_call],
2417 })
2418 }
2419 Err(e) => {
2420 let err_msg = format!(
2421 "Failed to deserialize generated tool calls: {action:#?}. Reason: {e}"
2422 );
2423
2424 #[cfg(feature = "logging")]
2425 error!(target: "stdout", "{}", &err_msg);
2426
2427 Err(LlamaCoreError::Operation(err_msg))
2428 }
2429 }
2430 }
2431 None => match input.contains("<final_answer>") {
2432 true => Ok(ParseResult {
2433 raw: input.to_owned(),
2434 content: Some(input.to_owned()),
2435 tool_calls: vec![],
2436 }),
2437 false => {
2438 let content = format!("<final_answer>{}</final_answer>", input.trim());
2439
2440 Ok(ParseResult {
2441 raw: input.to_owned(),
2442 content: Some(content),
2443 tool_calls: vec![],
2444 })
2445 }
2446 },
2447 }
2448 }
2449 _ => {
2450 let err_msg = format!(
2451 "Tool use is only supported for the following prompt templates: {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, and {}.",
2452 PromptTemplateType::MistralTool,
2453 PromptTemplateType::ChatMLTool,
2454 PromptTemplateType::GroqLlama3Tool,
2455 PromptTemplateType::Llama3Tool,
2456 PromptTemplateType::InternLM2Tool,
2457 PromptTemplateType::NemotronTool,
2458 PromptTemplateType::FunctionaryV32,
2459 PromptTemplateType::MistralSmallTool,
2460 PromptTemplateType::Llama4Chat,
2461 PromptTemplateType::Qwen3NoThink,
2462 PromptTemplateType::Smol3NoThink,
2463 PromptTemplateType::Gemma3,
2464 PromptTemplateType::GptOss,
2465 PromptTemplateType::Qwen3Agent,
2466 );
2467
2468 #[cfg(feature = "logging")]
2469 error!(target: "stdout", "{}", &err_msg);
2470
2471 Err(LlamaCoreError::Operation(err_msg))
2472 }
2473 }
2474}
2475
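/// Reconciles per-request sampling settings (`temperature`, `top_p`,
/// `frequency_penalty`, `presence_penalty`) with the cached model metadata and,
/// when anything differs, pushes the updated configuration to the backend via
/// `update_model_metadata`. For prompt templates with image support it also
/// rejects images supplied as URLs, since only base64-encoded images are accepted.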
2476fn check_model_metadata(
2477 chat_request: &ChatCompletionRequest,
2478) -> Result<GgmlMetadata, LlamaCoreError> {
2479 let mut should_update = false;
2480 let mut metadata = get_model_metadata(chat_request.model.as_ref())?;
2481
2482 if metadata.prompt_template.is_image_supported() {
2484 if let Some(ChatCompletionRequestMessage::User(user_message)) = chat_request.messages.last()
2485 {
2486 if let ChatCompletionUserMessageContent::Parts(parts) = user_message.content() {
2487 for part in parts {
2488 if let ContentPart::Image(image_part) = part {
2489 let image = image_part.image();
2490
2491 if image.is_url() {
2492 let err_msg = "The image is provided in URL format. Only base64 format is supported.".to_string();
2493
2494 #[cfg(feature = "logging")]
2495 error!(target: "stdout", "{}", &err_msg);
2496
2497 return Err(LlamaCoreError::Operation(err_msg));
2498 } else {
2499 #[cfg(feature = "logging")]
2500 info!(target: "stdout", "The image is provided in base64 format.");
2501
2502 break;
2505 }
2506 }
2507 }
2508 }
2509 }
2510 }
2511
2512 if let Some(temp) = chat_request.temperature {
2514 if metadata.temperature != temp {
2515 metadata.temperature = temp;
2517
2518 if !should_update {
2519 should_update = true;
2520 }
2521 }
2522 }
2523
2524 if let Some(top_p) = chat_request.top_p {
2526 if metadata.top_p != top_p {
2527 metadata.top_p = top_p;
2529
2530 if !should_update {
2531 should_update = true;
2532 }
2533 }
2534 }
2535
2536 if let Some(frequency_penalty) = chat_request.frequency_penalty {
2538 if metadata.frequency_penalty != frequency_penalty {
2539 metadata.frequency_penalty = frequency_penalty;
2541
2542 if !should_update {
2543 should_update = true;
2544 }
2545 }
2546 }
2547
2548 if let Some(presence_penalty) = chat_request.presence_penalty {
2550 if metadata.presence_penalty != presence_penalty {
2551 metadata.presence_penalty = presence_penalty;
2553
2554 if !should_update {
2555 should_update = true;
2556 }
2557 }
2558 }
2559
2560 if metadata.embeddings {
2562 metadata.embeddings = false;
2563
2564 if !should_update {
2565 should_update = true;
2566 }
2567 }
2568
2569 if should_update {
2570 #[cfg(feature = "logging")]
2571 info!(target: "stdout", "Update the model metadata.");
2572
2573 update_model_metadata(chat_request.model.as_ref(), &metadata)?;
2575 }
2576
2577 Ok(metadata)
2578}
2579
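/// Adjusts `n_predict` for the current request. An explicit `max_completion_tokens`
/// in the request takes precedence. The sentinel values appear to follow the
/// llama.cpp convention (`-1` = unlimited, `-2` = fill the remaining context); in
/// those cases, and when a positive `n_predict` is smaller than the remaining
/// budget, `n_predict` is replaced with `available_completion_tokens`. Any change
/// is written back to the backend via `update_model_metadata`.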
2580fn update_n_predict(
2581 chat_request: &ChatCompletionRequest,
2582 metadata: &mut GgmlMetadata,
2583 available_completion_tokens: u64,
2584) -> Result<(), LlamaCoreError> {
2585 let mut should_update = false;
2586
2587 #[cfg(feature = "logging")]
2588 info!(target: "stdout", "n_predict: {}", metadata.n_predict);
2589
2590 if let Some(max_completion_tokens) = chat_request.max_completion_tokens {
2596 if metadata.n_predict != max_completion_tokens {
2597 #[cfg(feature = "logging")]
2598 info!(target: "stdout", "Update n_predict with max_completion_tokens from {} to {}", metadata.n_predict, max_completion_tokens);
2599
2600 metadata.n_predict = max_completion_tokens;
2601
2602 if !should_update {
2603 should_update = true;
2604 }
2605 }
2606 }
2607
2608 if metadata.n_predict == -2 {
2610 #[cfg(feature = "logging")]
2611 info!(target: "stdout", "Update n_predict with available_completion_tokens from {} to {}", metadata.n_predict, available_completion_tokens);
2612
2613 metadata.n_predict = available_completion_tokens as i32;
2615
2616 if !should_update {
2617 should_update = true;
2618 }
2619 }
2620
2621 if metadata.n_predict == -1
2622 || (metadata.n_predict > 0 && metadata.n_predict < available_completion_tokens as i32)
2623 || (metadata.n_predict < 0 && metadata.n_predict != -2)
2624 {
2626 #[cfg(feature = "logging")]
2627 info!(target: "stdout", "Update n_predict with available_completion_tokens from {} to {}", metadata.n_predict, available_completion_tokens);
2628
2629 metadata.n_predict = available_completion_tokens as i32;
2631
2632 if !should_update {
2633 should_update = true;
2634 }
2635 }
2636
2637 if should_update {
2638 #[cfg(feature = "logging")]
2639 info!(target: "stdout", "Update the model metadata.");
2640
2641 update_model_metadata(chat_request.model.as_ref(), metadata)?;
2643 }
2644
2645 Ok(())
2646}
2647
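/// Strips template-specific stop tokens and chat markers (for example `<|eot_id|>`,
/// `<|im_end|>`, `</s>`, `<end_of_turn>`) from the raw model output so that only
/// the assistant's message text remains.
///
/// A minimal, illustrative example (not compiled; the markers that get stripped
/// depend on the prompt template in use):
///
/// ```ignore
/// let cleaned = post_process("Hello!<|eot_id|>", &PromptTemplateType::Llama3Chat).unwrap();
/// assert_eq!(cleaned, "Hello!");
/// ```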
2648fn post_process(
2650 output: impl AsRef<str>,
2651 template_ty: &PromptTemplateType,
2652) -> Result<String, String> {
2653 let output = if *template_ty == PromptTemplateType::Baichuan2 {
2654 if output.as_ref().contains("用户:") {
2655 output.as_ref().trim_end_matches("用户:").trim().to_owned()
2656 } else {
2657 output.as_ref().trim().to_owned()
2658 }
2659 } else if *template_ty == PromptTemplateType::OpenChat {
2660 if output.as_ref().contains("<|end_of_turn|>") {
2661 output
2662 .as_ref()
2663 .trim_end_matches("<|end_of_turn|>")
2664 .trim()
2665 .to_owned()
2666 } else {
2667 output.as_ref().trim().to_owned()
2668 }
2669 } else if *template_ty == PromptTemplateType::GemmaInstruct
2670 || *template_ty == PromptTemplateType::Gemma3
2671 {
2672 let s = output.as_ref().trim();
2673 if s.ends_with("<end_of_turn>") {
2674 s.trim_end_matches("<end_of_turn>").trim().to_owned()
2675 } else {
2676 s.to_owned()
2677 }
2678 } else if *template_ty == PromptTemplateType::ChatML
2679 || *template_ty == PromptTemplateType::ChatMLTool
2680 || *template_ty == PromptTemplateType::InternLM2Tool
2681 || *template_ty == PromptTemplateType::MiniCPMV
2682 {
2683 let mut s = output.as_ref().trim();
2684 if s.ends_with("<|endoftext|>") {
2685 s = s.trim_end_matches("<|endoftext|>").trim();
2686 }
2687
2688 if s.starts_with(":") {
2689 s = s.trim_start_matches(":").trim();
2690 }
2691
2692 let x = {
2694 let pat = r#"<think>
2695
2696</think>
2697"#;
2698 if s.contains(pat) {
2699 let x = s.replace(pat, "");
2700 if x.starts_with("()") {
2701 x.trim_start_matches("()").to_owned()
2702 } else {
2703 x.to_owned()
2704 }
2705 } else {
2706 s.to_owned()
2707 }
2708 };
2709 s = x.trim();
2710
2711 if s.contains("<|im_start|>") && s.contains("<|im_end|>") {
2712 let idx_start = s.find("<|im_start|>").unwrap();
2713 let idx_end = s.find("<|im_end|>").unwrap();
2714
2715 match idx_start <= idx_end {
2716 true => s.split("<|im_start|>").collect::<Vec<_>>()[0]
2717 .trim()
2718 .to_owned(),
2719 false => s.split("<|im_end|>").collect::<Vec<_>>()[0]
2720 .trim()
2721 .to_owned(),
2722 }
2723 } else if s.contains("<|im_start|>") {
2724 s.split("<|im_start|>").collect::<Vec<_>>()[0]
2725 .trim()
2726 .to_owned()
2727 } else if s.contains("<|im_end|>") {
2728 let output = s.trim_end_matches("<|im_end|>").trim();
2729 if output.starts_with(": ") {
2730 output.trim_start_matches(": ").to_owned()
2731 } else {
2732 output.to_owned()
2733 }
2734 } else {
2735 s.to_owned()
2736 }
2737 } else if *template_ty == PromptTemplateType::Zephyr
2738 || *template_ty == PromptTemplateType::MistralLite
2739 || *template_ty == PromptTemplateType::MistralTool
2740 || *template_ty == PromptTemplateType::MistralInstruct
2741 || *template_ty == PromptTemplateType::MistralSmallChat
2742 || *template_ty == PromptTemplateType::MistralSmallTool
2743 || *template_ty == PromptTemplateType::BreezeInstruct
2744 {
2745 if output.as_ref().contains("</s><") {
2746 output.as_ref().trim_end_matches("</s><").trim().to_owned()
2747 } else if output.as_ref().contains("</s>") {
2748 output
2749 .as_ref()
2750 .strip_suffix("</s>")
2751 .unwrap()
2752 .trim()
2753 .to_owned()
2754 } else {
2755 output.as_ref().trim().to_owned()
2756 }
2757 } else if *template_ty == PromptTemplateType::DeepseekChat {
2758 if output.as_ref().contains("<|end_of_sentence|>") {
2759 output
2760 .as_ref()
2761 .trim_end_matches("<|end_of_sentence|>")
2762 .trim()
2763 .replace("<|end_of_sentence|>", " ")
2764 .trim()
2765 .to_owned()
2766 } else {
2767 output.as_ref().trim().to_owned()
2768 }
2769 } else if *template_ty == PromptTemplateType::HumanAssistant {
2770 if output.as_ref().contains("Human:") {
2771 output.as_ref().trim_end_matches("Human:").trim().to_owned()
2772 } else {
2773 output.as_ref().trim().to_owned()
2774 }
2775 } else if *template_ty == PromptTemplateType::SolarInstruct {
2776 let s = output.as_ref().trim();
2777
2778 if s.starts_with("### Answer") {
2779 let s = s.trim_start_matches("###").trim();
2780
2781 if s.starts_with("Answer:\n") {
2782 s.replace("Answer:\n", "Answer: ")
2783 } else {
2784 s.to_owned()
2785 }
2786 } else {
2787 s.to_owned()
2788 }
2789 } else if *template_ty == PromptTemplateType::Llama2Chat
2790 || *template_ty == PromptTemplateType::NemotronTool
2791 || *template_ty == PromptTemplateType::NemotronChat
2792 {
2793 let s = output.as_ref().trim();
2794 if s.ends_with("</s>") {
2795 s.trim_end_matches("</s>").trim().to_owned()
2796 } else {
2797 s.to_owned()
2798 }
2799 } else if *template_ty == PromptTemplateType::Llama3Chat
2800 || *template_ty == PromptTemplateType::GroqLlama3Tool
2801 || *template_ty == PromptTemplateType::Llama3Tool
2802 || *template_ty == PromptTemplateType::FunctionaryV32
2803 {
2804 let s = output.as_ref().trim();
2805 if s.ends_with("<|eot_id|>") {
2806 s.trim_end_matches("<|eot_id|>").trim().to_owned()
2807 } else {
2808 s.to_owned()
2809 }
2810 } else if *template_ty == PromptTemplateType::Phi3Chat {
2811 let s = output.as_ref().trim();
2812 if s.ends_with("<|end|>") {
2813 s.trim_end_matches("<|end|>").trim().to_owned()
2814 } else {
2815 s.to_owned()
2816 }
2817 } else if *template_ty == PromptTemplateType::Phi4Chat {
2818 let mut s = output.as_ref().trim();
2819
2820 if s.starts_with("think>") {
2821 s = s.trim_start_matches("think>").trim();
2822 }
2823
2824 if s.ends_with("<|im_end|>") {
2825 s.trim_end_matches("<|im_end|>").trim().to_owned()
2826 } else if s.ends_with("<|end|>") {
2827 s.trim_end_matches("<|end|>").trim().to_owned()
2828 } else {
2829 s.to_owned()
2830 }
2831 } else if *template_ty == PromptTemplateType::FunctionaryV31 {
2832 let mut s = output.as_ref().trim();
2833 if s.ends_with("<|eot_id|>") {
2834 s = s.trim_end_matches("<|eot_id|>").trim();
2835 }
2836 if s.ends_with("<|eom_id|>") {
2837 s = s.trim_end_matches("<|eom_id|>").trim();
2838 }
2839 s.to_owned()
2840 } else if *template_ty == PromptTemplateType::MoxinChat
2841 || *template_ty == PromptTemplateType::MoxinInstruct
2842 {
2843 let s = output.as_ref().trim();
2844 if s.ends_with("</s>") {
2845 s.trim_end_matches("</s>").trim().to_owned()
2846 } else if s.ends_with("[INST]") {
2847 s.trim_end_matches("[INST]").trim().to_owned()
2848 } else {
2849 s.to_owned()
2850 }
2851 } else if *template_ty == PromptTemplateType::Falcon3 {
2852 let s = output.as_ref().trim();
2853 if s.ends_with("<|endoftext|>") {
2854 s.trim_end_matches("<|endoftext|>").trim().to_owned()
2855 } else {
2856 s.to_owned()
2857 }
2858 } else if *template_ty == PromptTemplateType::Megrez {
2859 let s = output.as_ref().trim();
2860 if s.ends_with("<|turn_end|>") {
2861 s.trim_end_matches("<|turn_end|>").trim().to_owned()
2862 } else {
2863 s.to_owned()
2864 }
2865 } else if *template_ty == PromptTemplateType::Qwen2vl
2866 || *template_ty == PromptTemplateType::Qwen3NoThink
2867 || *template_ty == PromptTemplateType::ChatMLThink
2868 {
2869 let mut s = output.as_ref().trim();
2870
2871 if s.starts_with(":") {
2872 s = s.trim_start_matches(":").trim();
2873 }
2874
2875 if s.starts_with("</think>") {
2876 s = s.trim_start_matches("</think>").trim();
2877 }
2878
2879 if s.ends_with("<|im_end|>") {
2880 s.trim_end_matches("<|im_end|>").trim().to_owned()
2881 } else {
2882 s.to_owned()
2883 }
2884 } else if *template_ty == PromptTemplateType::VicunaLlava {
2885 let s = output.as_ref().trim();
2886 if s.ends_with("</s>") {
2887 s.trim_end_matches("</s>").trim().to_owned()
2888 } else {
2889 s.to_owned()
2890 }
2891 } else if *template_ty == PromptTemplateType::ExaoneDeepChat
2892 || *template_ty == PromptTemplateType::ExaoneChat
2893 {
2894 let mut s = output.as_ref().trim();
2895
2896 if s.ends_with("[|endofturn|]") {
2897 s = s.trim_end_matches("[|endofturn|]").trim();
2898 }
2899
2900 s.to_owned()
2901 } else if *template_ty == PromptTemplateType::Llama4Chat {
2902 let mut s = output.as_ref().trim();
2903
2904 if s.ends_with("<|eot|>") {
2905 s = s.trim_end_matches("<|eot|>").trim();
2906 }
2907
2908 s.to_owned()
2909 } else if *template_ty == PromptTemplateType::Smolvl {
2910 let mut s = output.as_ref().trim();
2911
2912 if s.starts_with(":") {
2913 s = s.trim_start_matches(":").trim();
2914 }
2915
2916 if s.ends_with("<end_of_utterance>") {
2917 s = s.trim_end_matches("<end_of_utterance>").trim();
2918 }
2919
2920 if s.contains("<end_of_utterance>:") {
2921 let parts = s.split("<end_of_utterance>:").collect::<Vec<_>>();
2922 parts.last().unwrap().trim().to_owned()
2923 } else {
2924 s.to_owned()
2925 }
2926 } else if *template_ty == PromptTemplateType::Smol3NoThink {
2927 let mut s = output.as_ref().trim();
2928
2929 if s.ends_with("<|im_end|>") {
2930 s = s.trim_end_matches("<|im_end|>").trim();
2931 }
2932
2933 let re = regex::Regex::new(r"(?s)^<think>.*?</think>\s*").unwrap();
2934 re.replace(s, "").to_string()
2935 } else if *template_ty == PromptTemplateType::GptOss {
2936 let s = output.as_ref().trim();
2937
2938 let re =
2939 regex::Regex::new(r"(?s).*<\|channel\|>final<\|message\|>(.*?)<\|return\|>$").unwrap();
2940
2941 if let Some(caps) = re.captures(s) {
2942 let extracted = &caps[1];
2943 extracted.to_owned()
2944 } else {
2945 s.to_owned()
2946 }
2947 } else if *template_ty == PromptTemplateType::Qwen3Agent {
2948 let mut s = output.as_ref().trim();
2949
2950 if s.starts_with(":") {
2951 s = s.trim_start_matches(":").trim();
2952 }
2953
2954 if s.starts_with("</think>") {
2955 s = s.trim_start_matches("</think>").trim();
2956 }
2957
2958 if s.ends_with("<|im_end|>") {
2959 s = s.trim_end_matches("<|im_end|>").trim();
2960 }
2961
2962 if s.contains("<final_answer>") && !s.contains("</final_answer>") {
2963 format!("{s}</final_answer>")
2964 } else {
2965 s.to_owned()
2966 }
2967 } else {
2968 output.as_ref().trim().to_owned()
2969 };
2970
2971 Ok(output)
2972}
2973
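/// Builds the chat prompt for the target model and returns `(prompt,
/// available_completion_tokens, tool_use)`. The prompt budget is capped at 4/5 of
/// the model's context size; when the rendered prompt exceeds that budget, older
/// messages are pruned from the history (keeping the leading system message, if
/// any, and the most recent user turn) and the prompt is rebuilt until it fits or
/// no further pruning is possible.
///
/// A hypothetical call site (not compiled; assumes the chat graphs have been
/// initialized):
///
/// ```ignore
/// let (prompt, available_completion_tokens, tool_use) =
///     build_prompt(chat_request.model.as_ref(), &mut chat_request)?;
/// ```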
2974fn build_prompt(
2986 model_name: Option<&String>,
2987 chat_request: &mut ChatCompletionRequest,
2988) -> Result<(String, u64, bool), LlamaCoreError> {
2989 let metadata = get_model_metadata(model_name)?;
2990 let ctx_size = metadata.ctx_size as u64;
2991 let chat_prompt = ChatPrompt::from(metadata.prompt_template);
2992
2993 let max_prompt_tokens = ctx_size * 4 / 5;
2995
2996 loop {
2997 {
2999 }
3012
3013 if chat_request.messages.is_empty() {
3014 let err_msg = "The messages in the chat request are empty.";
3015
3016 #[cfg(feature = "logging")]
3017 error!(target: "stdout", "{err_msg}");
3018
3019 return Err(LlamaCoreError::Operation(err_msg.to_owned()));
3020 }
3021
3022 #[cfg(feature = "logging")]
3023 {
3024 let mut role_chain = String::new();
3025 for (idx, message) in chat_request.messages.iter().enumerate() {
3026 if idx == chat_request.messages.len() - 1 {
3027 role_chain.push_str(&format!("{}", message.role()));
3028 } else {
3029 role_chain.push_str(&format!("{} -> ", message.role()));
3030 }
3031 }
3032 info!(target: "stdout", "Role chain: {role_chain}");
3033 }
3034
3035 let (prompt, tool_use) = match chat_request.tool_choice.as_ref() {
3036 Some(tool_choice) => match tool_choice {
3037 ToolChoice::None => {
3038 match chat_prompt.build_with_tools(&mut chat_request.messages, Some(&[])) {
3039 Ok(prompt) => (prompt, false),
3040 Err(e) => {
3041 let err_msg = format!("Fail to build chat prompts. Reason: {e}");
3042
3043 #[cfg(feature = "logging")]
3044 error!(target: "stdout", "{}", &err_msg);
3045
3046 return Err(LlamaCoreError::Operation(err_msg));
3047 }
3048 }
3049 }
3050 _ => match chat_request.tools.as_ref() {
3051 Some(tools) => match chat_prompt
3052 .build_with_tools(&mut chat_request.messages, Some(tools.as_slice()))
3053 {
3054 Ok(prompt) => (prompt, true),
3055 Err(e) => {
3056 let err_msg = format!("Fail to build chat prompts. Reason: {e}");
3057
3058 #[cfg(feature = "logging")]
3059 error!(target: "stdout", "{}", &err_msg);
3060
3061 return Err(LlamaCoreError::Operation(err_msg));
3062 }
3063 },
3064 None => {
3065 #[cfg(feature = "logging")]
3066 warn!(target: "stdout", "A tool choice was specified without tools; building the prompt without tools.");
3067
3068 match chat_prompt.build_with_tools(&mut chat_request.messages, None) {
3069 Ok(prompt) => (prompt, false),
3070 Err(e) => {
3071 let err_msg = format!("Fail to build chat prompts. Reason: {e}");
3072
3073 #[cfg(feature = "logging")]
3074 error!(target: "stdout", "{}", &err_msg);
3075
3076 return Err(LlamaCoreError::Operation(err_msg));
3077 }
3078 }
3079 }
3080 },
3081 },
3082 None => match chat_prompt.build_with_tools(&mut chat_request.messages, None) {
3083 Ok(prompt) => (prompt, false),
3084 Err(e) => {
3085 let err_msg = format!("Fail to build chat prompts. Reason: {e}");
3086
3087 #[cfg(feature = "logging")]
3088 error!(target: "stdout", "{}", &err_msg);
3089
3090 return Err(LlamaCoreError::Operation(err_msg));
3091 }
3092 },
3093 };
3094 #[cfg(feature = "logging")]
3095 info!(target: "stdout", "Try to set prompt: {prompt}");
3096
3097 set_prompt(model_name, &prompt)?;
3099
3100 let token_info = get_token_info_by_graph_name(model_name)?;
3102
3103 match token_info.prompt_tokens > max_prompt_tokens {
3104 true => {
3105 match chat_request.messages[0].role() {
3106 ChatCompletionRole::System => {
3107 if chat_request.messages.len() == 4
3109 && chat_request.messages[1].role() == ChatCompletionRole::User
3110 && chat_request.messages[2].role() == ChatCompletionRole::Assistant
3111 && chat_request.messages[3].role() == ChatCompletionRole::Tool
3112 {
3113 let err_msg = format!(
3114 "The number of prompt tokens ({}) is greater than the max prompt tokens ({}). Please increase the context size.",
3115 token_info.prompt_tokens, max_prompt_tokens
3116 );
3117
3118 #[cfg(feature = "logging")]
3119 error!(target: "stdout", "{}", &err_msg);
3120
3121 return Err(LlamaCoreError::Operation(err_msg));
3122 }
3123
3124 if chat_request.messages.len() > 2 {
3125 #[cfg(feature = "logging")]
3126 info!(target: "stdout", "Prune chat history: current length {}", chat_request.messages.len());
3127
3128 if chat_request.messages[1].role() == ChatCompletionRole::User {
3131 let user_message = chat_request.messages.remove(1);
3132
3133 #[cfg(feature = "logging")]
3134 info!(target: "stdout", "Remove a user message from the chat history: {user_message:?}");
3135 }
3136
3137 while chat_request.messages[1].role() != ChatCompletionRole::User {
3140 let message = chat_request.messages.remove(1);
3141
3142 #[cfg(feature = "logging")]
3143 info!(target: "stdout", "Remove a {} message from the chat history: {:?}", message.role(), message);
3144
3145 if chat_request.messages.len() == 1 {
3146 let err_msg = format!("The last message in the chat history should be a user message, but found a {} message.", message.role());
3147
3148 #[cfg(feature = "logging")]
3149 error!(target: "stdout", "{err_msg}");
3150
3151 return Err(LlamaCoreError::Operation(err_msg));
3152 }
3153 }
3154 } else if token_info.prompt_tokens > ctx_size {
3155 let err_msg = format!(
3156 "The number of prompt tokens ({}) is greater than the context size ({}). Please increase the context size, or simplify the input message.",
3157 token_info.prompt_tokens, ctx_size
3158 );
3159
3160 #[cfg(feature = "logging")]
3161 error!(target: "stdout", "{}", &err_msg);
3162
3163 return Err(LlamaCoreError::Operation(err_msg));
3164 } else {
3165 return Ok((prompt, ctx_size - token_info.prompt_tokens, tool_use));
3166 }
3167 }
3168 ChatCompletionRole::User => {
3169 if chat_request.messages.len() == 3
3171 && chat_request.messages[0].role() == ChatCompletionRole::User
3172 && chat_request.messages[1].role() == ChatCompletionRole::Assistant
3173 && chat_request.messages[2].role() == ChatCompletionRole::Tool
3174 {
3175 let err_msg = format!(
3176 "The number of prompt tokens ({}) is greater than the max prompt tokens ({}). Please increase the context size.",
3177 token_info.prompt_tokens, max_prompt_tokens
3178 );
3179
3180 #[cfg(feature = "logging")]
3181 error!(target: "stdout", "{}", &err_msg);
3182
3183 return Err(LlamaCoreError::Operation(err_msg));
3184 }
3185
3186 if chat_request.messages.len() > 1 {
3187 if chat_request.messages[0].role() == ChatCompletionRole::User {
3192 let user_message = chat_request.messages.remove(0);
3193
3194 #[cfg(feature = "logging")]
3195 info!(target: "stdout", "Remove a user message from the chat history: {user_message:?}");
3196 }
3197
3198 while chat_request.messages[0].role() != ChatCompletionRole::User {
3201 let message = chat_request.messages.remove(0);
3202
3203 #[cfg(feature = "logging")]
3204 info!(target: "stdout", "Remove a {} message from the chat history: {:?}", message.role(), message);
3205
3206 if chat_request.messages.is_empty() {
3207 let err_msg = format!("The last message in the chat history should be a user message, but found a {} message.", message.role());
3208
3209 #[cfg(feature = "logging")]
3210 error!(target: "stdout", "{err_msg}");
3211
3212 return Err(LlamaCoreError::Operation(err_msg));
3213 }
3214 }
3215 } else if token_info.prompt_tokens > ctx_size {
3216 let err_msg = format!(
3217 "The number of prompt tokens ({}) is greater than the context size ({}). Please increase the context size, or simplify the input message.",
3218 token_info.prompt_tokens, ctx_size
3219 );
3220
3221 #[cfg(feature = "logging")]
3222 error!(target: "stdout", "{}", &err_msg);
3223
3224 return Err(LlamaCoreError::Operation(err_msg));
3225 } else {
3226 return Ok((prompt, ctx_size - token_info.prompt_tokens, tool_use));
3227 }
3228 }
3229 _ => {
3230 #[cfg(feature = "logging")]
3231 info!(target: "stdout", "Remove a {} message from the message queue", chat_request.messages[0].role());
3232
3233 chat_request.messages.remove(0);
3234 }
3235 }
3236
3237 continue;
3238 }
3239 false => return Ok((prompt, ctx_size - max_prompt_tokens, tool_use)),
3240 }
3241 }
3242}
3243
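/// Writes the prompt string into input tensor 0 of the selected chat graph,
/// falling back to the first available graph when no model name is given or the
/// name is unknown.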
3244fn set_prompt(model_name: Option<&String>, prompt: impl AsRef<str>) -> Result<(), LlamaCoreError> {
3245 let chat_graphs = match CHAT_GRAPHS.get() {
3246 Some(chat_graphs) => chat_graphs,
3247 None => {
3248 let err_msg = "Fail to get the underlying value of `CHAT_GRAPHS`.";
3249
3250 #[cfg(feature = "logging")]
3251 error!(target: "stdout", "{}", &err_msg);
3252
3253 return Err(LlamaCoreError::Operation(err_msg.into()));
3254 }
3255 };
3256
3257 let mut chat_graphs = chat_graphs.lock().map_err(|e| {
3258 let err_msg = format!("Fail to acquire the lock of `CHAT_GRAPHS`. {e}");
3259
3260 #[cfg(feature = "logging")]
3261 error!(target: "stdout", "{}", &err_msg);
3262
3263 LlamaCoreError::Operation(err_msg)
3264 })?;
3265
3266 match model_name {
3267 Some(model_name) => {
3268 #[cfg(feature = "logging")]
3269 info!(target: "stdout", "Set prompt to the chat model named {model_name}");
3270
3271 match chat_graphs.contains_key(model_name) {
3272 true => {
3273 let graph = chat_graphs.get_mut(model_name).unwrap();
3274 let tensor_data = prompt.as_ref().as_bytes().to_vec();
3275 set_tensor_data_u8(graph, 0, &tensor_data)
3276 }
3277 false => match chat_graphs.iter_mut().next() {
3278 Some((_, graph)) => {
3279 let tensor_data = prompt.as_ref().as_bytes().to_vec();
3280 set_tensor_data_u8(graph, 0, &tensor_data)
3281 }
3282 None => {
3283 let err_msg = "There is no model available in the chat graphs.";
3284
3285 #[cfg(feature = "logging")]
3286 error!(target: "stdout", "{}", &err_msg);
3287
3288 Err(LlamaCoreError::Operation(err_msg.into()))
3289 }
3290 },
3291 }
3292 }
3293 None => {
3294 #[cfg(feature = "logging")]
3295 info!(target: "stdout", "Set prompt to the default chat model.");
3296
3297 match chat_graphs.iter_mut().next() {
3298 Some((_, graph)) => {
3299 let tensor_data = prompt.as_ref().as_bytes().to_vec();
3300 set_tensor_data_u8(graph, 0, &tensor_data)
3301 }
3302 None => {
3303 let err_msg = "There is no model available in the chat graphs while trying to set prompt to the default model.";
3304
3305 #[cfg(feature = "logging")]
3306 error!(target: "stdout", "{err_msg}");
3307
3308 Err(LlamaCoreError::Operation(err_msg.into()))
3309 }
3310 }
3311 }
3312 }
3313}
3314
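/// Returns a clone of the metadata of the named chat model, or of the first
/// available model when no name is given or the name is unknown.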
3315fn get_model_metadata(model_name: Option<&String>) -> Result<GgmlMetadata, LlamaCoreError> {
3317 let chat_graphs = match CHAT_GRAPHS.get() {
3318 Some(chat_graphs) => chat_graphs,
3319 None => {
3320 let err_msg = "Fail to get the underlying value of `CHAT_GRAPHS`.";
3321
3322 #[cfg(feature = "logging")]
3323 error!(target: "stdout", "{err_msg}");
3324
3325 return Err(LlamaCoreError::Operation(err_msg.into()));
3326 }
3327 };
3328
3329 let chat_graphs = chat_graphs.lock().map_err(|e| {
3330 let err_msg = format!("Fail to acquire the lock of `CHAT_GRAPHS`. {e}");
3331
3332 #[cfg(feature = "logging")]
3333 error!(target: "stdout", "{}", &err_msg);
3334
3335 LlamaCoreError::Operation(err_msg)
3336 })?;
3337
3338 match model_name {
3339 Some(model_name) => match chat_graphs.contains_key(model_name) {
3340 true => {
3341 let graph = chat_graphs.get(model_name).unwrap();
3342 Ok(graph.metadata.clone())
3343 }
3344 false => match chat_graphs.iter().next() {
3345 Some((_, graph)) => Ok(graph.metadata.clone()),
3346 None => {
3347 let err_msg = "There is no model available in the chat graphs.";
3348
3349 #[cfg(feature = "logging")]
3350 error!(target: "stdout", "{}", &err_msg);
3351
3352 Err(LlamaCoreError::Operation(err_msg.into()))
3353 }
3354 },
3355 },
3356 None => match chat_graphs.iter().next() {
3357 Some((_, graph)) => Ok(graph.metadata.clone()),
3358 None => {
3359 let err_msg = "There is no model available in the chat graphs.";
3360
3361 #[cfg(feature = "logging")]
3362 error!(target: "stdout", "{err_msg}");
3363
3364 Err(LlamaCoreError::Operation(err_msg.into()))
3365 }
3366 },
3367 }
3368}
3369
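/// Serializes the metadata to JSON and writes it into input tensor 1 of the
/// selected chat graph so the backend picks up the new configuration.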
3370fn update_model_metadata(
3371 model_name: Option<&String>,
3372 metadata: &GgmlMetadata,
3373) -> Result<(), LlamaCoreError> {
3374 let config = match serde_json::to_string(metadata) {
3375 Ok(config) => config,
3376 Err(e) => {
3377 let err_msg = format!("Fail to serialize metadata to a JSON string. {e}");
3378
3379 #[cfg(feature = "logging")]
3380 error!(target: "stdout", "{}", &err_msg);
3381
3382 return Err(LlamaCoreError::Operation(err_msg));
3383 }
3384 };
3385
3386 let chat_graphs = match CHAT_GRAPHS.get() {
3387 Some(chat_graphs) => chat_graphs,
3388 None => {
3389 let err_msg = "Fail to get the underlying value of `CHAT_GRAPHS`.";
3390
3391 #[cfg(feature = "logging")]
3392 error!(target: "stdout", "{err_msg}");
3393
3394 return Err(LlamaCoreError::Operation(err_msg.into()));
3395 }
3396 };
3397
3398 let mut chat_graphs = chat_graphs.lock().map_err(|e| {
3399 let err_msg = format!("Fail to acquire the lock of `CHAT_GRAPHS`. Reason: {e}");
3400
3401 #[cfg(feature = "logging")]
3402 error!(target: "stdout", "{}", &err_msg);
3403
3404 LlamaCoreError::Operation(err_msg)
3405 })?;
3406
3407 match model_name {
3408 Some(model_name) => {
3409 match chat_graphs.contains_key(model_name) {
3410 true => {
3411 let graph = chat_graphs.get_mut(model_name).unwrap();
3412 set_tensor_data_u8(graph, 1, config.as_bytes())
3414 }
3415 false => match chat_graphs.iter_mut().next() {
3416 Some((_, graph)) => {
3417 set_tensor_data_u8(graph, 1, config.as_bytes())
3419 }
3420 None => {
3421 let err_msg = "There is no model available in the chat graphs.";
3422
3423 #[cfg(feature = "logging")]
3424 error!(target: "stdout", "{}", &err_msg);
3425
3426 Err(LlamaCoreError::Operation(err_msg.into()))
3427 }
3428 },
3429 }
3430 }
3431 None => {
3432 match chat_graphs.iter_mut().next() {
3433 Some((_, graph)) => {
3434 set_tensor_data_u8(graph, 1, config.as_bytes())
3436 }
3437 None => {
3438 let err_msg = "There is no model available in the chat graphs.";
3439
3440 #[cfg(feature = "logging")]
3441 error!(target: "stdout", "{err_msg}");
3442
3443 Err(LlamaCoreError::Operation(err_msg.into()))
3444 }
3445 }
3446 }
3447 }
3448}
3449
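/// Re-applies the cached model metadata to the backend, undoing any per-request
/// overrides made by `check_model_metadata` and `update_n_predict`.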
3450fn reset_model_metadata(model_name: Option<&String>) -> Result<(), LlamaCoreError> {
3451 let metadata = get_model_metadata(model_name)?;
3453
3454 update_model_metadata(model_name, &metadata)
3456}
3457
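/// Emission state for the SSE tail sent after the backend reports a full context:
/// a final message chunk, an optional usage chunk, `data: [DONE]`, then end of
/// sequence.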
3458#[derive(Debug, Clone, Copy, PartialEq, Eq)]
3459enum ContextFullState {
3460 Message,
3461 Usage,
3462 Done,
3463 EndOfSequence,
3464}
3465
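/// Emission state for the regular streaming path; `Usage` is used when the client
/// requested `include_usage`, otherwise `NoUsage`.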
3466#[derive(Debug, Clone, Copy, PartialEq, Eq)]
3467enum StreamState {
3468 Usage,
3469 NoUsage,
3470 Done,
3471 EndOfSequence,
3472}
3473
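/// Emission state for the SSE tail sent when the prompt is too long for the
/// context window.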
3474#[derive(Debug, Clone, Copy, PartialEq, Eq)]
3475enum PromptTooLongState {
3476 Message,
3477 Usage,
3478 Done,
3479 EndOfSequence,
3480}
3481
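/// A `futures::Stream` of SSE-formatted chat completion chunks. Only one
/// `ChatStream` may drive the backend at a time: the constructor tries to take the
/// global `CHAT_STREAM_ACTIVE` flag and, if the flag is already held, the stream
/// registers its waker in `CHAT_STREAM_WAKER_QUEUE` when polled and yields
/// `Poll::Pending` until the active stream is dropped. Dropping a stream finishes
/// the backend's single-token computation, resets the model metadata, releases the
/// flag, and wakes the next waiting stream.
///
/// Sketch of how the stream is consumed (illustrative only; `futures::StreamExt`
/// is assumed to be in scope at the call site):
///
/// ```ignore
/// use futures::StreamExt;
///
/// while let Some(chunk) = chat_stream.next().await {
///     let sse_event: String = chunk?; // e.g. "data: {...}\n\n" or "data: [DONE]\n\n"
///     // forward `sse_event` to the HTTP response
/// }
/// ```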
3482struct ChatStream {
3483 id: String,
3484 model: Option<String>,
3485 include_usage: bool,
3486 context_full_state: ContextFullState,
3487 prompt_too_long_state: PromptTooLongState,
3488 stream_state: StreamState,
3489 cache: Option<VecDeque<String>>,
3490 is_waiting: bool,
3491 has_lock: bool,
3492}
3493impl ChatStream {
3494 fn new(
3495 model: Option<String>,
3496 id: String,
3497 include_usage: bool,
3498 cache: Option<Vec<String>>,
3499 ) -> Self {
3500 let has_lock = CHAT_STREAM_ACTIVE
3502 .compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst)
3503 .is_ok();
3504
3505 #[cfg(feature = "logging")]
3506 if !has_lock {
3507 info!(target: "stdout", "Lock acquisition failed in ChatStream::new, creating with waiting status");
3508 }
3509
3510 ChatStream {
3511 id,
3512 model,
3513 include_usage,
3514 context_full_state: ContextFullState::Message,
3515 prompt_too_long_state: PromptTooLongState::Message,
3516 stream_state: if include_usage {
3517 StreamState::Usage
3518 } else {
3519 StreamState::NoUsage
3520 },
3521 cache: cache.map(VecDeque::from),
3522 is_waiting: !has_lock,
3523 has_lock,
3524 }
3525 }
3526
3527 fn try_acquire_lock(&mut self) -> bool {
3529 if self.has_lock {
3530 return true;
3531 }
3532
3533 let acquired = CHAT_STREAM_ACTIVE
3534 .compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst)
3535 .is_ok();
3536
3537 if acquired {
3538 self.has_lock = true;
3539 self.is_waiting = false;
3540 }
3541
3542 acquired
3543 }
3544}
3545impl Drop for ChatStream {
3546 fn drop(&mut self) {
3547 if self.has_lock || (self.cache.is_none() && !self.is_waiting) {
3549 #[cfg(feature = "logging")]
3550 info!(target: "stdout", "Cleaning up context for ChatStream {}", &self.id);
3551
3552 match &self.model {
3553 Some(model_name) => {
3554 match CHAT_GRAPHS.get() {
3555 Some(chat_graphs) => {
3556 match chat_graphs.lock() {
3557 Ok(mut chat_graphs) => match chat_graphs.contains_key(model_name) {
3558 true => {
3559 let graph = chat_graphs.get_mut(model_name).unwrap();
3560
3561 if let Err(e) = graph.finish_single() {
3563 let err_msg = format!(
3564 "Failed to clean up the context. Reason: {e}"
3565 );
3566
3567 #[cfg(feature = "logging")]
3568 error!(target: "stdout", "{}", &err_msg);
3569
3570 #[cfg(not(feature = "logging"))]
3571 println!(
3572 "[ERROR][llama_core] Failed to clean up the context. Reason: {}",
3573 &err_msg
3574 );
3575 }
3576 }
3577 false => match chat_graphs.iter_mut().next() {
3578 Some((_, graph)) => {
3579 if let Err(e) = graph.finish_single() {
3581 let err_msg = format!(
3582 "Failed to clean up the context. Reason: {e}"
3583 );
3584
3585 #[cfg(feature = "logging")]
3586 error!(target: "stdout", "{}", &err_msg);
3587
3588 #[cfg(not(feature = "logging"))]
3589 println!(
3590 "[ERROR][llama_core] Failed to clean up the context. Reason: {}",
3591 &err_msg
3592 );
3593 }
3594 }
3595 None => {
3596 let err_msg =
3597 "There is no model available in the chat graphs.";
3598
3599 #[cfg(feature = "logging")]
3600 error!(target: "stdout", "{}", &err_msg);
3601
3602 #[cfg(not(feature = "logging"))]
3603 println!(
3604 "[ERROR][llama_core] Failed to clean up the context. Reason: {}",
3605 &err_msg
3606 );
3607 }
3608 },
3609 },
3610 Err(e) => {
3611 let err_msg =
3612 format!("Fail to acquire the lock of `CHAT_GRAPHS`. {e}");
3613
3614 #[cfg(feature = "logging")]
3615 error!(target: "stdout", "{}", &err_msg);
3616
3617 #[cfg(not(feature = "logging"))]
3618 println!(
3619 "[ERROR][llama_core] Failed to clean up the context. Reason: {}",
3620 &err_msg
3621 );
3622 }
3623 }
3624 }
3625 None => {
3626 let err_msg = "Fail to get the underlying value of `CHAT_GRAPHS`.";
3627
3628 #[cfg(feature = "logging")]
3629 error!(target: "stdout", "{}", &err_msg);
3630
3631 #[cfg(not(feature = "logging"))]
3632 println!(
3633 "[ERROR][llama_core] Failed to clean up the context. Reason: {}",
3634 &err_msg
3635 );
3636 }
3637 };
3638 }
3639 None => {
3640 match CHAT_GRAPHS.get() {
3641 Some(chat_graphs) => {
3642 match chat_graphs.lock() {
3643 Ok(mut chat_graphs) => match chat_graphs.iter_mut().next() {
3644 Some((_, graph)) => {
3645 if let Err(e) = graph.finish_single() {
3647 let err_msg = format!(
3648 "Failed to clean up the context. Reason: {e}"
3649 );
3650
3651 #[cfg(feature = "logging")]
3652 error!(target: "stdout", "{}", &err_msg);
3653
3654 #[cfg(not(feature = "logging"))]
3655 println!(
3656 "[ERROR][llama_core] Failed to clean up the context. Reason: {}",
3657 &err_msg
3658 );
3659 }
3660 }
3661 None => {
3662 let err_msg =
3663 "There is no model available in the chat graphs.";
3664
3665 #[cfg(feature = "logging")]
3666 error!(target: "stdout", "{err_msg}");
3667
3668 #[cfg(not(feature = "logging"))]
3669 println!(
3670 "[ERROR][llama_core] Failed to clean up the context. Reason: {}",
3671 err_msg
3672 );
3673 }
3674 },
3675 Err(e) => {
3676 let err_msg =
3677 format!("Fail to acquire the lock of `CHAT_GRAPHS`. {e}");
3678
3679 #[cfg(feature = "logging")]
3680 error!(target: "stdout", "{}", &err_msg);
3681
3682 #[cfg(not(feature = "logging"))]
3683 println!(
3684 "[ERROR][llama_core] Failed to clean up the context. Reason: {}",
3685 &err_msg
3686 );
3687 }
3688 }
3689 }
3690 None => {
3691 let err_msg = "Fail to get the underlying value of `CHAT_GRAPHS`.";
3692
3693 #[cfg(feature = "logging")]
3694 error!(target: "stdout", "{}", &err_msg);
3695
3696 #[cfg(not(feature = "logging"))]
3697 println!(
3698 "[ERROR][llama_core] Failed to clean up the context. Reason: {}",
3699 &err_msg
3700 );
3701 }
3702 };
3703 }
3704 }
3705
3706 #[cfg(feature = "logging")]
3707 info!(target: "stdout", "Model context cleanup done!");
3708 }
3709
3710 if let Err(e) = reset_model_metadata(self.model.as_ref()) {
3712 let err_msg = format!("Fail to reset model metadata. Reason: {e}");
3713
3714 #[cfg(feature = "logging")]
3715 error!(target: "stdout", "{}", &err_msg);
3716
3717 #[cfg(not(feature = "logging"))]
3718 println!("[ERROR][llama_core] {}", &err_msg);
3719 }
3720 #[cfg(feature = "logging")]
3721 info!(target: "stdout", "Model metadata reset done!");
3722
3723 if self.has_lock {
3725 CHAT_STREAM_ACTIVE.store(false, Ordering::SeqCst);
3727
3728 #[cfg(feature = "logging")]
3729 info!(target: "stdout", "Lock from ChatStream {} released", &self.id);
3730
3731 if let Ok(mut queue) = get_chat_stream_waker_queue().lock() {
3733 if let Some(waker) = queue.pop_front() {
3734 #[cfg(feature = "logging")]
3735 info!(target: "stdout", "Waking up a waiting ChatStream");
3736
3737 waker.wake();
3738 }
3739 }
3740 }
3741 }
3742}
3743impl futures::Stream for ChatStream {
3744 type Item = Result<String, LlamaCoreError>;
3745
3746 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
3747 let this = self.get_mut();
3748
3749 if this.is_waiting {
3751 if !this.try_acquire_lock() {
3752 if let Ok(mut queue) = get_chat_stream_waker_queue().lock() {
3754 queue.retain(|w| !w.will_wake(cx.waker()));
3756 queue.push_back(cx.waker().clone());
3758
3759 #[cfg(feature = "logging")]
3760 debug!(target: "stdout", "ChatStream {} is waiting for lock, added waker to queue", &this.id);
3761 }
3762
3763 return Poll::Pending;
3764 }
3765
3766 #[cfg(feature = "logging")]
3767 info!(target: "stdout", "ChatStream {} acquired lock and is now active", &this.id);
3768 }
3770
3771 if !this.has_lock && !this.try_acquire_lock() {
3773 this.is_waiting = true;
3775
3776 if let Ok(mut queue) = get_chat_stream_waker_queue().lock() {
3778 queue.retain(|w| !w.will_wake(cx.waker()));
3779 queue.push_back(cx.waker().clone());
3780 }
3781
3782 return Poll::Pending;
3783 }
3784
3785 if this.cache.is_none() {
3786 let res = compute_stream(
3787 this.model.clone(),
3788 this.id.clone(),
3789 this.include_usage,
3790 &mut this.prompt_too_long_state,
3791 &mut this.context_full_state,
3792 &mut this.stream_state,
3793 );
3794
3795 match res {
3796 Ok(x) => {
3797 #[cfg(feature = "logging")]
3798 info!(target: "stdout", "next item for ChatStream {}: {}", &this.id, &x);
3799
3800 if x != "[GGML] End of sequence" && !x.is_empty() {
3801 Poll::Ready(Some(Ok(x)))
3802 } else {
3803 Poll::Ready(None)
3805 }
3806 }
3807 Err(e) => Poll::Ready(Some(Err(e))),
3808 }
3809 } else {
3810 let x = this.cache.as_mut().unwrap().pop_front();
3811
3812 #[cfg(feature = "logging")]
3813 info!(target: "stdout", "Get the next item from the cache for ChatStream {}: {:?}", &this.id, &x);
3814
3815 match x {
3816 Some(x) => Poll::Ready(Some(Ok(x))),
3817 None => Poll::Ready(None),
3818 }
3819 }
3820 }
3821}
3822
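/// Lazily initializes and returns the global queue of wakers belonging to
/// `ChatStream`s that are waiting for the single-stream lock.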
3823fn get_chat_stream_waker_queue() -> &'static Mutex<VecDeque<Waker>> {
3825 CHAT_STREAM_WAKER_QUEUE.get_or_init(|| {
3826 #[cfg(feature = "logging")]
3827 info!(target: "stdout", "Initializing ChatStream waker queue");
3828 Mutex::new(VecDeque::new())
3829 })
3830}
3831
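/// Computes the next SSE chunk for a stream: runs one `compute_single` step on the
/// chat graph, decodes the emitted token (buffering incomplete UTF-8 sequences in
/// `CACHED_UTF8_ENCODINGS`), and wraps the text in a `ChatCompletionChunk`. Backend
/// `EndOfSequence` and `ContextFull` errors are translated into the closing usage
/// and `[DONE]` events according to the state machines above.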
3832fn compute_stream(
3833 model_name: Option<String>,
3834 id: String,
3835 include_usage: bool,
3836 prompt_too_long_state: &mut PromptTooLongState,
3837 context_full_state: &mut ContextFullState,
3838 stream_state: &mut StreamState,
3839) -> Result<String, LlamaCoreError> {
3840 #[cfg(feature = "logging")]
3841 info!(target: "stdout", "Computing stream chunk for ChatStream {}", &id);
3842
3843 #[cfg(feature = "logging")]
3844 debug!(target: "stdout", "prompt_too_long_state: {:?}", *prompt_too_long_state);
3845 #[cfg(feature = "logging")]
3846 debug!(target: "stdout", "context_full_state: {:?}", *context_full_state);
3847 #[cfg(feature = "logging")]
3848 debug!(target: "stdout", "stream_state: {:?}", *stream_state);
3849
3850 if *prompt_too_long_state == PromptTooLongState::EndOfSequence
3851 || *context_full_state == ContextFullState::EndOfSequence
3852 || *stream_state == StreamState::EndOfSequence
3853 {
3854 #[cfg(feature = "logging")]
3855 info!(target: "stdout", "Return the chat stream chunk!");
3856
3857 return Ok("[GGML] End of sequence".to_string());
3858 }
3859
3860 let chat_graphs = match CHAT_GRAPHS.get() {
3861 Some(chat_graphs) => chat_graphs,
3862 None => {
3863 let err_msg = "Fail to get the underlying value of `CHAT_GRAPHS`.";
3864
3865 #[cfg(feature = "logging")]
3866 error!(target: "stdout", "{}", &err_msg);
3867
3868 return Err(LlamaCoreError::Operation(err_msg.into()));
3869 }
3870 };
3871
3872 let mut chat_graphs = chat_graphs.lock().map_err(|e| {
3874 let err_msg = format!("Fail to acquire the lock of `CHAT_GRAPHS`. {e}");
3875
3876 #[cfg(feature = "logging")]
3877 error!(target: "stdout", "{}", &err_msg);
3878
3879 LlamaCoreError::Operation(err_msg)
3880 })?;
3881
3882 let res = match &model_name {
3884 Some(model_name) => {
3885 match chat_graphs.contains_key(model_name) {
3886 true => {
3887 let graph = chat_graphs.get_mut(model_name).unwrap();
3888 match graph.compute_single() {
3890 Ok(_) => {
3891 #[cfg(feature = "logging")]
3892 debug!(target: "stdout", "Compute the chat stream chunk successfully.");
3893
3894 match stream_state {
3896 StreamState::Usage | StreamState::NoUsage => {
3897 let output_buffer =
3899 get_output_buffer_single(graph, OUTPUT_TENSOR)?;
3900
3901 #[cfg(feature = "logging")]
3902 info!(target: "stdout", "retrieved the output buffer");
3903
3904 let output = match String::from_utf8(output_buffer.clone()) {
3906 Ok(token) => token,
3907 Err(_) => {
3908 let mutex = CACHED_UTF8_ENCODINGS
3909 .get_or_init(|| Mutex::new(Vec::new()));
3910 let mut cached_encodings = mutex.lock().map_err(|e| {
3911 let err_msg = format!(
3912 "Fail to acquire the lock of `CACHED_UTF8_ENCODINGS`. Reason: {e}"
3913 );
3914
3915 #[cfg(feature = "logging")]
3916 error!(target: "stdout", "{}", &err_msg);
3917
3919 LlamaCoreError::Operation(err_msg)
3920 })?;
3921
3922 cached_encodings.extend_from_slice(&output_buffer[..]);
3924
3925 match String::from_utf8(cached_encodings.to_vec()) {
3926 Ok(token) => {
3927 cached_encodings.clear();
3929
3930 token
3931 }
3932 Err(e) => {
3933 if cached_encodings.len() > 4 {
3935 let err_msg = format!("Fail to convert a vector of bytes to string. The length of the utf8 bytes exceeds 4. {e}");
3936
3937 #[cfg(feature = "logging")]
3938 error!(target: "stdout", "{}", &err_msg);
3939
3940 #[cfg(feature = "logging")]
3941 error!(target: "stdout", "The cached buffer: {:?}", &cached_encodings[..]);
3942
3943 cached_encodings.clear();
3950
3951 String::from("")
3952 } else {
3953 let warn_msg = format!("Fail to convert a vector of bytes to string. {e}");
3954
3955 #[cfg(feature = "logging")]
3956 warn!(target: "stdout", "{}", &warn_msg);
3957
3958 String::from("")
3959 }
3960 }
3961 }
3962 }
3963 };
3964
3965 #[cfg(feature = "logging")]
3966 info!(target: "stdout", "decoded the output buffer");
3967
3968 let created = SystemTime::now()
3969 .duration_since(std::time::UNIX_EPOCH)
3970 .map_err(|e| {
3971 let err_msg = format!(
3972 "Failed to get the current time. Reason: {e}"
3973 );
3974
3975 #[cfg(feature = "logging")]
3976 error!(target: "stdout", "{}", &err_msg);
3977
3978 LlamaCoreError::Operation(err_msg)
3979 })?;
3980
3981 let chat_completion_chunk = ChatCompletionChunk {
3982 id,
3983 object: "chat.completion.chunk".to_string(),
3984 created: created.as_secs(),
3985 model: graph.name().to_owned(),
3986 system_fingerprint: "fp_44709d6fcb".to_string(),
3987 choices: vec![ChatCompletionChunkChoice {
3988 index: 0,
3989 delta: ChatCompletionChunkChoiceDelta {
3990 role: ChatCompletionRole::Assistant,
3991 content: Some(output),
3992 tool_calls: vec![],
3993 },
3994 logprobs: None,
3995 finish_reason: None,
3996 }],
3997 usage: None,
3998 };
3999
4000 #[cfg(feature = "logging")]
4001 info!(target: "stdout", "created chat completion chunk");
4002
4003 let chunk_str = serde_json::to_string(&chat_completion_chunk)
4005 .map_err(|e| {
4006 let err_msg = format!(
4007 "Failed to serialize chat completion chunk. Reason: {e}"
4008 );
4009
4010 #[cfg(feature = "logging")]
4011 error!(target: "stdout", "{}", &err_msg);
4012
4013 LlamaCoreError::Operation(err_msg)
4014 })?;
4015
4016 Ok(format!("data: {chunk_str}\n\n"))
4017 }
4018 StreamState::Done => {
4019 *stream_state = StreamState::EndOfSequence;
4020
4021 Ok("data: [DONE]\n\n".to_string())
4022 }
4023 StreamState::EndOfSequence => {
4024 Ok("[GGML] End of sequence".to_string())
4025 }
4026 }
4027 }
4028 Err(wasmedge_wasi_nn::Error::BackendError(
4029 wasmedge_wasi_nn::BackendError::EndOfSequence,
4030 )) => {
4031 #[cfg(feature = "logging")]
4032 debug!(target: "stdout", "End of sequence");
4033
4034 match stream_state {
4035 StreamState::Usage => {
4036 *stream_state = StreamState::Done;
4037
4038 let token_info = get_token_info_by_graph(graph)?;
4040
4041 let usage = Some(Usage {
4042 prompt_tokens: token_info.prompt_tokens,
4043 completion_tokens: token_info.completion_tokens,
4044 total_tokens: token_info.prompt_tokens
4045 + token_info.completion_tokens,
4046 });
4047
4048 #[cfg(feature = "logging")]
4049 info!(target: "stdout", "token_info: {} prompt tokens, {} completion tokens", token_info.prompt_tokens, token_info.completion_tokens);
4050
4051 let created = SystemTime::now()
4052 .duration_since(std::time::UNIX_EPOCH)
4053 .map_err(|e| {
4054 let err_msg = format!(
4055 "Failed to get the current time. Reason: {e}"
4056 );
4057
4058 #[cfg(feature = "logging")]
4059 error!(target: "stdout", "{}", &err_msg);
4060
4061 LlamaCoreError::Operation(err_msg)
4062 })?;
4063
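     // Usage-only chunk: `choices` is left empty and only the `usage` field is populated.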
4064 let chat_completion_chunk = ChatCompletionChunk {
4065 id,
4066 object: "chat.completion.chunk".to_string(),
4067 created: created.as_secs(),
4068 model: graph.name().to_owned(),
4069 system_fingerprint: "fp_44709d6fcb".to_string(),
4070 choices: vec![],
4071 usage,
4072 };
4073
4074 let chunk_str = serde_json::to_string(&chat_completion_chunk)
4076 .map_err(|e| {
4077 let err_msg = format!(
4078 "Failed to serialize chat completion chunk. Reason: {e}"
4079 );
4080
4081 #[cfg(feature = "logging")]
4082 error!(target: "stdout", "{}", &err_msg);
4083
4084 LlamaCoreError::Operation(err_msg)
4085 })?;
4086
4087 Ok(format!("data: {chunk_str}\n\n"))
4088 }
4089 StreamState::Done | StreamState::NoUsage => {
4090 *stream_state = StreamState::EndOfSequence;
4091
4092 Ok("data: [DONE]\n\n".to_string())
4093 }
4094 StreamState::EndOfSequence => {
4095 Ok("[GGML] End of sequence".to_string())
4096 }
4097 }
4098 }
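     // The context window is full: close the stream with finish_reason `length`, emitting a
     // context-full marker as the final content, followed by usage (if requested) and [DONE].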
4099 Err(wasmedge_wasi_nn::Error::BackendError(
4100 wasmedge_wasi_nn::BackendError::ContextFull,
4101 )) => {
4102 #[cfg(feature = "logging")]
4103 debug!(target: "stdout", "Context full");
4104
4105 match context_full_state {
4106 ContextFullState::Message => {
4107 match include_usage {
4108 true => *context_full_state = ContextFullState::Usage,
4109 false => *context_full_state = ContextFullState::Done,
4110 }
4111
4112 let created = SystemTime::now()
4113 .duration_since(std::time::UNIX_EPOCH)
4114 .map_err(|e| {
4115 let err_msg = format!(
4116 "Failed to get the current time. Reason: {e}"
4117 );
4118
4119 #[cfg(feature = "logging")]
4120 error!(target: "stdout", "{}", &err_msg);
4121
4122 LlamaCoreError::Operation(err_msg)
4123 })?;
4124
4125 let chat_completion_chunk = ChatCompletionChunk {
4126 id,
4127 object: "chat.completion.chunk".to_string(),
4128 created: created.as_secs(),
4129 model: graph.name().to_owned(),
4130 system_fingerprint: "fp_44709d6fcb".to_string(),
4131 choices: vec![ChatCompletionChunkChoice {
4132 index: 0,
4133 delta: ChatCompletionChunkChoiceDelta {
4134 role: ChatCompletionRole::Assistant,
4135 content: Some(
4136 "<|WASMEDGE-GGML-CONTEXT-FULL|>".to_string(),
4137 ),
4138 tool_calls: vec![],
4139 },
4140 logprobs: None,
4141 finish_reason: Some(FinishReason::length),
4142 }],
4143 usage: None,
4144 };
4145
4146 let chunk_str = serde_json::to_string(&chat_completion_chunk)
4148 .map_err(|e| {
4149 let err_msg = format!(
4150 "Failed to serialize chat completion chunk. Reason: {e}"
4151 );
4152
4153 #[cfg(feature = "logging")]
4154 error!(target: "stdout", "{}", &err_msg);
4155
4156 LlamaCoreError::Operation(err_msg)
4157 })?;
4158
4159 Ok(format!("data: {chunk_str}\n\n"))
4160 }
4161 ContextFullState::Usage => {
4162 *context_full_state = ContextFullState::Done;
4163
4164 let token_info = get_token_info_by_graph(graph)?;
4166
4167 let usage = Some(Usage {
4168 prompt_tokens: token_info.prompt_tokens,
4169 completion_tokens: token_info.completion_tokens,
4170 total_tokens: token_info.prompt_tokens
4171 + token_info.completion_tokens,
4172 });
4173
4174 let created = SystemTime::now()
4175 .duration_since(std::time::UNIX_EPOCH)
4176 .map_err(|e| {
4177 let err_msg = format!(
4178 "Failed to get the current time. Reason: {e}"
4179 );
4180
4181 #[cfg(feature = "logging")]
4182 error!(target: "stdout", "{}", &err_msg);
4183
4184 LlamaCoreError::Operation(err_msg)
4185 })?;
4186
4187 let chat_completion_chunk = ChatCompletionChunk {
4188 id,
4189 object: "chat.completion.chunk".to_string(),
4190 created: created.as_secs(),
4191 model: graph.name().to_owned(),
4192 system_fingerprint: "fp_44709d6fcb".to_string(),
4193 choices: vec![],
4194 usage,
4195 };
4196
4197 let chunk_str = serde_json::to_string(&chat_completion_chunk)
4199 .map_err(|e| {
4200 let err_msg = format!(
4201 "Failed to serialize chat completion chunk. Reason: {e}"
4202 );
4203
4204 #[cfg(feature = "logging")]
4205 error!(target: "stdout", "{}", &err_msg);
4206
4207 LlamaCoreError::Operation(err_msg)
4208 })?;
4209
4210 Ok(format!("data: {chunk_str}\n\n"))
4211 }
4212 ContextFullState::Done => {
4213 *context_full_state = ContextFullState::EndOfSequence;
4214
4215 Ok("data: [DONE]\n\n".to_string())
4216 }
4217 ContextFullState::EndOfSequence => {
4218 Ok("[GGML] End of sequence".to_string())
4219 }
4220 }
4221 }
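     // The prompt alone exceeds the context window: close the stream with finish_reason `length`
     // and no content, followed by usage (if requested) and [DONE].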
4222 Err(wasmedge_wasi_nn::Error::BackendError(
4223 wasmedge_wasi_nn::BackendError::PromptTooLong,
4224 )) => {
4225 #[cfg(feature = "logging")]
4226 debug!(target: "stdout", "Prompt too long");
4227
4228 match prompt_too_long_state {
4229 PromptTooLongState::Message => {
4230 match include_usage {
4231 true => *prompt_too_long_state = PromptTooLongState::Usage,
4232 false => *prompt_too_long_state = PromptTooLongState::Done,
4233 }
4234
4235 let created = SystemTime::now()
4236 .duration_since(std::time::UNIX_EPOCH)
4237 .map_err(|e| {
4238 let err_msg = format!(
4239 "Failed to get the current time. Reason: {e}"
4240 );
4241
4242 #[cfg(feature = "logging")]
4243 error!(target: "stdout", "{}", &err_msg);
4244
4245 LlamaCoreError::Operation(err_msg)
4246 })?;
4247
4248 let chat_completion_chunk = ChatCompletionChunk {
4249 id,
4250 object: "chat.completion.chunk".to_string(),
4251 created: created.as_secs(),
4252 model: graph.name().to_owned(),
4253 system_fingerprint: "fp_44709d6fcb".to_string(),
4254 choices: vec![ChatCompletionChunkChoice {
4255 index: 0,
4256 delta: ChatCompletionChunkChoiceDelta {
4257 role: ChatCompletionRole::Assistant,
4258 content: None,
4259 tool_calls: vec![],
4260 },
4261 logprobs: None,
4262 finish_reason: Some(FinishReason::length),
4263 }],
4264 usage: None,
4265 };
4266
4267 let chunk_str = serde_json::to_string(&chat_completion_chunk)
4269 .map_err(|e| {
4270 let err_msg = format!(
4271 "Failed to serialize chat completion chunk. Reason: {e}"
4272 );
4273
4274 #[cfg(feature = "logging")]
4275 error!(target: "stdout", "{}", &err_msg);
4276
4277 LlamaCoreError::Operation(err_msg)
4278 })?;
4279
4280 Ok(format!("data: {chunk_str}\n\n"))
4281 }
4282 PromptTooLongState::Usage => {
4283 *prompt_too_long_state = PromptTooLongState::Done;
4284
4285 let token_info = get_token_info_by_graph(graph)?;
4287
4288 let usage = Some(Usage {
4289 prompt_tokens: token_info.prompt_tokens,
4290 completion_tokens: token_info.completion_tokens,
4291 total_tokens: token_info.prompt_tokens
4292 + token_info.completion_tokens,
4293 });
4294
4295 let created = SystemTime::now()
4296 .duration_since(std::time::UNIX_EPOCH)
4297 .map_err(|e| {
4298 let err_msg = format!(
4299 "Failed to get the current time. Reason: {e}"
4300 );
4301
4302 #[cfg(feature = "logging")]
4303 error!(target: "stdout", "{}", &err_msg);
4304
4305 LlamaCoreError::Operation(err_msg)
4306 })?;
4307
4308 let chat_completion_chunk = ChatCompletionChunk {
4309 id,
4310 object: "chat.completion.chunk".to_string(),
4311 created: created.as_secs(),
4312 model: graph.name().to_owned(),
4313 system_fingerprint: "fp_44709d6fcb".to_string(),
4314 choices: vec![],
4315 usage,
4316 };
4317
4318 let chunk_str = serde_json::to_string(&chat_completion_chunk)
4320 .map_err(|e| {
4321 let err_msg = format!(
4322 "Failed to serialize chat completion chunk. Reason: {e}"
4323 );
4324
4325 #[cfg(feature = "logging")]
4326 error!(target: "stdout", "{}", &err_msg);
4327
4328 LlamaCoreError::Operation(err_msg)
4329 })?;
4330
4331 Ok(format!("data: {chunk_str}\n\n"))
4332 }
4333 PromptTooLongState::Done => {
4334 *prompt_too_long_state = PromptTooLongState::EndOfSequence;
4335
4336 Ok("data: [DONE]\n\n".to_string())
4337 }
4338 PromptTooLongState::EndOfSequence => {
4339 Ok("[GGML] End of sequence".to_string())
4340 }
4341 }
4342 }
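     // Any other backend error aborts the stream.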
4343 Err(e) => {
4344 let err_msg =
4345 format!("Failed to compute the chat completion. Reason: {e}");
4346
4347 #[cfg(feature = "logging")]
4348 error!(target: "stdout", "{}", &err_msg);
4349
4350 Err(LlamaCoreError::Backend(BackendError::ComputeSingle(
4351 err_msg,
4352 )))
4353 }
4354 }
4355 }
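     // Fallback branch: the requested model was not found among the chat graphs, so stream from
     // the first available graph instead.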
4356 false => {
4357 match chat_graphs.iter_mut().next() {
4358 Some((_, graph)) => {
4359 match graph.compute_single() {
4361 Ok(_) => {
4362 #[cfg(feature = "logging")]
4363                                             debug!(target: "stdout", "Computed the chat stream chunk successfully.");
4364
4365 match stream_state {
4366 StreamState::Usage | StreamState::NoUsage => {
4367 let output_buffer =
4369 get_output_buffer_single(graph, OUTPUT_TENSOR)?;
4370
4371 #[cfg(feature = "logging")]
4372 info!(target: "stdout", "retrieved the output buffer");
4373
4374 let output = match String::from_utf8(
4376 output_buffer.clone(),
4377 ) {
4378 Ok(token) => token,
4379 Err(_) => {
4380 let mutex = CACHED_UTF8_ENCODINGS
4381 .get_or_init(|| Mutex::new(Vec::new()));
4382 let mut cached_encodings = mutex.lock().map_err(|e| {
4383 let err_msg = format!(
4384                                                 "Failed to acquire the lock of `CACHED_UTF8_ENCODINGS`. Reason: {e}"
4385 );
4386
4387 #[cfg(feature = "logging")]
4388 error!(target: "stdout", "{}", &err_msg);
4389
4391 LlamaCoreError::Operation(err_msg)
4392 })?;
4393
4394 cached_encodings
4396 .extend_from_slice(&output_buffer[..]);
4397
4398 match String::from_utf8(
4399 cached_encodings.to_vec(),
4400 ) {
4401 Ok(token) => {
4402 cached_encodings.clear();
4404
4405 token
4406 }
4407 Err(e) => {
4408 if cached_encodings.len() > 4 {
4410                                                     let err_msg = format!("Failed to convert a vector of bytes to a string: the length of the cached UTF-8 bytes exceeds 4. {e}");
4411
4412 #[cfg(feature = "logging")]
4413 error!(target: "stdout", "{}", &err_msg);
4414
4415 #[cfg(feature = "logging")]
4416 error!(target: "stdout", "The cached buffer: {:?}", &cached_encodings[..]);
4417
4418 cached_encodings.clear();
4426
4427 String::from("")
4428 } else {
4429                                                     let warn_msg = format!("Failed to convert a vector of bytes to a string. {e}");
4430
4431 #[cfg(feature = "logging")]
4432 warn!(target: "stdout", "{}", &warn_msg);
4433
4434 String::from("")
4435 }
4436 }
4437 }
4438 }
4439 };
4440
4441 #[cfg(feature = "logging")]
4442 info!(target: "stdout", "decoded the output buffer");
4443
4444 let created = SystemTime::now()
4445 .duration_since(std::time::UNIX_EPOCH)
4446 .map_err(|e| {
4447 let err_msg = format!(
4448 "Failed to get the current time. Reason: {e}"
4449 );
4450
4451 #[cfg(feature = "logging")]
4452 error!(target: "stdout", "{}", &err_msg);
4453
4454 LlamaCoreError::Operation(err_msg)
4455 })?;
4456
4457 let chat_completion_chunk = ChatCompletionChunk {
4458 id,
4459 object: "chat.completion.chunk".to_string(),
4460 created: created.as_secs(),
4461 model: graph.name().to_owned(),
4462 system_fingerprint: "fp_44709d6fcb".to_string(),
4463 choices: vec![ChatCompletionChunkChoice {
4464 index: 0,
4465 delta: ChatCompletionChunkChoiceDelta {
4466 role: ChatCompletionRole::Assistant,
4467 content: Some(output),
4468 tool_calls: vec![],
4469 },
4470 logprobs: None,
4471 finish_reason: None,
4472 }],
4473 usage: None,
4474 };
4475
4476 #[cfg(feature = "logging")]
4477 info!(target: "stdout", "created chat completion chunk");
4478
4479 let chunk_str =
4481 serde_json::to_string(&chat_completion_chunk)
4482 .map_err(|e| {
4483 let err_msg = format!(
4484 "Failed to serialize chat completion chunk. Reason: {e}"
4485 );
4486
4487 #[cfg(feature = "logging")]
4488 error!(target: "stdout", "{}", &err_msg);
4489
4490 LlamaCoreError::Operation(err_msg)
4491 })?;
4492
4493 Ok(format!("data: {chunk_str}\n\n"))
4494 }
4495 StreamState::Done => {
4496 *stream_state = StreamState::EndOfSequence;
4497
4498 Ok("data: [DONE]\n\n".to_string())
4499 }
4500 StreamState::EndOfSequence => {
4501 Ok("[GGML] End of sequence".to_string())
4502 }
4503 }
4504 }
4505 Err(wasmedge_wasi_nn::Error::BackendError(
4506 wasmedge_wasi_nn::BackendError::EndOfSequence,
4507 )) => {
4508 #[cfg(feature = "logging")]
4509 debug!(target: "stdout", "End of sequence");
4510
4511 match stream_state {
4512 StreamState::Usage => {
4513 *stream_state = StreamState::Done;
4514
4515 let token_info = get_token_info_by_graph(graph)?;
4517
4518 let usage = Some(Usage {
4519 prompt_tokens: token_info.prompt_tokens,
4520 completion_tokens: token_info.completion_tokens,
4521 total_tokens: token_info.prompt_tokens
4522 + token_info.completion_tokens,
4523 });
4524
4525 #[cfg(feature = "logging")]
4526 info!(target: "stdout", "token_info: {} prompt tokens, {} completion tokens", token_info.prompt_tokens, token_info.completion_tokens);
4527
4528 let created = SystemTime::now()
4529 .duration_since(std::time::UNIX_EPOCH)
4530 .map_err(|e| {
4531 let err_msg = format!(
4532 "Failed to get the current time. Reason: {e}"
4533 );
4534
4535 #[cfg(feature = "logging")]
4536 error!(target: "stdout", "{}", &err_msg);
4537
4538 LlamaCoreError::Operation(err_msg)
4539 })?;
4540
4541 let chat_completion_chunk = ChatCompletionChunk {
4542 id,
4543 object: "chat.completion.chunk".to_string(),
4544 created: created.as_secs(),
4545 model: graph.name().to_owned(),
4546 system_fingerprint: "fp_44709d6fcb".to_string(),
4547 choices: vec![],
4548 usage,
4549 };
4550
4551 let chunk_str =
4553 serde_json::to_string(&chat_completion_chunk)
4554 .map_err(|e| {
4555 let err_msg = format!(
4556 "Failed to serialize chat completion chunk. Reason: {e}"
4557 );
4558
4559 #[cfg(feature = "logging")]
4560 error!(target: "stdout", "{}", &err_msg);
4561
4562 LlamaCoreError::Operation(err_msg)
4563 })?;
4564
4565 Ok(format!("data: {chunk_str}\n\n"))
4566 }
4567 StreamState::Done | StreamState::NoUsage => {
4568 *stream_state = StreamState::EndOfSequence;
4569
4570 Ok("data: [DONE]\n\n".to_string())
4571 }
4572 StreamState::EndOfSequence => {
4573 Ok("[GGML] End of sequence".to_string())
4574 }
4575 }
4576 }
4577 Err(wasmedge_wasi_nn::Error::BackendError(
4578 wasmedge_wasi_nn::BackendError::ContextFull,
4579 )) => {
4580 #[cfg(feature = "logging")]
4581 debug!(target: "stdout", "Context full");
4582
4583 match context_full_state {
4584 ContextFullState::Message => {
4585 match include_usage {
4586 true => {
4587 *context_full_state = ContextFullState::Usage
4588 }
4589 false => {
4590 *context_full_state = ContextFullState::Done
4591 }
4592 }
4593
4594 let created = SystemTime::now()
4595 .duration_since(std::time::UNIX_EPOCH)
4596 .map_err(|e| {
4597 let err_msg = format!(
4598 "Failed to get the current time. Reason: {e}"
4599 );
4600
4601 #[cfg(feature = "logging")]
4602 error!(target: "stdout", "{}", &err_msg);
4603
4604 LlamaCoreError::Operation(err_msg)
4605 })?;
4606
4607 let chat_completion_chunk = ChatCompletionChunk {
4608 id,
4609 object: "chat.completion.chunk".to_string(),
4610 created: created.as_secs(),
4611 model: graph.name().to_owned(),
4612 system_fingerprint: "fp_44709d6fcb".to_string(),
4613 choices: vec![ChatCompletionChunkChoice {
4614 index: 0,
4615 delta: ChatCompletionChunkChoiceDelta {
4616 role: ChatCompletionRole::Assistant,
4617 content: Some(
4618 "<|WASMEDGE-GGML-CONTEXT-FULL|>"
4619 .to_string(),
4620 ),
4621 tool_calls: vec![],
4622 },
4623 logprobs: None,
4624 finish_reason: Some(FinishReason::length),
4625 }],
4626 usage: None,
4627 };
4628
4629 let chunk_str =
4631 serde_json::to_string(&chat_completion_chunk)
4632 .map_err(|e| {
4633 let err_msg = format!(
4634 "Failed to serialize chat completion chunk. Reason: {e}"
4635 );
4636
4637 #[cfg(feature = "logging")]
4638 error!(target: "stdout", "{}", &err_msg);
4639
4640 LlamaCoreError::Operation(err_msg)
4641 })?;
4642
4643 Ok(format!("data: {chunk_str}\n\n"))
4644 }
4645 ContextFullState::Usage => {
4646 *context_full_state = ContextFullState::Done;
4647
4648 let token_info = get_token_info_by_graph(graph)?;
4650
4651 let usage = Some(Usage {
4652 prompt_tokens: token_info.prompt_tokens,
4653 completion_tokens: token_info.completion_tokens,
4654 total_tokens: token_info.prompt_tokens
4655 + token_info.completion_tokens,
4656 });
4657
4658 let created = SystemTime::now()
4659 .duration_since(std::time::UNIX_EPOCH)
4660 .map_err(|e| {
4661 let err_msg = format!(
4662 "Failed to get the current time. Reason: {e}"
4663 );
4664
4665 #[cfg(feature = "logging")]
4666 error!(target: "stdout", "{}", &err_msg);
4667
4668 LlamaCoreError::Operation(err_msg)
4669 })?;
4670
4671 let chat_completion_chunk = ChatCompletionChunk {
4672 id,
4673 object: "chat.completion.chunk".to_string(),
4674 created: created.as_secs(),
4675 model: graph.name().to_owned(),
4676 system_fingerprint: "fp_44709d6fcb".to_string(),
4677 choices: vec![],
4678 usage,
4679 };
4680
4681 let chunk_str =
4683 serde_json::to_string(&chat_completion_chunk)
4684 .map_err(|e| {
4685 let err_msg = format!(
4686 "Failed to serialize chat completion chunk. Reason: {e}"
4687 );
4688
4689 #[cfg(feature = "logging")]
4690 error!(target: "stdout", "{}", &err_msg);
4691
4692 LlamaCoreError::Operation(err_msg)
4693 })?;
4694
4695 Ok(format!("data: {chunk_str}\n\n"))
4696 }
4697 ContextFullState::Done => {
4698 *context_full_state = ContextFullState::EndOfSequence;
4699
4700 Ok("data: [DONE]\n\n".to_string())
4701 }
4702 ContextFullState::EndOfSequence => {
4703 Ok("[GGML] End of sequence".to_string())
4704 }
4705 }
4706 }
4707 Err(wasmedge_wasi_nn::Error::BackendError(
4708 wasmedge_wasi_nn::BackendError::PromptTooLong,
4709 )) => {
4710 #[cfg(feature = "logging")]
4711 debug!(target: "stdout", "Prompt too long");
4712
4713 match prompt_too_long_state {
4714 PromptTooLongState::Message => {
4715 match include_usage {
4716 true => {
4717 *prompt_too_long_state =
4718 PromptTooLongState::Usage
4719 }
4720 false => {
4721 *prompt_too_long_state =
4722 PromptTooLongState::Done
4723 }
4724 }
4725
4726 let created = SystemTime::now()
4727 .duration_since(std::time::UNIX_EPOCH)
4728 .map_err(|e| {
4729 let err_msg = format!(
4730 "Failed to get the current time. Reason: {e}"
4731 );
4732
4733 #[cfg(feature = "logging")]
4734 error!(target: "stdout", "{}", &err_msg);
4735
4736 LlamaCoreError::Operation(err_msg)
4737 })?;
4738
4739 let chat_completion_chunk = ChatCompletionChunk {
4740 id,
4741 object: "chat.completion.chunk".to_string(),
4742 created: created.as_secs(),
4743 model: graph.name().to_owned(),
4744 system_fingerprint: "fp_44709d6fcb".to_string(),
4745 choices: vec![ChatCompletionChunkChoice {
4746 index: 0,
4747 delta: ChatCompletionChunkChoiceDelta {
4748 role: ChatCompletionRole::Assistant,
4749 content: None,
4750 tool_calls: vec![],
4751 },
4752 logprobs: None,
4753 finish_reason: Some(FinishReason::length),
4754 }],
4755 usage: None,
4756 };
4757
4758 let chunk_str =
4760 serde_json::to_string(&chat_completion_chunk)
4761 .map_err(|e| {
4762 let err_msg = format!(
4763 "Failed to serialize chat completion chunk. Reason: {e}"
4764 );
4765
4766 #[cfg(feature = "logging")]
4767 error!(target: "stdout", "{}", &err_msg);
4768
4769 LlamaCoreError::Operation(err_msg)
4770 })?;
4771
4772 Ok(format!("data: {chunk_str}\n\n"))
4773 }
4774 PromptTooLongState::Usage => {
4775 *prompt_too_long_state = PromptTooLongState::Done;
4776
4777 let token_info = get_token_info_by_graph(graph)?;
4779
4780 let usage = Some(Usage {
4781 prompt_tokens: token_info.prompt_tokens,
4782 completion_tokens: token_info.completion_tokens,
4783 total_tokens: token_info.prompt_tokens
4784 + token_info.completion_tokens,
4785 });
4786
4787 let created = SystemTime::now()
4788 .duration_since(std::time::UNIX_EPOCH)
4789 .map_err(|e| {
4790 let err_msg = format!(
4791 "Failed to get the current time. Reason: {e}"
4792 );
4793
4794 #[cfg(feature = "logging")]
4795 error!(target: "stdout", "{}", &err_msg);
4796
4797 LlamaCoreError::Operation(err_msg)
4798 })?;
4799
4800 let chat_completion_chunk = ChatCompletionChunk {
4801 id,
4802 object: "chat.completion.chunk".to_string(),
4803 created: created.as_secs(),
4804 model: graph.name().to_owned(),
4805 system_fingerprint: "fp_44709d6fcb".to_string(),
4806 choices: vec![],
4807 usage,
4808 };
4809
4810 let chunk_str =
4812 serde_json::to_string(&chat_completion_chunk)
4813 .map_err(|e| {
4814 let err_msg = format!(
4815 "Failed to serialize chat completion chunk. Reason: {e}"
4816 );
4817
4818 #[cfg(feature = "logging")]
4819 error!(target: "stdout", "{}", &err_msg);
4820
4821 LlamaCoreError::Operation(err_msg)
4822 })?;
4823
4824 Ok(format!("data: {chunk_str}\n\n"))
4825 }
4826 PromptTooLongState::Done => {
4827 *prompt_too_long_state =
4828 PromptTooLongState::EndOfSequence;
4829
4830 Ok("data: [DONE]\n\n".to_string())
4831 }
4832 PromptTooLongState::EndOfSequence => {
4833 Ok("[GGML] End of sequence".to_string())
4834 }
4835 }
4836 }
4837 Err(e) => {
4838 let err_msg = format!(
4839 "Failed to compute the chat completion. Reason: {e}"
4840 );
4841
4842 #[cfg(feature = "logging")]
4843 error!(target: "stdout", "{}", &err_msg);
4844
4845 Err(LlamaCoreError::Backend(BackendError::ComputeSingle(
4846 err_msg,
4847 )))
4848 }
4849 }
4850 }
4851 None => {
4852 let err_msg = "There is no model available in the chat graphs.";
4853
4854 #[cfg(feature = "logging")]
4855 error!(target: "stdout", "{}", &err_msg);
4856
4857 Err(LlamaCoreError::Operation(err_msg.into()))
4858 }
4859 }
4860 }
4861 }
4862 }
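     // No model name was specified in the request, so stream from the first available chat graph.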
4863 None => {
4864 match chat_graphs.iter_mut().next() {
4865 Some((_, graph)) => {
4866 match graph.compute_single() {
4868 Ok(_) => {
4869 #[cfg(feature = "logging")]
4870                             debug!(target: "stdout", "Computed the chat stream chunk successfully.");
4871
4872 match stream_state {
4873 StreamState::Usage | StreamState::NoUsage => {
4874 let output_buffer =
4876 get_output_buffer_single(graph, OUTPUT_TENSOR)?;
4877
4878 #[cfg(feature = "logging")]
4879 info!(target: "stdout", "retrieved the output buffer");
4880
4881 let output = match String::from_utf8(output_buffer.clone()) {
4883 Ok(token) => token,
4884 Err(_) => {
4885 let mutex = CACHED_UTF8_ENCODINGS
4886 .get_or_init(|| Mutex::new(Vec::new()));
4887 let mut cached_encodings = mutex.lock().map_err(|e| {
4888 let err_msg = format!(
4889                                                 "Failed to acquire the lock of `CACHED_UTF8_ENCODINGS`. Reason: {e}"
4890 );
4891
4892 #[cfg(feature = "logging")]
4893 error!(target: "stdout", "{}", &err_msg);
4894
4895 LlamaCoreError::Operation(err_msg)
4896 })?;
4897
4898 cached_encodings.extend_from_slice(&output_buffer[..]);
4899
4900 match String::from_utf8(cached_encodings.to_vec()) {
4901 Ok(token) => {
4902 cached_encodings.clear();
4904
4905 token
4906 }
4907 Err(e) => {
4908 if cached_encodings.len() > 4 {
4910                                                     let err_msg = format!("Failed to convert a vector of bytes to a string: the length of the cached UTF-8 bytes exceeds 4. {e}");
4911
4912 #[cfg(feature = "logging")]
4913 error!(target: "stdout", "{}", &err_msg);
4914
4915 #[cfg(feature = "logging")]
4916 error!(target: "stdout", "The cached buffer: {:?}", &cached_encodings[..]);
4917
4918 cached_encodings.clear();
4925
4926 String::from("")
4927 } else {
4928                                                     let warn_msg = format!("Failed to convert a vector of bytes to a string. {e}");
4929
4930 #[cfg(feature = "logging")]
4931 warn!(target: "stdout", "{}", &warn_msg);
4932
4933 String::from("")
4934 }
4935 }
4936 }
4937 }
4938 };
4939
4940 #[cfg(feature = "logging")]
4941 info!(target: "stdout", "decoded the output buffer");
4942
4943 let created = SystemTime::now()
4944 .duration_since(std::time::UNIX_EPOCH)
4945 .map_err(|e| {
4946 let err_msg = format!(
4947 "Failed to get the current time. Reason: {e}"
4948 );
4949
4950 #[cfg(feature = "logging")]
4951 error!(target: "stdout", "{}", &err_msg);
4952
4953 LlamaCoreError::Operation(err_msg)
4954 })?;
4955
4956 let chat_completion_chunk = ChatCompletionChunk {
4957 id,
4958 object: "chat.completion.chunk".to_string(),
4959 created: created.as_secs(),
4960 model: graph.name().to_owned(),
4961 system_fingerprint: "fp_44709d6fcb".to_string(),
4962 choices: vec![ChatCompletionChunkChoice {
4963 index: 0,
4964 delta: ChatCompletionChunkChoiceDelta {
4965 role: ChatCompletionRole::Assistant,
4966 content: Some(output),
4967 tool_calls: vec![],
4968 },
4969 logprobs: None,
4970 finish_reason: None,
4971 }],
4972 usage: None,
4973 };
4974
4975 #[cfg(feature = "logging")]
4976 info!(target: "stdout", "created chat completion chunk");
4977
4978 let chunk_str = serde_json::to_string(&chat_completion_chunk)
4980 .map_err(|e| {
4981 let err_msg = format!(
4982 "Failed to serialize chat completion chunk. Reason: {e}"
4983 );
4984
4985 #[cfg(feature = "logging")]
4986 error!(target: "stdout", "{}", &err_msg);
4987
4988 LlamaCoreError::Operation(err_msg)
4989 })?;
4990
4991 Ok(format!("data: {chunk_str}\n\n"))
4992 }
4993 StreamState::Done => {
4994 *stream_state = StreamState::EndOfSequence;
4995
4996 Ok("data: [DONE]\n\n".to_string())
4997 }
4998 StreamState::EndOfSequence => {
4999 Ok("[GGML] End of sequence".to_string())
5000 }
5001 }
5002 }
5003 Err(wasmedge_wasi_nn::Error::BackendError(
5004 wasmedge_wasi_nn::BackendError::EndOfSequence,
5005 )) => {
5006 #[cfg(feature = "logging")]
5007 debug!(target: "stdout", "End of sequence");
5008
5009 match stream_state {
5010 StreamState::Usage => {
5011 *stream_state = StreamState::Done;
5012
5013 let token_info = get_token_info_by_graph(graph)?;
5015
5016 let usage = Some(Usage {
5017 prompt_tokens: token_info.prompt_tokens,
5018 completion_tokens: token_info.completion_tokens,
5019 total_tokens: token_info.prompt_tokens
5020 + token_info.completion_tokens,
5021 });
5022
5023 #[cfg(feature = "logging")]
5024 info!(target: "stdout", "token_info: {} prompt tokens, {} completion tokens", token_info.prompt_tokens, token_info.completion_tokens);
5025
5026 let created = SystemTime::now()
5027 .duration_since(std::time::UNIX_EPOCH)
5028 .map_err(|e| {
5029 let err_msg = format!(
5030 "Failed to get the current time. Reason: {e}"
5031 );
5032
5033 #[cfg(feature = "logging")]
5034 error!(target: "stdout", "{}", &err_msg);
5035
5036 LlamaCoreError::Operation(err_msg)
5037 })?;
5038
5039 let chat_completion_chunk = ChatCompletionChunk {
5040 id,
5041 object: "chat.completion.chunk".to_string(),
5042 created: created.as_secs(),
5043 model: graph.name().to_owned(),
5044 system_fingerprint: "fp_44709d6fcb".to_string(),
5045 choices: vec![],
5046 usage,
5047 };
5048
5049 let chunk_str = serde_json::to_string(&chat_completion_chunk)
5051 .map_err(|e| {
5052 let err_msg = format!(
5053 "Failed to serialize chat completion chunk. Reason: {e}"
5054 );
5055
5056 #[cfg(feature = "logging")]
5057 error!(target: "stdout", "{}", &err_msg);
5058
5059 LlamaCoreError::Operation(err_msg)
5060 })?;
5061
5062 Ok(format!("data: {chunk_str}\n\n"))
5063 }
5064 StreamState::Done | StreamState::NoUsage => {
5065 *stream_state = StreamState::EndOfSequence;
5066
5067 Ok("data: [DONE]\n\n".to_string())
5068 }
5069 StreamState::EndOfSequence => {
5070 Ok("[GGML] End of sequence".to_string())
5071 }
5072 }
5073 }
5074 Err(wasmedge_wasi_nn::Error::BackendError(
5075 wasmedge_wasi_nn::BackendError::ContextFull,
5076 )) => {
5077 #[cfg(feature = "logging")]
5078 debug!(target: "stdout", "Context full");
5079
5080 match context_full_state {
5081 ContextFullState::Message => {
5082 match include_usage {
5083 true => *context_full_state = ContextFullState::Usage,
5084 false => *context_full_state = ContextFullState::Done,
5085 }
5086
5087 let created = SystemTime::now()
5088 .duration_since(std::time::UNIX_EPOCH)
5089 .map_err(|e| {
5090 let err_msg = format!(
5091 "Failed to get the current time. Reason: {e}"
5092 );
5093
5094 #[cfg(feature = "logging")]
5095 error!(target: "stdout", "{}", &err_msg);
5096
5097 LlamaCoreError::Operation(err_msg)
5098 })?;
5099
5100 let chat_completion_chunk = ChatCompletionChunk {
5101 id,
5102 object: "chat.completion.chunk".to_string(),
5103 created: created.as_secs(),
5104 model: graph.name().to_owned(),
5105 system_fingerprint: "fp_44709d6fcb".to_string(),
5106 choices: vec![ChatCompletionChunkChoice {
5107 index: 0,
5108 delta: ChatCompletionChunkChoiceDelta {
5109 role: ChatCompletionRole::Assistant,
5110 content: Some(
5111 "<|WASMEDGE-GGML-CONTEXT-FULL|>".to_string(),
5112 ),
5113 tool_calls: vec![],
5114 },
5115 logprobs: None,
5116 finish_reason: Some(FinishReason::length),
5117 }],
5118 usage: None,
5119 };
5120
5121 let chunk_str = serde_json::to_string(&chat_completion_chunk)
5123 .map_err(|e| {
5124 let err_msg = format!(
5125 "Failed to serialize chat completion chunk. Reason: {e}"
5126 );
5127
5128 #[cfg(feature = "logging")]
5129 error!(target: "stdout", "{}", &err_msg);
5130
5131 LlamaCoreError::Operation(err_msg)
5132 })?;
5133
5134 Ok(format!("data: {chunk_str}\n\n"))
5135 }
5136 ContextFullState::Usage => {
5137 *context_full_state = ContextFullState::Done;
5138
5139 let token_info = get_token_info_by_graph(graph)?;
5141
5142 let usage = Some(Usage {
5143 prompt_tokens: token_info.prompt_tokens,
5144 completion_tokens: token_info.completion_tokens,
5145 total_tokens: token_info.prompt_tokens
5146 + token_info.completion_tokens,
5147 });
5148
5149 let created = SystemTime::now()
5150 .duration_since(std::time::UNIX_EPOCH)
5151 .map_err(|e| {
5152 let err_msg = format!(
5153 "Failed to get the current time. Reason: {e}"
5154 );
5155
5156 #[cfg(feature = "logging")]
5157 error!(target: "stdout", "{}", &err_msg);
5158
5159 LlamaCoreError::Operation(err_msg)
5160 })?;
5161
5162 let chat_completion_chunk = ChatCompletionChunk {
5163 id,
5164 object: "chat.completion.chunk".to_string(),
5165 created: created.as_secs(),
5166 model: graph.name().to_owned(),
5167 system_fingerprint: "fp_44709d6fcb".to_string(),
5168 choices: vec![],
5169 usage,
5170 };
5171
5172 let chunk_str = serde_json::to_string(&chat_completion_chunk)
5174 .map_err(|e| {
5175 let err_msg = format!(
5176 "Failed to serialize chat completion chunk. Reason: {e}"
5177 );
5178
5179 #[cfg(feature = "logging")]
5180 error!(target: "stdout", "{}", &err_msg);
5181
5182 LlamaCoreError::Operation(err_msg)
5183 })?;
5184
5185 Ok(format!("data: {chunk_str}\n\n"))
5186 }
5187 ContextFullState::Done => {
5188 *context_full_state = ContextFullState::EndOfSequence;
5189
5190 Ok("data: [DONE]\n\n".to_string())
5191 }
5192 ContextFullState::EndOfSequence => {
5193 Ok("[GGML] End of sequence".to_string())
5194 }
5195 }
5196 }
5197 Err(wasmedge_wasi_nn::Error::BackendError(
5198 wasmedge_wasi_nn::BackendError::PromptTooLong,
5199 )) => {
5200 #[cfg(feature = "logging")]
5201 debug!(target: "stdout", "Prompt too long");
5202
5203 match prompt_too_long_state {
5204 PromptTooLongState::Message => {
5205 match include_usage {
5206 true => *prompt_too_long_state = PromptTooLongState::Usage,
5207 false => *prompt_too_long_state = PromptTooLongState::Done,
5208 }
5209
5210 let created = SystemTime::now()
5211 .duration_since(std::time::UNIX_EPOCH)
5212 .map_err(|e| {
5213 let err_msg = format!(
5214 "Failed to get the current time. Reason: {e}"
5215 );
5216
5217 #[cfg(feature = "logging")]
5218 error!(target: "stdout", "{}", &err_msg);
5219
5220 LlamaCoreError::Operation(err_msg)
5221 })?;
5222
5223 let chat_completion_chunk = ChatCompletionChunk {
5224 id,
5225 object: "chat.completion.chunk".to_string(),
5226 created: created.as_secs(),
5227 model: graph.name().to_owned(),
5228 system_fingerprint: "fp_44709d6fcb".to_string(),
5229 choices: vec![ChatCompletionChunkChoice {
5230 index: 0,
5231 delta: ChatCompletionChunkChoiceDelta {
5232 role: ChatCompletionRole::Assistant,
5233 content: None,
5234 tool_calls: vec![],
5235 },
5236 logprobs: None,
5237 finish_reason: Some(FinishReason::length),
5238 }],
5239 usage: None,
5240 };
5241
5242 let chunk_str = serde_json::to_string(&chat_completion_chunk)
5244 .map_err(|e| {
5245 let err_msg = format!(
5246 "Failed to serialize chat completion chunk. Reason: {e}"
5247 );
5248
5249 #[cfg(feature = "logging")]
5250 error!(target: "stdout", "{}", &err_msg);
5251
5252 LlamaCoreError::Operation(err_msg)
5253 })?;
5254
5255 Ok(format!("data: {chunk_str}\n\n"))
5256 }
5257 PromptTooLongState::Usage => {
5258 *prompt_too_long_state = PromptTooLongState::Done;
5259
5260 let token_info = get_token_info_by_graph(graph)?;
5262
5263 let usage = Some(Usage {
5264 prompt_tokens: token_info.prompt_tokens,
5265 completion_tokens: token_info.completion_tokens,
5266 total_tokens: token_info.prompt_tokens
5267 + token_info.completion_tokens,
5268 });
5269
5270 let created = SystemTime::now()
5271 .duration_since(std::time::UNIX_EPOCH)
5272 .map_err(|e| {
5273 let err_msg = format!(
5274 "Failed to get the current time. Reason: {e}"
5275 );
5276
5277 #[cfg(feature = "logging")]
5278 error!(target: "stdout", "{}", &err_msg);
5279
5280 LlamaCoreError::Operation(err_msg)
5281 })?;
5282
5283 let chat_completion_chunk = ChatCompletionChunk {
5284 id,
5285 object: "chat.completion.chunk".to_string(),
5286 created: created.as_secs(),
5287 model: graph.name().to_owned(),
5288 system_fingerprint: "fp_44709d6fcb".to_string(),
5289 choices: vec![],
5290 usage,
5291 };
5292
5293 let chunk_str = serde_json::to_string(&chat_completion_chunk)
5295 .map_err(|e| {
5296 let err_msg = format!(
5297 "Failed to serialize chat completion chunk. Reason: {e}"
5298 );
5299
5300 #[cfg(feature = "logging")]
5301 error!(target: "stdout", "{}", &err_msg);
5302
5303 LlamaCoreError::Operation(err_msg)
5304 })?;
5305
5306 Ok(format!("data: {chunk_str}\n\n"))
5307 }
5308 PromptTooLongState::Done => {
5309 *prompt_too_long_state = PromptTooLongState::EndOfSequence;
5310
5311 Ok("data: [DONE]\n\n".to_string())
5312 }
5313 PromptTooLongState::EndOfSequence => {
5314 Ok("[GGML] End of sequence".to_string())
5315 }
5316 }
5317 }
5318 Err(e) => {
5319 let err_msg =
5320 format!("Failed to compute the chat completion. Reason: {e}");
5321
5322 #[cfg(feature = "logging")]
5323 error!(target: "stdout", "{}", &err_msg);
5324
5325 Err(LlamaCoreError::Backend(BackendError::ComputeSingle(
5326 err_msg,
5327 )))
5328 }
5329 }
5330 }
5331 None => {
5332 let err_msg = "There is no model available in the chat graphs.";
5333
5334 #[cfg(feature = "logging")]
5335 error!(target: "stdout", "{}", &err_msg);
5336
5337 Err(LlamaCoreError::Operation(err_msg.into()))
5338 }
5339 }
5340 }
5341 };
5342
5343 #[cfg(feature = "logging")]
5344 info!(target: "stdout", "Return the chat stream chunk!");
5345
5346 res
5347}
5348
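/// Intermediate result of parsing the model's raw output: the original text, any plain
/// assistant content, and the tool calls extracted from it.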
5349#[allow(dead_code)]
5350#[derive(Debug)]
5351struct ParseResult {
5352 raw: String,
5353 content: Option<String>,
5354 tool_calls: Vec<ToolCall>,
5355}