use crate::{
    error,
    metadata::ggml::GgmlMetadata,
    running_mode,
    utils::{
        gen_chat_id, get_output_buffer, get_output_buffer_single, get_token_info_by_graph,
        get_token_info_by_graph_name, set_tensor_data_u8,
    },
    Graph, RunningMode, CACHED_UTF8_ENCODINGS, CHAT_GRAPHS, OUTPUT_TENSOR,
};
use chat_prompts::{
    chat::{BuildChatPrompt, ChatPrompt},
    PromptTemplateType,
};
use either::{Either, Left, Right};
use endpoints::{
    chat::{
        ChatCompletionChunk, ChatCompletionChunkChoice, ChatCompletionChunkChoiceDelta,
        ChatCompletionObject, ChatCompletionObjectChoice, ChatCompletionObjectMessage,
        ChatCompletionRequest, ChatCompletionRequestMessage, ChatCompletionRole,
        ChatCompletionUserMessageContent, ContentPart, Function, ToolCall, ToolCallForChunk,
        ToolChoice,
    },
    common::{FinishReason, Usage},
};
use error::{BackendError, LlamaCoreError};
use std::{
    collections::VecDeque,
    pin::Pin,
    sync::{
        atomic::{AtomicBool, Ordering},
        Mutex, OnceLock,
    },
    task::{Context, Poll, Waker},
    time::SystemTime,
};

static CHAT_STREAM_WAKER_QUEUE: OnceLock<Mutex<VecDeque<Waker>>> = OnceLock::new();

static CHAT_STREAM_ACTIVE: AtomicBool = AtomicBool::new(false);
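/// Processes a chat completion request and returns either a token stream (stream mode) or a
/// complete `ChatCompletionObject` (non-stream mode), together with a flag indicating whether
/// the response contains tool calls.
///
/// A minimal usage sketch (illustrative only; it assumes `ChatCompletionRequest` implements
/// `Default`, which may not hold in this crate):
///
/// ```ignore
/// let mut request = ChatCompletionRequest::default();
/// request.stream = Some(false);
///
/// match chat(&mut request).await? {
///     (Either::Right(completion), _includes_tool_calls) => {
///         // Non-stream mode: the full completion object is available immediately.
///         let _first_choice = completion.choices.first();
///     }
///     (Either::Left(stream), _) => {
///         // Stream mode: poll `stream` for pre-formatted SSE chunks.
///         let _ = stream;
///     }
/// }
/// ```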
pub async fn chat(
    chat_request: &mut ChatCompletionRequest,
) -> Result<
    (
        Either<impl futures::TryStream<Ok = String, Error = LlamaCoreError>, ChatCompletionObject>,
        bool,
    ),
    LlamaCoreError,
> {
    #[cfg(feature = "logging")]
    {
        debug!(target: "stdout", "tool choice: {:?}", chat_request.tool_choice.as_ref());
        debug!(target: "stdout", "tools: {:?}", chat_request.tools.as_ref());
        debug!(target: "stdout", "stream mode: {:?}", chat_request.stream);
    }

    let result = match chat_request.stream {
        Some(true) => match chat_stream(chat_request).await {
            Ok((stream, include_tool_calls)) => Ok((Left(stream), include_tool_calls)),
            Err(e) => Err(e),
        },
        Some(false) | None => match chat_once(chat_request).await {
            Ok((chat_completion_object, include_tool_calls)) => {
                Ok((Right(chat_completion_object), include_tool_calls))
            }
            Err(e) => Err(e),
        },
    };

    #[cfg(feature = "logging")]
    info!(target: "stdout", "Reset the model metadata");

    result
}
81
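/// Handles a chat completion request in stream mode: checks the running mode, builds the chat
/// prompt, feeds it to the model, and returns a `ChatStream` plus a flag indicating whether the
/// response includes tool calls. When tools are in play, the output is computed eagerly via
/// `chat_stream_for_tool` and replayed as pre-built SSE chunks.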
82async fn chat_stream(
83 chat_request: &mut ChatCompletionRequest,
84) -> Result<
85 (
86 impl futures::TryStream<Ok = String, Error = LlamaCoreError>,
87 bool,
88 ),
89 LlamaCoreError,
90> {
91 #[cfg(feature = "logging")]
92 info!(target: "stdout", "Process chat completion request in the stream mode");
93
94 let running_mode = running_mode()?;
95 if !running_mode.contains(RunningMode::CHAT) && !running_mode.contains(RunningMode::RAG) {
96 let err_msg = "The chat completion is only supported in the chat or rag mode.";
97
98 #[cfg(feature = "logging")]
99 error!(target: "stdout", "{err_msg}");
100
101 return Err(LlamaCoreError::Operation(err_msg.to_string()));
102 }
103
104 let model_name = chat_request.model.clone();
105 let id = match &chat_request.user {
106 Some(id) => id.clone(),
107 None => gen_chat_id(),
108 };
109 #[cfg(feature = "logging")]
110 info!(target: "stdout", "user: {}", &id);
111
112 #[cfg(feature = "logging")]
113 info!(target: "stdout", "Check model metadata");
114
115 let mut metadata = check_model_metadata(chat_request)?;
117
118 let include_usage = match chat_request.stream_options {
120 Some(ref stream_options) => stream_options.include_usage.unwrap_or_default(),
121 None => metadata.include_usage,
122 };
123 #[cfg(feature = "logging")]
124 info!(target: "stdout", "include_usage: {include_usage}");
125
126 #[cfg(feature = "logging")]
127 info!(target: "stdout", "Build the chat prompt");
128
    let (prompt, available_completion_tokens, tool_use) =
        build_prompt(model_name.as_ref(), chat_request)?;
132
133 #[cfg(feature = "logging")]
134 {
135 info!(target: "stdout", "prompt:\n{}", &prompt);
        info!(target: "stdout", "available_completion_tokens: {available_completion_tokens}");
137 info!(target: "stdout", "tool_use: {tool_use}");
138 }
139
140 #[cfg(feature = "logging")]
141 info!(target: "stdout", "Update the n_predict");
142
    update_n_predict(chat_request, &mut metadata, available_completion_tokens)?;
145
146 #[cfg(feature = "logging")]
147 info!(target: "stdout", "Feed the prompt to the model");
148
149 set_prompt(chat_request.model.as_ref(), &prompt)?;
151
152 let stream = match tool_use {
153 false => (ChatStream::new(model_name, id, include_usage, None), false),
154 true => {
155 let chat_graphs = match CHAT_GRAPHS.get() {
156 Some(chat_graphs) => chat_graphs,
157 None => {
158 let err_msg = "Fail to get the underlying value of `CHAT_GRAPHS`.";
159
160 #[cfg(feature = "logging")]
161 error!(target: "stdout", "{}", &err_msg);
162
163 return Err(LlamaCoreError::Operation(err_msg.into()));
164 }
165 };
166
167 let mut chat_graphs = chat_graphs.lock().map_err(|e| {
168 let err_msg = format!("Fail to acquire the lock of `CHAT_GRAPHS`. {e}");
169
170 #[cfg(feature = "logging")]
171 error!(target: "stdout", "{}", &err_msg);
172
173 LlamaCoreError::Operation(err_msg)
174 })?;
175
176 match model_name {
177 Some(model_name) => match chat_graphs.contains_key(&model_name) {
178 true => {
179 let graph = chat_graphs.get_mut(&model_name).unwrap();
180 chat_stream_for_tool(graph, id, include_usage)?
181 }
182 false => match chat_graphs.iter_mut().next() {
183 Some((_, graph)) => chat_stream_for_tool(graph, id, include_usage)?,
184 None => {
185 let err_msg = "There is no model available in the chat graphs.";
186
187 #[cfg(feature = "logging")]
188 error!(target: "stdout", "{}", &err_msg);
189
190 return Err(LlamaCoreError::Operation(err_msg.into()));
191 }
192 },
193 },
194 None => match chat_graphs.iter_mut().next() {
195 Some((_, graph)) => chat_stream_for_tool(graph, id, include_usage)?,
196 None => {
197 let err_msg = "There is no model available in the chat graphs.";
198
199 #[cfg(feature = "logging")]
200 error!(target: "stdout", "{}", &err_msg);
201
202 return Err(LlamaCoreError::Operation(err_msg.into()));
203 }
204 },
205 }
206 }
207 };
208
    #[cfg(feature = "logging")]
    info!(target: "stdout", "Return the chat completion stream.");
211
212 Ok(stream)
213}
214
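/// Runs a single compute pass for a tool-call request and packages the result as a short
/// `ChatStream` of pre-serialized SSE chunks (tool-call delta, usage, and `[DONE]`).
/// Context-full and prompt-too-long backend errors are reported as chunks with a `length`
/// finish reason instead of hard failures.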
215fn chat_stream_for_tool(
216 graph: &mut Graph<GgmlMetadata>,
217 id: impl Into<String>,
218 include_usage: bool,
219) -> Result<(ChatStream, bool), LlamaCoreError> {
220 #[cfg(feature = "logging")]
221 info!(target: "stdout", "Handle chat request with available tools by the model named {}.", graph.name());
222
223 let id = id.into();
224
225 match graph.compute() {
226 Ok(_) => {
227 let output_buffer = get_output_buffer(graph, OUTPUT_TENSOR)?;
229 let output = std::str::from_utf8(&output_buffer[..]).map_err(|e| {
230 let err_msg = format!(
231 "Failed to decode the buffer of the inference result to a utf-8 string. {e}"
232 );
233
234 #[cfg(feature = "logging")]
235 error!(target: "stdout", "{}", &err_msg);
236
237 LlamaCoreError::Operation(err_msg)
238 })?;
239
240 #[cfg(feature = "logging")]
241 info!(target: "stdout", "raw generation:\n{output}");
242
243 let message = post_process(output, &graph.metadata.prompt_template).map_err(|e| {
245 LlamaCoreError::Operation(format!("Failed to post-process the output. {e}"))
246 })?;
247
248 #[cfg(feature = "logging")]
249 info!(target: "stdout", "post-processed generation:\n{}", &message);
250
251 let token_info = get_token_info_by_graph(graph)?;
253
254 #[cfg(feature = "logging")]
255 info!(target: "stdout", "prompt tokens: {}, completion tokens: {}", token_info.prompt_tokens, token_info.completion_tokens);
256
257 let usage = Some(Usage {
258 prompt_tokens: token_info.prompt_tokens,
259 completion_tokens: token_info.completion_tokens,
260 total_tokens: token_info.prompt_tokens + token_info.completion_tokens,
261 });
262
263 let created = SystemTime::now()
264 .duration_since(std::time::UNIX_EPOCH)
265 .map_err(|e| {
266 let err_msg = format!("Failed to get the current time. Reason: {e}");
267
268 #[cfg(feature = "logging")]
269 error!(target: "stdout", "{}", &err_msg);
270
271 LlamaCoreError::Operation(err_msg)
272 })?;
273
274 if graph.metadata.prompt_template != PromptTemplateType::MistralTool
275 && graph.metadata.prompt_template != PromptTemplateType::ChatMLTool
276 && graph.metadata.prompt_template != PromptTemplateType::GroqLlama3Tool
277 && graph.metadata.prompt_template != PromptTemplateType::Llama3Tool
278 && graph.metadata.prompt_template != PromptTemplateType::InternLM2Tool
279 && graph.metadata.prompt_template != PromptTemplateType::NemotronTool
280 && graph.metadata.prompt_template != PromptTemplateType::FunctionaryV32
281 && graph.metadata.prompt_template != PromptTemplateType::FunctionaryV31
282 && graph.metadata.prompt_template != PromptTemplateType::MistralSmallTool
283 && graph.metadata.prompt_template != PromptTemplateType::Llama4Chat
284 && graph.metadata.prompt_template != PromptTemplateType::Qwen3NoThink
285 {
286 let err_msg = format!("Unsupported prompt template: {}. The tool use is only supported for 'mistral-tool', 'chatml-tool', 'groq-llama3-tool', 'llama-3-tool', 'internlm-2-tool', 'nemotron-tool', 'functionary-31', 'functionary-32', 'mistral-small-tool', 'llama-4-chat', and 'qwen-3-no-think' prompt templates.", graph.metadata.prompt_template);
287
288 #[cfg(feature = "logging")]
289 error!(target: "stdout", "{}", &err_msg);
290
291 return Err(LlamaCoreError::Operation(err_msg));
292 }
293
294 let parsed_result = parse_tool_calls(&message, graph.metadata.prompt_template)?;
295
296 let content = if parsed_result.tool_calls.is_empty() {
297 Some(parsed_result.raw.clone())
298 } else {
299 parsed_result.content.clone()
300 };
301
302 let (tool_calls, include_tool_calls) = match parsed_result.tool_calls.is_empty() {
303 false => {
304 let tool_calls: Vec<ToolCallForChunk> = parsed_result
305 .tool_calls
306 .into_iter()
307 .enumerate()
308 .map(|(index, tool_call)| ToolCallForChunk {
309 index,
310 id: tool_call.id,
311 ty: tool_call.ty,
312 function: tool_call.function,
313 })
314 .collect();
315 (tool_calls, true)
316 }
317 true => (vec![], false),
318 };
319
320 let tool_call_chunk = {
322 let chat_completion_chunk = ChatCompletionChunk {
323 id: id.clone(),
324 object: "chat.completion.chunk".to_string(),
325 created: created.as_secs(),
326 model: graph.name().to_owned(),
327 system_fingerprint: "fp_44709d6fcb".to_string(),
328 choices: vec![ChatCompletionChunkChoice {
329 index: 0,
330 delta: ChatCompletionChunkChoiceDelta {
331 role: ChatCompletionRole::Assistant,
332 content,
333 tool_calls,
334 },
335 logprobs: None,
336 finish_reason: None,
337 }],
338 usage: None,
339 };
340 let chunk_str = serde_json::to_string(&chat_completion_chunk).map_err(|e| {
341 let err_msg = format!("Failed to serialize chat completion chunk. Reason: {e}");
342
343 #[cfg(feature = "logging")]
344 error!(target: "stdout", "{}", &err_msg);
345
346 LlamaCoreError::Operation(err_msg)
347 })?;
348
349 format!("data: {chunk_str}\n\n")
350 };
351
352 let usage_chunk = {
354 let chat_completion_chunk = ChatCompletionChunk {
355 id: id.clone(),
356 object: "chat.completion.chunk".to_string(),
357 created: created.as_secs(),
358 model: graph.name().to_owned(),
359 system_fingerprint: "fp_44709d6fcb".to_string(),
360 choices: vec![],
361 usage,
362 };
363 let chunk_str = serde_json::to_string(&chat_completion_chunk).map_err(|e| {
364 let err_msg = format!("Failed to serialize chat completion chunk. Reason: {e}");
365
366 #[cfg(feature = "logging")]
367 error!(target: "stdout", "{}", &err_msg);
368
369 LlamaCoreError::Operation(err_msg)
370 })?;
371
372 format!("data: {chunk_str}\n\n")
373 };
374
375 let ending_chunk = "data: [DONE]\n\n".to_string();
377
378 let chunks = vec![tool_call_chunk, usage_chunk, ending_chunk];
379
380 let stream = ChatStream::new(
381 Some(graph.name().to_owned()),
382 id,
383 include_usage,
384 Some(chunks),
385 );
386
387 Ok((stream, include_tool_calls))
388 }
389 Err(wasmedge_wasi_nn::Error::BackendError(wasmedge_wasi_nn::BackendError::ContextFull)) => {
390 let output_buffer = get_output_buffer(graph, OUTPUT_TENSOR)?;
392 let output = std::str::from_utf8(&output_buffer[..]).map_err(|e| {
393 let err_msg = format!(
394 "Failed to decode the buffer of the inference result to a utf-8 string. {e}"
395 );
396
397 #[cfg(feature = "logging")]
398 error!(target: "stdout", "{}", &err_msg);
399
400 LlamaCoreError::Operation(err_msg)
401 })?;
402
403 let message = post_process(output, &graph.metadata.prompt_template).map_err(|e| {
405 let err_msg = format!("Failed to post-process the output. {e}");
406
407 #[cfg(feature = "logging")]
408 error!(target: "stdout", "{}", &err_msg);
409
410 LlamaCoreError::Operation(err_msg)
411 })?;
412
413 let token_info = get_token_info_by_graph(graph)?;
415
416 #[cfg(feature = "logging")]
417 info!(target: "stdout", "prompt tokens: {}, completion tokens: {}", token_info.prompt_tokens, token_info.completion_tokens);
418
419 let usage = Some(Usage {
420 prompt_tokens: token_info.prompt_tokens,
421 completion_tokens: token_info.completion_tokens,
422 total_tokens: token_info.prompt_tokens + token_info.completion_tokens,
423 });
424
425 let created = SystemTime::now()
426 .duration_since(std::time::UNIX_EPOCH)
427 .map_err(|e| {
428 let err_msg = format!("Failed to get the current time. Reason: {e}");
429
430 #[cfg(feature = "logging")]
431 error!(target: "stdout", "{}", &err_msg);
432
433 LlamaCoreError::Operation(err_msg)
434 })?;
435
436 let context_full_chunk = {
438 let chat_completion_chunk = ChatCompletionChunk {
439 id: id.clone(),
440 object: "chat.completion.chunk".to_string(),
441 created: created.as_secs(),
442 model: graph.name().to_owned(),
443 system_fingerprint: "fp_44709d6fcb".to_string(),
444 choices: vec![ChatCompletionChunkChoice {
445 index: 0,
446 delta: ChatCompletionChunkChoiceDelta {
447 role: ChatCompletionRole::Assistant,
448 content: Some(message),
449 tool_calls: vec![],
450 },
451 logprobs: None,
452 finish_reason: Some(FinishReason::length),
453 }],
454 usage: None,
455 };
456
457 let chunk_str = serde_json::to_string(&chat_completion_chunk).map_err(|e| {
459 let err_msg = format!("Failed to serialize chat completion chunk. Reason: {e}");
460
461 #[cfg(feature = "logging")]
462 error!(target: "stdout", "{}", &err_msg);
463
464 LlamaCoreError::Operation(err_msg)
465 })?;
466
467 format!("data: {chunk_str}\n\n")
468 };
469
470 let usage_chunk = {
472 let chat_completion_chunk = ChatCompletionChunk {
473 id: id.clone(),
474 object: "chat.completion.chunk".to_string(),
475 created: created.as_secs(),
476 model: graph.name().to_owned(),
477 system_fingerprint: "fp_44709d6fcb".to_string(),
478 choices: vec![],
479 usage,
480 };
481
482 let chunk_str = serde_json::to_string(&chat_completion_chunk).map_err(|e| {
484 let err_msg = format!("Failed to serialize chat completion chunk. Reason: {e}");
485
486 #[cfg(feature = "logging")]
487 error!(target: "stdout", "{}", &err_msg);
488
489 LlamaCoreError::Operation(err_msg)
490 })?;
491
492 format!("data: {chunk_str}\n\n")
493 };
494
495 let ending_chunk = "data: [DONE]\n\n".to_string();
497
498 let chunks = vec![context_full_chunk, usage_chunk, ending_chunk];
499
500 let stream = ChatStream::new(
501 Some(graph.name().to_owned()),
502 id,
503 include_usage,
504 Some(chunks),
505 );
506
507 Ok((stream, false))
508 }
509 Err(wasmedge_wasi_nn::Error::BackendError(
510 wasmedge_wasi_nn::BackendError::PromptTooLong,
511 )) => {
512 #[cfg(feature = "logging")]
513 warn!(target: "stdout", "The prompt is too long. Please reduce the length of your input and try again.");
514
515 let output_buffer = get_output_buffer(graph, OUTPUT_TENSOR)?;
517 let output = std::str::from_utf8(&output_buffer[..]).map_err(|e| {
518 let err_msg = format!(
519 "Failed to decode the buffer of the inference result to a utf-8 string. {e}"
520 );
521
522 #[cfg(feature = "logging")]
523 error!(target: "stdout", "{}", &err_msg);
524
525 LlamaCoreError::Operation(err_msg)
526 })?;
527
528 let message = post_process(output, &graph.metadata.prompt_template).map_err(|e| {
530 let err_msg = format!("Failed to post-process the output. {e}");
531
532 #[cfg(feature = "logging")]
533 error!(target: "stdout", "{}", &err_msg);
534
535 LlamaCoreError::Operation(err_msg)
536 })?;
537
538 let token_info = get_token_info_by_graph(graph)?;
540
541 #[cfg(feature = "logging")]
542 info!(target: "stdout", "prompt tokens: {}, completion tokens: {}", token_info.prompt_tokens, token_info.completion_tokens);
543
544 let usage = Some(Usage {
545 prompt_tokens: token_info.prompt_tokens,
546 completion_tokens: token_info.completion_tokens,
547 total_tokens: token_info.prompt_tokens + token_info.completion_tokens,
548 });
549
550 let created = SystemTime::now()
551 .duration_since(std::time::UNIX_EPOCH)
552 .map_err(|e| {
553 let err_msg = format!("Failed to get the current time. Reason: {e}");
554
555 #[cfg(feature = "logging")]
556 error!(target: "stdout", "{}", &err_msg);
557
558 LlamaCoreError::Operation(err_msg)
559 })?;
560
561 let prompt_too_long_chunk = {
563 let chat_completion_chunk = ChatCompletionChunk {
564 id: id.clone(),
565 object: "chat.completion.chunk".to_string(),
566 created: created.as_secs(),
567 model: graph.name().to_owned(),
568 system_fingerprint: "fp_44709d6fcb".to_string(),
569 choices: vec![ChatCompletionChunkChoice {
570 index: 0,
571 delta: ChatCompletionChunkChoiceDelta {
572 role: ChatCompletionRole::Assistant,
573 content: Some(message),
574 tool_calls: vec![],
575 },
576 logprobs: None,
577 finish_reason: Some(FinishReason::length),
578 }],
579 usage: None,
580 };
581
582 let chunk_str = serde_json::to_string(&chat_completion_chunk).map_err(|e| {
584 let err_msg = format!("Failed to serialize chat completion chunk. Reason: {e}");
585
586 #[cfg(feature = "logging")]
587 error!(target: "stdout", "{}", &err_msg);
588
589 LlamaCoreError::Operation(err_msg)
590 })?;
591
592 format!("data: {chunk_str}\n\n")
593 };
594
595 let usage_chunk = {
597 let chat_completion_chunk = ChatCompletionChunk {
598 id: id.clone(),
599 object: "chat.completion.chunk".to_string(),
600 created: created.as_secs(),
601 model: graph.name().to_owned(),
602 system_fingerprint: "fp_44709d6fcb".to_string(),
603 choices: vec![],
604 usage,
605 };
606
607 let chunk_str = serde_json::to_string(&chat_completion_chunk).map_err(|e| {
609 let err_msg = format!("Failed to serialize chat completion chunk. Reason: {e}");
610
611 #[cfg(feature = "logging")]
612 error!(target: "stdout", "{}", &err_msg);
613
614 LlamaCoreError::Operation(err_msg)
615 })?;
616
617 format!("data: {chunk_str}\n\n")
618 };
619
620 let ending_chunk = "data: [DONE]\n\n".to_string();
622
623 let chunks = vec![prompt_too_long_chunk, usage_chunk, ending_chunk];
624
625 let stream = ChatStream::new(
626 Some(graph.name().to_owned()),
627 id,
628 include_usage,
629 Some(chunks),
630 );
631
632 Ok((stream, false))
633 }
634 Err(e) => {
635 let err_msg = format!("Failed to compute the chat completion. Reason: {e}");
636
637 #[cfg(feature = "logging")]
638 error!(target: "stdout", "{}", &err_msg);
639
640 Err(LlamaCoreError::Backend(BackendError::Compute(err_msg)))
641 }
642 }
643}
644
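/// Handles a chat completion request in non-stream mode: checks the running mode, builds the
/// chat prompt, runs the computation, and resets the model metadata before returning the
/// completion object and a flag indicating whether the response includes tool calls.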
645async fn chat_once(
646 chat_request: &mut ChatCompletionRequest,
647) -> Result<(ChatCompletionObject, bool), LlamaCoreError> {
648 #[cfg(feature = "logging")]
649 info!(target: "stdout", "Processing chat completion request in non-stream mode");
650
651 let running_mode = running_mode()?;
652 if !running_mode.contains(RunningMode::CHAT) && !running_mode.contains(RunningMode::RAG) {
653 let err_msg = "The chat completion is only supported in the chat or rag mode.";
654
655 #[cfg(feature = "logging")]
656 error!(target: "stdout", "{err_msg}");
657
658 return Err(LlamaCoreError::Operation(err_msg.to_string()));
659 }
660
661 let model_name = chat_request.model.clone();
662 let id = match &chat_request.user {
663 Some(id) => id.clone(),
664 None => gen_chat_id(),
665 };
666
667 #[cfg(feature = "logging")]
668 info!(target: "stdout", "user: {}", &id);
669
670 #[cfg(feature = "logging")]
671 info!(target: "stdout", "Check model metadata");
672
673 let mut metadata = check_model_metadata(chat_request)?;
675
676 #[cfg(feature = "logging")]
677 info!(target: "stdout", "Build the chat prompt");
678
    let (prompt, available_completion_tokens, tool_use) =
        build_prompt(model_name.as_ref(), chat_request)?;
682
683 #[cfg(feature = "logging")]
684 {
685 info!(target: "stdout", "prompt:\n{}", &prompt);
        info!(target: "stdout", "available_completion_tokens: {available_completion_tokens}");
687 info!(target: "stdout", "tool_use: {tool_use}");
688 }
689
690 #[cfg(feature = "logging")]
691 info!(target: "stdout", "Update n_predict");
692
    update_n_predict(chat_request, &mut metadata, available_completion_tokens)?;
695
696 #[cfg(feature = "logging")]
697 info!(target: "stdout", "Feed the prompt to the model");
698
699 set_prompt(model_name.as_ref(), &prompt)?;
701
702 #[cfg(feature = "logging")]
703 info!(target: "stdout", "Compute chat completion.");
704
705 let res = compute(model_name.as_ref(), id, tool_use);
707
708 #[cfg(feature = "logging")]
709 info!(target: "stdout", "End of the chat completion");
710
711 reset_model_metadata(model_name.as_ref())?;
713
714 res
715}
716
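/// Locks `CHAT_GRAPHS`, selects the graph for the requested model (falling back to the first
/// available graph when the name is missing or unknown), and delegates to `compute_by_graph`.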
717fn compute(
718 model_name: Option<&String>,
719 id: impl Into<String>,
720 tool_use: bool,
721) -> Result<(ChatCompletionObject, bool), LlamaCoreError> {
722 let chat_graphs = match CHAT_GRAPHS.get() {
723 Some(chat_graphs) => chat_graphs,
724 None => {
725 let err_msg = "Fail to get the underlying value of `CHAT_GRAPHS`.";
726
727 #[cfg(feature = "logging")]
728 error!(target: "stdout", "{}", &err_msg);
729
730 return Err(LlamaCoreError::Operation(err_msg.into()));
731 }
732 };
733
734 let mut chat_graphs = chat_graphs.lock().map_err(|e| {
735 let err_msg = format!("Fail to acquire the lock of `CHAT_GRAPHS`. {e}");
736
737 #[cfg(feature = "logging")]
738 error!(target: "stdout", "{}", &err_msg);
739
740 LlamaCoreError::Operation(err_msg)
741 })?;
742
743 match model_name {
744 Some(model_name) => match chat_graphs.contains_key(model_name) {
745 true => {
746 let graph = chat_graphs.get_mut(model_name).unwrap();
747 compute_by_graph(graph, id, tool_use)
748 }
749 false => match chat_graphs.iter_mut().next() {
750 Some((_, graph)) => compute_by_graph(graph, id, tool_use),
751 None => {
752 let err_msg = "There is no model available in the chat graphs.";
753
754 #[cfg(feature = "logging")]
755 error!(target: "stdout", "{}", &err_msg);
756
757 Err(LlamaCoreError::Operation(err_msg.into()))
758 }
759 },
760 },
761 None => match chat_graphs.iter_mut().next() {
762 Some((_, graph)) => compute_by_graph(graph, id, tool_use),
763 None => {
764 let err_msg = "There is no model available in the chat graphs.";
765
766 #[cfg(feature = "logging")]
767 error!(target: "stdout", "{}", &err_msg);
768
769 Err(LlamaCoreError::Operation(err_msg.into()))
770 }
771 },
772 }
773}
774
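/// Runs inference on the given graph and assembles a `ChatCompletionObject`. When `tool_use` is
/// set, tool calls are parsed from the raw output; context-full and prompt-too-long backend
/// errors are mapped to a `length` finish reason instead of hard failures.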
775fn compute_by_graph(
776 graph: &mut Graph<GgmlMetadata>,
777 id: impl Into<String>,
778 tool_use: bool,
779) -> Result<(ChatCompletionObject, bool), LlamaCoreError> {
780 #[cfg(feature = "logging")]
781 info!(target: "stdout", "Compute chat completion by the model named {}.", graph.name());
782
783 match graph.compute() {
784 Ok(_) => {
785 let output_buffer = get_output_buffer(graph, OUTPUT_TENSOR)?;
787 let output = std::str::from_utf8(&output_buffer[..]).map_err(|e| {
788 let err_msg = format!(
789 "Failed to decode the buffer of the inference result to a utf-8 string. {e}"
790 );
791
792 #[cfg(feature = "logging")]
793 error!(target: "stdout", "{}", &err_msg);
794
795 LlamaCoreError::Operation(err_msg)
796 })?;
797
798 #[cfg(feature = "logging")]
799 info!(target: "stdout", "raw generation: {output}");
800
801 let message = post_process(output, &graph.metadata.prompt_template).map_err(|e| {
803 LlamaCoreError::Operation(format!("Failed to post-process the output. {e}"))
804 })?;
805
806 #[cfg(feature = "logging")]
807 info!(target: "stdout", "post-processed generation:\n{}", &message);
808
809 let token_info = get_token_info_by_graph(graph)?;
811
812 #[cfg(feature = "logging")]
813 info!(target: "stdout", "prompt tokens: {}, completion tokens: {}", token_info.prompt_tokens, token_info.completion_tokens);
814
815 let created = SystemTime::now()
816 .duration_since(std::time::UNIX_EPOCH)
817 .map_err(|e| {
818 let err_msg = format!("Failed to get the current time. Reason: {e}");
819
820 #[cfg(feature = "logging")]
821 error!(target: "stdout", "{}", &err_msg);
822
823 LlamaCoreError::Operation(err_msg)
824 })?;
825
826 match tool_use {
827 true => {
828 if graph.metadata.prompt_template != PromptTemplateType::MistralTool
829 && graph.metadata.prompt_template != PromptTemplateType::ChatMLTool
830 && graph.metadata.prompt_template != PromptTemplateType::GroqLlama3Tool
831 && graph.metadata.prompt_template != PromptTemplateType::Llama3Tool
832 && graph.metadata.prompt_template != PromptTemplateType::InternLM2Tool
833 && graph.metadata.prompt_template != PromptTemplateType::NemotronTool
834 && graph.metadata.prompt_template != PromptTemplateType::FunctionaryV32
835 && graph.metadata.prompt_template != PromptTemplateType::FunctionaryV31
836 && graph.metadata.prompt_template != PromptTemplateType::MistralSmallTool
837 && graph.metadata.prompt_template != PromptTemplateType::Llama4Chat
838 && graph.metadata.prompt_template != PromptTemplateType::Qwen3NoThink
839 {
840 let err_msg = format!("Unsupported prompt template: {}. The tool use is only supported for 'mistral-tool', 'chatml-tool', 'groq-llama3-tool', 'llama-3-tool', 'internlm-2-tool', 'nemotron-tool', 'functionary-31', 'functionary-32', 'mistral-small-tool', 'llama-4-chat', and 'qwen-3-no-think' prompt templates.", graph.metadata.prompt_template);
841
842 #[cfg(feature = "logging")]
843 error!(target: "stdout", "{}", &err_msg);
844
845 return Err(LlamaCoreError::Operation(err_msg));
846 }
847
848 let parsed_result = parse_tool_calls(&message, graph.metadata.prompt_template)?;
849
850 let (finish_reason, content, include_tool_calls) =
851 if parsed_result.tool_calls.is_empty() {
852 (FinishReason::stop, Some(parsed_result.raw.clone()), false)
853 } else {
854 (
855 FinishReason::tool_calls,
856 parsed_result.content.clone(),
857 true,
858 )
859 };
860
861 let res = ChatCompletionObject {
862 id: id.into(),
863 object: String::from("chat.completion"),
864 created: created.as_secs(),
865 model: graph.name().to_owned(),
866 choices: vec![ChatCompletionObjectChoice {
867 index: 0,
868 message: ChatCompletionObjectMessage {
869 role: ChatCompletionRole::Assistant,
870 content,
871 tool_calls: parsed_result.tool_calls,
872 function_call: None,
873 },
874 finish_reason,
875 logprobs: None,
876 }],
877 usage: Usage {
878 prompt_tokens: token_info.prompt_tokens,
879 completion_tokens: token_info.completion_tokens,
880 total_tokens: token_info.prompt_tokens + token_info.completion_tokens,
881 },
882 };
883
884 Ok((res, include_tool_calls))
886 }
887 false => {
888 let res = ChatCompletionObject {
890 id: id.into(),
891 object: String::from("chat.completion"),
892 created: created.as_secs(),
893 model: graph.name().to_owned(),
894 choices: vec![ChatCompletionObjectChoice {
895 index: 0,
896 message: ChatCompletionObjectMessage {
897 role: ChatCompletionRole::Assistant,
898 content: Some(message),
899 tool_calls: vec![],
900 function_call: None,
901 },
902 finish_reason: FinishReason::stop,
903 logprobs: None,
904 }],
905 usage: Usage {
906 prompt_tokens: token_info.prompt_tokens,
907 completion_tokens: token_info.completion_tokens,
908 total_tokens: token_info.prompt_tokens + token_info.completion_tokens,
909 },
910 };
911
912 Ok((res, false))
913 }
914 }
915 }
916 Err(wasmedge_wasi_nn::Error::BackendError(wasmedge_wasi_nn::BackendError::ContextFull)) => {
917 let output_buffer = get_output_buffer(graph, OUTPUT_TENSOR)?;
919 let output = std::str::from_utf8(&output_buffer[..]).map_err(|e| {
920 let err_msg = format!(
921 "Failed to decode the buffer of the inference result to a utf-8 string. {e}"
922 );
923
924 #[cfg(feature = "logging")]
925 error!(target: "stdout", "{}", &err_msg);
926
927 LlamaCoreError::Operation(err_msg)
928 })?;
929
930 let message = post_process(output, &graph.metadata.prompt_template).map_err(|e| {
932 let err_msg = format!("Failed to post-process the output. {e}");
933
934 #[cfg(feature = "logging")]
935 error!(target: "stdout", "{}", &err_msg);
936
937 LlamaCoreError::Operation(err_msg)
938 })?;
939
940 let token_info = get_token_info_by_graph(graph)?;
942
943 #[cfg(feature = "logging")]
944 info!(target: "stdout", "prompt tokens: {}, completion tokens: {}", token_info.prompt_tokens, token_info.completion_tokens);
945
946 let created = SystemTime::now()
947 .duration_since(std::time::UNIX_EPOCH)
948 .map_err(|e| {
949 let err_msg = format!("Failed to get the current time. Reason: {e}");
950
951 #[cfg(feature = "logging")]
952 error!(target: "stdout", "{}", &err_msg);
953
954 LlamaCoreError::Operation(err_msg)
955 })?;
956
957 let res = ChatCompletionObject {
959 id: id.into(),
960 object: String::from("chat.completion"),
961 created: created.as_secs(),
962 model: graph.name().to_owned(),
963 choices: vec![ChatCompletionObjectChoice {
964 index: 0,
965 message: ChatCompletionObjectMessage {
966 role: ChatCompletionRole::Assistant,
967 content: Some(message),
968 tool_calls: vec![],
969 function_call: None,
970 },
971 finish_reason: FinishReason::length,
972 logprobs: None,
973 }],
974 usage: Usage {
975 prompt_tokens: token_info.prompt_tokens,
976 completion_tokens: token_info.completion_tokens,
977 total_tokens: token_info.prompt_tokens + token_info.completion_tokens,
978 },
979 };
980
981 Ok((res, false))
982 }
983 Err(wasmedge_wasi_nn::Error::BackendError(
984 wasmedge_wasi_nn::BackendError::PromptTooLong,
985 )) => {
986 #[cfg(feature = "logging")]
987 warn!(target: "stdout", "The prompt is too long. Please reduce the length of your input and try again.");
988
989 let output_buffer = get_output_buffer(graph, OUTPUT_TENSOR)?;
991 let output = std::str::from_utf8(&output_buffer[..]).map_err(|e| {
992 let err_msg = format!(
993 "Failed to decode the buffer of the inference result to a utf-8 string. {e}"
994 );
995
996 #[cfg(feature = "logging")]
997 error!(target: "stdout", "{}", &err_msg);
998
999 LlamaCoreError::Operation(err_msg)
1000 })?;
1001
1002 let message = post_process(output, &graph.metadata.prompt_template).map_err(|e| {
1004 let err_msg = format!("Failed to post-process the output. {e}");
1005
1006 #[cfg(feature = "logging")]
1007 error!(target: "stdout", "{}", &err_msg);
1008
1009 LlamaCoreError::Operation(err_msg)
1010 })?;
1011
1012 let token_info = get_token_info_by_graph(graph)?;
1014
1015 #[cfg(feature = "logging")]
1016 info!(target: "stdout", "prompt tokens: {}, completion tokens: {}", token_info.prompt_tokens, token_info.completion_tokens);
1017
1018 let usage = Usage {
1019 prompt_tokens: token_info.prompt_tokens,
1020 completion_tokens: token_info.completion_tokens,
1021 total_tokens: token_info.prompt_tokens + token_info.completion_tokens,
1022 };
1023
1024 let created = SystemTime::now()
1025 .duration_since(std::time::UNIX_EPOCH)
1026 .map_err(|e| {
1027 let err_msg = format!("Failed to get the current time. Reason: {e}");
1028
1029 #[cfg(feature = "logging")]
1030 error!(target: "stdout", "{}", &err_msg);
1031
1032 LlamaCoreError::Operation(err_msg)
1033 })?;
1034
1035 let res = ChatCompletionObject {
1037 id: id.into(),
1038 object: String::from("chat.completion"),
1039 created: created.as_secs(),
1040 model: graph.name().to_owned(),
1041 choices: vec![ChatCompletionObjectChoice {
1042 index: 0,
1043 message: ChatCompletionObjectMessage {
1044 role: ChatCompletionRole::Assistant,
1045 content: Some(message),
1046 tool_calls: vec![],
1047 function_call: None,
1048 },
1049 finish_reason: FinishReason::length,
1050 logprobs: None,
1051 }],
1052 usage,
1053 };
1054
1055 Ok((res, false))
1056 }
1057 Err(e) => {
1058 let err_msg = format!("Failed to compute the chat completion. Reason: {e}");
1059
1060 #[cfg(feature = "logging")]
1061 error!(target: "stdout", "{}", &err_msg);
1062
1063 Err(LlamaCoreError::Backend(BackendError::Compute(err_msg)))
1064 }
1065 }
1066}
1067
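/// Extracts tool calls from the raw model output according to the tool-call syntax of the given
/// prompt template (e.g. `<tool_call>...</tool_call>` tags, `[TOOL_CALLS]` arrays, or bare JSON
/// objects), returning the raw text, any remaining content, and the parsed calls.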
1068fn parse_tool_calls(
1069 input: &str,
1070 prompt_template: PromptTemplateType,
1071) -> Result<ParseResult, LlamaCoreError> {
1072 match prompt_template {
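        // Mistral-tool output is expected to contain one or more JSON arrays of
        // `{"name": ..., "arguments": ...}` objects; each regex match below is deserialized
        // into a batch of tool calls.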
1073 PromptTemplateType::MistralTool => match regex::Regex::new(r"\[\{.*?\}\]") {
1074 Ok(re) => {
1075 let mut values: Vec<serde_json::Value> = vec![];
1076 for cap in re.captures_iter(input) {
1077 let matched = &cap[0];
1078
1079 #[cfg(feature = "logging")]
1080 info!(target: "stdout", "captured: {matched}");
1081
1082 match serde_json::from_str::<Vec<serde_json::Value>>(matched) {
1083 Ok(group) => values.extend(group),
1084 Err(e) => {
1085 let err_msg =
1086 format!("Failed to deserialize generated tool calls. Reason: {e}");
1087
1088 #[cfg(feature = "logging")]
1089 error!(target: "stdout", "{}", &err_msg);
1090
1091 return Err(LlamaCoreError::Operation(err_msg));
1092 }
1093 }
1094 }
1095
1096 let mut tool_calls: Vec<ToolCall> = vec![];
1097 for value in values.iter() {
1098 let name = match value.get("name") {
1099 Some(name) => name.to_string().replace("\"", ""),
1100 None => {
1101 let err_msg = format!(
1102 "Failed to get the name of the function. Tool call: {value:?}"
1103 );
1104
1105 #[cfg(feature = "logging")]
1106 error!(target: "stdout", "{}", &err_msg);
1107
1108 return Err(LlamaCoreError::Operation(err_msg));
1109 }
1110 };
1111
1112 let arguments = match value.get("arguments") {
1113 Some(arguments) => arguments.to_string(),
1114 None => {
1115 let err_msg = format!(
1116 "Failed to get the arguments of the function. Tool call: {value:?}"
1117 );
1118
1119 #[cfg(feature = "logging")]
1120 error!(target: "stdout", "{}", &err_msg);
1121
1122 return Err(LlamaCoreError::Operation(err_msg));
1123 }
1124 };
1125
1126 let function = Function { name, arguments };
1127
1128 let tool_call = ToolCall {
1129 id: "call_abc123".to_string(),
1130 ty: "function".to_string(),
1131 function,
1132 };
1133
1134 tool_calls.push(tool_call);
1135 }
1136
1137 let parsed = ParseResult {
1138 raw: input.to_owned(),
1139 content: None,
1140 tool_calls,
1141 };
1142
1143 #[cfg(feature = "logging")]
1144 info!(target: "stdout", "parsed result: {parsed:?}");
1145
1146 Ok(parsed)
1147 }
1148 Err(e) => {
1149 let err_msg = format!("Failed to create a regex pattern. Reason: {e}");
1150
1151 #[cfg(feature = "logging")]
1152 error!(target: "stdout", "{}", &err_msg);
1153
1154 Err(LlamaCoreError::Operation(err_msg))
1155 }
1156 },
1157 PromptTemplateType::ChatMLTool => {
1158 match regex::Regex::new(r"<tool_call>(.*?)</tool_call>") {
1159 Ok(re) => {
1160 let mut values: Vec<serde_json::Value> = vec![];
1161 for cap in re.captures_iter(input) {
                        let matched = cap[1].replace("\\n", "");

                        #[cfg(feature = "logging")]
                        info!(target: "stdout", "captured: {}", &matched);
1166
1167 match serde_json::from_str::<serde_json::Value>(&matched) {
1168 Ok(value) => values.push(value),
1169 Err(e) => {
1170 let err_msg = format!(
1171 "Failed to deserialize generated tool calls. Reason: {e}"
1172 );
1173
1174 #[cfg(feature = "logging")]
1175 error!(target: "stdout", "{}", &err_msg);
1176
1177 return Err(LlamaCoreError::Operation(err_msg));
1178 }
1179 }
1180 }
1181
1182 let mut tool_calls: Vec<ToolCall> = vec![];
1183 for value in values.iter() {
1184 let name = match value.get("name") {
1185 Some(name) => name.to_string().replace("\"", ""),
1186 None => {
1187 let err_msg = format!(
1188 "Failed to get the name of the function. Tool call: {value:?}"
1189 );
1190
1191 #[cfg(feature = "logging")]
1192 error!(target: "stdout", "{}", &err_msg);
1193
1194 return Err(LlamaCoreError::Operation(err_msg));
1195 }
1196 };
1197
1198 let arguments = match value.get("arguments") {
1199 Some(arguments) => arguments.to_string(),
1200 None => {
1201 let err_msg = format!(
1202 "Failed to get the arguments of the function. Tool call: {value:?}"
1203 );
1204
1205 #[cfg(feature = "logging")]
1206 error!(target: "stdout", "{}", &err_msg);
1207
1208 return Err(LlamaCoreError::Operation(err_msg));
1209 }
1210 };
1211
1212 let function = Function { name, arguments };
1213
1214 let tool_call = ToolCall {
1215 id: "call_abc123".to_string(),
1216 ty: "function".to_string(),
1217 function,
1218 };
1219
1220 tool_calls.push(tool_call);
1221 }
1222
1223 let parsed = ParseResult {
1224 raw: input.to_owned(),
1225 content: None,
1226 tool_calls,
1227 };
1228
1229 #[cfg(feature = "logging")]
1230 info!(target: "stdout", "parsed result: {parsed:?}");
1231
1232 Ok(parsed)
1233 }
1234 Err(e) => {
1235 let err_msg = format!("Failed to create a regex pattern. Reason: {e}");
1236
1237 #[cfg(feature = "logging")]
1238 error!(target: "stdout", "{}", &err_msg);
1239
1240 Err(LlamaCoreError::Operation(err_msg))
1241 }
1242 }
1243 }
1244 PromptTemplateType::GroqLlama3Tool => {
1245 #[cfg(feature = "logging")]
1246 info!(target: "stdout", "raw input: {input}");
1247
1248 match regex::Regex::new(r"(?s)<tool_call>((.|\r|\n)*?)</tool_call>") {
1249 Ok(re) => {
1250 let mut values: Vec<serde_json::Value> = vec![];
1251 for cap in re.captures_iter(input) {
1252 let matched = cap[1].trim();
1253
1254 #[cfg(feature = "logging")]
1255 info!(target: "stdout", "captured: {matched}");
1256
1257 match serde_json::from_str::<serde_json::Value>(matched) {
1258 Ok(value) => values.push(value),
1259 Err(e) => {
1260 let err_msg = format!(
1261 "Failed to deserialize generated tool calls. Reason: {e}"
1262 );
1263
1264 #[cfg(feature = "logging")]
1265 error!(target: "stdout", "{}", &err_msg);
1266
1267 return Err(LlamaCoreError::Operation(err_msg));
1268 }
1269 }
1270 }
1271
1272 let mut tool_calls: Vec<ToolCall> = vec![];
1273 for value in values.iter() {
1274 let name = match value.get("name") {
1275 Some(name) => name.to_string().replace("\"", ""),
1276 None => {
1277 let err_msg = format!(
1278 "Failed to get the name of the function. Tool call: {value:?}"
1279 );
1280
1281 #[cfg(feature = "logging")]
1282 error!(target: "stdout", "{}", &err_msg);
1283
1284 return Err(LlamaCoreError::Operation(err_msg));
1285 }
1286 };
1287
1288 let arguments = match value.get("arguments") {
1289 Some(arguments) => {
1290 if arguments.is_string() {
1291 arguments.as_str().unwrap().to_string()
1292 } else if arguments.is_object() {
1293 let map = arguments.as_object().unwrap();
1294
1295 #[cfg(feature = "logging")]
1296 info!(target: "stdout", "func arguments: {map:?}");
1297
1298 serde_json::to_string(map).unwrap()
1299 } else {
1300 serde_json::to_string(arguments).unwrap()
1301 }
1302 }
1303 None => {
1304 let err_msg = format!(
1305 "Failed to get the arguments of the function. Tool call: {value:?}"
1306 );
1307
1308 #[cfg(feature = "logging")]
1309 error!(target: "stdout", "{}", &err_msg);
1310
1311 return Err(LlamaCoreError::Operation(err_msg));
1312 }
1313 };
1314
1315 let function = Function { name, arguments };
1316
1317 let tool_call = ToolCall {
1318 id: "call_abc123".to_string(),
1319 ty: "function".to_string(),
1320 function,
1321 };
1322
1323 tool_calls.push(tool_call);
1324 }
1325
1326 let parsed = if tool_calls.is_empty() {
1327 ParseResult {
1328 raw: input.to_owned(),
1329 content: Some(input.to_owned()),
1330 tool_calls: vec![],
1331 }
1332 } else {
1333 ParseResult {
1334 raw: input.to_owned(),
1335 content: None,
1336 tool_calls,
1337 }
1338 };
1339
1340 #[cfg(feature = "logging")]
1341 info!(target: "stdout", "parsed result: {parsed:?}");
1342
1343 Ok(parsed)
1344 }
1345 Err(e) => {
1346 let err_msg = format!("Failed to create a regex pattern. Reason: {e}");
1347
1348 #[cfg(feature = "logging")]
1349 error!(target: "stdout", "{}", &err_msg);
1350
1351 Err(LlamaCoreError::Operation(err_msg))
1352 }
1353 }
1354 }
1355 PromptTemplateType::Llama3Tool => {
1356 #[cfg(feature = "logging")]
1357 info!(target: "stdout", "raw input: {input}");
1358
1359 let re = match regex::Regex::new(r"^\{(.|\r|\n)*\}$") {
1360 Ok(re) => re,
1361 Err(e) => {
1362 let err_msg = format!("Failed to create a regex pattern. Reason: {e}");
1363
1364 #[cfg(feature = "logging")]
1365 error!(target: "stdout", "{}", &err_msg);
1366
1367 return Err(LlamaCoreError::Operation(err_msg));
1368 }
1369 };
1370
1371 if re.is_match(input) {
1372 match serde_json::from_str::<serde_json::Value>(input) {
1373 Ok(value) => {
1374 let values: Vec<serde_json::Value> = vec![value];
1375
1376 let mut tool_calls: Vec<ToolCall> = vec![];
1377 for value in values.iter() {
1378 let name = match value.get("name") {
1379 Some(name) => name.to_string().replace("\"", ""),
1380 None => {
1381 let err_msg = format!(
1382 "Failed to get the name of the function. Tool call: {value:?}"
1383 );
1384
1385 #[cfg(feature = "logging")]
1386 error!(target: "stdout", "{}", &err_msg);
1387
1388 return Err(LlamaCoreError::Operation(err_msg));
1389 }
1390 };
1391
1392 let arguments = match value.get("parameters") {
1393 Some(arguments) => arguments.to_string(),
1394 None => {
1395 let err_msg = format!(
1396 "Failed to get the arguments of the function. Tool call: {value:?}"
1397 );
1398
1399 #[cfg(feature = "logging")]
1400 error!(target: "stdout", "{}", &err_msg);
1401
1402 return Err(LlamaCoreError::Operation(err_msg));
1403 }
1404 };
1405
1406 let function = Function { name, arguments };
1407
1408 let tool_call = ToolCall {
1409 id: "call_abc123".to_string(),
1410 ty: "function".to_string(),
1411 function,
1412 };
1413
1414 tool_calls.push(tool_call);
1415 }
1416
1417 let parsed = ParseResult {
1418 raw: input.to_owned(),
1419 content: None,
1420 tool_calls,
1421 };
1422
1423 #[cfg(feature = "logging")]
1424 info!(target: "stdout", "parsed result: {parsed:?}");
1425
1426 Ok(parsed)
1427 }
1428 Err(e) => {
1429 let err_msg =
1430 format!("Failed to deserialize generated tool calls. Reason: {e}");
1431
1432 #[cfg(feature = "logging")]
1433 error!(target: "stdout", "{}", &err_msg);
1434
1435 Err(LlamaCoreError::Operation(err_msg))
1436 }
1437 }
1438 } else {
1439 let parsed = ParseResult {
1440 raw: input.to_owned(),
1441 content: None,
1442 tool_calls: vec![],
1443 };
1444
1445 #[cfg(feature = "logging")]
1446 info!(target: "stdout", "parsed result: {parsed:?}");
1447
1448 Ok(parsed)
1449 }
1450 }
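        // InternLM2-tool output interleaves plain text with tool-call blocks delimited by
        // `<|action_start|><|plugin|>` and `<|action_end|>`; the loop below splits on the start
        // marker and parses each block that ends with the end marker as a JSON tool call.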
1451 PromptTemplateType::InternLM2Tool => {
1452 #[cfg(feature = "logging")]
1453 info!(target: "stdout", "raw input: {input}");
1454
1455 let blocks: Vec<&str> = input.trim().split("<|action_start|><|plugin|>").collect();
1456
1457 #[cfg(feature = "logging")]
1458 info!(target: "stdout", "blocks: {blocks:?}");
1459
1460 let mut tool_calls: Vec<ToolCall> = vec![];
1461 let mut content = String::new();
1462 for block in blocks {
1463 let block = block.trim();
1464 if !block.is_empty() {
1465 if block.ends_with("<|action_end|>") {
1466 let value = block.trim().trim_end_matches("<|action_end|>");
1467
1468 #[cfg(feature = "logging")]
1469 info!(target: "stdout", "tool call: {value}");
1470
1471 match serde_json::from_str::<serde_json::Value>(value) {
1472 Ok(value) => {
1473 let name = match value.get("name") {
1474 Some(name) => name.to_string().replace("\"", ""),
1475 None => {
1476 let err_msg = format!(
1477 "Failed to get the name of the function. Tool call: {value:?}"
1478 );
1479
1480 #[cfg(feature = "logging")]
1481 error!(target: "stdout", "{}", &err_msg);
1482
1483 return Err(LlamaCoreError::Operation(err_msg));
1484 }
1485 };
1486
1487 let arguments = match value.get("parameters") {
1488 Some(arguments) => arguments.to_string(),
1489 None => {
1490 let err_msg = format!(
1491 "Failed to get the arguments of the function. Tool call: {value:?}"
1492 );
1493
1494 #[cfg(feature = "logging")]
1495 error!(target: "stdout", "{}", &err_msg);
1496
1497 return Err(LlamaCoreError::Operation(err_msg));
1498 }
1499 };
1500
1501 let function = Function { name, arguments };
1502
1503 let tool_call = ToolCall {
1504 id: "call_abc123".to_string(),
1505 ty: "function".to_string(),
1506 function,
1507 };
1508
1509 tool_calls.push(tool_call);
1510 }
1511 Err(e) => {
1512 let err_msg = format!(
1513 "Failed to deserialize generated tool calls. Reason: {e}"
1514 );
1515
1516 #[cfg(feature = "logging")]
1517 error!(target: "stdout", "{}", &err_msg);
1518
1519 return Err(LlamaCoreError::Operation(err_msg));
1520 }
1521 }
1522 } else {
1523 content.push_str(block);
1524 content.push('\n');
1525 }
1526 }
1527 }
1528
1529 let parsed = match content.is_empty() {
1530 true => ParseResult {
1531 raw: input.to_owned(),
1532 content: None,
1533 tool_calls,
1534 },
1535 false => ParseResult {
1536 raw: input.to_owned(),
1537 content: Some(content.trim().to_owned()),
1538 tool_calls,
1539 },
1540 };
1541
1542 #[cfg(feature = "logging")]
1543 info!(target: "stdout", "parsed result: {parsed:?}");
1544
1545 Ok(parsed)
1546 }
1547 PromptTemplateType::NemotronTool => {
1548 #[cfg(feature = "logging")]
1549 info!(target: "stdout", "raw input: {input}");
1550
1551 match regex::Regex::new(r"(?s)<toolcall>\s*(.*?)\s*</toolcall>") {
1552 Ok(re) => {
1553 let mut values: Vec<serde_json::Value> = vec![];
1554 for cap in re.captures_iter(input) {
1555 #[cfg(feature = "logging")]
1556 info!(target: "stdout", "captured: {}", &cap[0]);
1557
1558 #[cfg(feature = "logging")]
1559 info!(target: "stdout", "extracted: {}", &cap[1]);
1560
1561 let matched = cap[1].trim();
1562
1563 #[cfg(feature = "logging")]
1564 info!(target: "stdout", "captured: {matched}");
1565
1566 match serde_json::from_str::<serde_json::Value>(matched) {
1567 Ok(value) => values.push(value),
1568 Err(e) => {
1569 let err_msg = format!(
1570 "Failed to deserialize generated tool calls. Reason: {e}"
1571 );
1572
1573 #[cfg(feature = "logging")]
1574 error!(target: "stdout", "{}", &err_msg);
1575
1576 return Err(LlamaCoreError::Operation(err_msg));
1577 }
1578 }
1579 }
1580
1581 let mut tool_calls: Vec<ToolCall> = vec![];
1582 for value in values.iter() {
1583 let name = match value.get("name") {
1584 Some(name) => name.to_string().replace("\"", ""),
1585 None => {
1586 let err_msg = format!(
1587 "Failed to get the name of the function. Tool call: {value:?}"
1588 );
1589
1590 #[cfg(feature = "logging")]
1591 error!(target: "stdout", "{}", &err_msg);
1592
1593 return Err(LlamaCoreError::Operation(err_msg));
1594 }
1595 };
1596
1597 let arguments = match value.get("arguments") {
1598 Some(arguments) => arguments.to_string(),
1599 None => {
1600 let err_msg = format!(
1601 "Failed to get the arguments of the function. Tool call: {value:?}"
1602 );
1603
1604 #[cfg(feature = "logging")]
1605 error!(target: "stdout", "{}", &err_msg);
1606
1607 return Err(LlamaCoreError::Operation(err_msg));
1608 }
1609 };
1610
1611 let function = Function { name, arguments };
1612
1613 let tool_call = ToolCall {
1614 id: "call_abc123".to_string(),
1615 ty: "function".to_string(),
1616 function,
1617 };
1618
1619 tool_calls.push(tool_call);
1620 }
1621
1622 let parsed = ParseResult {
1623 raw: input.to_owned(),
1624 content: None,
1625 tool_calls,
1626 };
1627
1628 #[cfg(feature = "logging")]
1629 info!(target: "stdout", "parsed result: {parsed:?}");
1630
1631 Ok(parsed)
1632 }
1633 Err(e) => {
1634 let err_msg = format!("Failed to create a regex pattern. Reason: {e}");
1635
1636 #[cfg(feature = "logging")]
1637 error!(target: "stdout", "{}", &err_msg);
1638
1639 Err(LlamaCoreError::Operation(err_msg))
1640 }
1641 }
1642 }
1643 PromptTemplateType::FunctionaryV32 => {
1644 #[cfg(feature = "logging")]
1645 info!(target: "stdout", "raw input: {input}");
1646
1647 match regex::Regex::new(r">>>\s*(\w+)\s*\{(.*)\}<\|eot_id\|>") {
1648 Ok(re) => {
1649 let mut tool_calls: Vec<ToolCall> = vec![];
1650 for cap in re.captures_iter(input) {
1651 #[cfg(feature = "logging")]
1652 info!(target: "stdout", "func_name: {}", &cap[1]);
1653
1654 #[cfg(feature = "logging")]
1655 info!(target: "stdout", "arguments: {}", &cap[2]);
1656
1657 let tool_call = ToolCall {
1658 id: "call_abc123".to_string(),
1659 ty: "function".to_string(),
1660 function: Function {
1661 name: cap[1].to_string(),
1662 arguments: cap[2].to_string(),
1663 },
1664 };
1665
1666 tool_calls.push(tool_call);
1667 }
1668
1669 let parsed = ParseResult {
1670 raw: input.to_owned(),
1671 content: None,
1672 tool_calls,
1673 };
1674
1675 #[cfg(feature = "logging")]
1676 info!(target: "stdout", "parsed result: {parsed:?}");
1677
1678 Ok(parsed)
1679 }
1680 Err(e) => {
1681 let warn_msg = format!("Failed to create a regex pattern. Reason: {e}");
1682
1683 #[cfg(feature = "logging")]
1684 warn!(target: "stdout", "{}", &warn_msg);
1685
1686 Ok(ParseResult {
1687 raw: input.to_owned(),
1688 content: None,
1689 tool_calls: vec![],
1690 })
1691 }
1692 }
1693 }
1694 PromptTemplateType::FunctionaryV31 => {
1695 #[cfg(feature = "logging")]
1696 info!(target: "stdout", "raw input: {input}");
1697
1698 match regex::Regex::new(r"<function=(\w+)>\s*(\{.*?\})</function>") {
1699 Ok(re) => {
1700 let mut tool_calls: Vec<ToolCall> = vec![];
1701 for cap in re.captures_iter(input) {
1702 #[cfg(feature = "logging")]
1703 info!(target: "stdout", "func_name: {}", &cap[1]);
1704
1705 #[cfg(feature = "logging")]
1706 info!(target: "stdout", "arguments: {}", &cap[2]);
1707
1708 let tool_call = ToolCall {
1709 id: "call_abc123".to_string(),
1710 ty: "function".to_string(),
1711 function: Function {
1712 name: cap[1].to_string(),
1713 arguments: cap[2].to_string(),
1714 },
1715 };
1716
1717 tool_calls.push(tool_call);
1718 }
1719
1720 let parsed = ParseResult {
1721 raw: input.to_owned(),
1722 content: None,
1723 tool_calls,
1724 };
1725
1726 #[cfg(feature = "logging")]
1727 info!(target: "stdout", "parsed result: {parsed:?}");
1728
1729 Ok(parsed)
1730 }
1731 Err(e) => {
1732 let warn_msg = format!("Failed to create a regex pattern. Reason: {e}");
1733
1734 #[cfg(feature = "logging")]
1735 warn!(target: "stdout", "{}", &warn_msg);
1736
1737 Ok(ParseResult {
1738 raw: input.to_owned(),
1739 content: None,
1740 tool_calls: vec![],
1741 })
1742 }
1743 }
1744 }
1745 PromptTemplateType::MistralSmallTool => {
1746 #[cfg(feature = "logging")]
1747 info!(target: "stdout", "raw input: {input}");
1748
1749 match regex::Regex::new(r"\[TOOL_CALLS\]\s*(\[(.*?)\])") {
1750 Ok(re) => {
1751 let mut values: Vec<serde_json::Value> = vec![];
1752 if let Some(cap) = re.captures(input) {
1753 let matched = cap[1].trim();
1754
1755 #[cfg(feature = "logging")]
1756 info!(target: "stdout", "captured: {matched}");
1757
1758 match serde_json::from_str::<Vec<serde_json::Value>>(matched) {
1759 Ok(vals) => values = vals,
1760 Err(e) => {
1761 let err_msg = format!(
1762 "Failed to deserialize generated tool calls. Reason: {e}"
1763 );
1764
1765 #[cfg(feature = "logging")]
1766 error!(target: "stdout", "{}", &err_msg);
1767
1768 return Err(LlamaCoreError::Operation(err_msg));
1769 }
1770 }
1771 };
1772
1773 let mut tool_calls: Vec<ToolCall> = vec![];
1774 for value in values.iter() {
1775 if let Some(object_map) = value.as_object() {
1776 if object_map.contains_key("function") {
1777 let mut function = Function {
1778 name: String::new(),
1779 arguments: String::new(),
1780 };
1781
1782 let value = object_map.get("function").unwrap();
1783 let func_map = value.as_object().unwrap();
1784 if func_map.contains_key("name") {
1785 let func_name = func_map.get("name").unwrap().as_str().unwrap();
                                #[cfg(feature = "logging")]
                                info!(target: "stdout", "func name: {func_name:?}");
1787
1788 function.name = func_name.to_string();
1789 }
1790 if func_map.contains_key("arguments") {
1791 let args = func_map.get("arguments").unwrap();
1792 let arguments = args.to_string();
                                #[cfg(feature = "logging")]
                                info!(target: "stdout", "func arguments: {arguments:?}");
1794
1795 function.arguments = arguments;
1796 }
1797
1798 let tool_call = ToolCall {
1799 id: "call_abc123".to_string(),
1800 ty: "function".to_string(),
1801 function,
1802 };
1803
1804 tool_calls.push(tool_call);
1805 } else if object_map.contains_key("name") {
1806 let mut function = Function {
1807 name: String::new(),
1808 arguments: String::new(),
1809 };
1810
1811 let name = object_map.get("name").unwrap().as_str().unwrap();
                            #[cfg(feature = "logging")]
                            info!(target: "stdout", "func name: {name:?}");
1813 function.name = name.to_string();
1814
1815 if object_map.contains_key("arguments") {
1816 let args = object_map.get("arguments").unwrap();
1817 let arguments = args.to_string();
                                #[cfg(feature = "logging")]
                                info!(target: "stdout", "func arguments: {arguments:?}");
1819
1820 function.arguments = arguments;
1821 }
1822
1823 let tool_call = ToolCall {
1824 id: "call_abc123".to_string(),
1825 ty: "function".to_string(),
1826 function,
1827 };
1828
1829 tool_calls.push(tool_call);
1830 }
1831 }
1832 }
1833
1834 let parsed = ParseResult {
1835 raw: input.to_owned(),
1836 content: None,
1837 tool_calls,
1838 };
1839
1840 #[cfg(feature = "logging")]
1841 info!(target: "stdout", "parsed result: {parsed:?}");
1842
1843 Ok(parsed)
1844 }
1845 Err(e) => {
1846 let err_msg = format!("Failed to create a regex pattern. Reason: {e}");
1847
1848 #[cfg(feature = "logging")]
1849 error!(target: "stdout", "{}", &err_msg);
1850
1851 Err(LlamaCoreError::Operation(err_msg))
1852 }
1853 }
1854 }
1855 PromptTemplateType::Llama4Chat => {
1856 #[cfg(feature = "logging")]
1857 info!(target: "stdout", "raw input: {input:?}");
1858
1859 let mut tool_calls: Vec<ToolCall> = vec![];
1860 if let Ok(value) = serde_json::from_str::<serde_json::Value>(input) {
1861 match value.as_object() {
1862 Some(object_map) => {
1863 #[cfg(feature = "logging")]
1864 debug!(target: "stdout", "object_map: {object_map:?}");
1865
1866 if object_map.contains_key("name") {
1868 let name = object_map.get("name").unwrap().as_str().unwrap();
1869
1870 #[cfg(feature = "logging")]
1871 debug!(target: "stdout", "name: {name:?}");
1872
1873 let mut function = Function {
1874 name: name.to_string(),
1875 arguments: String::new(),
1876 };
1877
1878 if object_map.contains_key("parameters") {
1880 let args = object_map.get("parameters").unwrap();
1881 let arguments = args.to_string();
1882
1883 #[cfg(feature = "logging")]
1884 debug!(target: "stdout", "arguments: {:?}", &arguments);
1885
1886 function.arguments = arguments;
1887 }
1888
1889 tool_calls.push(ToolCall {
1890 id: "call_abc123".to_string(),
1891 ty: "function".to_string(),
1892 function,
1893 });
1894 } else {
1895 let err_msg = format!(
1896 "Failed to get the name of the function. raw input: {input:?}"
1897 );
1898
1899 #[cfg(feature = "logging")]
1900 error!(target: "stdout", "{}", &err_msg);
1901
1902 return Err(LlamaCoreError::Operation(err_msg));
1903 }
1904 }
1905 None => {
1906 let err_msg = format!("Failed to parse the JSON string. JSON: {input}");
1907
1908 #[cfg(feature = "logging")]
1909 error!(target: "stdout", "{}", &err_msg);
1910
1911 return Err(LlamaCoreError::Operation(err_msg));
1912 }
1913 }
1914 }
1915
1916 let parsed = ParseResult {
1917 raw: input.to_owned(),
1918 content: None,
1919 tool_calls,
1920 };
1921
1922 #[cfg(feature = "logging")]
1923 info!(target: "stdout", "parsed result: {parsed:?}");
1924
1925 Ok(parsed)
1926 }
1927 PromptTemplateType::Qwen3NoThink => {
1928 #[cfg(feature = "logging")]
1929 info!(target: "stdout", "raw input: {input:?}");
1930
1931 match regex::Regex::new(r"(?s)<tool_call>((.|\r|\n)*?)</tool_call>") {
1932 Ok(re) => {
1933 let mut values: Vec<serde_json::Value> = vec![];
1934 for cap in re.captures_iter(input) {
1935 let matched = cap[1].trim();
1936
1937 #[cfg(feature = "logging")]
1938 info!(target: "stdout", "captured: {matched}");
1939
1940 match serde_json::from_str::<serde_json::Value>(matched) {
1941 Ok(value) => values.push(value),
1942 Err(e) => {
1943 let err_msg = format!(
1944 "Failed to deserialize generated tool calls. Reason: {e}"
1945 );
1946
1947 #[cfg(feature = "logging")]
1948 error!(target: "stdout", "{}", &err_msg);
1949
1950 return Err(LlamaCoreError::Operation(err_msg));
1951 }
1952 }
1953 }
1954
1955 let mut tool_calls: Vec<ToolCall> = vec![];
1956 for value in values.iter() {
1957 let name = match value.get("name") {
1958 Some(name) => name.to_string().replace("\"", ""),
1959 None => {
1960 let err_msg = format!(
1961 "Failed to get the name of the function. Tool call: {value:?}"
1962 );
1963
1964 #[cfg(feature = "logging")]
1965 error!(target: "stdout", "{}", &err_msg);
1966
1967 return Err(LlamaCoreError::Operation(err_msg));
1968 }
1969 };
1970
1971 let arguments = match value.get("arguments") {
1972 Some(arguments) => {
1973 if arguments.is_string() {
1974 arguments.as_str().unwrap().to_string()
1975 } else if arguments.is_object() {
1976 let map = arguments.as_object().unwrap();
1977
1978 #[cfg(feature = "logging")]
1979 info!(target: "stdout", "func arguments: {map:?}");
1980
1981 serde_json::to_string(map).unwrap()
1982 } else {
1983 serde_json::to_string(arguments).unwrap()
1984 }
1985 }
1986 None => {
1987 let err_msg = format!(
1988 "Failed to get the arguments of the function. Tool call: {value:?}"
1989 );
1990
1991 #[cfg(feature = "logging")]
1992 error!(target: "stdout", "{}", &err_msg);
1993
1994 return Err(LlamaCoreError::Operation(err_msg));
1995 }
1996 };
1997
1998 let function = Function { name, arguments };
1999
2000 let tool_call = ToolCall {
2001 id: "call_abc123".to_string(),
2002 ty: "function".to_string(),
2003 function,
2004 };
2005
2006 tool_calls.push(tool_call);
2007 }
2008
2009 let parsed = if tool_calls.is_empty() {
2010 ParseResult {
2011 raw: input.to_owned(),
2012 content: Some(input.to_owned()),
2013 tool_calls: vec![],
2014 }
2015 } else {
2016 ParseResult {
2017 raw: input.to_owned(),
2018 content: None,
2019 tool_calls,
2020 }
2021 };
2022
2023 #[cfg(feature = "logging")]
2024 info!(target: "stdout", "parsed result: {parsed:?}");
2025
2026 Ok(parsed)
2027 }
2028 Err(e) => {
2029 let err_msg = format!("Failed to create a regex pattern. Reason: {e}");
2030
2031 #[cfg(feature = "logging")]
2032 error!(target: "stdout", "{}", &err_msg);
2033
2034 Err(LlamaCoreError::Operation(err_msg))
2035 }
2036 }
2037 }
2038 _ => {
2039 let err_msg = format!(
2040 "Tool use is only supported for the following prompt templates: {}, {}, {}, {}, {}, {}, {}, {}, {}, and {}.",
2041 PromptTemplateType::MistralTool,
2042 PromptTemplateType::ChatMLTool,
2043 PromptTemplateType::GroqLlama3Tool,
2044 PromptTemplateType::Llama3Tool,
2045 PromptTemplateType::InternLM2Tool,
2046 PromptTemplateType::NemotronTool,
2047 PromptTemplateType::FunctionaryV32,
2048 PromptTemplateType::MistralSmallTool,
2049 PromptTemplateType::Llama4Chat,
2050 PromptTemplateType::Qwen3NoThink,
2051 );
2052
2053 #[cfg(feature = "logging")]
2054 error!(target: "stdout", "{}", &err_msg);
2055
2056 Err(LlamaCoreError::Operation(err_msg))
2057 }
2058 }
2059}
2060
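/// Merges per-request sampling options into a copy of the model metadata and pushes the
/// result to the backend when anything changed.
///
/// For image-capable prompt templates, requests whose last user message carries an image
/// in URL form are rejected (only base64-encoded images are accepted). `temperature`,
/// `top_p`, `frequency_penalty`, and `presence_penalty` are copied from the request, and
/// `embeddings` is forced off. Returns the metadata that applies to this request.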
2061fn check_model_metadata(
2062 chat_request: &ChatCompletionRequest,
2063) -> Result<GgmlMetadata, LlamaCoreError> {
2064 let mut should_update = false;
2065 let mut metadata = get_model_metadata(chat_request.model.as_ref())?;
2066
2067 if metadata.prompt_template.is_image_supported() {
2069 if let Some(ChatCompletionRequestMessage::User(user_message)) = chat_request.messages.last()
2070 {
2071 if let ChatCompletionUserMessageContent::Parts(parts) = user_message.content() {
2072 for part in parts {
2073 if let ContentPart::Image(image_part) = part {
2074 let image = image_part.image();
2075
2076 if image.is_url() {
2077 let err_msg = "The image is provided in URL format. Only base64 format is supported.".to_string();
2078
2079 #[cfg(feature = "logging")]
2080 error!(target: "stdout", "{}", &err_msg);
2081
2082 return Err(LlamaCoreError::Operation(err_msg));
2083 } else {
2084 #[cfg(feature = "logging")]
2085 info!(target: "stdout", "The image is provided in base64 format.");
2086
2087 break;
2090 }
2091 }
2092 }
2093 }
2094 }
2095 }
2096
2097 if let Some(temp) = chat_request.temperature {
2099 if metadata.temperature != temp {
2100 metadata.temperature = temp;
2102
2103 if !should_update {
2104 should_update = true;
2105 }
2106 }
2107 }
2108
2109 if let Some(top_p) = chat_request.top_p {
2111 if metadata.top_p != top_p {
2112 metadata.top_p = top_p;
2114
2115 if !should_update {
2116 should_update = true;
2117 }
2118 }
2119 }
2120
2121 if let Some(frequency_penalty) = chat_request.frequency_penalty {
2123 if metadata.frequency_penalty != frequency_penalty {
2124 metadata.frequency_penalty = frequency_penalty;
2126
2127 if !should_update {
2128 should_update = true;
2129 }
2130 }
2131 }
2132
2133 if let Some(presence_penalty) = chat_request.presence_penalty {
2135 if metadata.presence_penalty != presence_penalty {
2136 metadata.presence_penalty = presence_penalty;
2138
2139 if !should_update {
2140 should_update = true;
2141 }
2142 }
2143 }
2144
2145 if metadata.embeddings {
2147 metadata.embeddings = false;
2148
2149 if !should_update {
2150 should_update = true;
2151 }
2152 }
2153
2154 if should_update {
2155 #[cfg(feature = "logging")]
2156 info!(target: "stdout", "Update the model metadata.");
2157
2158 update_model_metadata(chat_request.model.as_ref(), &metadata)?;
2160 }
2161
2162 Ok(metadata)
2163}
2164
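/// Reconciles `n_predict` with the request and the remaining context budget.
///
/// A `max_completion_tokens` value from the request takes precedence. After that, the
/// sentinel `-2`, any other negative value, and any positive value smaller than
/// `available_completion_tokens` are all replaced with `available_completion_tokens`.
/// The backend metadata is pushed only when the value actually changed.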
2165fn update_n_predict(
2166 chat_request: &ChatCompletionRequest,
2167 metadata: &mut GgmlMetadata,
2168 available_completion_tokens: u64,
2169) -> Result<(), LlamaCoreError> {
2170 let mut should_update = false;
2171
2172 #[cfg(feature = "logging")]
2173 info!(target: "stdout", "n_predict: {}", metadata.n_predict);
2174
2175 if let Some(max_completion_tokens) = chat_request.max_completion_tokens {
2181 if metadata.n_predict != max_completion_tokens {
2182 #[cfg(feature = "logging")]
2183 info!(target: "stdout", "Update n_predict with max_completion_tokens from {} to {}", metadata.n_predict, max_completion_tokens);
2184
2185 metadata.n_predict = max_completion_tokens;
2186
2187 if !should_update {
2188 should_update = true;
2189 }
2190 }
2191 }
2192
2193 if metadata.n_predict == -2 {
2195 #[cfg(feature = "logging")]
2196 info!(target: "stdout", "Update n_predict with available_completion_tokens from {} to {}", metadata.n_predict, available_completion_tokens);
2197
2198 metadata.n_predict = available_completion_tokens as i32;
2200
2201 if !should_update {
2202 should_update = true;
2203 }
2204 }
2205
2206 if metadata.n_predict == -1
2207 || (metadata.n_predict > 0 && metadata.n_predict < available_completion_tokens as i32)
2208 || (metadata.n_predict < 0 && metadata.n_predict != -2)
2209 {
2211 #[cfg(feature = "logging")]
2212 info!(target: "stdout", "Update n_predict with available_completion_tokens from {} to {}", metadata.n_predict, available_completion_tokens);
2213
2214 metadata.n_predict = available_completion_tokens as i32;
2216
2217 if !should_update {
2218 should_update = true;
2219 }
2220 }
2221
2222 if should_update {
2223 #[cfg(feature = "logging")]
2224 info!(target: "stdout", "Update the model metadata.");
2225
2226 update_model_metadata(chat_request.model.as_ref(), metadata)?;
2228 }
2229
2230 Ok(())
2231}
2232
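/// Strips template-specific stop tokens and end-of-turn markers (e.g. `<|im_end|>`,
/// `</s>`, `<|eot_id|>`, `<end_of_turn>`) from the raw model output so that only the
/// assistant text remains; templates without special handling are simply trimmed.
///
/// Illustrative only (this is a private helper, so the snippet is not run as a doc-test):
///
/// ```ignore
/// let cleaned = post_process("Hello!<|im_end|>", &PromptTemplateType::ChatML)?;
/// assert_eq!(cleaned, "Hello!");
/// ```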
2233fn post_process(
2235 output: impl AsRef<str>,
2236 template_ty: &PromptTemplateType,
2237) -> Result<String, String> {
2238 let output = if *template_ty == PromptTemplateType::Baichuan2 {
2239 if output.as_ref().contains("用户:") {
2240 output.as_ref().trim_end_matches("用户:").trim().to_owned()
2241 } else {
2242 output.as_ref().trim().to_owned()
2243 }
2244 } else if *template_ty == PromptTemplateType::OpenChat {
2245 if output.as_ref().contains("<|end_of_turn|>") {
2246 output
2247 .as_ref()
2248 .trim_end_matches("<|end_of_turn|>")
2249 .trim()
2250 .to_owned()
2251 } else {
2252 output.as_ref().trim().to_owned()
2253 }
2254 } else if *template_ty == PromptTemplateType::GemmaInstruct
2255 || *template_ty == PromptTemplateType::Gemma3
2256 {
2257 let s = output.as_ref().trim();
2258 if s.ends_with("<end_of_turn>") {
2259 s.trim_end_matches("<end_of_turn>").trim().to_owned()
2260 } else {
2261 s.to_owned()
2262 }
2263 } else if *template_ty == PromptTemplateType::ChatML
2264 || *template_ty == PromptTemplateType::ChatMLTool
2265 || *template_ty == PromptTemplateType::InternLM2Tool
2266 || *template_ty == PromptTemplateType::MiniCPMV
2267 {
2268 let mut s = output.as_ref().trim();
2269 if s.ends_with("<|endoftext|>") {
2270 s = s.trim_end_matches("<|endoftext|>").trim();
2271 }
2272
2273 if s.starts_with(":") {
2274 s = s.trim_start_matches(":").trim();
2275 }
2276
2277 let x = {
2279 let pat = r#"<think>
2280
2281</think>
2282"#;
2283 if s.contains(pat) {
2284 let x = s.replace(pat, "");
2285 if x.starts_with("()") {
2286 x.trim_start_matches("()").to_owned()
2287 } else {
2288 x.to_owned()
2289 }
2290 } else {
2291 s.to_owned()
2292 }
2293 };
2294 s = x.trim();
2295
2296 if s.contains("<|im_start|>") && s.contains("<|im_end|>") {
2297 let idx_start = s.find("<|im_start|>").unwrap();
2298 let idx_end = s.find("<|im_end|>").unwrap();
2299
2300 match idx_start <= idx_end {
2301 true => s.split("<|im_start|>").collect::<Vec<_>>()[0]
2302 .trim()
2303 .to_owned(),
2304 false => s.split("<|im_end|>").collect::<Vec<_>>()[0]
2305 .trim()
2306 .to_owned(),
2307 }
2308 } else if s.contains("<|im_start|>") {
2309 s.split("<|im_start|>").collect::<Vec<_>>()[0]
2310 .trim()
2311 .to_owned()
2312 } else if s.contains("<|im_end|>") {
2313 let output = s.trim_end_matches("<|im_end|>").trim();
2314 if output.starts_with(": ") {
2315 output.trim_start_matches(": ").to_owned()
2316 } else {
2317 output.to_owned()
2318 }
2319 } else {
2320 s.to_owned()
2321 }
2322 } else if *template_ty == PromptTemplateType::Zephyr
2323 || *template_ty == PromptTemplateType::MistralLite
2324 || *template_ty == PromptTemplateType::MistralTool
2325 || *template_ty == PromptTemplateType::MistralInstruct
2326 || *template_ty == PromptTemplateType::MistralSmallChat
2327 || *template_ty == PromptTemplateType::MistralSmallTool
2328 || *template_ty == PromptTemplateType::BreezeInstruct
2329 {
2330 if output.as_ref().contains("</s><") {
2331 output.as_ref().trim_end_matches("</s><").trim().to_owned()
2332 } else if output.as_ref().contains("</s>") {
2333 output
2334 .as_ref()
2335 .strip_suffix("</s>")
2336 .unwrap()
2337 .trim()
2338 .to_owned()
2339 } else {
2340 output.as_ref().trim().to_owned()
2341 }
2342 } else if *template_ty == PromptTemplateType::DeepseekChat {
2343 if output.as_ref().contains("<|end_of_sentence|>") {
2344 output
2345 .as_ref()
2346 .trim_end_matches("<|end_of_sentence|>")
2347 .trim()
2348 .replace("<|end_of_sentence|>", " ")
2349 .trim()
2350 .to_owned()
2351 } else {
2352 output.as_ref().trim().to_owned()
2353 }
2354 } else if *template_ty == PromptTemplateType::HumanAssistant {
2355 if output.as_ref().contains("Human:") {
2356 output.as_ref().trim_end_matches("Human:").trim().to_owned()
2357 } else {
2358 output.as_ref().trim().to_owned()
2359 }
2360 } else if *template_ty == PromptTemplateType::SolarInstruct {
2361 let s = output.as_ref().trim();
2362
2363 if s.starts_with("### Answer") {
2364 let s = s.trim_start_matches("###").trim();
2365
2366 if s.starts_with("Answer:\n") {
2367 s.replace("Answer:\n", "Answer: ")
2368 } else {
2369 s.to_owned()
2370 }
2371 } else {
2372 s.to_owned()
2373 }
2374 } else if *template_ty == PromptTemplateType::Llama2Chat
2375 || *template_ty == PromptTemplateType::NemotronTool
2376 || *template_ty == PromptTemplateType::NemotronChat
2377 {
2378 let s = output.as_ref().trim();
2379 if s.ends_with("</s>") {
2380 s.trim_end_matches("</s>").trim().to_owned()
2381 } else {
2382 s.to_owned()
2383 }
2384 } else if *template_ty == PromptTemplateType::Llama3Chat
2385 || *template_ty == PromptTemplateType::GroqLlama3Tool
2386 || *template_ty == PromptTemplateType::Llama3Tool
2387 || *template_ty == PromptTemplateType::FunctionaryV32
2388 {
2389 let s = output.as_ref().trim();
2390 if s.ends_with("<|eot_id|>") {
2391 s.trim_end_matches("<|eot_id|>").trim().to_owned()
2392 } else {
2393 s.to_owned()
2394 }
2395 } else if *template_ty == PromptTemplateType::Phi3Chat {
2396 let s = output.as_ref().trim();
2397 if s.ends_with("<|end|>") {
2398 s.trim_end_matches("<|end|>").trim().to_owned()
2399 } else {
2400 s.to_owned()
2401 }
2402 } else if *template_ty == PromptTemplateType::Phi4Chat {
2403 let mut s = output.as_ref().trim();
2404
2405 if s.starts_with("think>") {
2406 s = s.trim_start_matches("think>").trim();
2407 }
2408
2409 if s.ends_with("<|im_end|>") {
2410 s.trim_end_matches("<|im_end|>").trim().to_owned()
2411 } else if s.ends_with("<|end|>") {
2412 s.trim_end_matches("<|end|>").trim().to_owned()
2413 } else {
2414 s.to_owned()
2415 }
2416 } else if *template_ty == PromptTemplateType::FunctionaryV31 {
2417 let mut s = output.as_ref().trim();
2418 if s.ends_with("<|eot_id|>") {
2419 s = s.trim_end_matches("<|eot_id|>").trim();
2420 }
2421 if s.ends_with("<|eom_id|>") {
2422 s = s.trim_end_matches("<|eom_id|>").trim();
2423 }
2424 s.to_owned()
2425 } else if *template_ty == PromptTemplateType::MoxinChat
2426 || *template_ty == PromptTemplateType::MoxinInstruct
2427 {
2428 let s = output.as_ref().trim();
2429 if s.ends_with("</s>") {
2430 s.trim_end_matches("</s>").trim().to_owned()
2431 } else if s.ends_with("[INST]") {
2432 s.trim_end_matches("[INST]").trim().to_owned()
2433 } else {
2434 s.to_owned()
2435 }
2436 } else if *template_ty == PromptTemplateType::Falcon3 {
2437 let s = output.as_ref().trim();
2438 if s.ends_with("<|endoftext|>") {
2439 s.trim_end_matches("<|endoftext|>").trim().to_owned()
2440 } else {
2441 s.to_owned()
2442 }
2443 } else if *template_ty == PromptTemplateType::Megrez {
2444 let s = output.as_ref().trim();
2445 if s.ends_with("<|turn_end|>") {
2446 s.trim_end_matches("<|turn_end|>").trim().to_owned()
2447 } else {
2448 s.to_owned()
2449 }
2450 } else if *template_ty == PromptTemplateType::Qwen2vl
2451 || *template_ty == PromptTemplateType::Qwen3NoThink
2452 || *template_ty == PromptTemplateType::ChatMLThink
2453 {
2454 let mut s = output.as_ref().trim();
2455
2456 if s.starts_with(":") {
2457 s = s.trim_start_matches(":").trim();
2458 }
2459
2460 if s.starts_with("</think>") {
2461 s = s.trim_start_matches("</think>").trim();
2462 }
2463
2464 if s.ends_with("<|im_end|>") {
2465 s.trim_end_matches("<|im_end|>").trim().to_owned()
2466 } else {
2467 s.to_owned()
2468 }
2469 } else if *template_ty == PromptTemplateType::VicunaLlava {
2470 let s = output.as_ref().trim();
2471 if s.ends_with("</s>") {
2472 s.trim_end_matches("</s>").trim().to_owned()
2473 } else {
2474 s.to_owned()
2475 }
2476 } else if *template_ty == PromptTemplateType::ExaoneDeepChat
2477 || *template_ty == PromptTemplateType::ExaoneChat
2478 {
2479 let mut s = output.as_ref().trim();
2480
2481 if s.ends_with("[|endofturn|]") {
2482 s = s.trim_end_matches("[|endofturn|]").trim();
2483 }
2484
2485 s.to_owned()
2486 } else if *template_ty == PromptTemplateType::Llama4Chat {
2487 let mut s = output.as_ref().trim();
2488
2489 if s.ends_with("<|eot|>") {
2490 s = s.trim_end_matches("<|eot|>").trim();
2491 }
2492
2493 s.to_owned()
2494 } else if *template_ty == PromptTemplateType::Smolvl {
2495 let mut s = output.as_ref().trim();
2496
2497 if s.starts_with(":") {
2498 s = s.trim_start_matches(":").trim();
2499 }
2500
2501 if s.ends_with("<end_of_utterance>") {
2502 s = s.trim_end_matches("<end_of_utterance>").trim();
2503 }
2504
2505 if s.contains("<end_of_utterance>:") {
2506 let parts = s.split("<end_of_utterance>:").collect::<Vec<_>>();
2507 parts.last().unwrap().trim().to_owned()
2508 } else {
2509 s.to_owned()
2510 }
2511 } else {
2512 output.as_ref().trim().to_owned()
2513 };
2514
2515 Ok(output)
2516}
2517
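/// Renders the chat prompt for the target model and returns
/// `(prompt, available_completion_tokens, tool_use)`.
///
/// The prompt budget is 4/5 of the model's context size. When the rendered prompt
/// exceeds that budget, the oldest non-system messages are pruned from the chat history
/// and the prompt is rebuilt, looping until it fits or no further pruning is possible.
/// `tool_use` reports whether tool definitions were embedded in the prompt.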
2518fn build_prompt(
2530 model_name: Option<&String>,
2531 chat_request: &mut ChatCompletionRequest,
2532) -> Result<(String, u64, bool), LlamaCoreError> {
2533 let metadata = get_model_metadata(model_name)?;
2534 let ctx_size = metadata.ctx_size as u64;
2535 let chat_prompt = ChatPrompt::from(metadata.prompt_template);
2536
2537 let max_prompt_tokens = ctx_size * 4 / 5;
2539
2540 loop {
2541 {
2543 }
2556
2557 if chat_request.messages.is_empty() {
2558 let err_msg = "The messages in the chat request are empty.";
2559
2560 #[cfg(feature = "logging")]
2561 error!(target: "stdout", "{err_msg}");
2562
2563 return Err(LlamaCoreError::Operation(err_msg.to_owned()));
2564 }
2565
2566 #[cfg(feature = "logging")]
2567 {
2568 let mut role_chain = String::new();
2569 for (idx, message) in chat_request.messages.iter().enumerate() {
2570 if idx == chat_request.messages.len() - 1 {
2571 role_chain.push_str(&format!("{}", message.role()));
2572 } else {
2573 role_chain.push_str(&format!("{} -> ", message.role()));
2574 }
2575 }
2576 info!(target: "stdout", "Role chain: {role_chain}");
2577 }
2578
2579 let (prompt, tool_use) = match chat_request.tool_choice.as_ref() {
2580 Some(tool_choice) => match tool_choice {
2581 ToolChoice::None => {
2582 match chat_prompt.build_with_tools(&mut chat_request.messages, Some(&[])) {
2583 Ok(prompt) => (prompt, false),
2584 Err(e) => {
2585 let err_msg = format!("Fail to build chat prompts. Reason: {e}");
2586
2587 #[cfg(feature = "logging")]
2588 error!(target: "stdout", "{}", &err_msg);
2589
2590 return Err(LlamaCoreError::Operation(err_msg));
2591 }
2592 }
2593 }
2594 _ => match chat_request.tools.as_ref() {
2595 Some(tools) => match chat_prompt
2596 .build_with_tools(&mut chat_request.messages, Some(tools.as_slice()))
2597 {
2598 Ok(prompt) => (prompt, true),
2599 Err(e) => {
2600 let err_msg = format!("Fail to build chat prompts. Reason: {e}");
2601
2602 #[cfg(feature = "logging")]
2603 error!(target: "stdout", "{}", &err_msg);
2604
2605 return Err(LlamaCoreError::Operation(err_msg));
2606 }
2607 },
2608 None => {
2609 #[cfg(feature = "logging")]
2610 warn!(target: "stdout", "A tool choice was specified but no tools were provided; ignoring it and building the prompt without tools.");
2611
2612 match chat_prompt.build_with_tools(&mut chat_request.messages, None) {
2613 Ok(prompt) => (prompt, false),
2614 Err(e) => {
2615 let err_msg = format!("Fail to build chat prompts. Reason: {e}");
2616
2617 #[cfg(feature = "logging")]
2618 error!(target: "stdout", "{}", &err_msg);
2619
2620 return Err(LlamaCoreError::Operation(err_msg));
2621 }
2622 }
2623 }
2624 },
2625 },
2626 None => match chat_prompt.build_with_tools(&mut chat_request.messages, None) {
2627 Ok(prompt) => (prompt, false),
2628 Err(e) => {
2629 let err_msg = format!("Fail to build chat prompts. Reason: {e}");
2630
2631 #[cfg(feature = "logging")]
2632 error!(target: "stdout", "{}", &err_msg);
2633
2634 return Err(LlamaCoreError::Operation(err_msg));
2635 }
2636 },
2637 };
2638 #[cfg(feature = "logging")]
2639 info!(target: "stdout", "Try to set prompt: {prompt}");
2640
2641 set_prompt(model_name, &prompt)?;
2643
2644 let token_info = get_token_info_by_graph_name(model_name)?;
2646
2647 match token_info.prompt_tokens > max_prompt_tokens {
2648 true => {
2649 match chat_request.messages[0].role() {
2650 ChatCompletionRole::System => {
2651 if chat_request.messages.len() > 2 {
2652 #[cfg(feature = "logging")]
2653 info!(target: "stdout", "Prune chat history: current length {}", chat_request.messages.len());
2654
2655 if chat_request.messages[1].role() == ChatCompletionRole::User {
2658 let user_message = chat_request.messages.remove(1);
2659
2660 #[cfg(feature = "logging")]
2661 info!(target: "stdout", "Remove a user message from the chat history: {user_message:?}");
2662 }
2663
2664 while chat_request.messages[1].role() != ChatCompletionRole::User {
2667 let message = chat_request.messages.remove(1);
2668
2669 #[cfg(feature = "logging")]
2670 info!(target: "stdout", "Remove a {} message from the chat history: {:?}", message.role(), message);
2671
2672 if chat_request.messages.len() == 1 {
2673 let err_msg = format!("The last message in the chat history should be a user message, but found a {} message.", message.role());
2674
2675 #[cfg(feature = "logging")]
2676 error!(target: "stdout", "{err_msg}");
2677
2678 return Err(LlamaCoreError::Operation(err_msg));
2679 }
2680 }
2681 } else if token_info.prompt_tokens > ctx_size {
2682 let err_msg = format!(
2683 "The number of prompt tokens ({}) is greater than the context size ({}). Please increase the context size, or simplify the input message.",
2684 token_info.prompt_tokens, ctx_size
2685 );
2686
2687 #[cfg(feature = "logging")]
2688 error!(target: "stdout", "{}", &err_msg);
2689
2690 return Err(LlamaCoreError::Operation(err_msg));
2691 } else {
2692 return Ok((prompt, ctx_size - token_info.prompt_tokens, tool_use));
2693 }
2694 }
2695 ChatCompletionRole::User => {
2696 if chat_request.messages.len() > 1 {
2697 if chat_request.messages[0].role() == ChatCompletionRole::User {
2702 let user_message = chat_request.messages.remove(0);
2703
2704 #[cfg(feature = "logging")]
2705 info!(target: "stdout", "Remove a user message from the chat history: {user_message:?}");
2706 }
2707
2708 while chat_request.messages[0].role() != ChatCompletionRole::User {
2711 let message = chat_request.messages.remove(0);
2712
2713 #[cfg(feature = "logging")]
2714 info!(target: "stdout", "Remove a {} message from the chat history: {:?}", message.role(), message);
2715
2716 if chat_request.messages.is_empty() {
2717 let err_msg = format!("The last message in the chat history should be a user message, but found a {} message.", message.role());
2718
2719 #[cfg(feature = "logging")]
2720 error!(target: "stdout", "{err_msg}");
2721
2722 return Err(LlamaCoreError::Operation(err_msg));
2723 }
2724 }
2725 } else if token_info.prompt_tokens > ctx_size {
2726 let err_msg = format!(
2727 "The number of prompt tokens ({}) is greater than the context size ({}). Please increase the context size, or simplify the input message.",
2728 token_info.prompt_tokens, ctx_size
2729 );
2730
2731 #[cfg(feature = "logging")]
2732 error!(target: "stdout", "{}", &err_msg);
2733
2734 return Err(LlamaCoreError::Operation(err_msg));
2735 } else {
2736 return Ok((prompt, ctx_size - token_info.prompt_tokens, tool_use));
2737 }
2738 }
2739 _ => {
2740 #[cfg(feature = "logging")]
2741 info!(target: "stdout", "remove a {} message from the message queue", chat_request.messages[0].role());
2742
2743 chat_request.messages.remove(0);
2744 }
2745 }
2746
2747 continue;
2748 }
2749 false => return Ok((prompt, ctx_size - max_prompt_tokens, tool_use)),
2750 }
2751 }
2752}
2753
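/// Writes the rendered prompt into input tensor 0 of the target chat graph, falling back
/// to the first available graph when `model_name` is `None` or unknown.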
2754fn set_prompt(model_name: Option<&String>, prompt: impl AsRef<str>) -> Result<(), LlamaCoreError> {
2755 let chat_graphs = match CHAT_GRAPHS.get() {
2756 Some(chat_graphs) => chat_graphs,
2757 None => {
2758 let err_msg = "Fail to get the underlying value of `CHAT_GRAPHS`.";
2759
2760 #[cfg(feature = "logging")]
2761 error!(target: "stdout", "{}", &err_msg);
2762
2763 return Err(LlamaCoreError::Operation(err_msg.into()));
2764 }
2765 };
2766
2767 let mut chat_graphs = chat_graphs.lock().map_err(|e| {
2768 let err_msg = format!("Fail to acquire the lock of `CHAT_GRAPHS`. {e}");
2769
2770 #[cfg(feature = "logging")]
2771 error!(target: "stdout", "{}", &err_msg);
2772
2773 LlamaCoreError::Operation(err_msg)
2774 })?;
2775
2776 match model_name {
2777 Some(model_name) => {
2778 #[cfg(feature = "logging")]
2779 info!(target: "stdout", "Set prompt to the chat model named {model_name}");
2780
2781 match chat_graphs.contains_key(model_name) {
2782 true => {
2783 let graph = chat_graphs.get_mut(model_name).unwrap();
2784 let tensor_data = prompt.as_ref().as_bytes().to_vec();
2785 set_tensor_data_u8(graph, 0, &tensor_data)
2786 }
2787 false => match chat_graphs.iter_mut().next() {
2788 Some((_, graph)) => {
2789 let tensor_data = prompt.as_ref().as_bytes().to_vec();
2790 set_tensor_data_u8(graph, 0, &tensor_data)
2791 }
2792 None => {
2793 let err_msg = "There is no model available in the chat graphs.";
2794
2795 #[cfg(feature = "logging")]
2796 error!(target: "stdout", "{}", &err_msg);
2797
2798 Err(LlamaCoreError::Operation(err_msg.into()))
2799 }
2800 },
2801 }
2802 }
2803 None => {
2804 #[cfg(feature = "logging")]
2805 info!(target: "stdout", "Set prompt to the default chat model.");
2806
2807 match chat_graphs.iter_mut().next() {
2808 Some((_, graph)) => {
2809 let tensor_data = prompt.as_ref().as_bytes().to_vec();
2810 set_tensor_data_u8(graph, 0, &tensor_data)
2811 }
2812 None => {
2813 let err_msg = "There is no model available in the chat graphs while trying to set prompt to the default model.";
2814
2815 #[cfg(feature = "logging")]
2816 error!(target: "stdout", "{err_msg}");
2817
2818 Err(LlamaCoreError::Operation(err_msg.into()))
2819 }
2820 }
2821 }
2822 }
2823}
2824
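/// Returns a clone of the metadata of the named chat graph, or of the first available
/// graph when `model_name` is `None` or not found.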
2825fn get_model_metadata(model_name: Option<&String>) -> Result<GgmlMetadata, LlamaCoreError> {
2827 let chat_graphs = match CHAT_GRAPHS.get() {
2828 Some(chat_graphs) => chat_graphs,
2829 None => {
2830 let err_msg = "Fail to get the underlying value of `CHAT_GRAPHS`.";
2831
2832 #[cfg(feature = "logging")]
2833 error!(target: "stdout", "{err_msg}");
2834
2835 return Err(LlamaCoreError::Operation(err_msg.into()));
2836 }
2837 };
2838
2839 let chat_graphs = chat_graphs.lock().map_err(|e| {
2840 let err_msg = format!("Fail to acquire the lock of `CHAT_GRAPHS`. {e}");
2841
2842 #[cfg(feature = "logging")]
2843 error!(target: "stdout", "{}", &err_msg);
2844
2845 LlamaCoreError::Operation(err_msg)
2846 })?;
2847
2848 match model_name {
2849 Some(model_name) => match chat_graphs.contains_key(model_name) {
2850 true => {
2851 let graph = chat_graphs.get(model_name).unwrap();
2852 Ok(graph.metadata.clone())
2853 }
2854 false => match chat_graphs.iter().next() {
2855 Some((_, graph)) => Ok(graph.metadata.clone()),
2856 None => {
2857 let err_msg = "There is no model available in the chat graphs.";
2858
2859 #[cfg(feature = "logging")]
2860 error!(target: "stdout", "{}", &err_msg);
2861
2862 Err(LlamaCoreError::Operation(err_msg.into()))
2863 }
2864 },
2865 },
2866 None => match chat_graphs.iter().next() {
2867 Some((_, graph)) => Ok(graph.metadata.clone()),
2868 None => {
2869 let err_msg = "There is no model available in the chat graphs.";
2870
2871 #[cfg(feature = "logging")]
2872 error!(target: "stdout", "{err_msg}");
2873
2874 Err(LlamaCoreError::Operation(err_msg.into()))
2875 }
2876 },
2877 }
2878}
2879
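/// Serializes `metadata` to JSON and writes it into input tensor 1 of the target chat
/// graph, which is how updated settings reach the backend. Falls back to the first
/// available graph when `model_name` is `None` or unknown.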
2880fn update_model_metadata(
2881 model_name: Option<&String>,
2882 metadata: &GgmlMetadata,
2883) -> Result<(), LlamaCoreError> {
2884 let config = match serde_json::to_string(metadata) {
2885 Ok(config) => config,
2886 Err(e) => {
2887 let err_msg = format!("Fail to serialize metadata to a JSON string. {e}");
2888
2889 #[cfg(feature = "logging")]
2890 error!(target: "stdout", "{}", &err_msg);
2891
2892 return Err(LlamaCoreError::Operation(err_msg));
2893 }
2894 };
2895
2896 let chat_graphs = match CHAT_GRAPHS.get() {
2897 Some(chat_graphs) => chat_graphs,
2898 None => {
2899 let err_msg = "Fail to get the underlying value of `CHAT_GRAPHS`.";
2900
2901 #[cfg(feature = "logging")]
2902 error!(target: "stdout", "{err_msg}");
2903
2904 return Err(LlamaCoreError::Operation(err_msg.into()));
2905 }
2906 };
2907
2908 let mut chat_graphs = chat_graphs.lock().map_err(|e| {
2909 let err_msg = format!("Fail to acquire the lock of `CHAT_GRAPHS`. Reason: {e}");
2910
2911 #[cfg(feature = "logging")]
2912 error!(target: "stdout", "{}", &err_msg);
2913
2914 LlamaCoreError::Operation(err_msg)
2915 })?;
2916
2917 match model_name {
2918 Some(model_name) => {
2919 match chat_graphs.contains_key(model_name) {
2920 true => {
2921 let graph = chat_graphs.get_mut(model_name).unwrap();
2922 set_tensor_data_u8(graph, 1, config.as_bytes())
2924 }
2925 false => match chat_graphs.iter_mut().next() {
2926 Some((_, graph)) => {
2927 set_tensor_data_u8(graph, 1, config.as_bytes())
2929 }
2930 None => {
2931 let err_msg = "There is no model available in the chat graphs.";
2932
2933 #[cfg(feature = "logging")]
2934 error!(target: "stdout", "{}", &err_msg);
2935
2936 Err(LlamaCoreError::Operation(err_msg.into()))
2937 }
2938 },
2939 }
2940 }
2941 None => {
2942 match chat_graphs.iter_mut().next() {
2943 Some((_, graph)) => {
2944 set_tensor_data_u8(graph, 1, config.as_bytes())
2946 }
2947 None => {
2948 let err_msg = "There is no model available in the chat graphs.";
2949
2950 #[cfg(feature = "logging")]
2951 error!(target: "stdout", "{err_msg}");
2952
2953 Err(LlamaCoreError::Operation(err_msg.into()))
2954 }
2955 }
2956 }
2957 }
2958}
2959
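/// Re-applies the metadata stored on the chat graph to the backend, clearing any
/// per-request overrides made by `check_model_metadata` and `update_n_predict`.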
2960fn reset_model_metadata(model_name: Option<&String>) -> Result<(), LlamaCoreError> {
2961 let metadata = get_model_metadata(model_name)?;
2963
2964 update_model_metadata(model_name, &metadata)
2966}
2967
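// State machines that sequence the chunks emitted by `compute_stream`: `StreamState`
// drives the normal streaming path, while `ContextFullState` and `PromptTooLongState`
// drive the corresponding backend error paths. Each advances toward `EndOfSequence`,
// after which only the `[GGML] End of sequence` marker is returned and the stream ends.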
2968#[derive(Debug, Clone, Copy, PartialEq, Eq)]
2969enum ContextFullState {
2970 Message,
2971 Usage,
2972 Done,
2973 EndOfSequence,
2974}
2975
2976#[derive(Debug, Clone, Copy, PartialEq, Eq)]
2977enum StreamState {
2978 Usage,
2979 NoUsage,
2980 Done,
2981 EndOfSequence,
2982}
2983
2984#[derive(Debug, Clone, Copy, PartialEq, Eq)]
2985enum PromptTooLongState {
2986 Message,
2987 Usage,
2988 Done,
2989 EndOfSequence,
2990}
2991
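/// A streaming chat session that yields SSE-formatted `ChatCompletionChunk` lines.
///
/// At most one `ChatStream` computes at a time: the global `CHAT_STREAM_ACTIVE` flag
/// serves as a lock, and streams that fail to acquire it park their waker in the
/// `CHAT_STREAM_WAKER_QUEUE` until the active stream is dropped. When `cache` is set,
/// the stream replays the pre-computed chunks instead of computing new ones.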
2992struct ChatStream {
2993 id: String,
2994 model: Option<String>,
2995 include_usage: bool,
2996 context_full_state: ContextFullState,
2997 prompt_too_long_state: PromptTooLongState,
2998 stream_state: StreamState,
2999 cache: Option<VecDeque<String>>,
3000 is_waiting: bool,
3001 has_lock: bool,
3002}
3003impl ChatStream {
3004 fn new(
3005 model: Option<String>,
3006 id: String,
3007 include_usage: bool,
3008 cache: Option<Vec<String>>,
3009 ) -> Self {
3010 let has_lock = CHAT_STREAM_ACTIVE
3012 .compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst)
3013 .is_ok();
3014
3015 #[cfg(feature = "logging")]
3016 if !has_lock {
3017 info!(target: "stdout", "Lock acquisition failed in ChatStream::new, creating with waiting status");
3018 }
3019
3020 ChatStream {
3021 id,
3022 model,
3023 include_usage,
3024 context_full_state: ContextFullState::Message,
3025 prompt_too_long_state: PromptTooLongState::Message,
3026 stream_state: if include_usage {
3027 StreamState::Usage
3028 } else {
3029 StreamState::NoUsage
3030 },
3031 cache: cache.map(VecDeque::from),
3032 is_waiting: !has_lock,
3033 has_lock,
3034 }
3035 }
3036
3037 fn try_acquire_lock(&mut self) -> bool {
3039 if self.has_lock {
3040 return true;
3041 }
3042
3043 let acquired = CHAT_STREAM_ACTIVE
3044 .compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst)
3045 .is_ok();
3046
3047 if acquired {
3048 self.has_lock = true;
3049 self.is_waiting = false;
3050 }
3051
3052 acquired
3053 }
3054}
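// Dropping a `ChatStream` finalizes the single-turn computation on the underlying graph
// (when this stream owned the lock or was actively computing), resets the model
// metadata, and, if it held the global stream lock, releases it and wakes the next
// waiting stream.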
3055impl Drop for ChatStream {
3056 fn drop(&mut self) {
3057 if self.has_lock || (self.cache.is_none() && !self.is_waiting) {
3059 #[cfg(feature = "logging")]
3060 info!(target: "stdout", "Cleaning up context for ChatStream {}", &self.id);
3061
3062 match &self.model {
3063 Some(model_name) => {
3064 match CHAT_GRAPHS.get() {
3065 Some(chat_graphs) => {
3066 match chat_graphs.lock() {
3067 Ok(mut chat_graphs) => match chat_graphs.contains_key(model_name) {
3068 true => {
3069 let graph = chat_graphs.get_mut(model_name).unwrap();
3070
3071 if let Err(e) = graph.finish_single() {
3073 let err_msg = format!(
3074 "Failed to clean up the context. Reason: {e}"
3075 );
3076
3077 #[cfg(feature = "logging")]
3078 error!(target: "stdout", "{}", &err_msg);
3079
3080 #[cfg(not(feature = "logging"))]
3081 println!(
3082 "[ERROR][llama_core] Failed to clean up the context. Reason: {}",
3083 &err_msg
3084 );
3085 }
3086 }
3087 false => match chat_graphs.iter_mut().next() {
3088 Some((_, graph)) => {
3089 if let Err(e) = graph.finish_single() {
3091 let err_msg = format!(
3092 "Failed to clean up the context. Reason: {e}"
3093 );
3094
3095 #[cfg(feature = "logging")]
3096 error!(target: "stdout", "{}", &err_msg);
3097
3098 #[cfg(not(feature = "logging"))]
3099 println!(
3100 "[ERROR][llama_core] Failed to clean up the context. Reason: {}",
3101 &err_msg
3102 );
3103 }
3104 }
3105 None => {
3106 let err_msg =
3107 "There is no model available in the chat graphs.";
3108
3109 #[cfg(feature = "logging")]
3110 error!(target: "stdout", "{}", &err_msg);
3111
3112 #[cfg(not(feature = "logging"))]
3113 println!(
3114 "[ERROR][llama_core] Failed to clean up the context. Reason: {}",
3115 &err_msg
3116 );
3117 }
3118 },
3119 },
3120 Err(e) => {
3121 let err_msg =
3122 format!("Fail to acquire the lock of `CHAT_GRAPHS`. {e}");
3123
3124 #[cfg(feature = "logging")]
3125 error!(target: "stdout", "{}", &err_msg);
3126
3127 #[cfg(not(feature = "logging"))]
3128 println!(
3129 "[ERROR][llama_core] Failed to clean up the context. Reason: {}",
3130 &err_msg
3131 );
3132 }
3133 }
3134 }
3135 None => {
3136 let err_msg = "Fail to get the underlying value of `CHAT_GRAPHS`.";
3137
3138 #[cfg(feature = "logging")]
3139 error!(target: "stdout", "{}", &err_msg);
3140
3141 #[cfg(not(feature = "logging"))]
3142 println!(
3143 "[ERROR][llama_core] Failed to clean up the context. Reason: {}",
3144 &err_msg
3145 );
3146 }
3147 };
3148 }
3149 None => {
3150 match CHAT_GRAPHS.get() {
3151 Some(chat_graphs) => {
3152 match chat_graphs.lock() {
3153 Ok(mut chat_graphs) => match chat_graphs.iter_mut().next() {
3154 Some((_, graph)) => {
3155 if let Err(e) = graph.finish_single() {
3157 let err_msg = format!(
3158 "Failed to clean up the context. Reason: {e}"
3159 );
3160
3161 #[cfg(feature = "logging")]
3162 error!(target: "stdout", "{}", &err_msg);
3163
3164 #[cfg(not(feature = "logging"))]
3165 println!(
3166 "[ERROR][llama_core] Failed to clean up the context. Reason: {}",
3167 &err_msg
3168 );
3169 }
3170 }
3171 None => {
3172 let err_msg =
3173 "There is no model available in the chat graphs.";
3174
3175 #[cfg(feature = "logging")]
3176 error!(target: "stdout", "{err_msg}");
3177
3178 #[cfg(not(feature = "logging"))]
3179 println!(
3180 "[ERROR][llama_core] Failed to clean up the context. Reason: {}",
3181 err_msg
3182 );
3183 }
3184 },
3185 Err(e) => {
3186 let err_msg =
3187 format!("Fail to acquire the lock of `CHAT_GRAPHS`. {e}");
3188
3189 #[cfg(feature = "logging")]
3190 error!(target: "stdout", "{}", &err_msg);
3191
3192 #[cfg(not(feature = "logging"))]
3193 println!(
3194 "[ERROR][llama_core] Failed to clean up the context. Reason: {}",
3195 &err_msg
3196 );
3197 }
3198 }
3199 }
3200 None => {
3201 let err_msg = "Fail to get the underlying value of `CHAT_GRAPHS`.";
3202
3203 #[cfg(feature = "logging")]
3204 error!(target: "stdout", "{}", &err_msg);
3205
3206 #[cfg(not(feature = "logging"))]
3207 println!(
3208 "[ERROR][llama_core] Failed to clean up the context. Reason: {}",
3209 &err_msg
3210 );
3211 }
3212 };
3213 }
3214 }
3215
3216 #[cfg(feature = "logging")]
3217 info!(target: "stdout", "Model context cleanup done!");
3218 }
3219
3220 if let Err(e) = reset_model_metadata(self.model.as_ref()) {
3222 let err_msg = format!("Fail to reset model metadata. Reason: {e}");
3223
3224 #[cfg(feature = "logging")]
3225 error!(target: "stdout", "{}", &err_msg);
3226
3227 #[cfg(not(feature = "logging"))]
3228 println!("[ERROR][llama_core] {}", &err_msg);
3229 }
3230 #[cfg(feature = "logging")]
3231 info!(target: "stdout", "Model metadata reset done!");
3232
3233 if self.has_lock {
3235 CHAT_STREAM_ACTIVE.store(false, Ordering::SeqCst);
3237
3238 #[cfg(feature = "logging")]
3239 info!(target: "stdout", "Lock from ChatStream {} released", &self.id);
3240
3241 if let Ok(mut queue) = get_chat_stream_waker_queue().lock() {
3243 if let Some(waker) = queue.pop_front() {
3244 #[cfg(feature = "logging")]
3245 info!(target: "stdout", "Waking up a waiting ChatStream");
3246
3247 waker.wake();
3248 }
3249 }
3250 }
3251 }
3252}
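// Polling a `ChatStream` first tries to acquire the global stream lock (registering the
// waker and returning `Pending` on failure), then either replays the next cached chunk
// or computes a new one via `compute_stream`, ending the stream on an empty chunk or
// the `[GGML] End of sequence` marker.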
3253impl futures::Stream for ChatStream {
3254 type Item = Result<String, LlamaCoreError>;
3255
3256 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
3257 let this = self.get_mut();
3258
3259 if this.is_waiting {
3261 if !this.try_acquire_lock() {
3262 if let Ok(mut queue) = get_chat_stream_waker_queue().lock() {
3264 queue.retain(|w| !w.will_wake(cx.waker()));
3266 queue.push_back(cx.waker().clone());
3268
3269 #[cfg(feature = "logging")]
3270 debug!(target: "stdout", "ChatStream {} is waiting for lock, added waker to queue", &this.id);
3271 }
3272
3273 return Poll::Pending;
3274 }
3275
3276 #[cfg(feature = "logging")]
3277 info!(target: "stdout", "ChatStream {} acquired lock and is now active", &this.id);
3278 }
3280
3281 if !this.has_lock && !this.try_acquire_lock() {
3283 this.is_waiting = true;
3285
3286 if let Ok(mut queue) = get_chat_stream_waker_queue().lock() {
3288 queue.retain(|w| !w.will_wake(cx.waker()));
3289 queue.push_back(cx.waker().clone());
3290 }
3291
3292 return Poll::Pending;
3293 }
3294
3295 if this.cache.is_none() {
3296 let res = compute_stream(
3297 this.model.clone(),
3298 this.id.clone(),
3299 this.include_usage,
3300 &mut this.prompt_too_long_state,
3301 &mut this.context_full_state,
3302 &mut this.stream_state,
3303 );
3304
3305 match res {
3306 Ok(x) => {
3307 #[cfg(feature = "logging")]
3308 info!(target: "stdout", "next item for ChatStream {}: {}", &this.id, &x);
3309
3310 if x != "[GGML] End of sequence" && !x.is_empty() {
3311 Poll::Ready(Some(Ok(x)))
3312 } else {
3313 Poll::Ready(None)
3315 }
3316 }
3317 Err(e) => Poll::Ready(Some(Err(e))),
3318 }
3319 } else {
3320 let x = this.cache.as_mut().unwrap().pop_front();
3321
3322 #[cfg(feature = "logging")]
3323 info!(target: "stdout", "Get the next item from the cache for ChatStream {}: {:?}", &this.id, &x);
3324
3325 match x {
3326 Some(x) => Poll::Ready(Some(Ok(x))),
3327 None => Poll::Ready(None),
3328 }
3329 }
3330 }
3331}
3332
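/// Returns the global FIFO of wakers registered by `ChatStream`s waiting for the stream
/// lock, initializing it on first use.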
3333fn get_chat_stream_waker_queue() -> &'static Mutex<VecDeque<Waker>> {
3335 CHAT_STREAM_WAKER_QUEUE.get_or_init(|| {
3336 #[cfg(feature = "logging")]
3337 info!(target: "stdout", "Initializing ChatStream waker queue");
3338 Mutex::new(VecDeque::new())
3339 })
3340}
3341
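/// Computes the next chunk of a streaming chat completion for the target graph and
/// returns it as an SSE `data:` line, e.g. (illustrative, fields elided):
///
/// ```text
/// data: {"id":"...","object":"chat.completion.chunk",...}
/// ```
///
/// Backend "context full" and "prompt too long" errors produce a final chunk with
/// `finish_reason: length`; those paths and "end of sequence" then emit an optional
/// usage chunk, a `data: [DONE]` line, and finally the internal
/// `[GGML] End of sequence` marker that tells `ChatStream` to stop polling.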
3342fn compute_stream(
3343 model_name: Option<String>,
3344 id: String,
3345 include_usage: bool,
3346 prompt_too_long_state: &mut PromptTooLongState,
3347 context_full_state: &mut ContextFullState,
3348 stream_state: &mut StreamState,
3349) -> Result<String, LlamaCoreError> {
3350 #[cfg(feature = "logging")]
3351 info!(target: "stdout", "Computing stream chunk for ChatStream {}", &id);
3352
3353 #[cfg(feature = "logging")]
3354 debug!(target: "stdout", "prompt_too_long_state: {:?}", *prompt_too_long_state);
3355 #[cfg(feature = "logging")]
3356 debug!(target: "stdout", "context_full_state: {:?}", *context_full_state);
3357 #[cfg(feature = "logging")]
3358 debug!(target: "stdout", "stream_state: {:?}", *stream_state);
3359
3360 if *prompt_too_long_state == PromptTooLongState::EndOfSequence
3361 || *context_full_state == ContextFullState::EndOfSequence
3362 || *stream_state == StreamState::EndOfSequence
3363 {
3364 #[cfg(feature = "logging")]
3365 info!(target: "stdout", "Return the chat stream chunk!");
3366
3367 return Ok("[GGML] End of sequence".to_string());
3368 }
3369
3370 let chat_graphs = match CHAT_GRAPHS.get() {
3371 Some(chat_graphs) => chat_graphs,
3372 None => {
3373 let err_msg = "Fail to get the underlying value of `CHAT_GRAPHS`.";
3374
3375 #[cfg(feature = "logging")]
3376 error!(target: "stdout", "{}", &err_msg);
3377
3378 return Err(LlamaCoreError::Operation(err_msg.into()));
3379 }
3380 };
3381
3382 let mut chat_graphs = chat_graphs.lock().map_err(|e| {
3384 let err_msg = format!("Fail to acquire the lock of `CHAT_GRAPHS`. {e}");
3385
3386 #[cfg(feature = "logging")]
3387 error!(target: "stdout", "{}", &err_msg);
3388
3389 LlamaCoreError::Operation(err_msg)
3390 })?;
3391
3392 let res = match &model_name {
3394 Some(model_name) => {
3395 match chat_graphs.contains_key(model_name) {
3396 true => {
3397 let graph = chat_graphs.get_mut(model_name).unwrap();
3398 match graph.compute_single() {
3400 Ok(_) => {
3401 #[cfg(feature = "logging")]
3402 debug!(target: "stdout", "Compute the chat stream chunk successfully.");
3403
3404 match stream_state {
3406 StreamState::Usage | StreamState::NoUsage => {
3407 let output_buffer =
3409 get_output_buffer_single(graph, OUTPUT_TENSOR)?;
3410
3411 #[cfg(feature = "logging")]
3412 info!(target: "stdout", "retrieved the output buffer");
3413
3414 let output = match String::from_utf8(output_buffer.clone()) {
3416 Ok(token) => token,
3417 Err(_) => {
3418 let mutex = CACHED_UTF8_ENCODINGS
3419 .get_or_init(|| Mutex::new(Vec::new()));
3420 let mut cached_encodings = mutex.lock().map_err(|e| {
3421 let err_msg = format!(
3422 "Fail to acquire the lock of `CACHED_UTF8_ENCODINGS`. Reason: {e}"
3423 );
3424
3425 #[cfg(feature = "logging")]
3426 error!(target: "stdout", "{}", &err_msg);
3427
3429 LlamaCoreError::Operation(err_msg)
3430 })?;
3431
3432 cached_encodings.extend_from_slice(&output_buffer[..]);
3434
3435 match String::from_utf8(cached_encodings.to_vec()) {
3436 Ok(token) => {
3437 cached_encodings.clear();
3439
3440 token
3441 }
3442 Err(e) => {
3443 if cached_encodings.len() > 4 {
3445 let err_msg = format!("Failed to convert the cached bytes to a UTF-8 string: more than 4 bytes are cached, so they cannot form a single partial character. {e}");
3446
3447 #[cfg(feature = "logging")]
3448 error!(target: "stdout", "{}", &err_msg);
3449
3450 #[cfg(feature = "logging")]
3451 error!(target: "stdout", "The cached buffer: {:?}", &cached_encodings[..]);
3452
3453 cached_encodings.clear();
3460
3461 String::from("")
3462 } else {
3463 let warn_msg = format!("Failed to convert the cached bytes to a UTF-8 string. {e}");
3464
3465 #[cfg(feature = "logging")]
3466 warn!(target: "stdout", "{}", &warn_msg);
3467
3468 String::from("")
3469 }
3470 }
3471 }
3472 }
3473 };
3474
3475 #[cfg(feature = "logging")]
3476 info!(target: "stdout", "decoded the output buffer");
3477
3478 let created = SystemTime::now()
3479 .duration_since(std::time::UNIX_EPOCH)
3480 .map_err(|e| {
3481 let err_msg = format!(
3482 "Failed to get the current time. Reason: {e}"
3483 );
3484
3485 #[cfg(feature = "logging")]
3486 error!(target: "stdout", "{}", &err_msg);
3487
3488 LlamaCoreError::Operation(err_msg)
3489 })?;
3490
3491 let chat_completion_chunk = ChatCompletionChunk {
3492 id,
3493 object: "chat.completion.chunk".to_string(),
3494 created: created.as_secs(),
3495 model: graph.name().to_owned(),
3496 system_fingerprint: "fp_44709d6fcb".to_string(),
3497 choices: vec![ChatCompletionChunkChoice {
3498 index: 0,
3499 delta: ChatCompletionChunkChoiceDelta {
3500 role: ChatCompletionRole::Assistant,
3501 content: Some(output),
3502 tool_calls: vec![],
3503 },
3504 logprobs: None,
3505 finish_reason: None,
3506 }],
3507 usage: None,
3508 };
3509
3510 #[cfg(feature = "logging")]
3511 info!(target: "stdout", "created chat completion chunk");
3512
3513 let chunk_str = serde_json::to_string(&chat_completion_chunk)
3515 .map_err(|e| {
3516 let err_msg = format!(
3517 "Failed to serialize chat completion chunk. Reason: {e}"
3518 );
3519
3520 #[cfg(feature = "logging")]
3521 error!(target: "stdout", "{}", &err_msg);
3522
3523 LlamaCoreError::Operation(err_msg)
3524 })?;
3525
3526 Ok(format!("data: {chunk_str}\n\n"))
3527 }
3528 StreamState::Done => {
3529 *stream_state = StreamState::EndOfSequence;
3530
3531 Ok("data: [DONE]\n\n".to_string())
3532 }
3533 StreamState::EndOfSequence => {
3534 Ok("[GGML] End of sequence".to_string())
3535 }
3536 }
3537 }
3538 Err(wasmedge_wasi_nn::Error::BackendError(
3539 wasmedge_wasi_nn::BackendError::EndOfSequence,
3540 )) => {
3541 #[cfg(feature = "logging")]
3542 debug!(target: "stdout", "End of sequence");
3543
3544 match stream_state {
3545 StreamState::Usage => {
3546 *stream_state = StreamState::Done;
3547
3548 let token_info = get_token_info_by_graph(graph)?;
3550
3551 let usage = Some(Usage {
3552 prompt_tokens: token_info.prompt_tokens,
3553 completion_tokens: token_info.completion_tokens,
3554 total_tokens: token_info.prompt_tokens
3555 + token_info.completion_tokens,
3556 });
3557
3558 #[cfg(feature = "logging")]
3559 info!(target: "stdout", "token_info: {} prompt tokens, {} completion tokens", token_info.prompt_tokens, token_info.completion_tokens);
3560
3561 let created = SystemTime::now()
3562 .duration_since(std::time::UNIX_EPOCH)
3563 .map_err(|e| {
3564 let err_msg = format!(
3565 "Failed to get the current time. Reason: {e}"
3566 );
3567
3568 #[cfg(feature = "logging")]
3569 error!(target: "stdout", "{}", &err_msg);
3570
3571 LlamaCoreError::Operation(err_msg)
3572 })?;
3573
3574 let chat_completion_chunk = ChatCompletionChunk {
3575 id,
3576 object: "chat.completion.chunk".to_string(),
3577 created: created.as_secs(),
3578 model: graph.name().to_owned(),
3579 system_fingerprint: "fp_44709d6fcb".to_string(),
3580 choices: vec![],
3581 usage,
3582 };
3583
3584 let chunk_str = serde_json::to_string(&chat_completion_chunk)
3586 .map_err(|e| {
3587 let err_msg = format!(
3588 "Failed to serialize chat completion chunk. Reason: {e}"
3589 );
3590
3591 #[cfg(feature = "logging")]
3592 error!(target: "stdout", "{}", &err_msg);
3593
3594 LlamaCoreError::Operation(err_msg)
3595 })?;
3596
3597 Ok(format!("data: {chunk_str}\n\n"))
3598 }
3599 StreamState::Done | StreamState::NoUsage => {
3600 *stream_state = StreamState::EndOfSequence;
3601
3602 Ok("data: [DONE]\n\n".to_string())
3603 }
3604 StreamState::EndOfSequence => {
3605 Ok("[GGML] End of sequence".to_string())
3606 }
3607 }
3608 }
3609 Err(wasmedge_wasi_nn::Error::BackendError(
3610 wasmedge_wasi_nn::BackendError::ContextFull,
3611 )) => {
3612 #[cfg(feature = "logging")]
3613 debug!(target: "stdout", "Context full");
3614
3615 match context_full_state {
3616 ContextFullState::Message => {
3617 match include_usage {
3618 true => *context_full_state = ContextFullState::Usage,
3619 false => *context_full_state = ContextFullState::Done,
3620 }
3621
3622 let created = SystemTime::now()
3623 .duration_since(std::time::UNIX_EPOCH)
3624 .map_err(|e| {
3625 let err_msg = format!(
3626 "Failed to get the current time. Reason: {e}"
3627 );
3628
3629 #[cfg(feature = "logging")]
3630 error!(target: "stdout", "{}", &err_msg);
3631
3632 LlamaCoreError::Operation(err_msg)
3633 })?;
3634
3635 let chat_completion_chunk = ChatCompletionChunk {
3636 id,
3637 object: "chat.completion.chunk".to_string(),
3638 created: created.as_secs(),
3639 model: graph.name().to_owned(),
3640 system_fingerprint: "fp_44709d6fcb".to_string(),
3641 choices: vec![ChatCompletionChunkChoice {
3642 index: 0,
3643 delta: ChatCompletionChunkChoiceDelta {
3644 role: ChatCompletionRole::Assistant,
3645 content: Some(
3646 "<|WASMEDGE-GGML-CONTEXT-FULL|>".to_string(),
3647 ),
3648 tool_calls: vec![],
3649 },
3650 logprobs: None,
3651 finish_reason: Some(FinishReason::length),
3652 }],
3653 usage: None,
3654 };
3655
3656 let chunk_str = serde_json::to_string(&chat_completion_chunk)
3658 .map_err(|e| {
3659 let err_msg = format!(
3660 "Failed to serialize chat completion chunk. Reason: {e}"
3661 );
3662
3663 #[cfg(feature = "logging")]
3664 error!(target: "stdout", "{}", &err_msg);
3665
3666 LlamaCoreError::Operation(err_msg)
3667 })?;
3668
3669 Ok(format!("data: {chunk_str}\n\n"))
3670 }
3671 ContextFullState::Usage => {
3672 *context_full_state = ContextFullState::Done;
3673
3674 let token_info = get_token_info_by_graph(graph)?;
3676
3677 let usage = Some(Usage {
3678 prompt_tokens: token_info.prompt_tokens,
3679 completion_tokens: token_info.completion_tokens,
3680 total_tokens: token_info.prompt_tokens
3681 + token_info.completion_tokens,
3682 });
3683
3684 let created = SystemTime::now()
3685 .duration_since(std::time::UNIX_EPOCH)
3686 .map_err(|e| {
3687 let err_msg = format!(
3688 "Failed to get the current time. Reason: {e}"
3689 );
3690
3691 #[cfg(feature = "logging")]
3692 error!(target: "stdout", "{}", &err_msg);
3693
3694 LlamaCoreError::Operation(err_msg)
3695 })?;
3696
3697 let chat_completion_chunk = ChatCompletionChunk {
3698 id,
3699 object: "chat.completion.chunk".to_string(),
3700 created: created.as_secs(),
3701 model: graph.name().to_owned(),
3702 system_fingerprint: "fp_44709d6fcb".to_string(),
3703 choices: vec![],
3704 usage,
3705 };
3706
3707 let chunk_str = serde_json::to_string(&chat_completion_chunk)
3709 .map_err(|e| {
3710 let err_msg = format!(
3711 "Failed to serialize chat completion chunk. Reason: {e}"
3712 );
3713
3714 #[cfg(feature = "logging")]
3715 error!(target: "stdout", "{}", &err_msg);
3716
3717 LlamaCoreError::Operation(err_msg)
3718 })?;
3719
3720 Ok(format!("data: {chunk_str}\n\n"))
3721 }
3722 ContextFullState::Done => {
3723 *context_full_state = ContextFullState::EndOfSequence;
3724
3725 Ok("data: [DONE]\n\n".to_string())
3726 }
3727 ContextFullState::EndOfSequence => {
3728 Ok("[GGML] End of sequence".to_string())
3729 }
3730 }
3731 }
3732 Err(wasmedge_wasi_nn::Error::BackendError(
3733 wasmedge_wasi_nn::BackendError::PromptTooLong,
3734 )) => {
3735 #[cfg(feature = "logging")]
3736 debug!(target: "stdout", "Prompt too long");
3737
3738 match prompt_too_long_state {
3739 PromptTooLongState::Message => {
3740 match include_usage {
3741 true => *prompt_too_long_state = PromptTooLongState::Usage,
3742 false => *prompt_too_long_state = PromptTooLongState::Done,
3743 }
3744
3745 let created = SystemTime::now()
3746 .duration_since(std::time::UNIX_EPOCH)
3747 .map_err(|e| {
3748 let err_msg = format!(
3749 "Failed to get the current time. Reason: {e}"
3750 );
3751
3752 #[cfg(feature = "logging")]
3753 error!(target: "stdout", "{}", &err_msg);
3754
3755 LlamaCoreError::Operation(err_msg)
3756 })?;
3757
3758 let chat_completion_chunk = ChatCompletionChunk {
3759 id,
3760 object: "chat.completion.chunk".to_string(),
3761 created: created.as_secs(),
3762 model: graph.name().to_owned(),
3763 system_fingerprint: "fp_44709d6fcb".to_string(),
3764 choices: vec![ChatCompletionChunkChoice {
3765 index: 0,
3766 delta: ChatCompletionChunkChoiceDelta {
3767 role: ChatCompletionRole::Assistant,
3768 content: None,
3769 tool_calls: vec![],
3770 },
3771 logprobs: None,
3772 finish_reason: Some(FinishReason::length),
3773 }],
3774 usage: None,
3775 };
3776
3777 let chunk_str = serde_json::to_string(&chat_completion_chunk)
3779 .map_err(|e| {
3780 let err_msg = format!(
3781 "Failed to serialize chat completion chunk. Reason: {e}"
3782 );
3783
3784 #[cfg(feature = "logging")]
3785 error!(target: "stdout", "{}", &err_msg);
3786
3787 LlamaCoreError::Operation(err_msg)
3788 })?;
3789
3790 Ok(format!("data: {chunk_str}\n\n"))
3791 }
3792 PromptTooLongState::Usage => {
3793 *prompt_too_long_state = PromptTooLongState::Done;
3794
3795 let token_info = get_token_info_by_graph(graph)?;
3797
3798 let usage = Some(Usage {
3799 prompt_tokens: token_info.prompt_tokens,
3800 completion_tokens: token_info.completion_tokens,
3801 total_tokens: token_info.prompt_tokens
3802 + token_info.completion_tokens,
3803 });
3804
3805 let created = SystemTime::now()
3806 .duration_since(std::time::UNIX_EPOCH)
3807 .map_err(|e| {
3808 let err_msg = format!(
3809 "Failed to get the current time. Reason: {e}"
3810 );
3811
3812 #[cfg(feature = "logging")]
3813 error!(target: "stdout", "{}", &err_msg);
3814
3815 LlamaCoreError::Operation(err_msg)
3816 })?;
3817
3818 let chat_completion_chunk = ChatCompletionChunk {
3819 id,
3820 object: "chat.completion.chunk".to_string(),
3821 created: created.as_secs(),
3822 model: graph.name().to_owned(),
3823 system_fingerprint: "fp_44709d6fcb".to_string(),
3824 choices: vec![],
3825 usage,
3826 };
3827
3828 let chunk_str = serde_json::to_string(&chat_completion_chunk)
3830 .map_err(|e| {
3831 let err_msg = format!(
3832 "Failed to serialize chat completion chunk. Reason: {e}"
3833 );
3834
3835 #[cfg(feature = "logging")]
3836 error!(target: "stdout", "{}", &err_msg);
3837
3838 LlamaCoreError::Operation(err_msg)
3839 })?;
3840
3841 Ok(format!("data: {chunk_str}\n\n"))
3842 }
3843 PromptTooLongState::Done => {
3844 *prompt_too_long_state = PromptTooLongState::EndOfSequence;
3845
3846 Ok("data: [DONE]\n\n".to_string())
3847 }
3848 PromptTooLongState::EndOfSequence => {
3849 Ok("[GGML] End of sequence".to_string())
3850 }
3851 }
3852 }
3853 Err(e) => {
3854 let err_msg =
3855 format!("Failed to compute the chat completion. Reason: {e}");
3856
3857 #[cfg(feature = "logging")]
3858 error!(target: "stdout", "{}", &err_msg);
3859
3860 Err(LlamaCoreError::Backend(BackendError::ComputeSingle(
3861 err_msg,
3862 )))
3863 }
3864 }
3865 }
3866 false => {
3867 match chat_graphs.iter_mut().next() {
3868 Some((_, graph)) => {
3869 match graph.compute_single() {
3871 Ok(_) => {
3872 #[cfg(feature = "logging")]
3873 debug!(target: "stdout", "Compute the chat stream chunk successfully.");
3874
3875 match stream_state {
3876 StreamState::Usage | StreamState::NoUsage => {
3877 let output_buffer =
3879 get_output_buffer_single(graph, OUTPUT_TENSOR)?;
3880
3881 #[cfg(feature = "logging")]
3882 info!(target: "stdout", "retrieved the output buffer");
3883
3884 let output = match String::from_utf8(
3886 output_buffer.clone(),
3887 ) {
3888 Ok(token) => token,
3889 Err(_) => {
3890 let mutex = CACHED_UTF8_ENCODINGS
3891 .get_or_init(|| Mutex::new(Vec::new()));
3892 let mut cached_encodings = mutex.lock().map_err(|e| {
3893 let err_msg = format!(
3894 "Fail to acquire the lock of `CACHED_UTF8_ENCODINGS`. Reason: {e}"
3895 );
3896
3897 #[cfg(feature = "logging")]
3898 error!(target: "stdout", "{}", &err_msg);
3899
3901 LlamaCoreError::Operation(err_msg)
3902 })?;
3903
3904 cached_encodings
3906 .extend_from_slice(&output_buffer[..]);
3907
3908 match String::from_utf8(
3909 cached_encodings.to_vec(),
3910 ) {
3911 Ok(token) => {
3912 cached_encodings.clear();
3914
3915 token
3916 }
3917 Err(e) => {
3918 if cached_encodings.len() > 4 {
3920 let err_msg = format!("Failed to convert the cached bytes to a UTF-8 string: more than 4 bytes are cached, so they cannot form a single partial character. {e}");
3921
3922 #[cfg(feature = "logging")]
3923 error!(target: "stdout", "{}", &err_msg);
3924
3925 #[cfg(feature = "logging")]
3926 error!(target: "stdout", "The cached buffer: {:?}", &cached_encodings[..]);
3927
3928 cached_encodings.clear();
3936
3937 String::from("")
3938 } else {
3939 let warn_msg = format!("Failed to convert the cached bytes to a UTF-8 string. {e}");
3940
3941 #[cfg(feature = "logging")]
3942 warn!(target: "stdout", "{}", &warn_msg);
3943
3944 String::from("")
3945 }
3946 }
3947 }
3948 }
3949 };
3950
3951 #[cfg(feature = "logging")]
3952 info!(target: "stdout", "decoded the output buffer");
3953
3954 let created = SystemTime::now()
3955 .duration_since(std::time::UNIX_EPOCH)
3956 .map_err(|e| {
3957 let err_msg = format!(
3958 "Failed to get the current time. Reason: {e}"
3959 );
3960
3961 #[cfg(feature = "logging")]
3962 error!(target: "stdout", "{}", &err_msg);
3963
3964 LlamaCoreError::Operation(err_msg)
3965 })?;
3966
3967 let chat_completion_chunk = ChatCompletionChunk {
3968 id,
3969 object: "chat.completion.chunk".to_string(),
3970 created: created.as_secs(),
3971 model: graph.name().to_owned(),
3972 system_fingerprint: "fp_44709d6fcb".to_string(),
3973 choices: vec![ChatCompletionChunkChoice {
3974 index: 0,
3975 delta: ChatCompletionChunkChoiceDelta {
3976 role: ChatCompletionRole::Assistant,
3977 content: Some(output),
3978 tool_calls: vec![],
3979 },
3980 logprobs: None,
3981 finish_reason: None,
3982 }],
3983 usage: None,
3984 };
3985
3986 #[cfg(feature = "logging")]
3987 info!(target: "stdout", "created chat completion chunk");
3988
3989 let chunk_str =
3991 serde_json::to_string(&chat_completion_chunk)
3992 .map_err(|e| {
3993 let err_msg = format!(
3994 "Failed to serialize chat completion chunk. Reason: {e}"
3995 );
3996
3997 #[cfg(feature = "logging")]
3998 error!(target: "stdout", "{}", &err_msg);
3999
4000 LlamaCoreError::Operation(err_msg)
4001 })?;
4002
4003 Ok(format!("data: {chunk_str}\n\n"))
4004 }
4005 StreamState::Done => {
4006 *stream_state = StreamState::EndOfSequence;
4007
4008 Ok("data: [DONE]\n\n".to_string())
4009 }
4010 StreamState::EndOfSequence => {
4011 Ok("[GGML] End of sequence".to_string())
4012 }
4013 }
4014 }
4015 Err(wasmedge_wasi_nn::Error::BackendError(
4016 wasmedge_wasi_nn::BackendError::EndOfSequence,
4017 )) => {
4018 #[cfg(feature = "logging")]
4019 debug!(target: "stdout", "End of sequence");
4020
4021 match stream_state {
4022 StreamState::Usage => {
4023 *stream_state = StreamState::Done;
4024
4025 let token_info = get_token_info_by_graph(graph)?;
4027
4028 let usage = Some(Usage {
4029 prompt_tokens: token_info.prompt_tokens,
4030 completion_tokens: token_info.completion_tokens,
4031 total_tokens: token_info.prompt_tokens
4032 + token_info.completion_tokens,
4033 });
4034
4035 #[cfg(feature = "logging")]
4036 info!(target: "stdout", "token_info: {} prompt tokens, {} completion tokens", token_info.prompt_tokens, token_info.completion_tokens);
4037
4038 let created = SystemTime::now()
4039 .duration_since(std::time::UNIX_EPOCH)
4040 .map_err(|e| {
4041 let err_msg = format!(
4042 "Failed to get the current time. Reason: {e}"
4043 );
4044
4045 #[cfg(feature = "logging")]
4046 error!(target: "stdout", "{}", &err_msg);
4047
4048 LlamaCoreError::Operation(err_msg)
4049 })?;
4050
4051 let chat_completion_chunk = ChatCompletionChunk {
4052 id,
4053 object: "chat.completion.chunk".to_string(),
4054 created: created.as_secs(),
4055 model: graph.name().to_owned(),
4056 system_fingerprint: "fp_44709d6fcb".to_string(),
4057 choices: vec![],
4058 usage,
4059 };
4060
4061 let chunk_str =
4063 serde_json::to_string(&chat_completion_chunk)
4064 .map_err(|e| {
4065 let err_msg = format!(
4066 "Failed to serialize chat completion chunk. Reason: {e}"
4067 );
4068
4069 #[cfg(feature = "logging")]
4070 error!(target: "stdout", "{}", &err_msg);
4071
4072 LlamaCoreError::Operation(err_msg)
4073 })?;
4074
4075 Ok(format!("data: {chunk_str}\n\n"))
4076 }
4077 StreamState::Done | StreamState::NoUsage => {
4078 *stream_state = StreamState::EndOfSequence;
4079
4080 Ok("data: [DONE]\n\n".to_string())
4081 }
4082 StreamState::EndOfSequence => {
4083 Ok("[GGML] End of sequence".to_string())
4084 }
4085 }
4086 }
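                        // The context window is full: close the stream with `finish_reason: length`,
                        // sending the `<|WASMEDGE-GGML-CONTEXT-FULL|>` marker and, if requested, a
                        // final usage chunk.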
4087 Err(wasmedge_wasi_nn::Error::BackendError(
4088 wasmedge_wasi_nn::BackendError::ContextFull,
4089 )) => {
4090 #[cfg(feature = "logging")]
4091 debug!(target: "stdout", "Context full");
4092
4093 match context_full_state {
4094 ContextFullState::Message => {
4095 match include_usage {
4096 true => {
4097 *context_full_state = ContextFullState::Usage
4098 }
4099 false => {
4100 *context_full_state = ContextFullState::Done
4101 }
4102 }
4103
4104 let created = SystemTime::now()
4105 .duration_since(std::time::UNIX_EPOCH)
4106 .map_err(|e| {
4107 let err_msg = format!(
4108 "Failed to get the current time. Reason: {e}"
4109 );
4110
4111 #[cfg(feature = "logging")]
4112 error!(target: "stdout", "{}", &err_msg);
4113
4114 LlamaCoreError::Operation(err_msg)
4115 })?;
4116
4117 let chat_completion_chunk = ChatCompletionChunk {
4118 id,
4119 object: "chat.completion.chunk".to_string(),
4120 created: created.as_secs(),
4121 model: graph.name().to_owned(),
4122 system_fingerprint: "fp_44709d6fcb".to_string(),
4123 choices: vec![ChatCompletionChunkChoice {
4124 index: 0,
4125 delta: ChatCompletionChunkChoiceDelta {
4126 role: ChatCompletionRole::Assistant,
4127 content: Some(
4128 "<|WASMEDGE-GGML-CONTEXT-FULL|>"
4129 .to_string(),
4130 ),
4131 tool_calls: vec![],
4132 },
4133 logprobs: None,
4134 finish_reason: Some(FinishReason::length),
4135 }],
4136 usage: None,
4137 };
4138
4139                                                 let chunk_str = serde_json::to_string(&chat_completion_chunk)
4142 .map_err(|e| {
4143 let err_msg = format!(
4144 "Failed to serialize chat completion chunk. Reason: {e}"
4145 );
4146
4147 #[cfg(feature = "logging")]
4148 error!(target: "stdout", "{}", &err_msg);
4149
4150 LlamaCoreError::Operation(err_msg)
4151 })?;
4152
4153 Ok(format!("data: {chunk_str}\n\n"))
4154 }
4155 ContextFullState::Usage => {
4156 *context_full_state = ContextFullState::Done;
4157
4158 let token_info = get_token_info_by_graph(graph)?;
4160
4161 let usage = Some(Usage {
4162 prompt_tokens: token_info.prompt_tokens,
4163 completion_tokens: token_info.completion_tokens,
4164 total_tokens: token_info.prompt_tokens
4165 + token_info.completion_tokens,
4166 });
4167
4168 let created = SystemTime::now()
4169 .duration_since(std::time::UNIX_EPOCH)
4170 .map_err(|e| {
4171 let err_msg = format!(
4172 "Failed to get the current time. Reason: {e}"
4173 );
4174
4175 #[cfg(feature = "logging")]
4176 error!(target: "stdout", "{}", &err_msg);
4177
4178 LlamaCoreError::Operation(err_msg)
4179 })?;
4180
4181 let chat_completion_chunk = ChatCompletionChunk {
4182 id,
4183 object: "chat.completion.chunk".to_string(),
4184 created: created.as_secs(),
4185 model: graph.name().to_owned(),
4186 system_fingerprint: "fp_44709d6fcb".to_string(),
4187 choices: vec![],
4188 usage,
4189 };
4190
4191                                                 let chunk_str = serde_json::to_string(&chat_completion_chunk)
4194 .map_err(|e| {
4195 let err_msg = format!(
4196 "Failed to serialize chat completion chunk. Reason: {e}"
4197 );
4198
4199 #[cfg(feature = "logging")]
4200 error!(target: "stdout", "{}", &err_msg);
4201
4202 LlamaCoreError::Operation(err_msg)
4203 })?;
4204
4205 Ok(format!("data: {chunk_str}\n\n"))
4206 }
4207 ContextFullState::Done => {
4208 *context_full_state = ContextFullState::EndOfSequence;
4209
4210 Ok("data: [DONE]\n\n".to_string())
4211 }
4212 ContextFullState::EndOfSequence => {
4213 Ok("[GGML] End of sequence".to_string())
4214 }
4215 }
4216 }
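                        // The prompt itself exceeds the context size: close the stream with
                        // `finish_reason: length` and, if requested, a final usage chunk.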
4217 Err(wasmedge_wasi_nn::Error::BackendError(
4218 wasmedge_wasi_nn::BackendError::PromptTooLong,
4219 )) => {
4220 #[cfg(feature = "logging")]
4221 debug!(target: "stdout", "Prompt too long");
4222
4223 match prompt_too_long_state {
4224 PromptTooLongState::Message => {
4225 match include_usage {
4226 true => {
4227 *prompt_too_long_state =
4228 PromptTooLongState::Usage
4229 }
4230 false => {
4231 *prompt_too_long_state =
4232 PromptTooLongState::Done
4233 }
4234 }
4235
4236 let created = SystemTime::now()
4237 .duration_since(std::time::UNIX_EPOCH)
4238 .map_err(|e| {
4239 let err_msg = format!(
4240 "Failed to get the current time. Reason: {e}"
4241 );
4242
4243 #[cfg(feature = "logging")]
4244 error!(target: "stdout", "{}", &err_msg);
4245
4246 LlamaCoreError::Operation(err_msg)
4247 })?;
4248
4249 let chat_completion_chunk = ChatCompletionChunk {
4250 id,
4251 object: "chat.completion.chunk".to_string(),
4252 created: created.as_secs(),
4253 model: graph.name().to_owned(),
4254 system_fingerprint: "fp_44709d6fcb".to_string(),
4255 choices: vec![ChatCompletionChunkChoice {
4256 index: 0,
4257 delta: ChatCompletionChunkChoiceDelta {
4258 role: ChatCompletionRole::Assistant,
4259 content: None,
4260 tool_calls: vec![],
4261 },
4262 logprobs: None,
4263 finish_reason: Some(FinishReason::length),
4264 }],
4265 usage: None,
4266 };
4267
4268                                                 let chunk_str = serde_json::to_string(&chat_completion_chunk)
4271 .map_err(|e| {
4272 let err_msg = format!(
4273 "Failed to serialize chat completion chunk. Reason: {e}"
4274 );
4275
4276 #[cfg(feature = "logging")]
4277 error!(target: "stdout", "{}", &err_msg);
4278
4279 LlamaCoreError::Operation(err_msg)
4280 })?;
4281
4282 Ok(format!("data: {chunk_str}\n\n"))
4283 }
4284 PromptTooLongState::Usage => {
4285 *prompt_too_long_state = PromptTooLongState::Done;
4286
4287 let token_info = get_token_info_by_graph(graph)?;
4289
4290 let usage = Some(Usage {
4291 prompt_tokens: token_info.prompt_tokens,
4292 completion_tokens: token_info.completion_tokens,
4293 total_tokens: token_info.prompt_tokens
4294 + token_info.completion_tokens,
4295 });
4296
4297 let created = SystemTime::now()
4298 .duration_since(std::time::UNIX_EPOCH)
4299 .map_err(|e| {
4300 let err_msg = format!(
4301 "Failed to get the current time. Reason: {e}"
4302 );
4303
4304 #[cfg(feature = "logging")]
4305 error!(target: "stdout", "{}", &err_msg);
4306
4307 LlamaCoreError::Operation(err_msg)
4308 })?;
4309
4310 let chat_completion_chunk = ChatCompletionChunk {
4311 id,
4312 object: "chat.completion.chunk".to_string(),
4313 created: created.as_secs(),
4314 model: graph.name().to_owned(),
4315 system_fingerprint: "fp_44709d6fcb".to_string(),
4316 choices: vec![],
4317 usage,
4318 };
4319
4320                                                 let chunk_str = serde_json::to_string(&chat_completion_chunk)
4323 .map_err(|e| {
4324 let err_msg = format!(
4325 "Failed to serialize chat completion chunk. Reason: {e}"
4326 );
4327
4328 #[cfg(feature = "logging")]
4329 error!(target: "stdout", "{}", &err_msg);
4330
4331 LlamaCoreError::Operation(err_msg)
4332 })?;
4333
4334 Ok(format!("data: {chunk_str}\n\n"))
4335 }
4336 PromptTooLongState::Done => {
4337 *prompt_too_long_state =
4338 PromptTooLongState::EndOfSequence;
4339
4340 Ok("data: [DONE]\n\n".to_string())
4341 }
4342 PromptTooLongState::EndOfSequence => {
4343 Ok("[GGML] End of sequence".to_string())
4344 }
4345 }
4346 }
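                        // Any other backend error aborts the stream.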
4347 Err(e) => {
4348 let err_msg = format!(
4349 "Failed to compute the chat completion. Reason: {e}"
4350 );
4351
4352 #[cfg(feature = "logging")]
4353 error!(target: "stdout", "{}", &err_msg);
4354
4355 Err(LlamaCoreError::Backend(BackendError::ComputeSingle(
4356 err_msg,
4357 )))
4358 }
4359 }
4360 }
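                        // `CHAT_GRAPHS` holds no graphs at all, so there is no model to run.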
4361 None => {
4362 let err_msg = "There is no model available in the chat graphs.";
4363
4364 #[cfg(feature = "logging")]
4365 error!(target: "stdout", "{}", &err_msg);
4366
4367 Err(LlamaCoreError::Operation(err_msg.into()))
4368 }
4369 }
4370 }
4371 }
4372 }
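            // No model was resolved by name: fall back to the first graph registered in
            // `CHAT_GRAPHS` and run the same per-token streaming state machine as above.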
4373 None => {
4374 match chat_graphs.iter_mut().next() {
4375 Some((_, graph)) => {
4376 match graph.compute_single() {
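                    // One decoding step: `Ok` means a fresh token is waiting in the output
                    // buffer; the backend errors below signal end of sequence, a full
                    // context, or an over-long prompt.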
4378 Ok(_) => {
4379 #[cfg(feature = "logging")]
4380                                 debug!(target: "stdout", "Computed the chat stream chunk successfully.");
4381
4382 match stream_state {
4383 StreamState::Usage | StreamState::NoUsage => {
4384                                         let output_buffer = get_output_buffer_single(graph, OUTPUT_TENSOR)?;
4387
4388 #[cfg(feature = "logging")]
4389 info!(target: "stdout", "retrieved the output buffer");
4390
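                                        // A token may stop in the middle of a multi-byte UTF-8 character. If
                                        // decoding fails, the raw bytes are buffered in `CACHED_UTF8_ENCODINGS`
                                        // and retried together with the next token; an empty string is emitted
                                        // in the meantime, and the cache is dropped if it grows beyond 4 bytes.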
4391 let output = match String::from_utf8(output_buffer.clone()) {
4393 Ok(token) => token,
4394 Err(_) => {
4395 let mutex = CACHED_UTF8_ENCODINGS
4396 .get_or_init(|| Mutex::new(Vec::new()));
4397 let mut cached_encodings = mutex.lock().map_err(|e| {
4398 let err_msg = format!(
4399                                                     "Failed to acquire the lock of `CACHED_UTF8_ENCODINGS`. Reason: {e}"
4400 );
4401
4402 #[cfg(feature = "logging")]
4403 error!(target: "stdout", "{}", &err_msg);
4404
4405 LlamaCoreError::Operation(err_msg)
4406 })?;
4407
4408 cached_encodings.extend_from_slice(&output_buffer[..]);
4409
4410 match String::from_utf8(cached_encodings.to_vec()) {
4411 Ok(token) => {
4412 cached_encodings.clear();
4414
4415 token
4416 }
4417 Err(e) => {
4418 if cached_encodings.len() > 4 {
4420                                                         let err_msg = format!("Failed to convert a vector of bytes to a string: the length of the cached UTF-8 bytes exceeds 4. {e}");
4421
4422 #[cfg(feature = "logging")]
4423 error!(target: "stdout", "{}", &err_msg);
4424
4425 #[cfg(feature = "logging")]
4426 error!(target: "stdout", "The cached buffer: {:?}", &cached_encodings[..]);
4427
4428 cached_encodings.clear();
4435
4436 String::from("")
4437 } else {
4438                                                         let warn_msg = format!("Failed to convert a vector of bytes to a string. {e}");
4439
4440 #[cfg(feature = "logging")]
4441 warn!(target: "stdout", "{}", &warn_msg);
4442
4443 String::from("")
4444 }
4445 }
4446 }
4447 }
4448 };
4449
4450 #[cfg(feature = "logging")]
4451 info!(target: "stdout", "decoded the output buffer");
4452
4453 let created = SystemTime::now()
4454 .duration_since(std::time::UNIX_EPOCH)
4455 .map_err(|e| {
4456 let err_msg = format!(
4457 "Failed to get the current time. Reason: {e}"
4458 );
4459
4460 #[cfg(feature = "logging")]
4461 error!(target: "stdout", "{}", &err_msg);
4462
4463 LlamaCoreError::Operation(err_msg)
4464 })?;
4465
4466 let chat_completion_chunk = ChatCompletionChunk {
4467 id,
4468 object: "chat.completion.chunk".to_string(),
4469 created: created.as_secs(),
4470 model: graph.name().to_owned(),
4471 system_fingerprint: "fp_44709d6fcb".to_string(),
4472 choices: vec![ChatCompletionChunkChoice {
4473 index: 0,
4474 delta: ChatCompletionChunkChoiceDelta {
4475 role: ChatCompletionRole::Assistant,
4476 content: Some(output),
4477 tool_calls: vec![],
4478 },
4479 logprobs: None,
4480 finish_reason: None,
4481 }],
4482 usage: None,
4483 };
4484
4485 #[cfg(feature = "logging")]
4486 info!(target: "stdout", "created chat completion chunk");
4487
4488 let chunk_str = serde_json::to_string(&chat_completion_chunk)
4490 .map_err(|e| {
4491 let err_msg = format!(
4492 "Failed to serialize chat completion chunk. Reason: {e}"
4493 );
4494
4495 #[cfg(feature = "logging")]
4496 error!(target: "stdout", "{}", &err_msg);
4497
4498 LlamaCoreError::Operation(err_msg)
4499 })?;
4500
4501 Ok(format!("data: {chunk_str}\n\n"))
4502 }
4503 StreamState::Done => {
4504 *stream_state = StreamState::EndOfSequence;
4505
4506 Ok("data: [DONE]\n\n".to_string())
4507 }
4508 StreamState::EndOfSequence => {
4509 Ok("[GGML] End of sequence".to_string())
4510 }
4511 }
4512 }
4513 Err(wasmedge_wasi_nn::Error::BackendError(
4514 wasmedge_wasi_nn::BackendError::EndOfSequence,
4515 )) => {
4516 #[cfg(feature = "logging")]
4517 debug!(target: "stdout", "End of sequence");
4518
4519 match stream_state {
4520 StreamState::Usage => {
4521 *stream_state = StreamState::Done;
4522
4523 let token_info = get_token_info_by_graph(graph)?;
4525
4526 let usage = Some(Usage {
4527 prompt_tokens: token_info.prompt_tokens,
4528 completion_tokens: token_info.completion_tokens,
4529 total_tokens: token_info.prompt_tokens
4530 + token_info.completion_tokens,
4531 });
4532
4533 #[cfg(feature = "logging")]
4534 info!(target: "stdout", "token_info: {} prompt tokens, {} completion tokens", token_info.prompt_tokens, token_info.completion_tokens);
4535
4536 let created = SystemTime::now()
4537 .duration_since(std::time::UNIX_EPOCH)
4538 .map_err(|e| {
4539 let err_msg = format!(
4540 "Failed to get the current time. Reason: {e}"
4541 );
4542
4543 #[cfg(feature = "logging")]
4544 error!(target: "stdout", "{}", &err_msg);
4545
4546 LlamaCoreError::Operation(err_msg)
4547 })?;
4548
4549 let chat_completion_chunk = ChatCompletionChunk {
4550 id,
4551 object: "chat.completion.chunk".to_string(),
4552 created: created.as_secs(),
4553 model: graph.name().to_owned(),
4554 system_fingerprint: "fp_44709d6fcb".to_string(),
4555 choices: vec![],
4556 usage,
4557 };
4558
4559 let chunk_str = serde_json::to_string(&chat_completion_chunk)
4561 .map_err(|e| {
4562 let err_msg = format!(
4563 "Failed to serialize chat completion chunk. Reason: {e}"
4564 );
4565
4566 #[cfg(feature = "logging")]
4567 error!(target: "stdout", "{}", &err_msg);
4568
4569 LlamaCoreError::Operation(err_msg)
4570 })?;
4571
4572 Ok(format!("data: {chunk_str}\n\n"))
4573 }
4574 StreamState::Done | StreamState::NoUsage => {
4575 *stream_state = StreamState::EndOfSequence;
4576
4577 Ok("data: [DONE]\n\n".to_string())
4578 }
4579 StreamState::EndOfSequence => {
4580 Ok("[GGML] End of sequence".to_string())
4581 }
4582 }
4583 }
4584 Err(wasmedge_wasi_nn::Error::BackendError(
4585 wasmedge_wasi_nn::BackendError::ContextFull,
4586 )) => {
4587 #[cfg(feature = "logging")]
4588 debug!(target: "stdout", "Context full");
4589
4590 match context_full_state {
4591 ContextFullState::Message => {
4592 match include_usage {
4593 true => *context_full_state = ContextFullState::Usage,
4594 false => *context_full_state = ContextFullState::Done,
4595 }
4596
4597 let created = SystemTime::now()
4598 .duration_since(std::time::UNIX_EPOCH)
4599 .map_err(|e| {
4600 let err_msg = format!(
4601 "Failed to get the current time. Reason: {e}"
4602 );
4603
4604 #[cfg(feature = "logging")]
4605 error!(target: "stdout", "{}", &err_msg);
4606
4607 LlamaCoreError::Operation(err_msg)
4608 })?;
4609
4610 let chat_completion_chunk = ChatCompletionChunk {
4611 id,
4612 object: "chat.completion.chunk".to_string(),
4613 created: created.as_secs(),
4614 model: graph.name().to_owned(),
4615 system_fingerprint: "fp_44709d6fcb".to_string(),
4616 choices: vec![ChatCompletionChunkChoice {
4617 index: 0,
4618 delta: ChatCompletionChunkChoiceDelta {
4619 role: ChatCompletionRole::Assistant,
4620 content: Some(
4621 "<|WASMEDGE-GGML-CONTEXT-FULL|>".to_string(),
4622 ),
4623 tool_calls: vec![],
4624 },
4625 logprobs: None,
4626 finish_reason: Some(FinishReason::length),
4627 }],
4628 usage: None,
4629 };
4630
4631 let chunk_str = serde_json::to_string(&chat_completion_chunk)
4633 .map_err(|e| {
4634 let err_msg = format!(
4635 "Failed to serialize chat completion chunk. Reason: {e}"
4636 );
4637
4638 #[cfg(feature = "logging")]
4639 error!(target: "stdout", "{}", &err_msg);
4640
4641 LlamaCoreError::Operation(err_msg)
4642 })?;
4643
4644 Ok(format!("data: {chunk_str}\n\n"))
4645 }
4646 ContextFullState::Usage => {
4647 *context_full_state = ContextFullState::Done;
4648
4649 let token_info = get_token_info_by_graph(graph)?;
4651
4652 let usage = Some(Usage {
4653 prompt_tokens: token_info.prompt_tokens,
4654 completion_tokens: token_info.completion_tokens,
4655 total_tokens: token_info.prompt_tokens
4656 + token_info.completion_tokens,
4657 });
4658
4659 let created = SystemTime::now()
4660 .duration_since(std::time::UNIX_EPOCH)
4661 .map_err(|e| {
4662 let err_msg = format!(
4663 "Failed to get the current time. Reason: {e}"
4664 );
4665
4666 #[cfg(feature = "logging")]
4667 error!(target: "stdout", "{}", &err_msg);
4668
4669 LlamaCoreError::Operation(err_msg)
4670 })?;
4671
4672 let chat_completion_chunk = ChatCompletionChunk {
4673 id,
4674 object: "chat.completion.chunk".to_string(),
4675 created: created.as_secs(),
4676 model: graph.name().to_owned(),
4677 system_fingerprint: "fp_44709d6fcb".to_string(),
4678 choices: vec![],
4679 usage,
4680 };
4681
4682 let chunk_str = serde_json::to_string(&chat_completion_chunk)
4684 .map_err(|e| {
4685 let err_msg = format!(
4686 "Failed to serialize chat completion chunk. Reason: {e}"
4687 );
4688
4689 #[cfg(feature = "logging")]
4690 error!(target: "stdout", "{}", &err_msg);
4691
4692 LlamaCoreError::Operation(err_msg)
4693 })?;
4694
4695 Ok(format!("data: {chunk_str}\n\n"))
4696 }
4697 ContextFullState::Done => {
4698 *context_full_state = ContextFullState::EndOfSequence;
4699
4700 Ok("data: [DONE]\n\n".to_string())
4701 }
4702 ContextFullState::EndOfSequence => {
4703 Ok("[GGML] End of sequence".to_string())
4704 }
4705 }
4706 }
4707 Err(wasmedge_wasi_nn::Error::BackendError(
4708 wasmedge_wasi_nn::BackendError::PromptTooLong,
4709 )) => {
4710 #[cfg(feature = "logging")]
4711 debug!(target: "stdout", "Prompt too long");
4712
4713 match prompt_too_long_state {
4714 PromptTooLongState::Message => {
4715 match include_usage {
4716 true => *prompt_too_long_state = PromptTooLongState::Usage,
4717 false => *prompt_too_long_state = PromptTooLongState::Done,
4718 }
4719
4720 let created = SystemTime::now()
4721 .duration_since(std::time::UNIX_EPOCH)
4722 .map_err(|e| {
4723 let err_msg = format!(
4724 "Failed to get the current time. Reason: {e}"
4725 );
4726
4727 #[cfg(feature = "logging")]
4728 error!(target: "stdout", "{}", &err_msg);
4729
4730 LlamaCoreError::Operation(err_msg)
4731 })?;
4732
4733 let chat_completion_chunk = ChatCompletionChunk {
4734 id,
4735 object: "chat.completion.chunk".to_string(),
4736 created: created.as_secs(),
4737 model: graph.name().to_owned(),
4738 system_fingerprint: "fp_44709d6fcb".to_string(),
4739 choices: vec![ChatCompletionChunkChoice {
4740 index: 0,
4741 delta: ChatCompletionChunkChoiceDelta {
4742 role: ChatCompletionRole::Assistant,
4743 content: None,
4744 tool_calls: vec![],
4745 },
4746 logprobs: None,
4747 finish_reason: Some(FinishReason::length),
4748 }],
4749 usage: None,
4750 };
4751
4752 let chunk_str = serde_json::to_string(&chat_completion_chunk)
4754 .map_err(|e| {
4755 let err_msg = format!(
4756 "Failed to serialize chat completion chunk. Reason: {e}"
4757 );
4758
4759 #[cfg(feature = "logging")]
4760 error!(target: "stdout", "{}", &err_msg);
4761
4762 LlamaCoreError::Operation(err_msg)
4763 })?;
4764
4765 Ok(format!("data: {chunk_str}\n\n"))
4766 }
4767 PromptTooLongState::Usage => {
4768 *prompt_too_long_state = PromptTooLongState::Done;
4769
4770 let token_info = get_token_info_by_graph(graph)?;
4772
4773 let usage = Some(Usage {
4774 prompt_tokens: token_info.prompt_tokens,
4775 completion_tokens: token_info.completion_tokens,
4776 total_tokens: token_info.prompt_tokens
4777 + token_info.completion_tokens,
4778 });
4779
4780 let created = SystemTime::now()
4781 .duration_since(std::time::UNIX_EPOCH)
4782 .map_err(|e| {
4783 let err_msg = format!(
4784 "Failed to get the current time. Reason: {e}"
4785 );
4786
4787 #[cfg(feature = "logging")]
4788 error!(target: "stdout", "{}", &err_msg);
4789
4790 LlamaCoreError::Operation(err_msg)
4791 })?;
4792
4793 let chat_completion_chunk = ChatCompletionChunk {
4794 id,
4795 object: "chat.completion.chunk".to_string(),
4796 created: created.as_secs(),
4797 model: graph.name().to_owned(),
4798 system_fingerprint: "fp_44709d6fcb".to_string(),
4799 choices: vec![],
4800 usage,
4801 };
4802
4803 let chunk_str = serde_json::to_string(&chat_completion_chunk)
4805 .map_err(|e| {
4806 let err_msg = format!(
4807 "Failed to serialize chat completion chunk. Reason: {e}"
4808 );
4809
4810 #[cfg(feature = "logging")]
4811 error!(target: "stdout", "{}", &err_msg);
4812
4813 LlamaCoreError::Operation(err_msg)
4814 })?;
4815
4816 Ok(format!("data: {chunk_str}\n\n"))
4817 }
4818 PromptTooLongState::Done => {
4819 *prompt_too_long_state = PromptTooLongState::EndOfSequence;
4820
4821 Ok("data: [DONE]\n\n".to_string())
4822 }
4823 PromptTooLongState::EndOfSequence => {
4824 Ok("[GGML] End of sequence".to_string())
4825 }
4826 }
4827 }
4828 Err(e) => {
4829 let err_msg =
4830 format!("Failed to compute the chat completion. Reason: {e}");
4831
4832 #[cfg(feature = "logging")]
4833 error!(target: "stdout", "{}", &err_msg);
4834
4835 Err(LlamaCoreError::Backend(BackendError::ComputeSingle(
4836 err_msg,
4837 )))
4838 }
4839 }
4840 }
4841 None => {
4842 let err_msg = "There is no model available in the chat graphs.";
4843
4844 #[cfg(feature = "logging")]
4845 error!(target: "stdout", "{}", &err_msg);
4846
4847 Err(LlamaCoreError::Operation(err_msg.into()))
4848 }
4849 }
4850 }
4851 };
4852
4853 #[cfg(feature = "logging")]
4854 info!(target: "stdout", "Return the chat stream chunk!");
4855
4856 res
4857}
4858
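/// Result of parsing the raw model output: the original generated text, the
/// assistant content recovered from it (if any), and any tool calls extracted
/// from it.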
4859#[allow(dead_code)]
4860#[derive(Debug)]
4861struct ParseResult {
4862 raw: String,
4863 content: Option<String>,
4864 tool_calls: Vec<ToolCall>,
4865}