1use crate::{
4 error,
5 metadata::ggml::GgmlMetadata,
6 running_mode,
7 utils::{
8 gen_chat_id, get_output_buffer, get_output_buffer_single, get_token_info_by_graph,
9 get_token_info_by_graph_name, set_tensor_data_u8,
10 },
11 Graph, RunningMode, CACHED_UTF8_ENCODINGS, CHAT_GRAPHS, OUTPUT_TENSOR,
12};
13use chat_prompts::{
14 chat::{BuildChatPrompt, ChatPrompt},
15 PromptTemplateType,
16};
17use either::{Either, Left, Right};
18use endpoints::{
19 chat::{
20 ChatCompletionChunk, ChatCompletionChunkChoice, ChatCompletionChunkChoiceDelta,
21 ChatCompletionObject, ChatCompletionObjectChoice, ChatCompletionObjectMessage,
22 ChatCompletionRequest, ChatCompletionRequestMessage, ChatCompletionRole,
23 ChatCompletionUserMessageContent, ContentPart, Function, ToolCall, ToolCallForChunk,
24 ToolChoice,
25 },
26 common::{FinishReason, Usage},
27};
28use error::{BackendError, LlamaCoreError};
29use futures::StreamExt;
30use std::{
31 collections::VecDeque,
32 fs::{self, File},
33 path::Path,
34 pin::Pin,
35 sync::{
36 atomic::{AtomicBool, Ordering},
37 Mutex, OnceLock,
38 },
39 task::{Context, Poll, Waker},
40 time::SystemTime,
41};
42
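// Wakers of pending `ChatStream`s, likely used together with `CHAT_STREAM_ACTIVE`
// below to ensure that only one chat stream drives the compute context at a time.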
43static CHAT_STREAM_WAKER_QUEUE: OnceLock<Mutex<VecDeque<Waker>>> = OnceLock::new();
45
46static CHAT_STREAM_ACTIVE: AtomicBool = AtomicBool::new(false);
48
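/// Process a chat completion request.
///
/// Returns `Left(stream)` in stream mode and `Right(ChatCompletionObject)` in
/// non-stream mode, plus a flag indicating whether the response carries tool calls.
///
/// A minimal usage sketch (illustrative only; `build_request()` is a hypothetical
/// helper that assembles a `ChatCompletionRequest`):
///
/// ```ignore
/// let mut request = build_request();
/// request.stream = Some(false);
/// match chat(&mut request).await? {
///     (either::Right(completion), _includes_tool_calls) => {
///         println!("{}", serde_json::to_string_pretty(&completion).unwrap());
///     }
///     (either::Left(_stream), _) => unreachable!("stream mode was not requested"),
/// }
/// ```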
49pub async fn chat(
51 chat_request: &mut ChatCompletionRequest,
52) -> Result<
53 (
54 Either<impl futures::TryStream<Ok = String, Error = LlamaCoreError>, ChatCompletionObject>,
55 bool,
56 ),
57 LlamaCoreError,
58> {
59 #[cfg(feature = "logging")]
60 {
61 debug!(target: "stdout", "tool choice: {:?}", chat_request.tool_choice.as_ref());
62 debug!(target: "stdout", "tools: {:?}", chat_request.tools.as_ref());
63 debug!(target: "stdout", "stream mode: {:?}", chat_request.stream);
64 }
65
66 let result = match chat_request.stream {
67 Some(true) => match chat_stream(chat_request).await {
68 Ok((stream, include_tool_calls)) => Ok((Left(stream), include_tool_calls)),
69 Err(e) => Err(e),
70 },
71 Some(false) | None => match chat_once(chat_request).await {
72 Ok((chat_completion_object, include_tool_calls)) => {
73 Ok((Right(chat_completion_object), include_tool_calls))
74 }
75 Err(e) => Err(e),
76 },
77 };
78
79 #[cfg(feature = "logging")]
80 info!(target: "stdout", "Reset the model metadata");
81
82 result
83}
84
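/// Process a chat completion request in stream mode.
///
/// For plain chat the returned `ChatStream` pulls tokens from the model lazily;
/// when tools are involved, the whole completion is computed up front by
/// `chat_stream_for_tool` and replayed as a fixed sequence of SSE chunks.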
85async fn chat_stream(
102 chat_request: &mut ChatCompletionRequest,
103) -> Result<
104 (
105 impl futures::TryStream<Ok = String, Error = LlamaCoreError>,
106 bool,
107 ),
108 LlamaCoreError,
109> {
110 #[cfg(feature = "logging")]
    info!(target: "stdout", "Processing chat completion request in stream mode");
112
113 let running_mode = running_mode()?;
114 if !running_mode.contains(RunningMode::CHAT) && !running_mode.contains(RunningMode::RAG) {
        let err_msg = "Chat completions are only supported in the chat or RAG mode.";
116
117 #[cfg(feature = "logging")]
118 error!(target: "stdout", "{err_msg}");
119
120 return Err(LlamaCoreError::Operation(err_msg.to_string()));
121 }
122
123 let model_name = chat_request.model.clone();
124 let id = match &chat_request.user {
125 Some(id) => id.clone(),
126 None => gen_chat_id(),
127 };
128 #[cfg(feature = "logging")]
129 info!(target: "stdout", "user: {}", &id);
130
131 #[cfg(feature = "logging")]
132 info!(target: "stdout", "Check model metadata");
133
134 let mut metadata = check_model_metadata(chat_request)?;
136
137 let include_usage = match chat_request.stream_options {
139 Some(ref stream_options) => stream_options.include_usage.unwrap_or_default(),
140 None => metadata.include_usage,
141 };
142 #[cfg(feature = "logging")]
143 info!(target: "stdout", "include_usage: {include_usage}");
144
145 #[cfg(feature = "logging")]
146 info!(target: "stdout", "Build the chat prompt");
147
    let (prompt, available_completion_tokens, tool_use) =
        build_prompt(model_name.as_ref(), chat_request)?;
151
152 #[cfg(feature = "logging")]
153 {
154 info!(target: "stdout", "prompt:\n{}", &prompt);
        info!(target: "stdout", "available_completion_tokens: {available_completion_tokens}");
156 info!(target: "stdout", "tool_use: {tool_use}");
157 }
158
159 #[cfg(feature = "logging")]
160 info!(target: "stdout", "Update the n_predict");
161
    update_n_predict(chat_request, &mut metadata, available_completion_tokens)?;
164
165 #[cfg(feature = "logging")]
166 info!(target: "stdout", "Feed the prompt to the model");
167
168 set_prompt(chat_request.model.as_ref(), &prompt)?;
170
171 let stream = match tool_use {
172 false => (ChatStream::new(model_name, id, include_usage, None), false),
173 true => {
174 let chat_graphs = match CHAT_GRAPHS.get() {
175 Some(chat_graphs) => chat_graphs,
176 None => {
                    let err_msg = "Failed to get the underlying value of `CHAT_GRAPHS`.";
178
179 #[cfg(feature = "logging")]
180 error!(target: "stdout", "{}", &err_msg);
181
182 return Err(LlamaCoreError::Operation(err_msg.into()));
183 }
184 };
185
186 let mut chat_graphs = chat_graphs.lock().map_err(|e| {
                let err_msg = format!("Failed to acquire the lock of `CHAT_GRAPHS`. {e}");
188
189 #[cfg(feature = "logging")]
190 error!(target: "stdout", "{}", &err_msg);
191
192 LlamaCoreError::Operation(err_msg)
193 })?;
194
195 match model_name {
196 Some(model_name) => match chat_graphs.contains_key(&model_name) {
197 true => {
198 let graph = chat_graphs.get_mut(&model_name).unwrap();
199 chat_stream_for_tool(graph, id, include_usage)?
200 }
201 false => match chat_graphs.iter_mut().next() {
202 Some((_, graph)) => chat_stream_for_tool(graph, id, include_usage)?,
203 None => {
204 let err_msg = "There is no model available in the chat graphs.";
205
206 #[cfg(feature = "logging")]
207 error!(target: "stdout", "{}", &err_msg);
208
209 return Err(LlamaCoreError::Operation(err_msg.into()));
210 }
211 },
212 },
213 None => match chat_graphs.iter_mut().next() {
214 Some((_, graph)) => chat_stream_for_tool(graph, id, include_usage)?,
215 None => {
216 let err_msg = "There is no model available in the chat graphs.";
217
218 #[cfg(feature = "logging")]
219 error!(target: "stdout", "{}", &err_msg);
220
221 return Err(LlamaCoreError::Operation(err_msg.into()));
222 }
223 },
224 }
225 }
226 };
227
228 #[cfg(feature = "logging")]
    info!(target: "stdout", "Return the chat completion stream.");
230
231 Ok(stream)
232}
233
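/// Run a single blocking `compute()` for a tool-call request and wrap the parsed
/// result into pre-built SSE chunks that the returned `ChatStream` replays verbatim.
///
/// Every chunk uses the server-sent-events framing `data: <json>\n\n`, e.g.
/// (illustrative payload, not produced by this exact code):
///
/// ```text
/// data: {"id":"chatcmpl-123","object":"chat.completion.chunk","choices":[...]}
///
/// data: [DONE]
/// ```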
234fn chat_stream_for_tool(
235 graph: &mut Graph<GgmlMetadata>,
236 id: impl Into<String>,
237 include_usage: bool,
238) -> Result<(ChatStream, bool), LlamaCoreError> {
239 #[cfg(feature = "logging")]
240 info!(target: "stdout", "Handle chat request with available tools by the model named {}.", graph.name());
241
242 let id = id.into();
243
244 match graph.compute() {
245 Ok(_) => {
246 let output_buffer = get_output_buffer(graph, OUTPUT_TENSOR)?;
248 let output = std::str::from_utf8(&output_buffer[..]).map_err(|e| {
249 let err_msg = format!(
250 "Failed to decode the buffer of the inference result to a utf-8 string. {e}"
251 );
252
253 #[cfg(feature = "logging")]
254 error!(target: "stdout", "{}", &err_msg);
255
256 LlamaCoreError::Operation(err_msg)
257 })?;
258
259 #[cfg(feature = "logging")]
260 info!(target: "stdout", "raw generation:\n{output}");
261
262 let message = post_process(output, &graph.metadata.prompt_template).map_err(|e| {
264 LlamaCoreError::Operation(format!("Failed to post-process the output. {e}"))
265 })?;
266
267 #[cfg(feature = "logging")]
268 info!(target: "stdout", "post-processed generation:\n{}", &message);
269
270 let token_info = get_token_info_by_graph(graph)?;
272
273 #[cfg(feature = "logging")]
274 info!(target: "stdout", "prompt tokens: {}, completion tokens: {}", token_info.prompt_tokens, token_info.completion_tokens);
275
276 let usage = Some(Usage {
277 prompt_tokens: token_info.prompt_tokens,
278 completion_tokens: token_info.completion_tokens,
279 total_tokens: token_info.prompt_tokens + token_info.completion_tokens,
280 });
281
282 let created = SystemTime::now()
283 .duration_since(std::time::UNIX_EPOCH)
284 .map_err(|e| {
285 let err_msg = format!("Failed to get the current time. Reason: {e}");
286
287 #[cfg(feature = "logging")]
288 error!(target: "stdout", "{}", &err_msg);
289
290 LlamaCoreError::Operation(err_msg)
291 })?;
292
            if !matches!(
                graph.metadata.prompt_template,
                PromptTemplateType::MistralTool
                    | PromptTemplateType::ChatMLTool
                    | PromptTemplateType::GroqLlama3Tool
                    | PromptTemplateType::Llama3Tool
                    | PromptTemplateType::InternLM2Tool
                    | PromptTemplateType::NemotronTool
                    | PromptTemplateType::FunctionaryV32
                    | PromptTemplateType::FunctionaryV31
                    | PromptTemplateType::MistralSmallTool
                    | PromptTemplateType::Llama4Chat
            ) {
304 let err_msg = format!("Unsupported prompt template: {}. The tool use is only supported for 'mistral-tool', 'chatml-tool', 'groq-llama3-tool', 'llama-3-tool', 'internlm-2-tool', 'nemotron-tool', 'functionary-31', 'functionary-32', 'mistral-small-tool', and 'llama-4-chat' prompt templates.", graph.metadata.prompt_template);
305
306 #[cfg(feature = "logging")]
307 error!(target: "stdout", "{}", &err_msg);
308
309 return Err(LlamaCoreError::Operation(err_msg));
310 }
311
312 let parsed_result = parse_tool_calls(&message, graph.metadata.prompt_template)?;
313
314 let content = if parsed_result.tool_calls.is_empty() {
315 Some(parsed_result.raw.clone())
316 } else {
317 parsed_result.content.clone()
318 };
319
320 let (tool_calls, include_tool_calls) = match parsed_result.tool_calls.is_empty() {
321 false => {
322 let tool_calls: Vec<ToolCallForChunk> = parsed_result
323 .tool_calls
324 .into_iter()
325 .enumerate()
326 .map(|(index, tool_call)| ToolCallForChunk {
327 index,
328 id: tool_call.id,
329 ty: tool_call.ty,
330 function: tool_call.function,
331 })
332 .collect();
333 (tool_calls, true)
334 }
335 true => (vec![], false),
336 };
337
338 let tool_call_chunk = {
340 let chat_completion_chunk = ChatCompletionChunk {
341 id: id.clone(),
342 object: "chat.completion.chunk".to_string(),
343 created: created.as_secs(),
344 model: graph.name().to_owned(),
345 system_fingerprint: "fp_44709d6fcb".to_string(),
346 choices: vec![ChatCompletionChunkChoice {
347 index: 0,
348 delta: ChatCompletionChunkChoiceDelta {
349 role: ChatCompletionRole::Assistant,
350 content,
351 tool_calls,
352 },
353 logprobs: None,
354 finish_reason: None,
355 }],
356 usage: None,
357 };
358 let chunk_str = serde_json::to_string(&chat_completion_chunk).map_err(|e| {
359 let err_msg = format!("Failed to serialize chat completion chunk. Reason: {e}");
360
361 #[cfg(feature = "logging")]
362 error!(target: "stdout", "{}", &err_msg);
363
364 LlamaCoreError::Operation(err_msg)
365 })?;
366
367 format!("data: {chunk_str}\n\n")
368 };
369
370 let usage_chunk = {
372 let chat_completion_chunk = ChatCompletionChunk {
373 id: id.clone(),
374 object: "chat.completion.chunk".to_string(),
375 created: created.as_secs(),
376 model: graph.name().to_owned(),
377 system_fingerprint: "fp_44709d6fcb".to_string(),
378 choices: vec![],
379 usage,
380 };
381 let chunk_str = serde_json::to_string(&chat_completion_chunk).map_err(|e| {
382 let err_msg = format!("Failed to serialize chat completion chunk. Reason: {e}");
383
384 #[cfg(feature = "logging")]
385 error!(target: "stdout", "{}", &err_msg);
386
387 LlamaCoreError::Operation(err_msg)
388 })?;
389
390 format!("data: {chunk_str}\n\n")
391 };
392
393 let ending_chunk = "data: [DONE]\n\n".to_string();
395
396 let chunks = vec![tool_call_chunk, usage_chunk, ending_chunk];
397
398 let stream = ChatStream::new(
399 Some(graph.name().to_owned()),
400 id,
401 include_usage,
402 Some(chunks),
403 );
404
405 Ok((stream, include_tool_calls))
406 }
407 Err(wasmedge_wasi_nn::Error::BackendError(wasmedge_wasi_nn::BackendError::ContextFull)) => {
408 let output_buffer = get_output_buffer(graph, OUTPUT_TENSOR)?;
410 let output = std::str::from_utf8(&output_buffer[..]).map_err(|e| {
411 let err_msg = format!(
412 "Failed to decode the buffer of the inference result to a utf-8 string. {e}"
413 );
414
415 #[cfg(feature = "logging")]
416 error!(target: "stdout", "{}", &err_msg);
417
418 LlamaCoreError::Operation(err_msg)
419 })?;
420
421 let message = post_process(output, &graph.metadata.prompt_template).map_err(|e| {
423 let err_msg = format!("Failed to post-process the output. {e}");
424
425 #[cfg(feature = "logging")]
426 error!(target: "stdout", "{}", &err_msg);
427
428 LlamaCoreError::Operation(err_msg)
429 })?;
430
431 let token_info = get_token_info_by_graph(graph)?;
433
434 #[cfg(feature = "logging")]
435 info!(target: "stdout", "prompt tokens: {}, completion tokens: {}", token_info.prompt_tokens, token_info.completion_tokens);
436
437 let usage = Some(Usage {
438 prompt_tokens: token_info.prompt_tokens,
439 completion_tokens: token_info.completion_tokens,
440 total_tokens: token_info.prompt_tokens + token_info.completion_tokens,
441 });
442
443 let created = SystemTime::now()
444 .duration_since(std::time::UNIX_EPOCH)
445 .map_err(|e| {
446 let err_msg = format!("Failed to get the current time. Reason: {e}");
447
448 #[cfg(feature = "logging")]
449 error!(target: "stdout", "{}", &err_msg);
450
451 LlamaCoreError::Operation(err_msg)
452 })?;
453
454 let context_full_chunk = {
456 let chat_completion_chunk = ChatCompletionChunk {
457 id: id.clone(),
458 object: "chat.completion.chunk".to_string(),
459 created: created.as_secs(),
460 model: graph.name().to_owned(),
461 system_fingerprint: "fp_44709d6fcb".to_string(),
462 choices: vec![ChatCompletionChunkChoice {
463 index: 0,
464 delta: ChatCompletionChunkChoiceDelta {
465 role: ChatCompletionRole::Assistant,
466 content: Some(message),
467 tool_calls: vec![],
468 },
469 logprobs: None,
470 finish_reason: Some(FinishReason::length),
471 }],
472 usage: None,
473 };
474
475 let chunk_str = serde_json::to_string(&chat_completion_chunk).map_err(|e| {
477 let err_msg = format!("Failed to serialize chat completion chunk. Reason: {e}");
478
479 #[cfg(feature = "logging")]
480 error!(target: "stdout", "{}", &err_msg);
481
482 LlamaCoreError::Operation(err_msg)
483 })?;
484
485 format!("data: {chunk_str}\n\n")
486 };
487
488 let usage_chunk = {
490 let chat_completion_chunk = ChatCompletionChunk {
491 id: id.clone(),
492 object: "chat.completion.chunk".to_string(),
493 created: created.as_secs(),
494 model: graph.name().to_owned(),
495 system_fingerprint: "fp_44709d6fcb".to_string(),
496 choices: vec![],
497 usage,
498 };
499
500 let chunk_str = serde_json::to_string(&chat_completion_chunk).map_err(|e| {
502 let err_msg = format!("Failed to serialize chat completion chunk. Reason: {e}");
503
504 #[cfg(feature = "logging")]
505 error!(target: "stdout", "{}", &err_msg);
506
507 LlamaCoreError::Operation(err_msg)
508 })?;
509
510 format!("data: {chunk_str}\n\n")
511 };
512
513 let ending_chunk = "data: [DONE]\n\n".to_string();
515
516 let chunks = vec![context_full_chunk, usage_chunk, ending_chunk];
517
518 let stream = ChatStream::new(
519 Some(graph.name().to_owned()),
520 id,
521 include_usage,
522 Some(chunks),
523 );
524
525 Ok((stream, false))
526 }
527 Err(wasmedge_wasi_nn::Error::BackendError(
528 wasmedge_wasi_nn::BackendError::PromptTooLong,
529 )) => {
530 #[cfg(feature = "logging")]
531 warn!(target: "stdout", "The prompt is too long. Please reduce the length of your input and try again.");
532
533 let output_buffer = get_output_buffer(graph, OUTPUT_TENSOR)?;
535 let output = std::str::from_utf8(&output_buffer[..]).map_err(|e| {
536 let err_msg = format!(
537 "Failed to decode the buffer of the inference result to a utf-8 string. {e}"
538 );
539
540 #[cfg(feature = "logging")]
541 error!(target: "stdout", "{}", &err_msg);
542
543 LlamaCoreError::Operation(err_msg)
544 })?;
545
546 let message = post_process(output, &graph.metadata.prompt_template).map_err(|e| {
548 let err_msg = format!("Failed to post-process the output. {e}");
549
550 #[cfg(feature = "logging")]
551 error!(target: "stdout", "{}", &err_msg);
552
553 LlamaCoreError::Operation(err_msg)
554 })?;
555
556 let token_info = get_token_info_by_graph(graph)?;
558
559 #[cfg(feature = "logging")]
560 info!(target: "stdout", "prompt tokens: {}, completion tokens: {}", token_info.prompt_tokens, token_info.completion_tokens);
561
562 let usage = Some(Usage {
563 prompt_tokens: token_info.prompt_tokens,
564 completion_tokens: token_info.completion_tokens,
565 total_tokens: token_info.prompt_tokens + token_info.completion_tokens,
566 });
567
568 let created = SystemTime::now()
569 .duration_since(std::time::UNIX_EPOCH)
570 .map_err(|e| {
571 let err_msg = format!("Failed to get the current time. Reason: {e}");
572
573 #[cfg(feature = "logging")]
574 error!(target: "stdout", "{}", &err_msg);
575
576 LlamaCoreError::Operation(err_msg)
577 })?;
578
579 let prompt_too_long_chunk = {
581 let chat_completion_chunk = ChatCompletionChunk {
582 id: id.clone(),
583 object: "chat.completion.chunk".to_string(),
584 created: created.as_secs(),
585 model: graph.name().to_owned(),
586 system_fingerprint: "fp_44709d6fcb".to_string(),
587 choices: vec![ChatCompletionChunkChoice {
588 index: 0,
589 delta: ChatCompletionChunkChoiceDelta {
590 role: ChatCompletionRole::Assistant,
591 content: Some(message),
592 tool_calls: vec![],
593 },
594 logprobs: None,
595 finish_reason: Some(FinishReason::length),
596 }],
597 usage: None,
598 };
599
600 let chunk_str = serde_json::to_string(&chat_completion_chunk).map_err(|e| {
602 let err_msg = format!("Failed to serialize chat completion chunk. Reason: {e}");
603
604 #[cfg(feature = "logging")]
605 error!(target: "stdout", "{}", &err_msg);
606
607 LlamaCoreError::Operation(err_msg)
608 })?;
609
610 format!("data: {chunk_str}\n\n")
611 };
612
613 let usage_chunk = {
615 let chat_completion_chunk = ChatCompletionChunk {
616 id: id.clone(),
617 object: "chat.completion.chunk".to_string(),
618 created: created.as_secs(),
619 model: graph.name().to_owned(),
620 system_fingerprint: "fp_44709d6fcb".to_string(),
621 choices: vec![],
622 usage,
623 };
624
625 let chunk_str = serde_json::to_string(&chat_completion_chunk).map_err(|e| {
627 let err_msg = format!("Failed to serialize chat completion chunk. Reason: {e}");
628
629 #[cfg(feature = "logging")]
630 error!(target: "stdout", "{}", &err_msg);
631
632 LlamaCoreError::Operation(err_msg)
633 })?;
634
635 format!("data: {chunk_str}\n\n")
636 };
637
638 let ending_chunk = "data: [DONE]\n\n".to_string();
640
641 let chunks = vec![prompt_too_long_chunk, usage_chunk, ending_chunk];
642
643 let stream = ChatStream::new(
644 Some(graph.name().to_owned()),
645 id,
646 include_usage,
647 Some(chunks),
648 );
649
650 Ok((stream, false))
651 }
652 Err(e) => {
653 let err_msg = format!("Failed to compute the chat completion. Reason: {e}");
654
655 #[cfg(feature = "logging")]
656 error!(target: "stdout", "{}", &err_msg);
657
658 Err(LlamaCoreError::Backend(BackendError::Compute(err_msg)))
659 }
660 }
661}
662
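/// Process a chat completion request in non-stream mode: build the prompt, run a
/// single inference pass, and reset the model metadata afterwards so that the
/// per-request overrides applied in `check_model_metadata`/`update_n_predict` do
/// not carry over to the next request.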
663async fn chat_once(
664 chat_request: &mut ChatCompletionRequest,
665) -> Result<(ChatCompletionObject, bool), LlamaCoreError> {
666 #[cfg(feature = "logging")]
667 info!(target: "stdout", "Processing chat completion request in non-stream mode");
668
669 let running_mode = running_mode()?;
670 if !running_mode.contains(RunningMode::CHAT) && !running_mode.contains(RunningMode::RAG) {
        let err_msg = "Chat completions are only supported in the chat or RAG mode.";
672
673 #[cfg(feature = "logging")]
674 error!(target: "stdout", "{err_msg}");
675
676 return Err(LlamaCoreError::Operation(err_msg.to_string()));
677 }
678
679 let model_name = chat_request.model.clone();
680 let id = match &chat_request.user {
681 Some(id) => id.clone(),
682 None => gen_chat_id(),
683 };
684
685 #[cfg(feature = "logging")]
686 info!(target: "stdout", "user: {}", &id);
687
688 #[cfg(feature = "logging")]
689 info!(target: "stdout", "Check model metadata");
690
691 let mut metadata = check_model_metadata(chat_request)?;
693
694 #[cfg(feature = "logging")]
695 info!(target: "stdout", "Build the chat prompt");
696
    let (prompt, available_completion_tokens, tool_use) =
        build_prompt(model_name.as_ref(), chat_request)?;
700
701 #[cfg(feature = "logging")]
702 {
703 info!(target: "stdout", "prompt:\n{}", &prompt);
        info!(target: "stdout", "available_completion_tokens: {available_completion_tokens}");
705 info!(target: "stdout", "tool_use: {tool_use}");
706 }
707
708 #[cfg(feature = "logging")]
709 info!(target: "stdout", "Update n_predict");
710
    update_n_predict(chat_request, &mut metadata, available_completion_tokens)?;
713
714 #[cfg(feature = "logging")]
715 info!(target: "stdout", "Feed the prompt to the model");
716
717 set_prompt(model_name.as_ref(), &prompt)?;
719
720 #[cfg(feature = "logging")]
721 info!(target: "stdout", "Compute chat completion.");
722
723 let res = compute(model_name.as_ref(), id, tool_use);
725
726 #[cfg(feature = "logging")]
727 info!(target: "stdout", "End of the chat completion");
728
729 reset_model_metadata(model_name.as_ref())?;
731
732 res
733}
734
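/// Resolve the target graph in `CHAT_GRAPHS` by model name (falling back to the
/// first registered model) and delegate to `compute_by_graph`.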
735fn compute(
736 model_name: Option<&String>,
737 id: impl Into<String>,
738 tool_use: bool,
739) -> Result<(ChatCompletionObject, bool), LlamaCoreError> {
740 let chat_graphs = match CHAT_GRAPHS.get() {
741 Some(chat_graphs) => chat_graphs,
742 None => {
            let err_msg = "Failed to get the underlying value of `CHAT_GRAPHS`.";
744
745 #[cfg(feature = "logging")]
746 error!(target: "stdout", "{}", &err_msg);
747
748 return Err(LlamaCoreError::Operation(err_msg.into()));
749 }
750 };
751
752 let mut chat_graphs = chat_graphs.lock().map_err(|e| {
        let err_msg = format!("Failed to acquire the lock of `CHAT_GRAPHS`. {e}");
754
755 #[cfg(feature = "logging")]
756 error!(target: "stdout", "{}", &err_msg);
757
758 LlamaCoreError::Operation(err_msg)
759 })?;
760
761 match model_name {
762 Some(model_name) => match chat_graphs.contains_key(model_name) {
763 true => {
764 let graph = chat_graphs.get_mut(model_name).unwrap();
765 compute_by_graph(graph, id, tool_use)
766 }
767 false => match chat_graphs.iter_mut().next() {
768 Some((_, graph)) => compute_by_graph(graph, id, tool_use),
769 None => {
770 let err_msg = "There is no model available in the chat graphs.";
771
772 #[cfg(feature = "logging")]
773 error!(target: "stdout", "{}", &err_msg);
774
775 Err(LlamaCoreError::Operation(err_msg.into()))
776 }
777 },
778 },
779 None => match chat_graphs.iter_mut().next() {
780 Some((_, graph)) => compute_by_graph(graph, id, tool_use),
781 None => {
782 let err_msg = "There is no model available in the chat graphs.";
783
784 #[cfg(feature = "logging")]
785 error!(target: "stdout", "{}", &err_msg);
786
787 Err(LlamaCoreError::Operation(err_msg.into()))
788 }
789 },
790 }
791}
792
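/// Run inference on the given graph and convert the backend result, including the
/// `ContextFull` and `PromptTooLong` error paths, into a `ChatCompletionObject`.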
793fn compute_by_graph(
794 graph: &mut Graph<GgmlMetadata>,
795 id: impl Into<String>,
796 tool_use: bool,
797) -> Result<(ChatCompletionObject, bool), LlamaCoreError> {
798 #[cfg(feature = "logging")]
799 info!(target: "stdout", "Compute chat completion by the model named {}.", graph.name());
800
801 match graph.compute() {
802 Ok(_) => {
803 let output_buffer = get_output_buffer(graph, OUTPUT_TENSOR)?;
805 let output = std::str::from_utf8(&output_buffer[..]).map_err(|e| {
806 let err_msg = format!(
807 "Failed to decode the buffer of the inference result to a utf-8 string. {e}"
808 );
809
810 #[cfg(feature = "logging")]
811 error!(target: "stdout", "{}", &err_msg);
812
813 LlamaCoreError::Operation(err_msg)
814 })?;
815
816 #[cfg(feature = "logging")]
817 info!(target: "stdout", "raw generation: {output}");
818
819 let message = post_process(output, &graph.metadata.prompt_template).map_err(|e| {
821 LlamaCoreError::Operation(format!("Failed to post-process the output. {e}"))
822 })?;
823
824 #[cfg(feature = "logging")]
825 info!(target: "stdout", "post-processed generation:\n{}", &message);
826
827 let token_info = get_token_info_by_graph(graph)?;
829
830 #[cfg(feature = "logging")]
831 info!(target: "stdout", "prompt tokens: {}, completion tokens: {}", token_info.prompt_tokens, token_info.completion_tokens);
832
833 let created = SystemTime::now()
834 .duration_since(std::time::UNIX_EPOCH)
835 .map_err(|e| {
836 let err_msg = format!("Failed to get the current time. Reason: {e}");
837
838 #[cfg(feature = "logging")]
839 error!(target: "stdout", "{}", &err_msg);
840
841 LlamaCoreError::Operation(err_msg)
842 })?;
843
844 match tool_use {
845 true => {
                if !matches!(
                    graph.metadata.prompt_template,
                    PromptTemplateType::MistralTool
                        | PromptTemplateType::ChatMLTool
                        | PromptTemplateType::GroqLlama3Tool
                        | PromptTemplateType::Llama3Tool
                        | PromptTemplateType::InternLM2Tool
                        | PromptTemplateType::NemotronTool
                        | PromptTemplateType::FunctionaryV32
                        | PromptTemplateType::FunctionaryV31
                        | PromptTemplateType::MistralSmallTool
                        | PromptTemplateType::Llama4Chat
                ) {
857 let err_msg = format!("Unsupported prompt template: {}. The tool use is only supported for 'mistral-tool', 'chatml-tool', 'groq-llama3-tool', 'llama-3-tool', 'internlm-2-tool', 'nemotron-tool', 'functionary-31', 'functionary-32', 'mistral-small-tool', and 'llama-4-chat' prompt templates.", graph.metadata.prompt_template);
858
859 #[cfg(feature = "logging")]
860 error!(target: "stdout", "{}", &err_msg);
861
862 return Err(LlamaCoreError::Operation(err_msg));
863 }
864
865 let parsed_result = parse_tool_calls(&message, graph.metadata.prompt_template)?;
866
867 let (finish_reason, content, include_tool_calls) =
868 if parsed_result.tool_calls.is_empty() {
869 (FinishReason::stop, Some(parsed_result.raw.clone()), false)
870 } else {
871 (
872 FinishReason::tool_calls,
873 parsed_result.content.clone(),
874 true,
875 )
876 };
877
878 let res = ChatCompletionObject {
879 id: id.into(),
880 object: String::from("chat.completion"),
881 created: created.as_secs(),
882 model: graph.name().to_owned(),
883 choices: vec![ChatCompletionObjectChoice {
884 index: 0,
885 message: ChatCompletionObjectMessage {
886 role: ChatCompletionRole::Assistant,
887 content,
888 tool_calls: parsed_result.tool_calls,
889 function_call: None,
890 },
891 finish_reason,
892 logprobs: None,
893 }],
894 usage: Usage {
895 prompt_tokens: token_info.prompt_tokens,
896 completion_tokens: token_info.completion_tokens,
897 total_tokens: token_info.prompt_tokens + token_info.completion_tokens,
898 },
899 };
900
901 Ok((res, include_tool_calls))
903 }
904 false => {
905 let res = ChatCompletionObject {
907 id: id.into(),
908 object: String::from("chat.completion"),
909 created: created.as_secs(),
910 model: graph.name().to_owned(),
911 choices: vec![ChatCompletionObjectChoice {
912 index: 0,
913 message: ChatCompletionObjectMessage {
914 role: ChatCompletionRole::Assistant,
915 content: Some(message),
916 tool_calls: vec![],
917 function_call: None,
918 },
919 finish_reason: FinishReason::stop,
920 logprobs: None,
921 }],
922 usage: Usage {
923 prompt_tokens: token_info.prompt_tokens,
924 completion_tokens: token_info.completion_tokens,
925 total_tokens: token_info.prompt_tokens + token_info.completion_tokens,
926 },
927 };
928
929 Ok((res, false))
930 }
931 }
932 }
933 Err(wasmedge_wasi_nn::Error::BackendError(wasmedge_wasi_nn::BackendError::ContextFull)) => {
934 let output_buffer = get_output_buffer(graph, OUTPUT_TENSOR)?;
936 let output = std::str::from_utf8(&output_buffer[..]).map_err(|e| {
937 let err_msg = format!(
938 "Failed to decode the buffer of the inference result to a utf-8 string. {e}"
939 );
940
941 #[cfg(feature = "logging")]
942 error!(target: "stdout", "{}", &err_msg);
943
944 LlamaCoreError::Operation(err_msg)
945 })?;
946
947 let message = post_process(output, &graph.metadata.prompt_template).map_err(|e| {
949 let err_msg = format!("Failed to post-process the output. {e}");
950
951 #[cfg(feature = "logging")]
952 error!(target: "stdout", "{}", &err_msg);
953
954 LlamaCoreError::Operation(err_msg)
955 })?;
956
957 let token_info = get_token_info_by_graph(graph)?;
959
960 #[cfg(feature = "logging")]
961 info!(target: "stdout", "prompt tokens: {}, completion tokens: {}", token_info.prompt_tokens, token_info.completion_tokens);
962
963 let created = SystemTime::now()
964 .duration_since(std::time::UNIX_EPOCH)
965 .map_err(|e| {
966 let err_msg = format!("Failed to get the current time. Reason: {e}");
967
968 #[cfg(feature = "logging")]
969 error!(target: "stdout", "{}", &err_msg);
970
971 LlamaCoreError::Operation(err_msg)
972 })?;
973
974 let res = ChatCompletionObject {
976 id: id.into(),
977 object: String::from("chat.completion"),
978 created: created.as_secs(),
979 model: graph.name().to_owned(),
980 choices: vec![ChatCompletionObjectChoice {
981 index: 0,
982 message: ChatCompletionObjectMessage {
983 role: ChatCompletionRole::Assistant,
984 content: Some(message),
985 tool_calls: vec![],
986 function_call: None,
987 },
988 finish_reason: FinishReason::length,
989 logprobs: None,
990 }],
991 usage: Usage {
992 prompt_tokens: token_info.prompt_tokens,
993 completion_tokens: token_info.completion_tokens,
994 total_tokens: token_info.prompt_tokens + token_info.completion_tokens,
995 },
996 };
997
998 Ok((res, false))
999 }
1000 Err(wasmedge_wasi_nn::Error::BackendError(
1001 wasmedge_wasi_nn::BackendError::PromptTooLong,
1002 )) => {
1003 #[cfg(feature = "logging")]
1004 warn!(target: "stdout", "The prompt is too long. Please reduce the length of your input and try again.");
1005
1006 let output_buffer = get_output_buffer(graph, OUTPUT_TENSOR)?;
1008 let output = std::str::from_utf8(&output_buffer[..]).map_err(|e| {
1009 let err_msg = format!(
1010 "Failed to decode the buffer of the inference result to a utf-8 string. {e}"
1011 );
1012
1013 #[cfg(feature = "logging")]
1014 error!(target: "stdout", "{}", &err_msg);
1015
1016 LlamaCoreError::Operation(err_msg)
1017 })?;
1018
1019 let message = post_process(output, &graph.metadata.prompt_template).map_err(|e| {
1021 let err_msg = format!("Failed to post-process the output. {e}");
1022
1023 #[cfg(feature = "logging")]
1024 error!(target: "stdout", "{}", &err_msg);
1025
1026 LlamaCoreError::Operation(err_msg)
1027 })?;
1028
1029 let token_info = get_token_info_by_graph(graph)?;
1031
1032 #[cfg(feature = "logging")]
1033 info!(target: "stdout", "prompt tokens: {}, completion tokens: {}", token_info.prompt_tokens, token_info.completion_tokens);
1034
1035 let usage = Usage {
1036 prompt_tokens: token_info.prompt_tokens,
1037 completion_tokens: token_info.completion_tokens,
1038 total_tokens: token_info.prompt_tokens + token_info.completion_tokens,
1039 };
1040
1041 let created = SystemTime::now()
1042 .duration_since(std::time::UNIX_EPOCH)
1043 .map_err(|e| {
1044 let err_msg = format!("Failed to get the current time. Reason: {e}");
1045
1046 #[cfg(feature = "logging")]
1047 error!(target: "stdout", "{}", &err_msg);
1048
1049 LlamaCoreError::Operation(err_msg)
1050 })?;
1051
1052 let res = ChatCompletionObject {
1054 id: id.into(),
1055 object: String::from("chat.completion"),
1056 created: created.as_secs(),
1057 model: graph.name().to_owned(),
1058 choices: vec![ChatCompletionObjectChoice {
1059 index: 0,
1060 message: ChatCompletionObjectMessage {
1061 role: ChatCompletionRole::Assistant,
1062 content: Some(message),
1063 tool_calls: vec![],
1064 function_call: None,
1065 },
1066 finish_reason: FinishReason::length,
1067 logprobs: None,
1068 }],
1069 usage,
1070 };
1071
1072 Ok((res, false))
1073 }
1074 Err(e) => {
1075 let err_msg = format!("Failed to compute the chat completion. Reason: {e}");
1076
1077 #[cfg(feature = "logging")]
1078 error!(target: "stdout", "{}", &err_msg);
1079
1080 Err(LlamaCoreError::Backend(BackendError::Compute(err_msg)))
1081 }
1082 }
1083}
1084
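/// Extract tool calls from the raw generation according to the prompt template.
///
/// For example, a `mistral-tool` style generation such as the following
/// (illustrative sample, not taken from a real model run):
///
/// ```text
/// [{"name": "get_current_weather", "arguments": {"location": "Berlin"}}]
/// ```
///
/// is parsed into a single `ToolCall` whose `function.name` is
/// `get_current_weather` and whose `function.arguments` holds the serialized
/// JSON arguments object.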
1085fn parse_tool_calls(
1086 input: &str,
1087 prompt_template: PromptTemplateType,
1088) -> Result<ParseResult, LlamaCoreError> {
1089 match prompt_template {
1090 PromptTemplateType::MistralTool => match regex::Regex::new(r"\[\{.*?\}\]") {
1091 Ok(re) => {
1092 let mut values: Vec<serde_json::Value> = vec![];
1093 for cap in re.captures_iter(input) {
1094 let matched = &cap[0];
1095
1096 #[cfg(feature = "logging")]
1097 info!(target: "stdout", "captured: {matched}");
1098
1099 match serde_json::from_str::<Vec<serde_json::Value>>(matched) {
1100 Ok(group) => values.extend(group),
1101 Err(e) => {
1102 let err_msg =
1103 format!("Failed to deserialize generated tool calls. Reason: {e}");
1104
1105 #[cfg(feature = "logging")]
1106 error!(target: "stdout", "{}", &err_msg);
1107
1108 return Err(LlamaCoreError::Operation(err_msg));
1109 }
1110 }
1111 }
1112
1113 let mut tool_calls: Vec<ToolCall> = vec![];
1114 for value in values.iter() {
1115 let name = match value.get("name") {
1116 Some(name) => name.to_string().replace("\"", ""),
1117 None => {
1118 let err_msg = format!(
1119 "Failed to get the name of the function. Tool call: {value:?}"
1120 );
1121
1122 #[cfg(feature = "logging")]
1123 error!(target: "stdout", "{}", &err_msg);
1124
1125 return Err(LlamaCoreError::Operation(err_msg));
1126 }
1127 };
1128
1129 let arguments = match value.get("arguments") {
1130 Some(arguments) => arguments.to_string(),
1131 None => {
1132 let err_msg = format!(
1133 "Failed to get the arguments of the function. Tool call: {value:?}"
1134 );
1135
1136 #[cfg(feature = "logging")]
1137 error!(target: "stdout", "{}", &err_msg);
1138
1139 return Err(LlamaCoreError::Operation(err_msg));
1140 }
1141 };
1142
1143 let function = Function { name, arguments };
1144
1145 let tool_call = ToolCall {
1146 id: "call_abc123".to_string(),
1147 ty: "function".to_string(),
1148 function,
1149 };
1150
1151 tool_calls.push(tool_call);
1152 }
1153
1154 let parsed = ParseResult {
1155 raw: input.to_owned(),
1156 content: None,
1157 tool_calls,
1158 };
1159
1160 #[cfg(feature = "logging")]
1161 info!(target: "stdout", "parsed result: {parsed:?}");
1162
1163 Ok(parsed)
1164 }
1165 Err(e) => {
1166 let err_msg = format!("Failed to create a regex pattern. Reason: {e}");
1167
1168 #[cfg(feature = "logging")]
1169 error!(target: "stdout", "{}", &err_msg);
1170
1171 Err(LlamaCoreError::Operation(err_msg))
1172 }
1173 },
1174 PromptTemplateType::ChatMLTool => {
1175 match regex::Regex::new(r"<tool_call>(.*?)</tool_call>") {
1176 Ok(re) => {
1177 let mut values: Vec<serde_json::Value> = vec![];
1178 for cap in re.captures_iter(input) {
                    let matched = cap[1].replace("\\n", "");

                    #[cfg(feature = "logging")]
                    info!(target: "stdout", "captured: {}", &matched);
1183
1184 match serde_json::from_str::<serde_json::Value>(&matched) {
1185 Ok(value) => values.push(value),
1186 Err(e) => {
1187 let err_msg = format!(
1188 "Failed to deserialize generated tool calls. Reason: {e}"
1189 );
1190
1191 #[cfg(feature = "logging")]
1192 error!(target: "stdout", "{}", &err_msg);
1193
1194 return Err(LlamaCoreError::Operation(err_msg));
1195 }
1196 }
1197 }
1198
1199 let mut tool_calls: Vec<ToolCall> = vec![];
1200 for value in values.iter() {
1201 let name = match value.get("name") {
1202 Some(name) => name.to_string().replace("\"", ""),
1203 None => {
1204 let err_msg = format!(
1205 "Failed to get the name of the function. Tool call: {value:?}"
1206 );
1207
1208 #[cfg(feature = "logging")]
1209 error!(target: "stdout", "{}", &err_msg);
1210
1211 return Err(LlamaCoreError::Operation(err_msg));
1212 }
1213 };
1214
1215 let arguments = match value.get("arguments") {
1216 Some(arguments) => arguments.to_string(),
1217 None => {
1218 let err_msg = format!(
1219 "Failed to get the arguments of the function. Tool call: {value:?}"
1220 );
1221
1222 #[cfg(feature = "logging")]
1223 error!(target: "stdout", "{}", &err_msg);
1224
1225 return Err(LlamaCoreError::Operation(err_msg));
1226 }
1227 };
1228
1229 let function = Function { name, arguments };
1230
1231 let tool_call = ToolCall {
1232 id: "call_abc123".to_string(),
1233 ty: "function".to_string(),
1234 function,
1235 };
1236
1237 tool_calls.push(tool_call);
1238 }
1239
1240 let parsed = ParseResult {
1241 raw: input.to_owned(),
1242 content: None,
1243 tool_calls,
1244 };
1245
1246 #[cfg(feature = "logging")]
1247 info!(target: "stdout", "parsed result: {parsed:?}");
1248
1249 Ok(parsed)
1250 }
1251 Err(e) => {
1252 let err_msg = format!("Failed to create a regex pattern. Reason: {e}");
1253
1254 #[cfg(feature = "logging")]
1255 error!(target: "stdout", "{}", &err_msg);
1256
1257 Err(LlamaCoreError::Operation(err_msg))
1258 }
1259 }
1260 }
1261 PromptTemplateType::GroqLlama3Tool => {
1262 #[cfg(feature = "logging")]
1263 info!(target: "stdout", "raw input: {input}");
1264
1265 match regex::Regex::new(r"(?s)<tool_call>((.|\r|\n)*?)</tool_call>") {
1266 Ok(re) => {
1267 let mut values: Vec<serde_json::Value> = vec![];
1268 for cap in re.captures_iter(input) {
1269 let matched = cap[1].trim();
1270
1271 #[cfg(feature = "logging")]
1272 info!(target: "stdout", "captured: {matched}");
1273
1274 match serde_json::from_str::<serde_json::Value>(matched) {
1275 Ok(value) => values.push(value),
1276 Err(e) => {
1277 let err_msg = format!(
1278 "Failed to deserialize generated tool calls. Reason: {e}"
1279 );
1280
1281 #[cfg(feature = "logging")]
1282 error!(target: "stdout", "{}", &err_msg);
1283
1284 return Err(LlamaCoreError::Operation(err_msg));
1285 }
1286 }
1287 }
1288
1289 let mut tool_calls: Vec<ToolCall> = vec![];
1290 for value in values.iter() {
1291 let name = match value.get("name") {
1292 Some(name) => name.to_string().replace("\"", ""),
1293 None => {
1294 let err_msg = format!(
1295 "Failed to get the name of the function. Tool call: {value:?}"
1296 );
1297
1298 #[cfg(feature = "logging")]
1299 error!(target: "stdout", "{}", &err_msg);
1300
1301 return Err(LlamaCoreError::Operation(err_msg));
1302 }
1303 };
1304
1305 let arguments = match value.get("arguments") {
1306 Some(arguments) => {
1307 if arguments.is_string() {
1308 arguments.as_str().unwrap().to_string()
1309 } else if arguments.is_object() {
1310 let map = arguments.as_object().unwrap();
1311
1312 #[cfg(feature = "logging")]
1313 info!(target: "stdout", "func arguments: {map:?}");
1314
1315 serde_json::to_string(map).unwrap()
1316 } else {
1317 serde_json::to_string(arguments).unwrap()
1318 }
1319 }
1320 None => {
1321 let err_msg = format!(
1322 "Failed to get the arguments of the function. Tool call: {value:?}"
1323 );
1324
1325 #[cfg(feature = "logging")]
1326 error!(target: "stdout", "{}", &err_msg);
1327
1328 return Err(LlamaCoreError::Operation(err_msg));
1329 }
1330 };
1331
1332 let function = Function { name, arguments };
1333
1334 let tool_call = ToolCall {
1335 id: "call_abc123".to_string(),
1336 ty: "function".to_string(),
1337 function,
1338 };
1339
1340 tool_calls.push(tool_call);
1341 }
1342
1343 let parsed = if tool_calls.is_empty() {
1344 ParseResult {
1345 raw: input.to_owned(),
1346 content: Some(input.to_owned()),
1347 tool_calls: vec![],
1348 }
1349 } else {
1350 ParseResult {
1351 raw: input.to_owned(),
1352 content: None,
1353 tool_calls,
1354 }
1355 };
1356
1357 #[cfg(feature = "logging")]
1358 info!(target: "stdout", "parsed result: {parsed:?}");
1359
1360 Ok(parsed)
1361 }
1362 Err(e) => {
1363 let err_msg = format!("Failed to create a regex pattern. Reason: {e}");
1364
1365 #[cfg(feature = "logging")]
1366 error!(target: "stdout", "{}", &err_msg);
1367
1368 Err(LlamaCoreError::Operation(err_msg))
1369 }
1370 }
1371 }
1372 PromptTemplateType::Llama3Tool => {
1373 #[cfg(feature = "logging")]
1374 info!(target: "stdout", "raw input: {input}");
1375
1376 let re = match regex::Regex::new(r"^\{(.|\r|\n)*\}$") {
1377 Ok(re) => re,
1378 Err(e) => {
1379 let err_msg = format!("Failed to create a regex pattern. Reason: {e}");
1380
1381 #[cfg(feature = "logging")]
1382 error!(target: "stdout", "{}", &err_msg);
1383
1384 return Err(LlamaCoreError::Operation(err_msg));
1385 }
1386 };
1387
1388 if re.is_match(input) {
1389 match serde_json::from_str::<serde_json::Value>(input) {
1390 Ok(value) => {
1391 let values: Vec<serde_json::Value> = vec![value];
1392
1393 let mut tool_calls: Vec<ToolCall> = vec![];
1394 for value in values.iter() {
1395 let name = match value.get("name") {
1396 Some(name) => name.to_string().replace("\"", ""),
1397 None => {
1398 let err_msg = format!(
1399 "Failed to get the name of the function. Tool call: {value:?}"
1400 );
1401
1402 #[cfg(feature = "logging")]
1403 error!(target: "stdout", "{}", &err_msg);
1404
1405 return Err(LlamaCoreError::Operation(err_msg));
1406 }
1407 };
1408
1409 let arguments = match value.get("parameters") {
1410 Some(arguments) => arguments.to_string(),
1411 None => {
1412 let err_msg = format!(
1413 "Failed to get the arguments of the function. Tool call: {value:?}"
1414 );
1415
1416 #[cfg(feature = "logging")]
1417 error!(target: "stdout", "{}", &err_msg);
1418
1419 return Err(LlamaCoreError::Operation(err_msg));
1420 }
1421 };
1422
1423 let function = Function { name, arguments };
1424
1425 let tool_call = ToolCall {
1426 id: "call_abc123".to_string(),
1427 ty: "function".to_string(),
1428 function,
1429 };
1430
1431 tool_calls.push(tool_call);
1432 }
1433
1434 let parsed = ParseResult {
1435 raw: input.to_owned(),
1436 content: None,
1437 tool_calls,
1438 };
1439
1440 #[cfg(feature = "logging")]
1441 info!(target: "stdout", "parsed result: {parsed:?}");
1442
1443 Ok(parsed)
1444 }
1445 Err(e) => {
1446 let err_msg =
1447 format!("Failed to deserialize generated tool calls. Reason: {e}");
1448
1449 #[cfg(feature = "logging")]
1450 error!(target: "stdout", "{}", &err_msg);
1451
1452 Err(LlamaCoreError::Operation(err_msg))
1453 }
1454 }
1455 } else {
1456 let parsed = ParseResult {
1457 raw: input.to_owned(),
1458 content: None,
1459 tool_calls: vec![],
1460 };
1461
1462 #[cfg(feature = "logging")]
1463 info!(target: "stdout", "parsed result: {parsed:?}");
1464
1465 Ok(parsed)
1466 }
1467 }
1468 PromptTemplateType::InternLM2Tool => {
1469 #[cfg(feature = "logging")]
1470 info!(target: "stdout", "raw input: {input}");
1471
1472 let blocks: Vec<&str> = input.trim().split("<|action_start|><|plugin|>").collect();
1473
1474 #[cfg(feature = "logging")]
1475 info!(target: "stdout", "blocks: {blocks:?}");
1476
1477 let mut tool_calls: Vec<ToolCall> = vec![];
1478 let mut content = String::new();
1479 for block in blocks {
1480 let block = block.trim();
1481 if !block.is_empty() {
1482 if block.ends_with("<|action_end|>") {
1483 let value = block.trim().trim_end_matches("<|action_end|>");
1484
1485 #[cfg(feature = "logging")]
1486 info!(target: "stdout", "tool call: {value}");
1487
1488 match serde_json::from_str::<serde_json::Value>(value) {
1489 Ok(value) => {
1490 let name = match value.get("name") {
1491 Some(name) => name.to_string().replace("\"", ""),
1492 None => {
1493 let err_msg = format!(
1494 "Failed to get the name of the function. Tool call: {value:?}"
1495 );
1496
1497 #[cfg(feature = "logging")]
1498 error!(target: "stdout", "{}", &err_msg);
1499
1500 return Err(LlamaCoreError::Operation(err_msg));
1501 }
1502 };
1503
1504 let arguments = match value.get("parameters") {
1505 Some(arguments) => arguments.to_string(),
1506 None => {
1507 let err_msg = format!(
1508 "Failed to get the arguments of the function. Tool call: {value:?}"
1509 );
1510
1511 #[cfg(feature = "logging")]
1512 error!(target: "stdout", "{}", &err_msg);
1513
1514 return Err(LlamaCoreError::Operation(err_msg));
1515 }
1516 };
1517
1518 let function = Function { name, arguments };
1519
1520 let tool_call = ToolCall {
1521 id: "call_abc123".to_string(),
1522 ty: "function".to_string(),
1523 function,
1524 };
1525
1526 tool_calls.push(tool_call);
1527 }
1528 Err(e) => {
1529 let err_msg = format!(
1530 "Failed to deserialize generated tool calls. Reason: {e}"
1531 );
1532
1533 #[cfg(feature = "logging")]
1534 error!(target: "stdout", "{}", &err_msg);
1535
1536 return Err(LlamaCoreError::Operation(err_msg));
1537 }
1538 }
1539 } else {
1540 content.push_str(block);
1541 content.push('\n');
1542 }
1543 }
1544 }
1545
1546 let parsed = match content.is_empty() {
1547 true => ParseResult {
1548 raw: input.to_owned(),
1549 content: None,
1550 tool_calls,
1551 },
1552 false => ParseResult {
1553 raw: input.to_owned(),
1554 content: Some(content.trim().to_owned()),
1555 tool_calls,
1556 },
1557 };
1558
1559 #[cfg(feature = "logging")]
1560 info!(target: "stdout", "parsed result: {parsed:?}");
1561
1562 Ok(parsed)
1563 }
1564 PromptTemplateType::NemotronTool => {
1565 #[cfg(feature = "logging")]
1566 info!(target: "stdout", "raw input: {input}");
1567
1568 match regex::Regex::new(r"(?s)<toolcall>\s*(.*?)\s*</toolcall>") {
1569 Ok(re) => {
1570 let mut values: Vec<serde_json::Value> = vec![];
1571 for cap in re.captures_iter(input) {
1572 #[cfg(feature = "logging")]
1573 info!(target: "stdout", "captured: {}", &cap[0]);
1574
1575 #[cfg(feature = "logging")]
1576 info!(target: "stdout", "extracted: {}", &cap[1]);
1577
1578 let matched = cap[1].trim();
1579
1580 #[cfg(feature = "logging")]
1581 info!(target: "stdout", "captured: {matched}");
1582
1583 match serde_json::from_str::<serde_json::Value>(matched) {
1584 Ok(value) => values.push(value),
1585 Err(e) => {
1586 let err_msg = format!(
1587 "Failed to deserialize generated tool calls. Reason: {e}"
1588 );
1589
1590 #[cfg(feature = "logging")]
1591 error!(target: "stdout", "{}", &err_msg);
1592
1593 return Err(LlamaCoreError::Operation(err_msg));
1594 }
1595 }
1596 }
1597
1598 let mut tool_calls: Vec<ToolCall> = vec![];
1599 for value in values.iter() {
1600 let name = match value.get("name") {
1601 Some(name) => name.to_string().replace("\"", ""),
1602 None => {
1603 let err_msg = format!(
1604 "Failed to get the name of the function. Tool call: {value:?}"
1605 );
1606
1607 #[cfg(feature = "logging")]
1608 error!(target: "stdout", "{}", &err_msg);
1609
1610 return Err(LlamaCoreError::Operation(err_msg));
1611 }
1612 };
1613
1614 let arguments = match value.get("arguments") {
1615 Some(arguments) => arguments.to_string(),
1616 None => {
1617 let err_msg = format!(
1618 "Failed to get the arguments of the function. Tool call: {value:?}"
1619 );
1620
1621 #[cfg(feature = "logging")]
1622 error!(target: "stdout", "{}", &err_msg);
1623
1624 return Err(LlamaCoreError::Operation(err_msg));
1625 }
1626 };
1627
1628 let function = Function { name, arguments };
1629
1630 let tool_call = ToolCall {
1631 id: "call_abc123".to_string(),
1632 ty: "function".to_string(),
1633 function,
1634 };
1635
1636 tool_calls.push(tool_call);
1637 }
1638
1639 let parsed = ParseResult {
1640 raw: input.to_owned(),
1641 content: None,
1642 tool_calls,
1643 };
1644
1645 #[cfg(feature = "logging")]
1646 info!(target: "stdout", "parsed result: {parsed:?}");
1647
1648 Ok(parsed)
1649 }
1650 Err(e) => {
1651 let err_msg = format!("Failed to create a regex pattern. Reason: {e}");
1652
1653 #[cfg(feature = "logging")]
1654 error!(target: "stdout", "{}", &err_msg);
1655
1656 Err(LlamaCoreError::Operation(err_msg))
1657 }
1658 }
1659 }
1660 PromptTemplateType::FunctionaryV32 => {
1661 #[cfg(feature = "logging")]
1662 info!(target: "stdout", "raw input: {input}");
1663
1664 match regex::Regex::new(r">>>\s*(\w+)\s*\{(.*)\}<\|eot_id\|>") {
1665 Ok(re) => {
1666 let mut tool_calls: Vec<ToolCall> = vec![];
1667 for cap in re.captures_iter(input) {
1668 #[cfg(feature = "logging")]
1669 info!(target: "stdout", "func_name: {}", &cap[1]);
1670
1671 #[cfg(feature = "logging")]
1672 info!(target: "stdout", "arguments: {}", &cap[2]);
1673
1674 let tool_call = ToolCall {
1675 id: "call_abc123".to_string(),
1676 ty: "function".to_string(),
1677 function: Function {
1678 name: cap[1].to_string(),
1679 arguments: cap[2].to_string(),
1680 },
1681 };
1682
1683 tool_calls.push(tool_call);
1684 }
1685
1686 let parsed = ParseResult {
1687 raw: input.to_owned(),
1688 content: None,
1689 tool_calls,
1690 };
1691
1692 #[cfg(feature = "logging")]
1693 info!(target: "stdout", "parsed result: {parsed:?}");
1694
1695 Ok(parsed)
1696 }
1697 Err(e) => {
1698 let warn_msg = format!("Failed to create a regex pattern. Reason: {e}");
1699
1700 #[cfg(feature = "logging")]
1701 warn!(target: "stdout", "{}", &warn_msg);
1702
1703 Ok(ParseResult {
1704 raw: input.to_owned(),
1705 content: None,
1706 tool_calls: vec![],
1707 })
1708 }
1709 }
1710 }
1711 PromptTemplateType::FunctionaryV31 => {
1712 #[cfg(feature = "logging")]
1713 info!(target: "stdout", "raw input: {input}");
1714
1715 match regex::Regex::new(r"<function=(\w+)>\s*(\{.*?\})</function>") {
1716 Ok(re) => {
1717 let mut tool_calls: Vec<ToolCall> = vec![];
1718 for cap in re.captures_iter(input) {
1719 #[cfg(feature = "logging")]
1720 info!(target: "stdout", "func_name: {}", &cap[1]);
1721
1722 #[cfg(feature = "logging")]
1723 info!(target: "stdout", "arguments: {}", &cap[2]);
1724
1725 let tool_call = ToolCall {
1726 id: "call_abc123".to_string(),
1727 ty: "function".to_string(),
1728 function: Function {
1729 name: cap[1].to_string(),
1730 arguments: cap[2].to_string(),
1731 },
1732 };
1733
1734 tool_calls.push(tool_call);
1735 }
1736
1737 let parsed = ParseResult {
1738 raw: input.to_owned(),
1739 content: None,
1740 tool_calls,
1741 };
1742
1743 #[cfg(feature = "logging")]
1744 info!(target: "stdout", "parsed result: {parsed:?}");
1745
1746 Ok(parsed)
1747 }
1748 Err(e) => {
1749 let warn_msg = format!("Failed to create a regex pattern. Reason: {e}");
1750
1751 #[cfg(feature = "logging")]
1752 warn!(target: "stdout", "{}", &warn_msg);
1753
1754 Ok(ParseResult {
1755 raw: input.to_owned(),
1756 content: None,
1757 tool_calls: vec![],
1758 })
1759 }
1760 }
1761 }
1762 PromptTemplateType::MistralSmallTool => {
1763 #[cfg(feature = "logging")]
1764 info!(target: "stdout", "raw input: {input}");
1765
1766 match regex::Regex::new(r"\[TOOL_CALLS\]\s*(\[(.*?)\])") {
1767 Ok(re) => {
1768 let mut values: Vec<serde_json::Value> = vec![];
1769 if let Some(cap) = re.captures(input) {
1770 let matched = cap[1].trim();
1771
1772 #[cfg(feature = "logging")]
1773 info!(target: "stdout", "captured: {matched}");
1774
1775 match serde_json::from_str::<Vec<serde_json::Value>>(matched) {
1776 Ok(vals) => values = vals,
1777 Err(e) => {
1778 let err_msg = format!(
1779 "Failed to deserialize generated tool calls. Reason: {e}"
1780 );
1781
1782 #[cfg(feature = "logging")]
1783 error!(target: "stdout", "{}", &err_msg);
1784
1785 return Err(LlamaCoreError::Operation(err_msg));
1786 }
1787 }
1788 };
1789
1790 let mut tool_calls: Vec<ToolCall> = vec![];
1791 for value in values.iter() {
1792 if let Some(object_map) = value.as_object() {
1793 if object_map.contains_key("function") {
1794 let mut function = Function {
1795 name: String::new(),
1796 arguments: String::new(),
1797 };
1798
1799 let value = object_map.get("function").unwrap();
1800 let func_map = value.as_object().unwrap();
1801 if func_map.contains_key("name") {
1802 let func_name = func_map.get("name").unwrap().as_str().unwrap();
1803 println!("Function name: {func_name:?}");
1804
1805 function.name = func_name.to_string();
1806 }
1807 if func_map.contains_key("arguments") {
1808 let args = func_map.get("arguments").unwrap();
1809 let arguments = args.to_string();
1810 println!("Arguments: {arguments:?}");
1811
1812 function.arguments = arguments;
1813 }
1814
1815 let tool_call = ToolCall {
1816 id: "call_abc123".to_string(),
1817 ty: "function".to_string(),
1818 function,
1819 };
1820
1821 tool_calls.push(tool_call);
1822 } else if object_map.contains_key("name") {
1823 let mut function = Function {
1824 name: String::new(),
1825 arguments: String::new(),
1826 };
1827
1828 let name = object_map.get("name").unwrap().as_str().unwrap();
1829 println!("name: {name:?}");
1830 function.name = name.to_string();
1831
1832 if object_map.contains_key("arguments") {
1833 let args = object_map.get("arguments").unwrap();
1834 let arguments = args.to_string();
1835 println!("Arguments: {arguments:?}");
1836
1837 function.arguments = arguments;
1838 }
1839
1840 let tool_call = ToolCall {
1841 id: "call_abc123".to_string(),
1842 ty: "function".to_string(),
1843 function,
1844 };
1845
1846 tool_calls.push(tool_call);
1847 }
1848 }
1849 }
1850
1851 let parsed = ParseResult {
1852 raw: input.to_owned(),
1853 content: None,
1854 tool_calls,
1855 };
1856
1857 #[cfg(feature = "logging")]
1858 info!(target: "stdout", "parsed result: {parsed:?}");
1859
1860 Ok(parsed)
1861 }
1862 Err(e) => {
1863 let err_msg = format!("Failed to create a regex pattern. Reason: {e}");
1864
1865 #[cfg(feature = "logging")]
1866 error!(target: "stdout", "{}", &err_msg);
1867
1868 Err(LlamaCoreError::Operation(err_msg))
1869 }
1870 }
1871 }
1872 PromptTemplateType::Llama4Chat => {
1873 #[cfg(feature = "logging")]
1874 info!(target: "stdout", "raw input: {input:?}");
1875
1876 let mut tool_calls: Vec<ToolCall> = vec![];
1877 if let Ok(value) = serde_json::from_str::<serde_json::Value>(input) {
1878 match value.as_object() {
1879 Some(object_map) => {
1880 #[cfg(feature = "logging")]
1881 debug!(target: "stdout", "object_map: {object_map:?}");
1882
1883 if object_map.contains_key("name") {
1885 let name = object_map.get("name").unwrap().as_str().unwrap();
1886
1887 #[cfg(feature = "logging")]
1888 debug!(target: "stdout", "name: {name:?}");
1889
1890 let mut function = Function {
1891 name: name.to_string(),
1892 arguments: String::new(),
1893 };
1894
1895 if object_map.contains_key("parameters") {
1897 let args = object_map.get("parameters").unwrap();
1898 let arguments = args.to_string();
1899
1900 #[cfg(feature = "logging")]
1901 debug!(target: "stdout", "arguments: {:?}", &arguments);
1902
1903 function.arguments = arguments;
1904 }
1905
1906 tool_calls.push(ToolCall {
1907 id: "call_abc123".to_string(),
1908 ty: "function".to_string(),
1909 function,
1910 });
1911 } else {
1912 let err_msg = format!(
1913 "Failed to get the name of the function. raw input: {input:?}"
1914 );
1915
1916 #[cfg(feature = "logging")]
1917 error!(target: "stdout", "{}", &err_msg);
1918
1919 return Err(LlamaCoreError::Operation(err_msg));
1920 }
1921 }
1922 None => {
1923 let err_msg = format!("Failed to parse the JSON string. JSON: {input}");
1924
1925 #[cfg(feature = "logging")]
1926 error!(target: "stdout", "{}", &err_msg);
1927
1928 return Err(LlamaCoreError::Operation(err_msg));
1929 }
1930 }
1931 }
1932
1933 let parsed = ParseResult {
1934 raw: input.to_owned(),
1935 content: None,
1936 tool_calls,
1937 };
1938
1939 #[cfg(feature = "logging")]
1940 info!(target: "stdout", "parsed result: {parsed:?}");
1941
1942 Ok(parsed)
1943 }
1944 _ => {
            let err_msg = format!(
                "The tool use is only supported for prompt templates: {}, {}, {}, {}, {}, {}, {}, {}, {}, and {}.",
                PromptTemplateType::MistralTool,
                PromptTemplateType::ChatMLTool,
                PromptTemplateType::GroqLlama3Tool,
                PromptTemplateType::Llama3Tool,
                PromptTemplateType::InternLM2Tool,
                PromptTemplateType::NemotronTool,
                PromptTemplateType::FunctionaryV32,
                PromptTemplateType::FunctionaryV31,
                PromptTemplateType::MistralSmallTool,
                PromptTemplateType::Llama4Chat,
            );
1956
1957 #[cfg(feature = "logging")]
1958 error!(target: "stdout", "{}", &err_msg);
1959
1960 Err(LlamaCoreError::Operation(err_msg))
1961 }
1962 }
1963}
1964
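/// Merge per-request overrides (image input, temperature, top_p, frequency and
/// presence penalties) into the cached model metadata, pushing the update to the
/// backend only when at least one field actually changed.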
1965fn check_model_metadata(
1966 chat_request: &ChatCompletionRequest,
1967) -> Result<GgmlMetadata, LlamaCoreError> {
1968 let mut should_update = false;
1969 let mut metadata = get_model_metadata(chat_request.model.as_ref())?;
1970
1971 if metadata.prompt_template.is_image_supported() {
1973 if let Some(ChatCompletionRequestMessage::User(user_message)) = chat_request.messages.last()
1974 {
1975 if let ChatCompletionUserMessageContent::Parts(parts) = user_message.content() {
1976 for part in parts {
1977 if let ContentPart::Image(image_part) = part {
1978 let image = image_part.image();
1979
1980 if image.is_url() {
1981 #[cfg(feature = "logging")]
1982 info!(target: "stdout", "The image is provided in URL format.");
1983
1984 let img_path_str = download_image(&image.url)?;
1986
1987 #[cfg(feature = "logging")]
1988 info!(target: "stdout", "The image is saved to {img_path_str}");
1989
1990 metadata.image = Some(img_path_str);
1992
1993 if !should_update {
1994 should_update = true;
1995 }
1996
1997 break;
2000 } else {
2001 #[cfg(feature = "logging")]
2002 info!(target: "stdout", "The image is provided in base64 format.");
2003
2004 break;
2007 }
2008 }
2009 }
2010 }
2011 }
2012 }
2013
2014 if let Some(temp) = chat_request.temperature {
2016 if metadata.temperature != temp {
2017 metadata.temperature = temp;
2019
2020 if !should_update {
2021 should_update = true;
2022 }
2023 }
2024 }
2025
2026 if let Some(top_p) = chat_request.top_p {
2028 if metadata.top_p != top_p {
2029 metadata.top_p = top_p;
2031
2032 if !should_update {
2033 should_update = true;
2034 }
2035 }
2036 }
2037
2038 if let Some(frequency_penalty) = chat_request.frequency_penalty {
2040 if metadata.frequency_penalty != frequency_penalty {
2041 metadata.frequency_penalty = frequency_penalty;
2043
2044 if !should_update {
2045 should_update = true;
2046 }
2047 }
2048 }
2049
2050 if let Some(presence_penalty) = chat_request.presence_penalty {
2052 if metadata.presence_penalty != presence_penalty {
2053 metadata.presence_penalty = presence_penalty;
2055
2056 if !should_update {
2057 should_update = true;
2058 }
2059 }
2060 }
2061
2062 if metadata.embeddings {
2064 metadata.embeddings = false;
2065
2066 if !should_update {
2067 should_update = true;
2068 }
2069 }
2070
2071 if should_update {
2072 #[cfg(feature = "logging")]
2073 info!(target: "stdout", "Update the model metadata.");
2074
2075 update_model_metadata(chat_request.model.as_ref(), &metadata)?;
2077 }
2078
2079 Ok(metadata)
2080}
2081
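/// Reconciles `metadata.n_predict` with the request and the remaining context budget:
///
/// * `max_completion_tokens` from the request overrides the current value when they differ;
/// * `-2` ("fill the remaining context") is replaced by `available_completion_tokens`;
/// * `-1`, any other negative value except `-2`, or a positive value smaller than
///   `available_completion_tokens` is likewise set to `available_completion_tokens`.
///
/// The backend metadata is rewritten only when one of these rules changed the value.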
2082fn update_n_predict(
2083 chat_request: &ChatCompletionRequest,
2084 metadata: &mut GgmlMetadata,
2085 available_completion_tokens: u64,
2086) -> Result<(), LlamaCoreError> {
2087 let mut should_update = false;
2088
2089 #[cfg(feature = "logging")]
2090 info!(target: "stdout", "n_predict: {}", metadata.n_predict);
2091
2092 if let Some(max_completion_tokens) = chat_request.max_completion_tokens {
2098 if metadata.n_predict != max_completion_tokens {
2099 #[cfg(feature = "logging")]
2100 info!(target: "stdout", "Update n_predict with max_completion_tokens from {} to {}", metadata.n_predict, max_completion_tokens);
2101
2102 metadata.n_predict = max_completion_tokens;
2103
2104 if !should_update {
2105 should_update = true;
2106 }
2107 }
2108 }
2109
2110 if metadata.n_predict == -2 {
2112 #[cfg(feature = "logging")]
2113 info!(target: "stdout", "Update n_predict with available_completion_tokens from {} to {}", metadata.n_predict, available_completion_tokens);
2114
2115 metadata.n_predict = available_completion_tokens as i32;
2117
2118 if !should_update {
2119 should_update = true;
2120 }
2121 }
2122
2123 if metadata.n_predict == -1
2124 || (metadata.n_predict > 0 && metadata.n_predict < available_completion_tokens as i32)
2125 || (metadata.n_predict < 0 && metadata.n_predict != -2)
2126 {
2128 #[cfg(feature = "logging")]
2129 info!(target: "stdout", "Update n_predict with available_completion_tokens from {} to {}", metadata.n_predict, available_completion_tokens);
2130
2131 metadata.n_predict = available_completion_tokens as i32;
2133
2134 if !should_update {
2135 should_update = true;
2136 }
2137 }
2138
2139 if should_update {
2140 #[cfg(feature = "logging")]
2141 info!(target: "stdout", "Update the model metadata.");
2142
2143 update_model_metadata(chat_request.model.as_ref(), metadata)?;
2145 }
2146
2147 Ok(())
2148}
2149
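/// Strips template-specific stop tokens and stray role markers from the raw model output so the
/// returned string contains only the assistant's visible text. Each prompt template family has
/// its own end-of-turn marker (e.g. `<|im_end|>` for ChatML-style templates, `<|eot_id|>` for
/// Llama 3, `</s>` for Mistral-style templates); unknown templates are simply trimmed.
///
/// A minimal sketch of the intended behaviour (hypothetical raw output):
///
/// ```ignore
/// let raw = "Hello there!<|im_end|>";
/// let cleaned = post_process(raw, &PromptTemplateType::ChatML)?;
/// assert_eq!(cleaned, "Hello there!");
/// ```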
2150fn post_process(
2152 output: impl AsRef<str>,
2153 template_ty: &PromptTemplateType,
2154) -> Result<String, String> {
2155 let output = if *template_ty == PromptTemplateType::Baichuan2 {
2156 if output.as_ref().contains("用户:") {
2157 output.as_ref().trim_end_matches("用户:").trim().to_owned()
2158 } else {
2159 output.as_ref().trim().to_owned()
2160 }
2161 } else if *template_ty == PromptTemplateType::OpenChat {
2162 if output.as_ref().contains("<|end_of_turn|>") {
2163 output
2164 .as_ref()
2165 .trim_end_matches("<|end_of_turn|>")
2166 .trim()
2167 .to_owned()
2168 } else {
2169 output.as_ref().trim().to_owned()
2170 }
2171 } else if *template_ty == PromptTemplateType::GemmaInstruct
2172 || *template_ty == PromptTemplateType::Gemma3
2173 {
2174 let s = output.as_ref().trim();
2175 if s.ends_with("<end_of_turn>") {
2176 s.trim_end_matches("<end_of_turn>").trim().to_owned()
2177 } else {
2178 s.to_owned()
2179 }
2180 } else if *template_ty == PromptTemplateType::ChatML
2181 || *template_ty == PromptTemplateType::ChatMLTool
2182 || *template_ty == PromptTemplateType::InternLM2Tool
2183 || *template_ty == PromptTemplateType::MiniCPMV
2184 {
2185 let mut s = output.as_ref().trim();
2186 if s.ends_with("<|endoftext|>") {
2187 s = s.trim_end_matches("<|endoftext|>").trim();
2188 }
2189
2190 if s.starts_with(":") {
2191 s = s.trim_start_matches(":").trim();
2192 }
2193
2194 if s.contains("<|im_start|>") && s.contains("<|im_end|>") {
2195 let idx_start = s.find("<|im_start|>").unwrap();
2196 let idx_end = s.find("<|im_end|>").unwrap();
2197
2198 match idx_start <= idx_end {
2199 true => s.split("<|im_start|>").collect::<Vec<_>>()[0]
2200 .trim()
2201 .to_owned(),
2202 false => s.split("<|im_end|>").collect::<Vec<_>>()[0]
2203 .trim()
2204 .to_owned(),
2205 }
2206 } else if s.contains("<|im_start|>") {
2207 s.split("<|im_start|>").collect::<Vec<_>>()[0]
2208 .trim()
2209 .to_owned()
2210 } else if s.contains("<|im_end|>") {
2211 let output = s.trim_end_matches("<|im_end|>").trim();
2212 if output.starts_with(": ") {
2213 output.trim_start_matches(": ").to_owned()
2214 } else {
2215 output.to_owned()
2216 }
2217 } else {
2218 s.to_owned()
2219 }
2220 } else if *template_ty == PromptTemplateType::Zephyr
2221 || *template_ty == PromptTemplateType::MistralLite
2222 || *template_ty == PromptTemplateType::MistralTool
2223 || *template_ty == PromptTemplateType::MistralInstruct
2224 || *template_ty == PromptTemplateType::MistralSmallChat
2225 || *template_ty == PromptTemplateType::MistralSmallTool
2226 || *template_ty == PromptTemplateType::BreezeInstruct
2227 {
2228 if output.as_ref().contains("</s><") {
2229 output.as_ref().trim_end_matches("</s><").trim().to_owned()
2230 } else if output.as_ref().contains("</s>") {
2231 output
2232 .as_ref()
2233 .strip_suffix("</s>")
2234 .unwrap()
2235 .trim()
2236 .to_owned()
2237 } else {
2238 output.as_ref().trim().to_owned()
2239 }
2240 } else if *template_ty == PromptTemplateType::DeepseekChat {
2241 if output.as_ref().contains("<|end_of_sentence|>") {
2242 output
2243 .as_ref()
2244 .trim_end_matches("<|end_of_sentence|>")
2245 .trim()
2246 .replace("<|end_of_sentence|>", " ")
2247 .trim()
2248 .to_owned()
2249 } else {
2250 output.as_ref().trim().to_owned()
2251 }
2252 } else if *template_ty == PromptTemplateType::HumanAssistant {
2253 if output.as_ref().contains("Human:") {
2254 output.as_ref().trim_end_matches("Human:").trim().to_owned()
2255 } else {
2256 output.as_ref().trim().to_owned()
2257 }
2258 } else if *template_ty == PromptTemplateType::SolarInstruct {
2259 let s = output.as_ref().trim();
2260
2261 if s.starts_with("### Answer") {
2262 let s = s.trim_start_matches("###").trim();
2263
2264 if s.starts_with("Answer:\n") {
2265 s.replace("Answer:\n", "Answer: ")
2266 } else {
2267 s.to_owned()
2268 }
2269 } else {
2270 s.to_owned()
2271 }
2272 } else if *template_ty == PromptTemplateType::Llama2Chat
2273 || *template_ty == PromptTemplateType::NemotronTool
2274 || *template_ty == PromptTemplateType::NemotronChat
2275 {
2276 let s = output.as_ref().trim();
2277 if s.ends_with("</s>") {
2278 s.trim_end_matches("</s>").trim().to_owned()
2279 } else {
2280 s.to_owned()
2281 }
2282 } else if *template_ty == PromptTemplateType::Llama3Chat
2283 || *template_ty == PromptTemplateType::GroqLlama3Tool
2284 || *template_ty == PromptTemplateType::Llama3Tool
2285 || *template_ty == PromptTemplateType::FunctionaryV32
2286 {
2287 let s = output.as_ref().trim();
2288 if s.ends_with("<|eot_id|>") {
2289 s.trim_end_matches("<|eot_id|>").trim().to_owned()
2290 } else {
2291 s.to_owned()
2292 }
2293 } else if *template_ty == PromptTemplateType::Phi3Chat {
2294 let s = output.as_ref().trim();
2295 if s.ends_with("<|end|>") {
2296 s.trim_end_matches("<|end|>").trim().to_owned()
2297 } else {
2298 s.to_owned()
2299 }
2300 } else if *template_ty == PromptTemplateType::Phi4Chat {
2301 let s = output.as_ref().trim();
2302 if s.ends_with("<|im_end|>") {
2303 s.trim_end_matches("<|im_end|>").trim().to_owned()
2304 } else {
2305 s.to_owned()
2306 }
2307 } else if *template_ty == PromptTemplateType::FunctionaryV31 {
2308 let mut s = output.as_ref().trim();
2309 if s.ends_with("<|eot_id|>") {
2310 s = s.trim_end_matches("<|eot_id|>").trim();
2311 }
2312 if s.ends_with("<|eom_id|>") {
2313 s = s.trim_end_matches("<|eom_id|>").trim();
2314 }
2315 s.to_owned()
2316 } else if *template_ty == PromptTemplateType::MoxinChat {
2317 let s = output.as_ref().trim();
2318 if s.ends_with("</s>") {
2319 s.trim_end_matches("</s>").trim().to_owned()
2320 } else if s.ends_with("[INST]") {
2321 s.trim_end_matches("[INST]").trim().to_owned()
2322 } else {
2323 s.to_owned()
2324 }
2325 } else if *template_ty == PromptTemplateType::Falcon3 {
2326 let s = output.as_ref().trim();
2327 if s.ends_with("<|endoftext|>") {
2328 s.trim_end_matches("<|endoftext|>").trim().to_owned()
2329 } else {
2330 s.to_owned()
2331 }
2332 } else if *template_ty == PromptTemplateType::Megrez {
2333 let s = output.as_ref().trim();
2334 if s.ends_with("<|turn_end|>") {
2335 s.trim_end_matches("<|turn_end|>").trim().to_owned()
2336 } else {
2337 s.to_owned()
2338 }
2339 } else if *template_ty == PromptTemplateType::Qwen2vl
2340 || *template_ty == PromptTemplateType::ChatMLThink
2341 {
2342 let mut s = output.as_ref().trim();
2343
2344 if s.starts_with(":") {
2345 s = s.trim_start_matches(":").trim();
2346 }
2347
2348 if s.ends_with("<|im_end|>") {
2349 s.trim_end_matches("<|im_end|>").trim().to_owned()
2350 } else {
2351 s.to_owned()
2352 }
2353 } else if *template_ty == PromptTemplateType::VicunaLlava {
2354 let s = output.as_ref().trim();
2355 if s.ends_with("</s>") {
2356 s.trim_end_matches("</s>").trim().to_owned()
2357 } else {
2358 s.to_owned()
2359 }
2360 } else if *template_ty == PromptTemplateType::ExaoneDeepChat
2361 || *template_ty == PromptTemplateType::ExaoneChat
2362 {
2363 let mut s = output.as_ref().trim();
2364
2365 if s.ends_with("[|endofturn|]") {
2366 s = s.trim_end_matches("[|endofturn|]").trim();
2367 }
2368
2369 s.to_owned()
2370 } else if *template_ty == PromptTemplateType::Llama4Chat {
2371 let mut s = output.as_ref().trim();
2372
2373 if s.ends_with("<|eot|>") {
2374 s = s.trim_end_matches("<|eot|>").trim();
2375 }
2376
2377 s.to_owned()
2378 } else {
2379 output.as_ref().trim().to_owned()
2380 };
2381
2382 Ok(output)
2383}
2384
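/// Renders the chat messages into a prompt for the target model and makes sure it fits within
/// roughly 4/5 of the model's context window. The messages are built with the model's prompt
/// template (with or without tool definitions, depending on `tool_choice` and `tools`), the
/// prompt is fed to the graph to count its tokens, and, if the budget is exceeded, the oldest
/// non-system messages are pruned and the loop retries.
///
/// Returns the final prompt, the number of tokens left for the completion, and whether tool
/// definitions were included in the prompt.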
2385fn build_prompt(
2397 model_name: Option<&String>,
2398 chat_request: &mut ChatCompletionRequest,
2399) -> Result<(String, u64, bool), LlamaCoreError> {
2400 let metadata = get_model_metadata(model_name)?;
2401 let ctx_size = metadata.ctx_size as u64;
2402 let chat_prompt = ChatPrompt::from(metadata.prompt_template);
2403
2404 let max_prompt_tokens = ctx_size * 4 / 5;
2406
2407 loop {
2423
2424 if chat_request.messages.is_empty() {
2425 let err_msg = "The messages in the chat request are empty.";
2426
2427 #[cfg(feature = "logging")]
2428 error!(target: "stdout", "{err_msg}");
2429
2430 return Err(LlamaCoreError::Operation(err_msg.to_owned()));
2431 }
2432
2433 #[cfg(feature = "logging")]
2434 {
2435 let mut role_chain = String::new();
2436 for (idx, message) in chat_request.messages.iter().enumerate() {
2437 if idx == chat_request.messages.len() - 1 {
2438 role_chain.push_str(&format!("{}", message.role()));
2439 } else {
2440 role_chain.push_str(&format!("{} -> ", message.role()));
2441 }
2442 }
2443 info!(target: "stdout", "Role chain: {role_chain}");
2444 }
2445
2446 let (prompt, tool_use) = match chat_request.tool_choice.as_ref() {
2447 Some(tool_choice) => match tool_choice {
2448 ToolChoice::None => {
2449 match chat_prompt.build_with_tools(&mut chat_request.messages, Some(&[])) {
2450 Ok(prompt) => (prompt, false),
2451 Err(e) => {
2452 let err_msg = format!("Fail to build chat prompts. Reason: {e}");
2453
2454 #[cfg(feature = "logging")]
2455 error!(target: "stdout", "{}", &err_msg);
2456
2457 return Err(LlamaCoreError::Operation(err_msg));
2458 }
2459 }
2460 }
2461 _ => match chat_request.tools.as_ref() {
2462 Some(tools) => match chat_prompt
2463 .build_with_tools(&mut chat_request.messages, Some(tools.as_slice()))
2464 {
2465 Ok(prompt) => (prompt, true),
2466 Err(e) => {
2467 let err_msg = format!("Fail to build chat prompts. Reason: {e}");
2468
2469 #[cfg(feature = "logging")]
2470 error!(target: "stdout", "{}", &err_msg);
2471
2472 return Err(LlamaCoreError::Operation(err_msg));
2473 }
2474 },
2475 None => {
2476 #[cfg(feature = "logging")]
2477 warn!(target: "stdout", "The tool choice without tools is not supported.");
2478
2479 match chat_prompt.build_with_tools(&mut chat_request.messages, None) {
2480 Ok(prompt) => (prompt, false),
2481 Err(e) => {
2482 let err_msg = format!("Fail to build chat prompts. Reason: {e}");
2483
2484 #[cfg(feature = "logging")]
2485 error!(target: "stdout", "{}", &err_msg);
2486
2487 return Err(LlamaCoreError::Operation(err_msg));
2488 }
2489 }
2490 }
2491 },
2492 },
2493 None => match chat_prompt.build_with_tools(&mut chat_request.messages, None) {
2494 Ok(prompt) => (prompt, false),
2495 Err(e) => {
2496 let err_msg = format!("Fail to build chat prompts. Reason: {e}");
2497
2498 #[cfg(feature = "logging")]
2499 error!(target: "stdout", "{}", &err_msg);
2500
2501 return Err(LlamaCoreError::Operation(err_msg));
2502 }
2503 },
2504 };
2505 #[cfg(feature = "logging")]
2506 info!(target: "stdout", "Try to set prompt: {prompt}");
2507
2508 set_prompt(model_name, &prompt)?;
2510
2511 let token_info = get_token_info_by_graph_name(model_name)?;
2513
2514 match token_info.prompt_tokens > max_prompt_tokens {
2515 true => {
2516 match chat_request.messages[0].role() {
2517 ChatCompletionRole::System => {
2518 if chat_request.messages.len() > 2 {
2519 #[cfg(feature = "logging")]
2520 info!(target: "stdout", "Prune chat history: current length {}", chat_request.messages.len());
2521
2522 if chat_request.messages[1].role() == ChatCompletionRole::User {
2525 let user_message = chat_request.messages.remove(1);
2526
2527 #[cfg(feature = "logging")]
2528 info!(target: "stdout", "Remove a user message from the chat history: {user_message:?}");
2529 }
2530
2531 while chat_request.messages[1].role() != ChatCompletionRole::User {
2534 let message = chat_request.messages.remove(1);
2535
2536 #[cfg(feature = "logging")]
2537 info!(target: "stdout", "Remove a {} message from the chat history: {:?}", message.role(), message);
2538
2539 if chat_request.messages.len() == 1 {
2540 let err_msg = format!("The last message in the chat history should be a user message, but found a {} message.", message.role());
2541
2542 #[cfg(feature = "logging")]
2543 error!(target: "stdout", "{err_msg}");
2544
2545 return Err(LlamaCoreError::Operation(err_msg));
2546 }
2547 }
2548 } else if token_info.prompt_tokens > ctx_size {
2549 let err_msg = format!(
2550 "The number of prompt tokens ({}) is greater than the context size ({}). Please increase the context size, or simplify the input message.",
2551 token_info.prompt_tokens, ctx_size
2552 );
2553
2554 #[cfg(feature = "logging")]
2555 error!(target: "stdout", "{}", &err_msg);
2556
2557 return Err(LlamaCoreError::Operation(err_msg));
2558 } else {
2559 return Ok((prompt, ctx_size - token_info.prompt_tokens, tool_use));
2560 }
2561 }
2562 ChatCompletionRole::User => {
2563 if chat_request.messages.len() > 1 {
2564 if chat_request.messages[0].role() == ChatCompletionRole::User {
2569 let user_message = chat_request.messages.remove(0);
2570
2571 #[cfg(feature = "logging")]
2572 info!(target: "stdout", "Remove a user message from the chat history: {user_message:?}");
2573 }
2574
2575 while chat_request.messages[0].role() != ChatCompletionRole::User {
2578 let message = chat_request.messages.remove(0);
2579
2580 #[cfg(feature = "logging")]
2581 info!(target: "stdout", "Remove a {} message from the chat history: {:?}", message.role(), message);
2582
2583 if chat_request.messages.is_empty() {
2584 let err_msg = format!("The last message in the chat history should be a user message, but found a {} message.", message.role());
2585
2586 #[cfg(feature = "logging")]
2587 error!(target: "stdout", "{err_msg}");
2588
2589 return Err(LlamaCoreError::Operation(err_msg));
2590 }
2591 }
2592 } else if token_info.prompt_tokens > ctx_size {
2593 let err_msg = format!(
2594 "The number of prompt tokens ({}) is greater than the context size ({}). Please increase the context size, or simplify the input message.",
2595 token_info.prompt_tokens, ctx_size
2596 );
2597
2598 #[cfg(feature = "logging")]
2599 error!(target: "stdout", "{}", &err_msg);
2600
2601 return Err(LlamaCoreError::Operation(err_msg));
2602 } else {
2603 return Ok((prompt, ctx_size - token_info.prompt_tokens, tool_use));
2604 }
2605 }
2606 _ => {
2607 #[cfg(feature = "logging")]
2608                            info!(target: "stdout", "Remove a {} message from the chat history", chat_request.messages[0].role());
2609
2610 chat_request.messages.remove(0);
2611 }
2612 }
2613
2614 continue;
2615 }
2616 false => return Ok((prompt, ctx_size - max_prompt_tokens, tool_use)),
2617 }
2618 }
2619}
2620
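/// Downloads the image referenced by `image_url` into a fresh `archives/file_<uuid>/` directory
/// and returns the local path of the saved file. The file name is taken from the last path
/// segment of the resolved URL. The download is driven through
/// `tokio::runtime::Handle::current().block_on`, so this helper must only be called from a
/// context where blocking on the current runtime handle is permitted.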
2621fn download_image(image_url: impl AsRef<str>) -> Result<String, LlamaCoreError> {
2623 let image_url = image_url.as_ref();
2624 let url = reqwest::Url::parse(image_url).map_err(|e| {
2625 let err_msg = format!("Fail to parse the image URL: {image_url}. Reason: {e}");
2626
2627 #[cfg(feature = "logging")]
2628 error!(target: "stdout", "{}", &err_msg);
2629
2630 LlamaCoreError::Operation(err_msg)
2631 })?;
2632
2633 let curr_rt_handle = tokio::runtime::Handle::current();
2634 let res = curr_rt_handle.block_on(async {
2635 let response = reqwest::get(url).await.map_err(|e| {
2636 let err_msg =
2637 format!("Fail to download the image from the URL: {image_url}. Reason: {e}");
2638
2639 #[cfg(feature = "logging")]
2640 error!(target: "stdout", "{}", &err_msg);
2641
2642 LlamaCoreError::Operation(err_msg)
2643 })?;
2644
2645 let fname = response
2646 .url()
2647 .path_segments()
2648 .and_then(|mut segments| segments.next_back())
2649 .and_then(|name| if name.is_empty() { None } else { Some(name) })
2650 .ok_or(LlamaCoreError::Operation(format!(
2651 "Fail to get the file name: {image_url}"
2652 )))?
2653 .to_string();
2654
2655 let id = format!("file_{}", uuid::Uuid::new_v4());
2657 let path = Path::new("archives");
2659 if !path.exists() {
2660 fs::create_dir(path).unwrap();
2661 }
2662 let file_path = path.join(&id);
2663 if !file_path.exists() {
2664 fs::create_dir(&file_path).unwrap();
2665 }
2666 let img_path = file_path.join(&fname);
2667 let mut dest: File = File::create(&img_path).map_err(|e| {
2668            let err_msg = format!("Failed to create the image file {}. Reason: {}", &fname, e);
2669
2670 #[cfg(feature = "logging")]
2672 error!(target: "stdout", "{}", &err_msg);
2673
2674 LlamaCoreError::Operation(err_msg)
2675 })?;
2676
2677 let mut content = response.bytes_stream();
2678 while let Some(Ok(item)) = content.next().await {
2679 std::io::copy(&mut item.as_ref(), &mut dest).map_err(|e| {
2680 let err_msg = format!(
2681 "Fail to write the image content to the file: {}. Reason: {}",
2682 &fname, e
2683 );
2684
2685 #[cfg(feature = "logging")]
2686 error!(target: "stdout", "{}", &err_msg);
2687
2688 LlamaCoreError::Operation(err_msg)
2689 })?;
2690 }
2691
2692 Ok(img_path.as_path().to_string_lossy().to_string())
2693 });
2694
2695 res
2696
2697}
2699
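/// Writes the rendered prompt into input tensor 0 of the selected chat graph. When `model_name`
/// is `None` or does not match a loaded graph, the first graph in `CHAT_GRAPHS` is used as the
/// default.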
2700fn set_prompt(model_name: Option<&String>, prompt: impl AsRef<str>) -> Result<(), LlamaCoreError> {
2701 let chat_graphs = match CHAT_GRAPHS.get() {
2702 Some(chat_graphs) => chat_graphs,
2703 None => {
2704 let err_msg = "Fail to get the underlying value of `CHAT_GRAPHS`.";
2705
2706 #[cfg(feature = "logging")]
2707 error!(target: "stdout", "{}", &err_msg);
2708
2709 return Err(LlamaCoreError::Operation(err_msg.into()));
2710 }
2711 };
2712
2713 let mut chat_graphs = chat_graphs.lock().map_err(|e| {
2714 let err_msg = format!("Fail to acquire the lock of `CHAT_GRAPHS`. {e}");
2715
2716 #[cfg(feature = "logging")]
2717 error!(target: "stdout", "{}", &err_msg);
2718
2719 LlamaCoreError::Operation(err_msg)
2720 })?;
2721
2722 match model_name {
2723 Some(model_name) => {
2724 #[cfg(feature = "logging")]
2725 info!(target: "stdout", "Set prompt to the chat model named {model_name}");
2726
2727 match chat_graphs.contains_key(model_name) {
2728 true => {
2729 let graph = chat_graphs.get_mut(model_name).unwrap();
2730 let tensor_data = prompt.as_ref().as_bytes().to_vec();
2731 set_tensor_data_u8(graph, 0, &tensor_data)
2732 }
2733 false => match chat_graphs.iter_mut().next() {
2734 Some((_, graph)) => {
2735 let tensor_data = prompt.as_ref().as_bytes().to_vec();
2736 set_tensor_data_u8(graph, 0, &tensor_data)
2737 }
2738 None => {
2739 let err_msg = "There is no model available in the chat graphs.";
2740
2741 #[cfg(feature = "logging")]
2742 error!(target: "stdout", "{}", &err_msg);
2743
2744 Err(LlamaCoreError::Operation(err_msg.into()))
2745 }
2746 },
2747 }
2748 }
2749 None => {
2750 #[cfg(feature = "logging")]
2751 info!(target: "stdout", "Set prompt to the default chat model.");
2752
2753 match chat_graphs.iter_mut().next() {
2754 Some((_, graph)) => {
2755 let tensor_data = prompt.as_ref().as_bytes().to_vec();
2756 set_tensor_data_u8(graph, 0, &tensor_data)
2757 }
2758 None => {
2759 let err_msg = "There is no model available in the chat graphs while trying to set prompt to the default model.";
2760
2761 #[cfg(feature = "logging")]
2762 error!(target: "stdout", "{err_msg}");
2763
2764 Err(LlamaCoreError::Operation(err_msg.into()))
2765 }
2766 }
2767 }
2768 }
2769}
2770
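/// Returns a clone of the metadata of the requested chat graph, falling back to the first
/// available graph when `model_name` is `None` or unknown.
///
/// A minimal sketch (the model name is hypothetical; an unknown name falls back to the default
/// graph):
///
/// ```ignore
/// let metadata = get_model_metadata(Some(&"llama-3-8b".to_string()))?;
/// println!("ctx_size = {}", metadata.ctx_size);
/// ```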
2771fn get_model_metadata(model_name: Option<&String>) -> Result<GgmlMetadata, LlamaCoreError> {
2773 let chat_graphs = match CHAT_GRAPHS.get() {
2774 Some(chat_graphs) => chat_graphs,
2775 None => {
2776 let err_msg = "Fail to get the underlying value of `CHAT_GRAPHS`.";
2777
2778 #[cfg(feature = "logging")]
2779 error!(target: "stdout", "{err_msg}");
2780
2781 return Err(LlamaCoreError::Operation(err_msg.into()));
2782 }
2783 };
2784
2785 let chat_graphs = chat_graphs.lock().map_err(|e| {
2786 let err_msg = format!("Fail to acquire the lock of `CHAT_GRAPHS`. {e}");
2787
2788 #[cfg(feature = "logging")]
2789 error!(target: "stdout", "{}", &err_msg);
2790
2791 LlamaCoreError::Operation(err_msg)
2792 })?;
2793
2794 match model_name {
2795 Some(model_name) => match chat_graphs.contains_key(model_name) {
2796 true => {
2797 let graph = chat_graphs.get(model_name).unwrap();
2798 Ok(graph.metadata.clone())
2799 }
2800 false => match chat_graphs.iter().next() {
2801 Some((_, graph)) => Ok(graph.metadata.clone()),
2802 None => {
2803 let err_msg = "There is no model available in the chat graphs.";
2804
2805 #[cfg(feature = "logging")]
2806 error!(target: "stdout", "{}", &err_msg);
2807
2808 Err(LlamaCoreError::Operation(err_msg.into()))
2809 }
2810 },
2811 },
2812 None => match chat_graphs.iter().next() {
2813 Some((_, graph)) => Ok(graph.metadata.clone()),
2814 None => {
2815 let err_msg = "There is no model available in the chat graphs.";
2816
2817 #[cfg(feature = "logging")]
2818 error!(target: "stdout", "{err_msg}");
2819
2820 Err(LlamaCoreError::Operation(err_msg.into()))
2821 }
2822 },
2823 }
2824}
2825
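/// Serializes the given metadata to JSON and writes it into input tensor 1 of the selected chat
/// graph so the backend picks up the new configuration. Falls back to the first available graph
/// when `model_name` is `None` or unknown.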
2826fn update_model_metadata(
2827 model_name: Option<&String>,
2828 metadata: &GgmlMetadata,
2829) -> Result<(), LlamaCoreError> {
2830 let config = match serde_json::to_string(metadata) {
2831 Ok(config) => config,
2832 Err(e) => {
2833 let err_msg = format!("Fail to serialize metadata to a JSON string. {e}");
2834
2835 #[cfg(feature = "logging")]
2836 error!(target: "stdout", "{}", &err_msg);
2837
2838 return Err(LlamaCoreError::Operation(err_msg));
2839 }
2840 };
2841
2842 let chat_graphs = match CHAT_GRAPHS.get() {
2843 Some(chat_graphs) => chat_graphs,
2844 None => {
2845 let err_msg = "Fail to get the underlying value of `CHAT_GRAPHS`.";
2846
2847 #[cfg(feature = "logging")]
2848 error!(target: "stdout", "{err_msg}");
2849
2850 return Err(LlamaCoreError::Operation(err_msg.into()));
2851 }
2852 };
2853
2854 let mut chat_graphs = chat_graphs.lock().map_err(|e| {
2855 let err_msg = format!("Fail to acquire the lock of `CHAT_GRAPHS`. Reason: {e}");
2856
2857 #[cfg(feature = "logging")]
2858 error!(target: "stdout", "{}", &err_msg);
2859
2860 LlamaCoreError::Operation(err_msg)
2861 })?;
2862
2863 match model_name {
2864 Some(model_name) => {
2865 match chat_graphs.contains_key(model_name) {
2866 true => {
2867 let graph = chat_graphs.get_mut(model_name).unwrap();
2868 set_tensor_data_u8(graph, 1, config.as_bytes())
2870 }
2871 false => match chat_graphs.iter_mut().next() {
2872 Some((_, graph)) => {
2873 set_tensor_data_u8(graph, 1, config.as_bytes())
2875 }
2876 None => {
2877 let err_msg = "There is no model available in the chat graphs.";
2878
2879 #[cfg(feature = "logging")]
2880 error!(target: "stdout", "{}", &err_msg);
2881
2882 Err(LlamaCoreError::Operation(err_msg.into()))
2883 }
2884 },
2885 }
2886 }
2887 None => {
2888 match chat_graphs.iter_mut().next() {
2889 Some((_, graph)) => {
2890 set_tensor_data_u8(graph, 1, config.as_bytes())
2892 }
2893 None => {
2894 let err_msg = "There is no model available in the chat graphs.";
2895
2896 #[cfg(feature = "logging")]
2897 error!(target: "stdout", "{err_msg}");
2898
2899 Err(LlamaCoreError::Operation(err_msg.into()))
2900 }
2901 }
2902 }
2903 }
2904}
2905
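/// Re-applies the metadata currently cached in the graph to the backend, discarding any
/// per-request overrides previously written by `check_model_metadata` or `update_n_predict`.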
2906fn reset_model_metadata(model_name: Option<&String>) -> Result<(), LlamaCoreError> {
2907 let metadata = get_model_metadata(model_name)?;
2909
2910 update_model_metadata(model_name, &metadata)
2912}
2913
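// The three enums below are small per-stream state machines used by `compute_stream` to emit
// the trailing SSE events exactly once: an optional final message, an optional usage chunk,
// the `data: [DONE]` marker, and the end-of-sequence sentinel, covering the "context full",
// "prompt too long", and normal end-of-stream cases.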
2914#[derive(Debug, Clone, Copy, PartialEq, Eq)]
2915enum ContextFullState {
2916 Message,
2917 Usage,
2918 Done,
2919 EndOfSequence,
2920}
2921
2922#[derive(Debug, Clone, Copy, PartialEq, Eq)]
2923enum StreamState {
2924 Usage,
2925 NoUsage,
2926 Done,
2927 EndOfSequence,
2928}
2929
2930#[derive(Debug, Clone, Copy, PartialEq, Eq)]
2931enum PromptTooLongState {
2932 Message,
2933 Usage,
2934 Done,
2935 EndOfSequence,
2936}
2937
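/// A `futures::Stream` of SSE-formatted chat completion chunks backed by `compute_stream`.
///
/// Only one `ChatStream` may drive the backend at a time: the global `CHAT_STREAM_ACTIVE` flag
/// acts as a lock, and streams that fail to acquire it park their wakers in
/// `CHAT_STREAM_WAKER_QUEUE` until the active stream is dropped. Dropping the stream finalizes
/// the backend context via `finish_single`, restores the model metadata, releases the lock, and
/// wakes the next waiting stream.
///
/// A minimal consumption sketch, assuming `stream` is a `ChatStream` obtained from the streaming
/// chat path:
///
/// ```ignore
/// use futures::StreamExt;
///
/// while let Some(chunk) = stream.next().await {
///     // Each item is an SSE line such as `data: {...}\n\n` or `data: [DONE]\n\n`.
///     print!("{}", chunk?);
/// }
/// ```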
2938struct ChatStream {
2939 id: String,
2940 model: Option<String>,
2941 include_usage: bool,
2942 context_full_state: ContextFullState,
2943 prompt_too_long_state: PromptTooLongState,
2944 stream_state: StreamState,
2945 cache: Option<VecDeque<String>>,
2946 is_waiting: bool,
2947 has_lock: bool,
2948}
2949impl ChatStream {
2950 fn new(
2951 model: Option<String>,
2952 id: String,
2953 include_usage: bool,
2954 cache: Option<Vec<String>>,
2955 ) -> Self {
2956 let has_lock = CHAT_STREAM_ACTIVE
2958 .compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst)
2959 .is_ok();
2960
2961 #[cfg(feature = "logging")]
2962 if !has_lock {
2963 info!(target: "stdout", "Lock acquisition failed in ChatStream::new, creating with waiting status");
2964 }
2965
2966 ChatStream {
2967 id,
2968 model,
2969 include_usage,
2970 context_full_state: ContextFullState::Message,
2971 prompt_too_long_state: PromptTooLongState::Message,
2972 stream_state: if include_usage {
2973 StreamState::Usage
2974 } else {
2975 StreamState::NoUsage
2976 },
2977 cache: cache.map(VecDeque::from),
2978 is_waiting: !has_lock,
2979 has_lock,
2980 }
2981 }
2982
2983 fn try_acquire_lock(&mut self) -> bool {
2985 if self.has_lock {
2986 return true;
2987 }
2988
2989 let acquired = CHAT_STREAM_ACTIVE
2990 .compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst)
2991 .is_ok();
2992
2993 if acquired {
2994 self.has_lock = true;
2995 self.is_waiting = false;
2996 }
2997
2998 acquired
2999 }
3000}
3001impl Drop for ChatStream {
3002 fn drop(&mut self) {
3003 if self.has_lock || (self.cache.is_none() && !self.is_waiting) {
3005 #[cfg(feature = "logging")]
3006 info!(target: "stdout", "Cleaning up context for ChatStream {}", &self.id);
3007
3008 match &self.model {
3009 Some(model_name) => {
3010 match CHAT_GRAPHS.get() {
3011 Some(chat_graphs) => {
3012 match chat_graphs.lock() {
3013 Ok(mut chat_graphs) => match chat_graphs.contains_key(model_name) {
3014 true => {
3015 let graph = chat_graphs.get_mut(model_name).unwrap();
3016
3017 if let Err(e) = graph.finish_single() {
3019 let err_msg = format!(
3020 "Failed to clean up the context. Reason: {e}"
3021 );
3022
3023 #[cfg(feature = "logging")]
3024 error!(target: "stdout", "{}", &err_msg);
3025
3026 #[cfg(not(feature = "logging"))]
3027 println!(
3028 "[ERROR][llama_core] Failed to clean up the context. Reason: {}",
3029 &err_msg
3030 );
3031 }
3032 }
3033 false => match chat_graphs.iter_mut().next() {
3034 Some((_, graph)) => {
3035 if let Err(e) = graph.finish_single() {
3037 let err_msg = format!(
3038 "Failed to clean up the context. Reason: {e}"
3039 );
3040
3041 #[cfg(feature = "logging")]
3042 error!(target: "stdout", "{}", &err_msg);
3043
3044 #[cfg(not(feature = "logging"))]
3045 println!(
3046 "[ERROR][llama_core] Failed to clean up the context. Reason: {}",
3047 &err_msg
3048 );
3049 }
3050 }
3051 None => {
3052 let err_msg =
3053 "There is no model available in the chat graphs.";
3054
3055 #[cfg(feature = "logging")]
3056 error!(target: "stdout", "{}", &err_msg);
3057
3058 #[cfg(not(feature = "logging"))]
3059 println!(
3060 "[ERROR][llama_core] Failed to clean up the context. Reason: {}",
3061 &err_msg
3062 );
3063 }
3064 },
3065 },
3066 Err(e) => {
3067 let err_msg =
3068 format!("Fail to acquire the lock of `CHAT_GRAPHS`. {e}");
3069
3070 #[cfg(feature = "logging")]
3071 error!(target: "stdout", "{}", &err_msg);
3072
3073 #[cfg(not(feature = "logging"))]
3074 println!(
3075 "[ERROR][llama_core] Failed to clean up the context. Reason: {}",
3076 &err_msg
3077 );
3078 }
3079 }
3080 }
3081 None => {
3082 let err_msg = "Fail to get the underlying value of `CHAT_GRAPHS`.";
3083
3084 #[cfg(feature = "logging")]
3085 error!(target: "stdout", "{}", &err_msg);
3086
3087 #[cfg(not(feature = "logging"))]
3088 println!(
3089 "[ERROR][llama_core] Failed to clean up the context. Reason: {}",
3090 &err_msg
3091 );
3092 }
3093 };
3094 }
3095 None => {
3096 match CHAT_GRAPHS.get() {
3097 Some(chat_graphs) => {
3098 match chat_graphs.lock() {
3099 Ok(mut chat_graphs) => match chat_graphs.iter_mut().next() {
3100 Some((_, graph)) => {
3101 if let Err(e) = graph.finish_single() {
3103 let err_msg = format!(
3104 "Failed to clean up the context. Reason: {e}"
3105 );
3106
3107 #[cfg(feature = "logging")]
3108 error!(target: "stdout", "{}", &err_msg);
3109
3110 #[cfg(not(feature = "logging"))]
3111 println!(
3112 "[ERROR][llama_core] Failed to clean up the context. Reason: {}",
3113 &err_msg
3114 );
3115 }
3116 }
3117 None => {
3118 let err_msg =
3119 "There is no model available in the chat graphs.";
3120
3121 #[cfg(feature = "logging")]
3122 error!(target: "stdout", "{err_msg}");
3123
3124 #[cfg(not(feature = "logging"))]
3125 println!(
3126 "[ERROR][llama_core] Failed to clean up the context. Reason: {}",
3127 err_msg
3128 );
3129 }
3130 },
3131 Err(e) => {
3132 let err_msg =
3133 format!("Fail to acquire the lock of `CHAT_GRAPHS`. {e}");
3134
3135 #[cfg(feature = "logging")]
3136 error!(target: "stdout", "{}", &err_msg);
3137
3138 #[cfg(not(feature = "logging"))]
3139 println!(
3140 "[ERROR][llama_core] Failed to clean up the context. Reason: {}",
3141 &err_msg
3142 );
3143 }
3144 }
3145 }
3146 None => {
3147 let err_msg = "Fail to get the underlying value of `CHAT_GRAPHS`.";
3148
3149 #[cfg(feature = "logging")]
3150 error!(target: "stdout", "{}", &err_msg);
3151
3152 #[cfg(not(feature = "logging"))]
3153 println!(
3154 "[ERROR][llama_core] Failed to clean up the context. Reason: {}",
3155 &err_msg
3156 );
3157 }
3158 };
3159 }
3160 }
3161
3162 #[cfg(feature = "logging")]
3163 info!(target: "stdout", "Model context cleanup done!");
3164 }
3165
3166 if let Err(e) = reset_model_metadata(self.model.as_ref()) {
3168 let err_msg = format!("Fail to reset model metadata. Reason: {e}");
3169
3170 #[cfg(feature = "logging")]
3171 error!(target: "stdout", "{}", &err_msg);
3172
3173 #[cfg(not(feature = "logging"))]
3174 println!("[ERROR][llama_core] {}", &err_msg);
3175 }
3176 #[cfg(feature = "logging")]
3177 info!(target: "stdout", "Model metadata reset done!");
3178
3179 if self.has_lock {
3181 CHAT_STREAM_ACTIVE.store(false, Ordering::SeqCst);
3183
3184 #[cfg(feature = "logging")]
3185 info!(target: "stdout", "Lock from ChatStream {} released", &self.id);
3186
3187 if let Ok(mut queue) = get_chat_stream_waker_queue().lock() {
3189 if let Some(waker) = queue.pop_front() {
3190 #[cfg(feature = "logging")]
3191 info!(target: "stdout", "Waking up a waiting ChatStream");
3192
3193 waker.wake();
3194 }
3195 }
3196 }
3197 }
3198}
3199impl futures::Stream for ChatStream {
3200 type Item = Result<String, LlamaCoreError>;
3201
3202 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
3203 let this = self.get_mut();
3204
3205 if this.is_waiting {
3207 if !this.try_acquire_lock() {
3208 if let Ok(mut queue) = get_chat_stream_waker_queue().lock() {
3210 queue.retain(|w| !w.will_wake(cx.waker()));
3212 queue.push_back(cx.waker().clone());
3214
3215 #[cfg(feature = "logging")]
3216 debug!(target: "stdout", "ChatStream {} is waiting for lock, added waker to queue", &this.id);
3217 }
3218
3219 return Poll::Pending;
3220 }
3221
3222 #[cfg(feature = "logging")]
3223 info!(target: "stdout", "ChatStream {} acquired lock and is now active", &this.id);
3224 }
3226
3227 if !this.has_lock && !this.try_acquire_lock() {
3229 this.is_waiting = true;
3231
3232 if let Ok(mut queue) = get_chat_stream_waker_queue().lock() {
3234 queue.retain(|w| !w.will_wake(cx.waker()));
3235 queue.push_back(cx.waker().clone());
3236 }
3237
3238 return Poll::Pending;
3239 }
3240
3241 if this.cache.is_none() {
3242 let res = compute_stream(
3243 this.model.clone(),
3244 this.id.clone(),
3245 this.include_usage,
3246 &mut this.prompt_too_long_state,
3247 &mut this.context_full_state,
3248 &mut this.stream_state,
3249 );
3250
3251 match res {
3252 Ok(x) => {
3253 #[cfg(feature = "logging")]
3254 info!(target: "stdout", "next item for ChatStream {}: {}", &this.id, &x);
3255
3256 if x != "[GGML] End of sequence" && !x.is_empty() {
3257 Poll::Ready(Some(Ok(x)))
3258 } else {
3259 Poll::Ready(None)
3261 }
3262 }
3263 Err(e) => Poll::Ready(Some(Err(e))),
3264 }
3265 } else {
3266 let x = this.cache.as_mut().unwrap().pop_front();
3267
3268 #[cfg(feature = "logging")]
3269 info!(target: "stdout", "Get the next item from the cache for ChatStream {}: {:?}", &this.id, &x);
3270
3271 match x {
3272 Some(x) => Poll::Ready(Some(Ok(x))),
3273 None => Poll::Ready(None),
3274 }
3275 }
3276 }
3277}
3278
3279fn get_chat_stream_waker_queue() -> &'static Mutex<VecDeque<Waker>> {
3281 CHAT_STREAM_WAKER_QUEUE.get_or_init(|| {
3282 #[cfg(feature = "logging")]
3283 info!(target: "stdout", "Initializing ChatStream waker queue");
3284 Mutex::new(VecDeque::new())
3285 })
3286}
3287
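/// Produces the next SSE chunk for a streaming chat completion by running one `compute_single`
/// step on the selected graph. Regular tokens are wrapped in a `chat.completion.chunk` object;
/// the backend errors `EndOfSequence`, `ContextFull`, and `PromptTooLong` are translated, via the
/// state machines above, into a final (optionally usage-carrying) chunk, a `data: [DONE]\n\n`
/// marker, and finally the `"[GGML] End of sequence"` sentinel that tells `ChatStream::poll_next`
/// to terminate the stream. Incomplete UTF-8 sequences returned by the backend are accumulated in
/// `CACHED_UTF8_ENCODINGS` until they decode into a valid string.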
3288fn compute_stream(
3289 model_name: Option<String>,
3290 id: String,
3291 include_usage: bool,
3292 prompt_too_long_state: &mut PromptTooLongState,
3293 context_full_state: &mut ContextFullState,
3294 stream_state: &mut StreamState,
3295) -> Result<String, LlamaCoreError> {
3296 #[cfg(feature = "logging")]
3297 info!(target: "stdout", "Computing stream chunk for ChatStream {}", &id);
3298
3299 #[cfg(feature = "logging")]
3300 debug!(target: "stdout", "prompt_too_long_state: {:?}", *prompt_too_long_state);
3301 #[cfg(feature = "logging")]
3302 debug!(target: "stdout", "context_full_state: {:?}", *context_full_state);
3303 #[cfg(feature = "logging")]
3304 debug!(target: "stdout", "stream_state: {:?}", *stream_state);
3305
3306 if *prompt_too_long_state == PromptTooLongState::EndOfSequence
3307 || *context_full_state == ContextFullState::EndOfSequence
3308 || *stream_state == StreamState::EndOfSequence
3309 {
3310 #[cfg(feature = "logging")]
3311 info!(target: "stdout", "Return the chat stream chunk!");
3312
3313 return Ok("[GGML] End of sequence".to_string());
3314 }
3315
3316 let chat_graphs = match CHAT_GRAPHS.get() {
3317 Some(chat_graphs) => chat_graphs,
3318 None => {
3319 let err_msg = "Fail to get the underlying value of `CHAT_GRAPHS`.";
3320
3321 #[cfg(feature = "logging")]
3322 error!(target: "stdout", "{}", &err_msg);
3323
3324 return Err(LlamaCoreError::Operation(err_msg.into()));
3325 }
3326 };
3327
3328 let mut chat_graphs = chat_graphs.lock().map_err(|e| {
3330 let err_msg = format!("Fail to acquire the lock of `CHAT_GRAPHS`. {e}");
3331
3332 #[cfg(feature = "logging")]
3333 error!(target: "stdout", "{}", &err_msg);
3334
3335 LlamaCoreError::Operation(err_msg)
3336 })?;
3337
3338 let res = match &model_name {
3340 Some(model_name) => {
3341 match chat_graphs.contains_key(model_name) {
3342 true => {
3343 let graph = chat_graphs.get_mut(model_name).unwrap();
3344 match graph.compute_single() {
3346 Ok(_) => {
3347 #[cfg(feature = "logging")]
3348 debug!(target: "stdout", "Compute the chat stream chunk successfully.");
3349
3350 match stream_state {
3352 StreamState::Usage | StreamState::NoUsage => {
3353 let output_buffer =
3355 get_output_buffer_single(graph, OUTPUT_TENSOR)?;
3356
3357 #[cfg(feature = "logging")]
3358 info!(target: "stdout", "retrieved the output buffer");
3359
3360 let output = match String::from_utf8(output_buffer.clone()) {
3362 Ok(token) => token,
3363 Err(_) => {
3364 let mutex = CACHED_UTF8_ENCODINGS
3365 .get_or_init(|| Mutex::new(Vec::new()));
3366 let mut cached_encodings = mutex.lock().map_err(|e| {
3367 let err_msg = format!(
3368 "Fail to acquire the lock of `UTF8_ENCODINGS`. Reason: {e}"
3369 );
3370
3371 #[cfg(feature = "logging")]
3372 error!(target: "stdout", "{}", &err_msg);
3373
3374
3375 LlamaCoreError::Operation(err_msg)
3376 })?;
3377
3378 cached_encodings.extend_from_slice(&output_buffer[..]);
3380
3381 match String::from_utf8(cached_encodings.to_vec()) {
3382 Ok(token) => {
3383 cached_encodings.clear();
3385
3386 token
3387 }
3388 Err(e) => {
3389 if cached_encodings.len() > 4 {
3391                                            let err_msg = format!("Fail to convert a vector of bytes to string: the cached UTF-8 buffer exceeds 4 bytes. {e}");
3392
3393 #[cfg(feature = "logging")]
3394 error!(target: "stdout", "{}", &err_msg);
3395
3396 #[cfg(feature = "logging")]
3397 error!(target: "stdout", "The cached buffer: {:?}", &cached_encodings[..]);
3398
3399 cached_encodings.clear();
3406
3407 String::from("")
3408 } else {
3409 let warn_msg = format!("Fail to convert a vector of bytes to string. {e}");
3410
3411 #[cfg(feature = "logging")]
3412 warn!(target: "stdout", "{}", &warn_msg);
3413
3414 String::from("")
3415 }
3416 }
3417 }
3418 }
3419 };
3420
3421 #[cfg(feature = "logging")]
3422 info!(target: "stdout", "decoded the output buffer");
3423
3424 let created = SystemTime::now()
3425 .duration_since(std::time::UNIX_EPOCH)
3426 .map_err(|e| {
3427 let err_msg = format!(
3428 "Failed to get the current time. Reason: {e}"
3429 );
3430
3431 #[cfg(feature = "logging")]
3432 error!(target: "stdout", "{}", &err_msg);
3433
3434 LlamaCoreError::Operation(err_msg)
3435 })?;
3436
3437 let chat_completion_chunk = ChatCompletionChunk {
3438 id,
3439 object: "chat.completion.chunk".to_string(),
3440 created: created.as_secs(),
3441 model: graph.name().to_owned(),
3442 system_fingerprint: "fp_44709d6fcb".to_string(),
3443 choices: vec![ChatCompletionChunkChoice {
3444 index: 0,
3445 delta: ChatCompletionChunkChoiceDelta {
3446 role: ChatCompletionRole::Assistant,
3447 content: Some(output),
3448 tool_calls: vec![],
3449 },
3450 logprobs: None,
3451 finish_reason: None,
3452 }],
3453 usage: None,
3454 };
3455
3456 #[cfg(feature = "logging")]
3457 info!(target: "stdout", "created chat completion chunk");
3458
3459 let chunk_str = serde_json::to_string(&chat_completion_chunk)
3461 .map_err(|e| {
3462 let err_msg = format!(
3463 "Failed to serialize chat completion chunk. Reason: {e}"
3464 );
3465
3466 #[cfg(feature = "logging")]
3467 error!(target: "stdout", "{}", &err_msg);
3468
3469 LlamaCoreError::Operation(err_msg)
3470 })?;
3471
3472 Ok(format!("data: {chunk_str}\n\n"))
3473 }
3474 StreamState::Done => {
3475 *stream_state = StreamState::EndOfSequence;
3476
3477 Ok("data: [DONE]\n\n".to_string())
3478 }
3479 StreamState::EndOfSequence => {
3480 Ok("[GGML] End of sequence".to_string())
3481 }
3482 }
3483 }
3484 Err(wasmedge_wasi_nn::Error::BackendError(
3485 wasmedge_wasi_nn::BackendError::EndOfSequence,
3486 )) => {
3487 #[cfg(feature = "logging")]
3488 debug!(target: "stdout", "End of sequence");
3489
3490 match stream_state {
3491 StreamState::Usage => {
3492 *stream_state = StreamState::Done;
3493
3494 let token_info = get_token_info_by_graph(graph)?;
3496
3497 let usage = Some(Usage {
3498 prompt_tokens: token_info.prompt_tokens,
3499 completion_tokens: token_info.completion_tokens,
3500 total_tokens: token_info.prompt_tokens
3501 + token_info.completion_tokens,
3502 });
3503
3504 #[cfg(feature = "logging")]
3505 info!(target: "stdout", "token_info: {} prompt tokens, {} completion tokens", token_info.prompt_tokens, token_info.completion_tokens);
3506
3507 let created = SystemTime::now()
3508 .duration_since(std::time::UNIX_EPOCH)
3509 .map_err(|e| {
3510 let err_msg = format!(
3511 "Failed to get the current time. Reason: {e}"
3512 );
3513
3514 #[cfg(feature = "logging")]
3515 error!(target: "stdout", "{}", &err_msg);
3516
3517 LlamaCoreError::Operation(err_msg)
3518 })?;
3519
3520 let chat_completion_chunk = ChatCompletionChunk {
3521 id,
3522 object: "chat.completion.chunk".to_string(),
3523 created: created.as_secs(),
3524 model: graph.name().to_owned(),
3525 system_fingerprint: "fp_44709d6fcb".to_string(),
3526 choices: vec![],
3527 usage,
3528 };
3529
3530 let chunk_str = serde_json::to_string(&chat_completion_chunk)
3532 .map_err(|e| {
3533 let err_msg = format!(
3534 "Failed to serialize chat completion chunk. Reason: {e}"
3535 );
3536
3537 #[cfg(feature = "logging")]
3538 error!(target: "stdout", "{}", &err_msg);
3539
3540 LlamaCoreError::Operation(err_msg)
3541 })?;
3542
3543 Ok(format!("data: {chunk_str}\n\n"))
3544 }
3545 StreamState::Done | StreamState::NoUsage => {
3546 *stream_state = StreamState::EndOfSequence;
3547
3548 Ok("data: [DONE]\n\n".to_string())
3549 }
3550 StreamState::EndOfSequence => {
3551 Ok("[GGML] End of sequence".to_string())
3552 }
3553 }
3554 }
3555 Err(wasmedge_wasi_nn::Error::BackendError(
3556 wasmedge_wasi_nn::BackendError::ContextFull,
3557 )) => {
3558 #[cfg(feature = "logging")]
3559 debug!(target: "stdout", "Context full");
3560
3561 match context_full_state {
3562 ContextFullState::Message => {
3563 match include_usage {
3564 true => *context_full_state = ContextFullState::Usage,
3565 false => *context_full_state = ContextFullState::Done,
3566 }
3567
3568 let created = SystemTime::now()
3569 .duration_since(std::time::UNIX_EPOCH)
3570 .map_err(|e| {
3571 let err_msg = format!(
3572 "Failed to get the current time. Reason: {e}"
3573 );
3574
3575 #[cfg(feature = "logging")]
3576 error!(target: "stdout", "{}", &err_msg);
3577
3578 LlamaCoreError::Operation(err_msg)
3579 })?;
3580
3581 let chat_completion_chunk = ChatCompletionChunk {
3582 id,
3583 object: "chat.completion.chunk".to_string(),
3584 created: created.as_secs(),
3585 model: graph.name().to_owned(),
3586 system_fingerprint: "fp_44709d6fcb".to_string(),
3587 choices: vec![ChatCompletionChunkChoice {
3588 index: 0,
3589 delta: ChatCompletionChunkChoiceDelta {
3590 role: ChatCompletionRole::Assistant,
3591 content: Some(
3592 "<|WASMEDGE-GGML-CONTEXT-FULL|>".to_string(),
3593 ),
3594 tool_calls: vec![],
3595 },
3596 logprobs: None,
3597 finish_reason: Some(FinishReason::length),
3598 }],
3599 usage: None,
3600 };
3601
3602 let chunk_str = serde_json::to_string(&chat_completion_chunk)
3604 .map_err(|e| {
3605 let err_msg = format!(
3606 "Failed to serialize chat completion chunk. Reason: {e}"
3607 );
3608
3609 #[cfg(feature = "logging")]
3610 error!(target: "stdout", "{}", &err_msg);
3611
3612 LlamaCoreError::Operation(err_msg)
3613 })?;
3614
3615 Ok(format!("data: {chunk_str}\n\n"))
3616 }
3617 ContextFullState::Usage => {
3618 *context_full_state = ContextFullState::Done;
3619
3620 let token_info = get_token_info_by_graph(graph)?;
3622
3623 let usage = Some(Usage {
3624 prompt_tokens: token_info.prompt_tokens,
3625 completion_tokens: token_info.completion_tokens,
3626 total_tokens: token_info.prompt_tokens
3627 + token_info.completion_tokens,
3628 });
3629
3630 let created = SystemTime::now()
3631 .duration_since(std::time::UNIX_EPOCH)
3632 .map_err(|e| {
3633 let err_msg = format!(
3634 "Failed to get the current time. Reason: {e}"
3635 );
3636
3637 #[cfg(feature = "logging")]
3638 error!(target: "stdout", "{}", &err_msg);
3639
3640 LlamaCoreError::Operation(err_msg)
3641 })?;
3642
3643 let chat_completion_chunk = ChatCompletionChunk {
3644 id,
3645 object: "chat.completion.chunk".to_string(),
3646 created: created.as_secs(),
3647 model: graph.name().to_owned(),
3648 system_fingerprint: "fp_44709d6fcb".to_string(),
3649 choices: vec![],
3650 usage,
3651 };
3652
3653 let chunk_str = serde_json::to_string(&chat_completion_chunk)
3655 .map_err(|e| {
3656 let err_msg = format!(
3657 "Failed to serialize chat completion chunk. Reason: {e}"
3658 );
3659
3660 #[cfg(feature = "logging")]
3661 error!(target: "stdout", "{}", &err_msg);
3662
3663 LlamaCoreError::Operation(err_msg)
3664 })?;
3665
3666 Ok(format!("data: {chunk_str}\n\n"))
3667 }
3668 ContextFullState::Done => {
3669 *context_full_state = ContextFullState::EndOfSequence;
3670
3671 Ok("data: [DONE]\n\n".to_string())
3672 }
3673 ContextFullState::EndOfSequence => {
3674 Ok("[GGML] End of sequence".to_string())
3675 }
3676 }
3677 }
3678 Err(wasmedge_wasi_nn::Error::BackendError(
3679 wasmedge_wasi_nn::BackendError::PromptTooLong,
3680 )) => {
3681 #[cfg(feature = "logging")]
3682 debug!(target: "stdout", "Prompt too long");
3683
3684 match prompt_too_long_state {
3685 PromptTooLongState::Message => {
3686 match include_usage {
3687 true => *prompt_too_long_state = PromptTooLongState::Usage,
3688 false => *prompt_too_long_state = PromptTooLongState::Done,
3689 }
3690
3691 let created = SystemTime::now()
3692 .duration_since(std::time::UNIX_EPOCH)
3693 .map_err(|e| {
3694 let err_msg = format!(
3695 "Failed to get the current time. Reason: {e}"
3696 );
3697
3698 #[cfg(feature = "logging")]
3699 error!(target: "stdout", "{}", &err_msg);
3700
3701 LlamaCoreError::Operation(err_msg)
3702 })?;
3703
3704 let chat_completion_chunk = ChatCompletionChunk {
3705 id,
3706 object: "chat.completion.chunk".to_string(),
3707 created: created.as_secs(),
3708 model: graph.name().to_owned(),
3709 system_fingerprint: "fp_44709d6fcb".to_string(),
3710 choices: vec![ChatCompletionChunkChoice {
3711 index: 0,
3712 delta: ChatCompletionChunkChoiceDelta {
3713 role: ChatCompletionRole::Assistant,
3714 content: None,
3715 tool_calls: vec![],
3716 },
3717 logprobs: None,
3718 finish_reason: Some(FinishReason::length),
3719 }],
3720 usage: None,
3721 };
3722
3723 let chunk_str = serde_json::to_string(&chat_completion_chunk)
3725 .map_err(|e| {
3726 let err_msg = format!(
3727 "Failed to serialize chat completion chunk. Reason: {e}"
3728 );
3729
3730 #[cfg(feature = "logging")]
3731 error!(target: "stdout", "{}", &err_msg);
3732
3733 LlamaCoreError::Operation(err_msg)
3734 })?;
3735
3736 Ok(format!("data: {chunk_str}\n\n"))
3737 }
3738 PromptTooLongState::Usage => {
3739 *prompt_too_long_state = PromptTooLongState::Done;
3740
3741 let token_info = get_token_info_by_graph(graph)?;
3743
3744 let usage = Some(Usage {
3745 prompt_tokens: token_info.prompt_tokens,
3746 completion_tokens: token_info.completion_tokens,
3747 total_tokens: token_info.prompt_tokens
3748 + token_info.completion_tokens,
3749 });
3750
3751 let created = SystemTime::now()
3752 .duration_since(std::time::UNIX_EPOCH)
3753 .map_err(|e| {
3754 let err_msg = format!(
3755 "Failed to get the current time. Reason: {e}"
3756 );
3757
3758 #[cfg(feature = "logging")]
3759 error!(target: "stdout", "{}", &err_msg);
3760
3761 LlamaCoreError::Operation(err_msg)
3762 })?;
3763
3764 let chat_completion_chunk = ChatCompletionChunk {
3765 id,
3766 object: "chat.completion.chunk".to_string(),
3767 created: created.as_secs(),
3768 model: graph.name().to_owned(),
3769 system_fingerprint: "fp_44709d6fcb".to_string(),
3770 choices: vec![],
3771 usage,
3772 };
3773
3774 let chunk_str = serde_json::to_string(&chat_completion_chunk)
3776 .map_err(|e| {
3777 let err_msg = format!(
3778 "Failed to serialize chat completion chunk. Reason: {e}"
3779 );
3780
3781 #[cfg(feature = "logging")]
3782 error!(target: "stdout", "{}", &err_msg);
3783
3784 LlamaCoreError::Operation(err_msg)
3785 })?;
3786
3787 Ok(format!("data: {chunk_str}\n\n"))
3788 }
3789 PromptTooLongState::Done => {
3790 *prompt_too_long_state = PromptTooLongState::EndOfSequence;
3791
3792 Ok("data: [DONE]\n\n".to_string())
3793 }
3794 PromptTooLongState::EndOfSequence => {
3795 Ok("[GGML] End of sequence".to_string())
3796 }
3797 }
3798 }
3799 Err(e) => {
3800 let err_msg =
3801 format!("Failed to compute the chat completion. Reason: {e}");
3802
3803 #[cfg(feature = "logging")]
3804 error!(target: "stdout", "{}", &err_msg);
3805
3806 Err(LlamaCoreError::Backend(BackendError::ComputeSingle(
3807 err_msg,
3808 )))
3809 }
3810 }
3811 }
3812 false => {
3813 match chat_graphs.iter_mut().next() {
3814 Some((_, graph)) => {
3815 match graph.compute_single() {
3817 Ok(_) => {
3818 #[cfg(feature = "logging")]
3819 debug!(target: "stdout", "Compute the chat stream chunk successfully.");
3820
3821 match stream_state {
3822 StreamState::Usage | StreamState::NoUsage => {
3823 let output_buffer =
3825 get_output_buffer_single(graph, OUTPUT_TENSOR)?;
3826
3827 #[cfg(feature = "logging")]
3828 info!(target: "stdout", "retrieved the output buffer");
3829
3830 let output = match String::from_utf8(
3832 output_buffer.clone(),
3833 ) {
3834 Ok(token) => token,
3835 Err(_) => {
3836 let mutex = CACHED_UTF8_ENCODINGS
3837 .get_or_init(|| Mutex::new(Vec::new()));
3838 let mut cached_encodings = mutex.lock().map_err(|e| {
3839 let err_msg = format!(
3840 "Fail to acquire the lock of `UTF8_ENCODINGS`. Reason: {e}"
3841 );
3842
3843 #[cfg(feature = "logging")]
3844 error!(target: "stdout", "{}", &err_msg);
3845
3846
3847 LlamaCoreError::Operation(err_msg)
3848 })?;
3849
3850 cached_encodings
3852 .extend_from_slice(&output_buffer[..]);
3853
3854 match String::from_utf8(
3855 cached_encodings.to_vec(),
3856 ) {
3857 Ok(token) => {
3858 cached_encodings.clear();
3860
3861 token
3862 }
3863 Err(e) => {
3864 if cached_encodings.len() > 4 {
3866                                                        let err_msg = format!("Fail to convert a vector of bytes to string: the cached UTF-8 buffer exceeds 4 bytes. {e}");
3867
3868 #[cfg(feature = "logging")]
3869 error!(target: "stdout", "{}", &err_msg);
3870
3871 #[cfg(feature = "logging")]
3872 error!(target: "stdout", "The cached buffer: {:?}", &cached_encodings[..]);
3873
3874 cached_encodings.clear();
3882
3883 String::from("")
3884 } else {
3885 let warn_msg = format!("Fail to convert a vector of bytes to string. {e}");
3886
3887 #[cfg(feature = "logging")]
3888 warn!(target: "stdout", "{}", &warn_msg);
3889
3890 String::from("")
3891 }
3892 }
3893 }
3894 }
3895 };
3896
3897 #[cfg(feature = "logging")]
3898 info!(target: "stdout", "decoded the output buffer");
3899
3900 let created = SystemTime::now()
3901 .duration_since(std::time::UNIX_EPOCH)
3902 .map_err(|e| {
3903 let err_msg = format!(
3904 "Failed to get the current time. Reason: {e}"
3905 );
3906
3907 #[cfg(feature = "logging")]
3908 error!(target: "stdout", "{}", &err_msg);
3909
3910 LlamaCoreError::Operation(err_msg)
3911 })?;
3912
3913 let chat_completion_chunk = ChatCompletionChunk {
3914 id,
3915 object: "chat.completion.chunk".to_string(),
3916 created: created.as_secs(),
3917 model: graph.name().to_owned(),
3918 system_fingerprint: "fp_44709d6fcb".to_string(),
3919 choices: vec![ChatCompletionChunkChoice {
3920 index: 0,
3921 delta: ChatCompletionChunkChoiceDelta {
3922 role: ChatCompletionRole::Assistant,
3923 content: Some(output),
3924 tool_calls: vec![],
3925 },
3926 logprobs: None,
3927 finish_reason: None,
3928 }],
3929 usage: None,
3930 };
3931
3932 #[cfg(feature = "logging")]
3933 info!(target: "stdout", "created chat completion chunk");
3934
3935 let chunk_str =
3937 serde_json::to_string(&chat_completion_chunk)
3938 .map_err(|e| {
3939 let err_msg = format!(
3940 "Failed to serialize chat completion chunk. Reason: {e}"
3941 );
3942
3943 #[cfg(feature = "logging")]
3944 error!(target: "stdout", "{}", &err_msg);
3945
3946 LlamaCoreError::Operation(err_msg)
3947 })?;
3948
3949 Ok(format!("data: {chunk_str}\n\n"))
3950 }
3951 StreamState::Done => {
3952 *stream_state = StreamState::EndOfSequence;
3953
3954 Ok("data: [DONE]\n\n".to_string())
3955 }
3956 StreamState::EndOfSequence => {
3957 Ok("[GGML] End of sequence".to_string())
3958 }
3959 }
3960 }
3961 Err(wasmedge_wasi_nn::Error::BackendError(
3962 wasmedge_wasi_nn::BackendError::EndOfSequence,
3963 )) => {
3964 #[cfg(feature = "logging")]
3965 debug!(target: "stdout", "End of sequence");
3966
3967 match stream_state {
3968 StreamState::Usage => {
3969 *stream_state = StreamState::Done;
3970
3971 let token_info = get_token_info_by_graph(graph)?;
3973
3974 let usage = Some(Usage {
3975 prompt_tokens: token_info.prompt_tokens,
3976 completion_tokens: token_info.completion_tokens,
3977 total_tokens: token_info.prompt_tokens
3978 + token_info.completion_tokens,
3979 });
3980
3981 #[cfg(feature = "logging")]
3982 info!(target: "stdout", "token_info: {} prompt tokens, {} completion tokens", token_info.prompt_tokens, token_info.completion_tokens);
3983
3984 let created = SystemTime::now()
3985 .duration_since(std::time::UNIX_EPOCH)
3986 .map_err(|e| {
3987 let err_msg = format!(
3988 "Failed to get the current time. Reason: {e}"
3989 );
3990
3991 #[cfg(feature = "logging")]
3992 error!(target: "stdout", "{}", &err_msg);
3993
3994 LlamaCoreError::Operation(err_msg)
3995 })?;
3996
3997 let chat_completion_chunk = ChatCompletionChunk {
3998 id,
3999 object: "chat.completion.chunk".to_string(),
4000 created: created.as_secs(),
4001 model: graph.name().to_owned(),
4002 system_fingerprint: "fp_44709d6fcb".to_string(),
4003 choices: vec![],
4004 usage,
4005 };
4006
4007 let chunk_str =
4009 serde_json::to_string(&chat_completion_chunk)
4010 .map_err(|e| {
4011 let err_msg = format!(
4012 "Failed to serialize chat completion chunk. Reason: {e}"
4013 );
4014
4015 #[cfg(feature = "logging")]
4016 error!(target: "stdout", "{}", &err_msg);
4017
4018 LlamaCoreError::Operation(err_msg)
4019 })?;
4020
4021 Ok(format!("data: {chunk_str}\n\n"))
4022 }
4023 StreamState::Done | StreamState::NoUsage => {
4024 *stream_state = StreamState::EndOfSequence;
4025
4026 Ok("data: [DONE]\n\n".to_string())
4027 }
4028 StreamState::EndOfSequence => {
4029 Ok("[GGML] End of sequence".to_string())
4030 }
4031 }
4032 }
4033 Err(wasmedge_wasi_nn::Error::BackendError(
4034 wasmedge_wasi_nn::BackendError::ContextFull,
4035 )) => {
4036 #[cfg(feature = "logging")]
4037 debug!(target: "stdout", "Context full");
4038
4039 match context_full_state {
4040 ContextFullState::Message => {
4041 match include_usage {
4042 true => {
4043 *context_full_state = ContextFullState::Usage
4044 }
4045 false => {
4046 *context_full_state = ContextFullState::Done
4047 }
4048 }
4049
4050 let created = SystemTime::now()
4051 .duration_since(std::time::UNIX_EPOCH)
4052 .map_err(|e| {
4053 let err_msg = format!(
4054 "Failed to get the current time. Reason: {e}"
4055 );
4056
4057 #[cfg(feature = "logging")]
4058 error!(target: "stdout", "{}", &err_msg);
4059
4060 LlamaCoreError::Operation(err_msg)
4061 })?;
4062
4063 let chat_completion_chunk = ChatCompletionChunk {
4064 id,
4065 object: "chat.completion.chunk".to_string(),
4066 created: created.as_secs(),
4067 model: graph.name().to_owned(),
4068 system_fingerprint: "fp_44709d6fcb".to_string(),
4069 choices: vec![ChatCompletionChunkChoice {
4070 index: 0,
4071 delta: ChatCompletionChunkChoiceDelta {
4072 role: ChatCompletionRole::Assistant,
4073 content: Some(
4074 "<|WASMEDGE-GGML-CONTEXT-FULL|>"
4075 .to_string(),
4076 ),
4077 tool_calls: vec![],
4078 },
4079 logprobs: None,
4080 finish_reason: Some(FinishReason::length),
4081 }],
4082 usage: None,
4083 };
4084
4085 let chunk_str =
4087 serde_json::to_string(&chat_completion_chunk)
4088 .map_err(|e| {
4089 let err_msg = format!(
4090 "Failed to serialize chat completion chunk. Reason: {e}"
4091 );
4092
4093 #[cfg(feature = "logging")]
4094 error!(target: "stdout", "{}", &err_msg);
4095
4096 LlamaCoreError::Operation(err_msg)
4097 })?;
4098
4099 Ok(format!("data: {chunk_str}\n\n"))
4100 }
4101 ContextFullState::Usage => {
4102 *context_full_state = ContextFullState::Done;
4103
4104 let token_info = get_token_info_by_graph(graph)?;
4106
4107 let usage = Some(Usage {
4108 prompt_tokens: token_info.prompt_tokens,
4109 completion_tokens: token_info.completion_tokens,
4110 total_tokens: token_info.prompt_tokens
4111 + token_info.completion_tokens,
4112 });
4113
4114 let created = SystemTime::now()
4115 .duration_since(std::time::UNIX_EPOCH)
4116 .map_err(|e| {
4117 let err_msg = format!(
4118 "Failed to get the current time. Reason: {e}"
4119 );
4120
4121 #[cfg(feature = "logging")]
4122 error!(target: "stdout", "{}", &err_msg);
4123
4124 LlamaCoreError::Operation(err_msg)
4125 })?;
4126
4127 let chat_completion_chunk = ChatCompletionChunk {
4128 id,
4129 object: "chat.completion.chunk".to_string(),
4130 created: created.as_secs(),
4131 model: graph.name().to_owned(),
4132 system_fingerprint: "fp_44709d6fcb".to_string(),
4133 choices: vec![],
4134 usage,
4135 };
4136
4137 let chunk_str =
4139 serde_json::to_string(&chat_completion_chunk)
4140 .map_err(|e| {
4141 let err_msg = format!(
4142 "Failed to serialize chat completion chunk. Reason: {e}"
4143 );
4144
4145 #[cfg(feature = "logging")]
4146 error!(target: "stdout", "{}", &err_msg);
4147
4148 LlamaCoreError::Operation(err_msg)
4149 })?;
4150
4151 Ok(format!("data: {chunk_str}\n\n"))
4152 }
4153 ContextFullState::Done => {
4154 *context_full_state = ContextFullState::EndOfSequence;
4155
4156 Ok("data: [DONE]\n\n".to_string())
4157 }
4158 ContextFullState::EndOfSequence => {
4159 Ok("[GGML] End of sequence".to_string())
4160 }
4161 }
4162 }
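                            // The prompt itself is longer than the context window. Stream an empty
                            // delta with finish_reason `length`, optionally a usage chunk, then [DONE].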
4163 Err(wasmedge_wasi_nn::Error::BackendError(
4164 wasmedge_wasi_nn::BackendError::PromptTooLong,
4165 )) => {
4166 #[cfg(feature = "logging")]
4167 debug!(target: "stdout", "Prompt too long");
4168
4169 match prompt_too_long_state {
4170 PromptTooLongState::Message => {
4171 match include_usage {
4172 true => {
4173 *prompt_too_long_state =
4174 PromptTooLongState::Usage
4175 }
4176 false => {
4177 *prompt_too_long_state =
4178 PromptTooLongState::Done
4179 }
4180 }
4181
4182 let created = SystemTime::now()
4183 .duration_since(std::time::UNIX_EPOCH)
4184 .map_err(|e| {
4185 let err_msg = format!(
4186 "Failed to get the current time. Reason: {e}"
4187 );
4188
4189 #[cfg(feature = "logging")]
4190 error!(target: "stdout", "{}", &err_msg);
4191
4192 LlamaCoreError::Operation(err_msg)
4193 })?;
4194
4195 let chat_completion_chunk = ChatCompletionChunk {
4196 id,
4197 object: "chat.completion.chunk".to_string(),
4198 created: created.as_secs(),
4199 model: graph.name().to_owned(),
4200 system_fingerprint: "fp_44709d6fcb".to_string(),
4201 choices: vec![ChatCompletionChunkChoice {
4202 index: 0,
4203 delta: ChatCompletionChunkChoiceDelta {
4204 role: ChatCompletionRole::Assistant,
4205 content: None,
4206 tool_calls: vec![],
4207 },
4208 logprobs: None,
4209 finish_reason: Some(FinishReason::length),
4210 }],
4211 usage: None,
4212 };
4213
4214 let chunk_str =
4216 serde_json::to_string(&chat_completion_chunk)
4217 .map_err(|e| {
4218 let err_msg = format!(
4219 "Failed to serialize chat completion chunk. Reason: {e}"
4220 );
4221
4222 #[cfg(feature = "logging")]
4223 error!(target: "stdout", "{}", &err_msg);
4224
4225 LlamaCoreError::Operation(err_msg)
4226 })?;
4227
4228 Ok(format!("data: {chunk_str}\n\n"))
4229 }
4230 PromptTooLongState::Usage => {
4231 *prompt_too_long_state = PromptTooLongState::Done;
4232
4233 let token_info = get_token_info_by_graph(graph)?;
4235
4236 let usage = Some(Usage {
4237 prompt_tokens: token_info.prompt_tokens,
4238 completion_tokens: token_info.completion_tokens,
4239 total_tokens: token_info.prompt_tokens
4240 + token_info.completion_tokens,
4241 });
4242
4243 let created = SystemTime::now()
4244 .duration_since(std::time::UNIX_EPOCH)
4245 .map_err(|e| {
4246 let err_msg = format!(
4247 "Failed to get the current time. Reason: {e}"
4248 );
4249
4250 #[cfg(feature = "logging")]
4251 error!(target: "stdout", "{}", &err_msg);
4252
4253 LlamaCoreError::Operation(err_msg)
4254 })?;
4255
4256 let chat_completion_chunk = ChatCompletionChunk {
4257 id,
4258 object: "chat.completion.chunk".to_string(),
4259 created: created.as_secs(),
4260 model: graph.name().to_owned(),
4261 system_fingerprint: "fp_44709d6fcb".to_string(),
4262 choices: vec![],
4263 usage,
4264 };
4265
4266 let chunk_str =
4268 serde_json::to_string(&chat_completion_chunk)
4269 .map_err(|e| {
4270 let err_msg = format!(
4271 "Failed to serialize chat completion chunk. Reason: {e}"
4272 );
4273
4274 #[cfg(feature = "logging")]
4275 error!(target: "stdout", "{}", &err_msg);
4276
4277 LlamaCoreError::Operation(err_msg)
4278 })?;
4279
4280 Ok(format!("data: {chunk_str}\n\n"))
4281 }
4282 PromptTooLongState::Done => {
4283 *prompt_too_long_state =
4284 PromptTooLongState::EndOfSequence;
4285
4286 Ok("data: [DONE]\n\n".to_string())
4287 }
4288 PromptTooLongState::EndOfSequence => {
4289 Ok("[GGML] End of sequence".to_string())
4290 }
4291 }
4292 }
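                            // Any other backend failure aborts the stream with a ComputeSingle error.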
4293 Err(e) => {
4294 let err_msg = format!(
4295 "Failed to compute the chat completion. Reason: {e}"
4296 );
4297
4298 #[cfg(feature = "logging")]
4299 error!(target: "stdout", "{}", &err_msg);
4300
4301 Err(LlamaCoreError::Backend(BackendError::ComputeSingle(
4302 err_msg,
4303 )))
4304 }
4305 }
4306 }
4307 None => {
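                        // The requested model does not correspond to any loaded chat graph.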
4308 let err_msg = "There is no model available in the chat graphs.";
4309
4310 #[cfg(feature = "logging")]
4311 error!(target: "stdout", "{}", &err_msg);
4312
4313 Err(LlamaCoreError::Operation(err_msg.into()))
4314 }
4315 }
4316 }
4317 }
4318 }
4319 None => {
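                // No specific model was requested, so fall back to the first graph
                // registered in CHAT_GRAPHS.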
4320 match chat_graphs.iter_mut().next() {
4321 Some((_, graph)) => {
4322 match graph.compute_single() {
4324 Ok(_) => {
4325 #[cfg(feature = "logging")]
4326                             debug!(target: "stdout", "Computed the chat stream chunk successfully.");
4327
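                        // A token was produced: decode it from the output tensor and stream it
                        // to the client as a delta chunk.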
4328 match stream_state {
4329 StreamState::Usage | StreamState::NoUsage => {
4330 let output_buffer =
4332 get_output_buffer_single(graph, OUTPUT_TENSOR)?;
4333
4334 #[cfg(feature = "logging")]
4335 info!(target: "stdout", "retrieved the output buffer");
4336
4337 let output = match String::from_utf8(output_buffer.clone()) {
4339 Ok(token) => token,
4340 Err(_) => {
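                                                // The bytes are not valid UTF-8 on their own, which happens
                                                // when a multi-byte character is split across tokens. Buffer
                                                // them in CACHED_UTF8_ENCODINGS and retry with the next token.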
4341 let mutex = CACHED_UTF8_ENCODINGS
4342 .get_or_init(|| Mutex::new(Vec::new()));
4343 let mut cached_encodings = mutex.lock().map_err(|e| {
4344 let err_msg = format!(
4345                                                         "Failed to acquire the lock of `CACHED_UTF8_ENCODINGS`. Reason: {e}"
4346 );
4347
4348 #[cfg(feature = "logging")]
4349 error!(target: "stdout", "{}", &err_msg);
4350
4351 LlamaCoreError::Operation(err_msg)
4352 })?;
4353
4354 cached_encodings.extend_from_slice(&output_buffer[..]);
4355
4356 match String::from_utf8(cached_encodings.to_vec()) {
4357 Ok(token) => {
4358 cached_encodings.clear();
4360
4361 token
4362 }
4363 Err(e) => {
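                                                                // A UTF-8 character is at most 4 bytes long. If the cache
                                                                // has grown beyond that, it can never decode successfully,
                                                                // so drop it; otherwise keep buffering and emit nothing.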
4364 if cached_encodings.len() > 4 {
4366                                                                     let err_msg = format!("Failed to convert the cached bytes to a valid UTF-8 string: the buffer exceeds 4 bytes, the maximum length of a single UTF-8 character. {e}");
4367
4368 #[cfg(feature = "logging")]
4369 error!(target: "stdout", "{}", &err_msg);
4370
4371 #[cfg(feature = "logging")]
4372 error!(target: "stdout", "The cached buffer: {:?}", &cached_encodings[..]);
4373
4374 cached_encodings.clear();
4381
4382 String::from("")
4383 } else {
4384                                                                     let warn_msg = format!("The cached bytes do not yet form a valid UTF-8 string; waiting for more output bytes. {e}");
4385
4386 #[cfg(feature = "logging")]
4387 warn!(target: "stdout", "{}", &warn_msg);
4388
4389 String::from("")
4390 }
4391 }
4392 }
4393 }
4394 };
4395
4396 #[cfg(feature = "logging")]
4397 info!(target: "stdout", "decoded the output buffer");
4398
4399 let created = SystemTime::now()
4400 .duration_since(std::time::UNIX_EPOCH)
4401 .map_err(|e| {
4402 let err_msg = format!(
4403 "Failed to get the current time. Reason: {e}"
4404 );
4405
4406 #[cfg(feature = "logging")]
4407 error!(target: "stdout", "{}", &err_msg);
4408
4409 LlamaCoreError::Operation(err_msg)
4410 })?;
4411
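                                    // Wrap the decoded token in a `chat.completion.chunk` object (the
                                    // OpenAI-style streaming schema); `system_fingerprint` is a fixed
                                    // placeholder value.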
4412 let chat_completion_chunk = ChatCompletionChunk {
4413 id,
4414 object: "chat.completion.chunk".to_string(),
4415 created: created.as_secs(),
4416 model: graph.name().to_owned(),
4417 system_fingerprint: "fp_44709d6fcb".to_string(),
4418 choices: vec![ChatCompletionChunkChoice {
4419 index: 0,
4420 delta: ChatCompletionChunkChoiceDelta {
4421 role: ChatCompletionRole::Assistant,
4422 content: Some(output),
4423 tool_calls: vec![],
4424 },
4425 logprobs: None,
4426 finish_reason: None,
4427 }],
4428 usage: None,
4429 };
4430
4431 #[cfg(feature = "logging")]
4432 info!(target: "stdout", "created chat completion chunk");
4433
4434 let chunk_str = serde_json::to_string(&chat_completion_chunk)
4436 .map_err(|e| {
4437 let err_msg = format!(
4438 "Failed to serialize chat completion chunk. Reason: {e}"
4439 );
4440
4441 #[cfg(feature = "logging")]
4442 error!(target: "stdout", "{}", &err_msg);
4443
4444 LlamaCoreError::Operation(err_msg)
4445 })?;
4446
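                                    // Each chunk is framed as a server-sent event; an illustrative frame
                                    // looks like:
                                    //   data: {"id":"...","object":"chat.completion.chunk","choices":[...]}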
4447 Ok(format!("data: {chunk_str}\n\n"))
4448 }
4449 StreamState::Done => {
4450 *stream_state = StreamState::EndOfSequence;
4451
4452 Ok("data: [DONE]\n\n".to_string())
4453 }
4454 StreamState::EndOfSequence => {
4455 Ok("[GGML] End of sequence".to_string())
4456 }
4457 }
4458 }
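                    // The backend signals the normal end of generation. Send a usage chunk
                    // first when requested, then the [DONE] terminator.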
4459 Err(wasmedge_wasi_nn::Error::BackendError(
4460 wasmedge_wasi_nn::BackendError::EndOfSequence,
4461 )) => {
4462 #[cfg(feature = "logging")]
4463 debug!(target: "stdout", "End of sequence");
4464
4465 match stream_state {
4466 StreamState::Usage => {
4467 *stream_state = StreamState::Done;
4468
4469 let token_info = get_token_info_by_graph(graph)?;
4471
4472 let usage = Some(Usage {
4473 prompt_tokens: token_info.prompt_tokens,
4474 completion_tokens: token_info.completion_tokens,
4475 total_tokens: token_info.prompt_tokens
4476 + token_info.completion_tokens,
4477 });
4478
4479 #[cfg(feature = "logging")]
4480 info!(target: "stdout", "token_info: {} prompt tokens, {} completion tokens", token_info.prompt_tokens, token_info.completion_tokens);
4481
4482 let created = SystemTime::now()
4483 .duration_since(std::time::UNIX_EPOCH)
4484 .map_err(|e| {
4485 let err_msg = format!(
4486 "Failed to get the current time. Reason: {e}"
4487 );
4488
4489 #[cfg(feature = "logging")]
4490 error!(target: "stdout", "{}", &err_msg);
4491
4492 LlamaCoreError::Operation(err_msg)
4493 })?;
4494
4495 let chat_completion_chunk = ChatCompletionChunk {
4496 id,
4497 object: "chat.completion.chunk".to_string(),
4498 created: created.as_secs(),
4499 model: graph.name().to_owned(),
4500 system_fingerprint: "fp_44709d6fcb".to_string(),
4501 choices: vec![],
4502 usage,
4503 };
4504
4505 let chunk_str = serde_json::to_string(&chat_completion_chunk)
4507 .map_err(|e| {
4508 let err_msg = format!(
4509 "Failed to serialize chat completion chunk. Reason: {e}"
4510 );
4511
4512 #[cfg(feature = "logging")]
4513 error!(target: "stdout", "{}", &err_msg);
4514
4515 LlamaCoreError::Operation(err_msg)
4516 })?;
4517
4518 Ok(format!("data: {chunk_str}\n\n"))
4519 }
4520 StreamState::Done | StreamState::NoUsage => {
4521 *stream_state = StreamState::EndOfSequence;
4522
4523 Ok("data: [DONE]\n\n".to_string())
4524 }
4525 StreamState::EndOfSequence => {
4526 Ok("[GGML] End of sequence".to_string())
4527 }
4528 }
4529 }
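                    // Context window exhausted: mirrors the handling of the same error in the
                    // branch above.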
4530 Err(wasmedge_wasi_nn::Error::BackendError(
4531 wasmedge_wasi_nn::BackendError::ContextFull,
4532 )) => {
4533 #[cfg(feature = "logging")]
4534 debug!(target: "stdout", "Context full");
4535
4536 match context_full_state {
4537 ContextFullState::Message => {
4538 match include_usage {
4539 true => *context_full_state = ContextFullState::Usage,
4540 false => *context_full_state = ContextFullState::Done,
4541 }
4542
4543 let created = SystemTime::now()
4544 .duration_since(std::time::UNIX_EPOCH)
4545 .map_err(|e| {
4546 let err_msg = format!(
4547 "Failed to get the current time. Reason: {e}"
4548 );
4549
4550 #[cfg(feature = "logging")]
4551 error!(target: "stdout", "{}", &err_msg);
4552
4553 LlamaCoreError::Operation(err_msg)
4554 })?;
4555
4556 let chat_completion_chunk = ChatCompletionChunk {
4557 id,
4558 object: "chat.completion.chunk".to_string(),
4559 created: created.as_secs(),
4560 model: graph.name().to_owned(),
4561 system_fingerprint: "fp_44709d6fcb".to_string(),
4562 choices: vec![ChatCompletionChunkChoice {
4563 index: 0,
4564 delta: ChatCompletionChunkChoiceDelta {
4565 role: ChatCompletionRole::Assistant,
4566 content: Some(
4567 "<|WASMEDGE-GGML-CONTEXT-FULL|>".to_string(),
4568 ),
4569 tool_calls: vec![],
4570 },
4571 logprobs: None,
4572 finish_reason: Some(FinishReason::length),
4573 }],
4574 usage: None,
4575 };
4576
4577 let chunk_str = serde_json::to_string(&chat_completion_chunk)
4579 .map_err(|e| {
4580 let err_msg = format!(
4581 "Failed to serialize chat completion chunk. Reason: {e}"
4582 );
4583
4584 #[cfg(feature = "logging")]
4585 error!(target: "stdout", "{}", &err_msg);
4586
4587 LlamaCoreError::Operation(err_msg)
4588 })?;
4589
4590 Ok(format!("data: {chunk_str}\n\n"))
4591 }
4592 ContextFullState::Usage => {
4593 *context_full_state = ContextFullState::Done;
4594
4595 let token_info = get_token_info_by_graph(graph)?;
4597
4598 let usage = Some(Usage {
4599 prompt_tokens: token_info.prompt_tokens,
4600 completion_tokens: token_info.completion_tokens,
4601 total_tokens: token_info.prompt_tokens
4602 + token_info.completion_tokens,
4603 });
4604
4605 let created = SystemTime::now()
4606 .duration_since(std::time::UNIX_EPOCH)
4607 .map_err(|e| {
4608 let err_msg = format!(
4609 "Failed to get the current time. Reason: {e}"
4610 );
4611
4612 #[cfg(feature = "logging")]
4613 error!(target: "stdout", "{}", &err_msg);
4614
4615 LlamaCoreError::Operation(err_msg)
4616 })?;
4617
4618 let chat_completion_chunk = ChatCompletionChunk {
4619 id,
4620 object: "chat.completion.chunk".to_string(),
4621 created: created.as_secs(),
4622 model: graph.name().to_owned(),
4623 system_fingerprint: "fp_44709d6fcb".to_string(),
4624 choices: vec![],
4625 usage,
4626 };
4627
4628 let chunk_str = serde_json::to_string(&chat_completion_chunk)
4630 .map_err(|e| {
4631 let err_msg = format!(
4632 "Failed to serialize chat completion chunk. Reason: {e}"
4633 );
4634
4635 #[cfg(feature = "logging")]
4636 error!(target: "stdout", "{}", &err_msg);
4637
4638 LlamaCoreError::Operation(err_msg)
4639 })?;
4640
4641 Ok(format!("data: {chunk_str}\n\n"))
4642 }
4643 ContextFullState::Done => {
4644 *context_full_state = ContextFullState::EndOfSequence;
4645
4646 Ok("data: [DONE]\n\n".to_string())
4647 }
4648 ContextFullState::EndOfSequence => {
4649 Ok("[GGML] End of sequence".to_string())
4650 }
4651 }
4652 }
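                    // Prompt longer than the context window: mirrors the handling of the same
                    // error in the branch above.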
4653 Err(wasmedge_wasi_nn::Error::BackendError(
4654 wasmedge_wasi_nn::BackendError::PromptTooLong,
4655 )) => {
4656 #[cfg(feature = "logging")]
4657 debug!(target: "stdout", "Prompt too long");
4658
4659 match prompt_too_long_state {
4660 PromptTooLongState::Message => {
4661 match include_usage {
4662 true => *prompt_too_long_state = PromptTooLongState::Usage,
4663 false => *prompt_too_long_state = PromptTooLongState::Done,
4664 }
4665
4666 let created = SystemTime::now()
4667 .duration_since(std::time::UNIX_EPOCH)
4668 .map_err(|e| {
4669 let err_msg = format!(
4670 "Failed to get the current time. Reason: {e}"
4671 );
4672
4673 #[cfg(feature = "logging")]
4674 error!(target: "stdout", "{}", &err_msg);
4675
4676 LlamaCoreError::Operation(err_msg)
4677 })?;
4678
4679 let chat_completion_chunk = ChatCompletionChunk {
4680 id,
4681 object: "chat.completion.chunk".to_string(),
4682 created: created.as_secs(),
4683 model: graph.name().to_owned(),
4684 system_fingerprint: "fp_44709d6fcb".to_string(),
4685 choices: vec![ChatCompletionChunkChoice {
4686 index: 0,
4687 delta: ChatCompletionChunkChoiceDelta {
4688 role: ChatCompletionRole::Assistant,
4689 content: None,
4690 tool_calls: vec![],
4691 },
4692 logprobs: None,
4693 finish_reason: Some(FinishReason::length),
4694 }],
4695 usage: None,
4696 };
4697
4698 let chunk_str = serde_json::to_string(&chat_completion_chunk)
4700 .map_err(|e| {
4701 let err_msg = format!(
4702 "Failed to serialize chat completion chunk. Reason: {e}"
4703 );
4704
4705 #[cfg(feature = "logging")]
4706 error!(target: "stdout", "{}", &err_msg);
4707
4708 LlamaCoreError::Operation(err_msg)
4709 })?;
4710
4711 Ok(format!("data: {chunk_str}\n\n"))
4712 }
4713 PromptTooLongState::Usage => {
4714 *prompt_too_long_state = PromptTooLongState::Done;
4715
4716 let token_info = get_token_info_by_graph(graph)?;
4718
4719 let usage = Some(Usage {
4720 prompt_tokens: token_info.prompt_tokens,
4721 completion_tokens: token_info.completion_tokens,
4722 total_tokens: token_info.prompt_tokens
4723 + token_info.completion_tokens,
4724 });
4725
4726 let created = SystemTime::now()
4727 .duration_since(std::time::UNIX_EPOCH)
4728 .map_err(|e| {
4729 let err_msg = format!(
4730 "Failed to get the current time. Reason: {e}"
4731 );
4732
4733 #[cfg(feature = "logging")]
4734 error!(target: "stdout", "{}", &err_msg);
4735
4736 LlamaCoreError::Operation(err_msg)
4737 })?;
4738
4739 let chat_completion_chunk = ChatCompletionChunk {
4740 id,
4741 object: "chat.completion.chunk".to_string(),
4742 created: created.as_secs(),
4743 model: graph.name().to_owned(),
4744 system_fingerprint: "fp_44709d6fcb".to_string(),
4745 choices: vec![],
4746 usage,
4747 };
4748
4749 let chunk_str = serde_json::to_string(&chat_completion_chunk)
4751 .map_err(|e| {
4752 let err_msg = format!(
4753 "Failed to serialize chat completion chunk. Reason: {e}"
4754 );
4755
4756 #[cfg(feature = "logging")]
4757 error!(target: "stdout", "{}", &err_msg);
4758
4759 LlamaCoreError::Operation(err_msg)
4760 })?;
4761
4762 Ok(format!("data: {chunk_str}\n\n"))
4763 }
4764 PromptTooLongState::Done => {
4765 *prompt_too_long_state = PromptTooLongState::EndOfSequence;
4766
4767 Ok("data: [DONE]\n\n".to_string())
4768 }
4769 PromptTooLongState::EndOfSequence => {
4770 Ok("[GGML] End of sequence".to_string())
4771 }
4772 }
4773 }
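                    // Any other backend failure aborts the stream with a ComputeSingle error.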
4774 Err(e) => {
4775 let err_msg =
4776 format!("Failed to compute the chat completion. Reason: {e}");
4777
4778 #[cfg(feature = "logging")]
4779 error!(target: "stdout", "{}", &err_msg);
4780
4781 Err(LlamaCoreError::Backend(BackendError::ComputeSingle(
4782 err_msg,
4783 )))
4784 }
4785 }
4786 }
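                // `iter_mut().next()` returned `None`, i.e. no chat graph has been loaded.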
4787 None => {
4788 let err_msg = "There is no model available in the chat graphs.";
4789
4790 #[cfg(feature = "logging")]
4791 error!(target: "stdout", "{}", &err_msg);
4792
4793 Err(LlamaCoreError::Operation(err_msg.into()))
4794 }
4795 }
4796 }
4797 };
4798
4799 #[cfg(feature = "logging")]
4800 info!(target: "stdout", "Return the chat stream chunk!");
4801
4802 res
4803}
4804
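/// The result of parsing a model's raw completion output: the original text,
/// the extracted plain-text content (if any), and any tool calls found in it.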
4805#[allow(dead_code)]
4806#[derive(Debug)]
4807struct ParseResult {
4808 raw: String,
4809 content: Option<String>,
4810 tool_calls: Vec<ToolCall>,
4811}