use crate::{
    error,
    metadata::ggml::GgmlMetadata,
    running_mode,
    utils::{
        gen_chat_id, get_output_buffer, get_output_buffer_single, get_token_info_by_graph,
        get_token_info_by_graph_name, set_tensor_data_u8,
    },
    Graph, RunningMode, CACHED_UTF8_ENCODINGS, CHAT_GRAPHS, OUTPUT_TENSOR,
};
use chat_prompts::{
    chat::{BuildChatPrompt, ChatPrompt},
    PromptTemplateType,
};
use either::{Either, Left, Right};
use endpoints::{
    chat::{
        ChatCompletionChunk, ChatCompletionChunkChoice, ChatCompletionChunkChoiceDelta,
        ChatCompletionObject, ChatCompletionObjectChoice, ChatCompletionObjectMessage,
        ChatCompletionRequest, ChatCompletionRequestMessage, ChatCompletionRole,
        ChatCompletionUserMessageContent, ContentPart, Function, ToolCall, ToolCallForChunk,
        ToolChoice,
    },
    common::{FinishReason, Usage},
};
use error::{BackendError, LlamaCoreError};
use futures::StreamExt;
use std::{
    collections::VecDeque,
    fs::{self, File},
    path::Path,
    pin::Pin,
    sync::Mutex,
    task::{Context, Poll},
    time::SystemTime,
};
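/// Processes a chat completion request and returns the result as either a token
/// stream (`Left`) or a complete `ChatCompletionObject` (`Right`), depending on the
/// `stream` field of the request. The model metadata is reset after the request has
/// been handled.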
pub async fn chat(
    chat_request: &mut ChatCompletionRequest,
) -> Result<
    Either<impl futures::TryStream<Ok = String, Error = LlamaCoreError>, ChatCompletionObject>,
    LlamaCoreError,
> {
    #[cfg(feature = "logging")]
    {
        debug!(target: "stdout", "tool choice: {:?}", chat_request.tool_choice.as_ref());
        debug!(target: "stdout", "tools: {:?}", chat_request.tools.as_ref());
        debug!(target: "stdout", "stream mode: {:?}", chat_request.stream);
    }

    let model_name = chat_request.model.clone();

    let result = match chat_request.stream {
        Some(true) => match chat_stream(chat_request).await {
            Ok(stream) => Ok(Left(stream)),
            Err(e) => Err(e),
        },
        Some(false) | None => match chat_once(chat_request).await {
            Ok(chat_completion_object) => Ok(Right(chat_completion_object)),
            Err(e) => Err(e),
        },
    };

    #[cfg(feature = "logging")]
    info!(target: "stdout", "Reset the model metadata");

    reset_model_metadata(model_name.as_ref())?;

    result
}

#[deprecated(since = "0.10.0", note = "Please use the `chat` function.")]
pub async fn chat_completions_stream(
    chat_request: &mut ChatCompletionRequest,
) -> Result<impl futures::TryStream<Ok = String, Error = LlamaCoreError>, LlamaCoreError> {
    chat_stream(chat_request).await
}

#[deprecated(since = "0.10.0", note = "Please use the `chat` function.")]
pub async fn chat_completions(
    chat_request: &mut ChatCompletionRequest,
) -> Result<ChatCompletionObject, LlamaCoreError> {
    chat_once(chat_request).await
}
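/// Handles a chat completion request in stream mode: checks the running mode, builds
/// the chat prompt, updates `n_predict`, feeds the prompt to the model, and returns a
/// `ChatStream` over the generated chunks. When tools are requested, the completion is
/// computed up front by `chat_stream_by_graph` and returned as a pre-built list of chunks.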
async fn chat_stream(
    chat_request: &mut ChatCompletionRequest,
) -> Result<impl futures::TryStream<Ok = String, Error = LlamaCoreError>, LlamaCoreError> {
    #[cfg(feature = "logging")]
    info!(target: "stdout", "Process chat completion request in the stream mode");

    let running_mode = running_mode()?;
    if !running_mode.contains(RunningMode::CHAT) && !running_mode.contains(RunningMode::RAG) {
        let err_msg = "The chat completion is only supported in the chat or rag mode.";

        #[cfg(feature = "logging")]
        error!(target: "stdout", "{}", err_msg);

        return Err(LlamaCoreError::Operation(err_msg.to_string()));
    }

    let model_name = chat_request.model.clone();
    let id = match &chat_request.user {
        Some(id) => id.clone(),
        None => gen_chat_id(),
    };
    #[cfg(feature = "logging")]
    info!(target: "stdout", "user: {}", &id);

    #[cfg(feature = "logging")]
    info!(target: "stdout", "Check model metadata");

    let mut metadata = check_model_metadata(chat_request).await?;

    let include_usage = match chat_request.stream_options {
        Some(ref stream_options) => stream_options.include_usage.unwrap_or_default(),
        None => metadata.include_usage,
    };
    #[cfg(feature = "logging")]
    info!(target: "stdout", "include_usage: {}", include_usage);

    #[cfg(feature = "logging")]
    info!(target: "stdout", "Build the chat prompt");

    let (prompt, available_completion_tokens, tool_use) =
        build_prompt(model_name.as_ref(), chat_request)?;

    #[cfg(feature = "logging")]
    {
        info!(target: "stdout", "prompt:\n{}", &prompt);
        info!(target: "stdout", "available_completion_tokens: {}", available_completion_tokens);
        info!(target: "stdout", "tool_use: {}", tool_use);
    }

    #[cfg(feature = "logging")]
    info!(target: "stdout", "Update the n_predict");

    update_n_predict(chat_request, &mut metadata, available_completion_tokens).await?;

    #[cfg(feature = "logging")]
    info!(target: "stdout", "Feed the prompt to the model");

    set_prompt(chat_request.model.as_ref(), &prompt)?;

    let stream = match tool_use {
        false => ChatStream::new(model_name, id, include_usage, None),
        true => {
            let chat_graphs = match CHAT_GRAPHS.get() {
                Some(chat_graphs) => chat_graphs,
                None => {
                    let err_msg = "Fail to get the underlying value of `CHAT_GRAPHS`.";

                    #[cfg(feature = "logging")]
                    error!(target: "stdout", "{}", &err_msg);

                    return Err(LlamaCoreError::Operation(err_msg.into()));
                }
            };

            let mut chat_graphs = chat_graphs.lock().map_err(|e| {
                let err_msg = format!("Fail to acquire the lock of `CHAT_GRAPHS`. {}", e);

                #[cfg(feature = "logging")]
                error!(target: "stdout", "{}", &err_msg);

                LlamaCoreError::Operation(err_msg)
            })?;

            match model_name {
                Some(model_name) => match chat_graphs.contains_key(&model_name) {
                    true => {
                        let graph = chat_graphs.get_mut(&model_name).unwrap();
                        chat_stream_by_graph(graph, id, include_usage)?
                    }
                    false => match chat_graphs.iter_mut().next() {
                        Some((_, graph)) => chat_stream_by_graph(graph, id, include_usage)?,
                        None => {
                            let err_msg = "There is no model available in the chat graphs.";

                            #[cfg(feature = "logging")]
                            error!(target: "stdout", "{}", &err_msg);

                            return Err(LlamaCoreError::Operation(err_msg.into()));
                        }
                    },
                },
                None => match chat_graphs.iter_mut().next() {
                    Some((_, graph)) => chat_stream_by_graph(graph, id, include_usage)?,
                    None => {
                        let err_msg = "There is no model available in the chat graphs.";

                        #[cfg(feature = "logging")]
                        error!(target: "stdout", "{}", &err_msg);

                        return Err(LlamaCoreError::Operation(err_msg.into()));
                    }
                },
            }
        }
    };

    #[cfg(feature = "logging")]
    info!(target: "stdout", "End of the chat completion stream.");

    Ok(stream)
}
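/// Runs a tool-call completion on the given graph and packages the result as a
/// pre-computed `ChatStream`: a chunk carrying the parsed tool calls (or the plain
/// content on context-full/prompt-too-long errors), a usage chunk, and the final
/// `data: [DONE]` marker.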
fn chat_stream_by_graph(
    graph: &mut Graph<GgmlMetadata>,
    id: impl Into<String>,
    include_usage: bool,
) -> Result<ChatStream, LlamaCoreError> {
    #[cfg(feature = "logging")]
    info!(target: "stdout", "Handle chat request with available tools by the model named {}.", graph.name());

    let id = id.into();

    match graph.compute() {
        Ok(_) => {
            let output_buffer = get_output_buffer(graph, OUTPUT_TENSOR)?;
            let output = std::str::from_utf8(&output_buffer[..]).map_err(|e| {
                let err_msg = format!(
                    "Failed to decode the buffer of the inference result to a utf-8 string. {}",
                    e
                );

                #[cfg(feature = "logging")]
                error!(target: "stdout", "{}", &err_msg);

                LlamaCoreError::Operation(err_msg)
            })?;

            #[cfg(feature = "logging")]
            info!(target: "stdout", "raw generation:\n{}", output);

            let message = post_process(output, &graph.metadata.prompt_template).map_err(|e| {
                LlamaCoreError::Operation(format!("Failed to post-process the output. {}", e))
            })?;

            #[cfg(feature = "logging")]
            info!(target: "stdout", "post-processed generation:\n{}", &message);

            let token_info = get_token_info_by_graph(graph)?;

            #[cfg(feature = "logging")]
            info!(target: "stdout", "prompt tokens: {}, completion tokens: {}", token_info.prompt_tokens, token_info.completion_tokens);

            let usage = Some(Usage {
                prompt_tokens: token_info.prompt_tokens,
                completion_tokens: token_info.completion_tokens,
                total_tokens: token_info.prompt_tokens + token_info.completion_tokens,
            });

            let created = SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .map_err(|e| {
                    let err_msg = format!("Failed to get the current time. Reason: {}", e);

                    #[cfg(feature = "logging")]
                    error!(target: "stdout", "{}", &err_msg);

                    LlamaCoreError::Operation(err_msg)
                })?;

            if graph.metadata.prompt_template != PromptTemplateType::MistralTool
                && graph.metadata.prompt_template != PromptTemplateType::ChatMLTool
                && graph.metadata.prompt_template != PromptTemplateType::GroqLlama3Tool
                && graph.metadata.prompt_template != PromptTemplateType::Llama3Tool
                && graph.metadata.prompt_template != PromptTemplateType::InternLM2Tool
                && graph.metadata.prompt_template != PromptTemplateType::NemotronTool
                && graph.metadata.prompt_template != PromptTemplateType::FunctionaryV32
                && graph.metadata.prompt_template != PromptTemplateType::FunctionaryV31
                && graph.metadata.prompt_template != PromptTemplateType::MistralSmallTool
            {
                let err_msg = format!("Unsupported prompt template: {}. The tool use is only supported for 'mistral-tool', 'chatml-tool', 'groq-llama3-tool', 'llama-3-tool', 'internlm-2-tool', 'nemotron-tool', 'functionary-31', 'functionary-32', and 'mistral-small-tool' prompt templates.", graph.metadata.prompt_template);

                #[cfg(feature = "logging")]
                error!(target: "stdout", "{}", &err_msg);

                return Err(LlamaCoreError::Operation(err_msg));
            }

            let parsed_result = parse_tool_calls(&message, graph.metadata.prompt_template)?;

            let content = match parsed_result.content {
                Some(content) => Some(content),
                None => Some(parsed_result.raw),
            };

            let tool_calls: Vec<ToolCallForChunk> = parsed_result
                .tool_calls
                .into_iter()
                .enumerate()
                .map(|(index, tool_call)| ToolCallForChunk {
                    index,
                    id: tool_call.id,
                    ty: tool_call.ty,
                    function: tool_call.function,
                })
                .collect();

            let tool_call_chunk = {
                let chat_completion_chunk = ChatCompletionChunk {
                    id: id.clone(),
                    object: "chat.completion.chunk".to_string(),
                    created: created.as_secs(),
                    model: graph.name().to_owned(),
                    system_fingerprint: "fp_44709d6fcb".to_string(),
                    choices: vec![ChatCompletionChunkChoice {
                        index: 0,
                        delta: ChatCompletionChunkChoiceDelta {
                            role: ChatCompletionRole::Assistant,
                            content,
                            tool_calls,
                        },
                        logprobs: None,
                        finish_reason: None,
                    }],
                    usage: None,
                };
                let chunk_str = serde_json::to_string(&chat_completion_chunk).map_err(|e| {
                    let err_msg =
                        format!("Failed to serialize chat completion chunk. Reason: {}", e);

                    #[cfg(feature = "logging")]
                    error!(target: "stdout", "{}", &err_msg);

                    LlamaCoreError::Operation(err_msg)
                })?;

                format!("data: {}\n\n", chunk_str)
            };

            let usage_chunk = {
                let chat_completion_chunk = ChatCompletionChunk {
                    id: id.clone(),
                    object: "chat.completion.chunk".to_string(),
                    created: created.as_secs(),
                    model: graph.name().to_owned(),
                    system_fingerprint: "fp_44709d6fcb".to_string(),
                    choices: vec![],
                    usage,
                };
                let chunk_str = serde_json::to_string(&chat_completion_chunk).map_err(|e| {
                    let err_msg =
                        format!("Failed to serialize chat completion chunk. Reason: {}", e);

                    #[cfg(feature = "logging")]
                    error!(target: "stdout", "{}", &err_msg);

                    LlamaCoreError::Operation(err_msg)
                })?;

                format!("data: {}\n\n", chunk_str)
            };

            let ending_chunk = "data: [DONE]\n\n".to_string();

            let chunks = vec![tool_call_chunk, usage_chunk, ending_chunk];

            Ok(ChatStream::new(
                Some(graph.name().to_owned()),
                id,
                include_usage,
                Some(chunks),
            ))
        }
        Err(wasmedge_wasi_nn::Error::BackendError(wasmedge_wasi_nn::BackendError::ContextFull)) => {
            let output_buffer = get_output_buffer(graph, OUTPUT_TENSOR)?;
            let output = std::str::from_utf8(&output_buffer[..]).map_err(|e| {
                let err_msg = format!(
                    "Failed to decode the buffer of the inference result to a utf-8 string. {}",
                    e
                );

                #[cfg(feature = "logging")]
                error!(target: "stdout", "{}", &err_msg);

                LlamaCoreError::Operation(err_msg)
            })?;

            let message = post_process(output, &graph.metadata.prompt_template).map_err(|e| {
                let err_msg = format!("Failed to post-process the output. {}", e);

                #[cfg(feature = "logging")]
                error!(target: "stdout", "{}", &err_msg);

                LlamaCoreError::Operation(err_msg)
            })?;

            let token_info = get_token_info_by_graph(graph)?;

            #[cfg(feature = "logging")]
            info!(target: "stdout", "prompt tokens: {}, completion tokens: {}", token_info.prompt_tokens, token_info.completion_tokens);

            let usage = Some(Usage {
                prompt_tokens: token_info.prompt_tokens,
                completion_tokens: token_info.completion_tokens,
                total_tokens: token_info.prompt_tokens + token_info.completion_tokens,
            });

            let created = SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .map_err(|e| {
                    let err_msg = format!("Failed to get the current time. Reason: {}", e);

                    #[cfg(feature = "logging")]
                    error!(target: "stdout", "{}", &err_msg);

                    LlamaCoreError::Operation(err_msg)
                })?;

            let context_full_chunk = {
                let chat_completion_chunk = ChatCompletionChunk {
                    id: id.clone(),
                    object: "chat.completion.chunk".to_string(),
                    created: created.as_secs(),
                    model: graph.name().to_owned(),
                    system_fingerprint: "fp_44709d6fcb".to_string(),
                    choices: vec![ChatCompletionChunkChoice {
                        index: 0,
                        delta: ChatCompletionChunkChoiceDelta {
                            role: ChatCompletionRole::Assistant,
                            content: Some(message),
                            tool_calls: vec![],
                        },
                        logprobs: None,
                        finish_reason: Some(FinishReason::length),
                    }],
                    usage: None,
                };

                let chunk_str = serde_json::to_string(&chat_completion_chunk).map_err(|e| {
                    let err_msg =
                        format!("Failed to serialize chat completion chunk. Reason: {}", e);

                    #[cfg(feature = "logging")]
                    error!(target: "stdout", "{}", &err_msg);

                    LlamaCoreError::Operation(err_msg)
                })?;

                format!("data: {}\n\n", chunk_str)
            };

            let usage_chunk = {
                let chat_completion_chunk = ChatCompletionChunk {
                    id: id.clone(),
                    object: "chat.completion.chunk".to_string(),
                    created: created.as_secs(),
                    model: graph.name().to_owned(),
                    system_fingerprint: "fp_44709d6fcb".to_string(),
                    choices: vec![],
                    usage,
                };

                let chunk_str = serde_json::to_string(&chat_completion_chunk).map_err(|e| {
                    let err_msg =
                        format!("Failed to serialize chat completion chunk. Reason: {}", e);

                    #[cfg(feature = "logging")]
                    error!(target: "stdout", "{}", &err_msg);

                    LlamaCoreError::Operation(err_msg)
                })?;

                format!("data: {}\n\n", chunk_str)
            };

            let ending_chunk = "data: [DONE]\n\n".to_string();

            let chunks = vec![context_full_chunk, usage_chunk, ending_chunk];

            Ok(ChatStream::new(
                Some(graph.name().to_owned()),
                id,
                include_usage,
                Some(chunks),
            ))
        }
        Err(wasmedge_wasi_nn::Error::BackendError(
            wasmedge_wasi_nn::BackendError::PromptTooLong,
        )) => {
            #[cfg(feature = "logging")]
            warn!(target: "stdout", "The prompt is too long. Please reduce the length of your input and try again.");

            let output_buffer = get_output_buffer(graph, OUTPUT_TENSOR)?;
            let output = std::str::from_utf8(&output_buffer[..]).map_err(|e| {
                let err_msg = format!(
                    "Failed to decode the buffer of the inference result to a utf-8 string. {}",
                    e
                );

                #[cfg(feature = "logging")]
                error!(target: "stdout", "{}", &err_msg);

                LlamaCoreError::Operation(err_msg)
            })?;

            let message = post_process(output, &graph.metadata.prompt_template).map_err(|e| {
                let err_msg = format!("Failed to post-process the output. {}", e);

                #[cfg(feature = "logging")]
                error!(target: "stdout", "{}", &err_msg);

                LlamaCoreError::Operation(err_msg)
            })?;

            let token_info = get_token_info_by_graph(graph)?;

            #[cfg(feature = "logging")]
            info!(target: "stdout", "prompt tokens: {}, completion tokens: {}", token_info.prompt_tokens, token_info.completion_tokens);

            let usage = Some(Usage {
                prompt_tokens: token_info.prompt_tokens,
                completion_tokens: token_info.completion_tokens,
                total_tokens: token_info.prompt_tokens + token_info.completion_tokens,
            });

            let created = SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .map_err(|e| {
                    let err_msg = format!("Failed to get the current time. Reason: {}", e);

                    #[cfg(feature = "logging")]
                    error!(target: "stdout", "{}", &err_msg);

                    LlamaCoreError::Operation(err_msg)
                })?;

            let prompt_too_long_chunk = {
                let chat_completion_chunk = ChatCompletionChunk {
                    id: id.clone(),
                    object: "chat.completion.chunk".to_string(),
                    created: created.as_secs(),
                    model: graph.name().to_owned(),
                    system_fingerprint: "fp_44709d6fcb".to_string(),
                    choices: vec![ChatCompletionChunkChoice {
                        index: 0,
                        delta: ChatCompletionChunkChoiceDelta {
                            role: ChatCompletionRole::Assistant,
                            content: Some(message),
                            tool_calls: vec![],
                        },
                        logprobs: None,
                        finish_reason: Some(FinishReason::length),
                    }],
                    usage: None,
                };

                let chunk_str = serde_json::to_string(&chat_completion_chunk).map_err(|e| {
                    let err_msg =
                        format!("Failed to serialize chat completion chunk. Reason: {}", e);

                    #[cfg(feature = "logging")]
                    error!(target: "stdout", "{}", &err_msg);

                    LlamaCoreError::Operation(err_msg)
                })?;

                format!("data: {}\n\n", chunk_str)
            };

            let usage_chunk = {
                let chat_completion_chunk = ChatCompletionChunk {
                    id: id.clone(),
                    object: "chat.completion.chunk".to_string(),
                    created: created.as_secs(),
                    model: graph.name().to_owned(),
                    system_fingerprint: "fp_44709d6fcb".to_string(),
                    choices: vec![],
                    usage,
                };

                let chunk_str = serde_json::to_string(&chat_completion_chunk).map_err(|e| {
                    let err_msg =
                        format!("Failed to serialize chat completion chunk. Reason: {}", e);

                    #[cfg(feature = "logging")]
                    error!(target: "stdout", "{}", &err_msg);

                    LlamaCoreError::Operation(err_msg)
                })?;

                format!("data: {}\n\n", chunk_str)
            };

            let ending_chunk = "data: [DONE]\n\n".to_string();

            let chunks = vec![prompt_too_long_chunk, usage_chunk, ending_chunk];

            Ok(ChatStream::new(
                Some(graph.name().to_owned()),
                id,
                include_usage,
                Some(chunks),
            ))
        }
        Err(e) => {
            let err_msg = format!("Failed to compute the chat completion. Reason: {}", e);

            #[cfg(feature = "logging")]
            error!(target: "stdout", "{}", &err_msg);

            Err(LlamaCoreError::Backend(BackendError::Compute(err_msg)))
        }
    }
}
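/// Handles a chat completion request in non-stream mode and returns a single
/// `ChatCompletionObject`.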
async fn chat_once(
    chat_request: &mut ChatCompletionRequest,
) -> Result<ChatCompletionObject, LlamaCoreError> {
    #[cfg(feature = "logging")]
    info!(target: "stdout", "Processing chat completion request in non-stream mode");

    let running_mode = running_mode()?;
    if !running_mode.contains(RunningMode::CHAT) && !running_mode.contains(RunningMode::RAG) {
        let err_msg = "The chat completion is only supported in the chat or rag mode.";

        #[cfg(feature = "logging")]
        error!(target: "stdout", "{}", err_msg);

        return Err(LlamaCoreError::Operation(err_msg.to_string()));
    }

    let model_name = chat_request.model.clone();
    let id = match &chat_request.user {
        Some(id) => id.clone(),
        None => gen_chat_id(),
    };

    #[cfg(feature = "logging")]
    info!(target: "stdout", "user: {}", &id);

    #[cfg(feature = "logging")]
    info!(target: "stdout", "Check model metadata");

    let mut metadata = check_model_metadata(chat_request).await?;

    #[cfg(feature = "logging")]
    info!(target: "stdout", "Build the chat prompt");

    let (prompt, available_completion_tokens, tool_use) =
        build_prompt(model_name.as_ref(), chat_request)?;

    #[cfg(feature = "logging")]
    {
        info!(target: "stdout", "prompt:\n{}", &prompt);
        info!(target: "stdout", "available_completion_tokens: {}", available_completion_tokens);
        info!(target: "stdout", "tool_use: {}", tool_use);
    }

    #[cfg(feature = "logging")]
    info!(target: "stdout", "Update n_predict");

    update_n_predict(chat_request, &mut metadata, available_completion_tokens).await?;

    #[cfg(feature = "logging")]
    info!(target: "stdout", "Feed the prompt to the model");

    set_prompt(model_name.as_ref(), &prompt)?;

    #[cfg(feature = "logging")]
    info!(target: "stdout", "Compute chat completion.");

    let res = compute(model_name.as_ref(), id, tool_use);

    #[cfg(feature = "logging")]
    info!(target: "stdout", "End of the chat completion");

    res
}
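/// Looks up the target chat graph by model name (falling back to the first available
/// graph) and runs the inference on it via `compute_by_graph`.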
fn compute(
    model_name: Option<&String>,
    id: impl Into<String>,
    tool_use: bool,
) -> Result<ChatCompletionObject, LlamaCoreError> {
    let chat_graphs = match CHAT_GRAPHS.get() {
        Some(chat_graphs) => chat_graphs,
        None => {
            let err_msg = "Fail to get the underlying value of `CHAT_GRAPHS`.";

            #[cfg(feature = "logging")]
            error!(target: "stdout", "{}", &err_msg);

            return Err(LlamaCoreError::Operation(err_msg.into()));
        }
    };

    let mut chat_graphs = chat_graphs.lock().map_err(|e| {
        let err_msg = format!("Fail to acquire the lock of `CHAT_GRAPHS`. {}", e);

        #[cfg(feature = "logging")]
        error!(target: "stdout", "{}", &err_msg);

        LlamaCoreError::Operation(err_msg)
    })?;

    match model_name {
        Some(model_name) => match chat_graphs.contains_key(model_name) {
            true => {
                let graph = chat_graphs.get_mut(model_name).unwrap();
                compute_by_graph(graph, id, tool_use)
            }
            false => match chat_graphs.iter_mut().next() {
                Some((_, graph)) => compute_by_graph(graph, id, tool_use),
                None => {
                    let err_msg = "There is no model available in the chat graphs.";

                    #[cfg(feature = "logging")]
                    error!(target: "stdout", "{}", &err_msg);

                    Err(LlamaCoreError::Operation(err_msg.into()))
                }
            },
        },
        None => match chat_graphs.iter_mut().next() {
            Some((_, graph)) => compute_by_graph(graph, id, tool_use),
            None => {
                let err_msg = "There is no model available in the chat graphs.";

                #[cfg(feature = "logging")]
                error!(target: "stdout", "{}", &err_msg);

                Err(LlamaCoreError::Operation(err_msg.into()))
            }
        },
    }
}
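/// Runs inference on the given graph and builds the final `ChatCompletionObject`,
/// parsing tool calls from the output when `tool_use` is enabled and mapping
/// context-full and prompt-too-long backend errors to a `length` finish reason.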
fn compute_by_graph(
    graph: &mut Graph<GgmlMetadata>,
    id: impl Into<String>,
    tool_use: bool,
) -> Result<ChatCompletionObject, LlamaCoreError> {
    #[cfg(feature = "logging")]
    info!(target: "stdout", "Compute chat completion by the model named {}.", graph.name());

    match graph.compute() {
        Ok(_) => {
            let output_buffer = get_output_buffer(graph, OUTPUT_TENSOR)?;
            let output = std::str::from_utf8(&output_buffer[..]).map_err(|e| {
                let err_msg = format!(
                    "Failed to decode the buffer of the inference result to a utf-8 string. {}",
                    e
                );

                #[cfg(feature = "logging")]
                error!(target: "stdout", "{}", &err_msg);

                LlamaCoreError::Operation(err_msg)
            })?;

            #[cfg(feature = "logging")]
            info!(target: "stdout", "raw generation: {}", output);

            let message = post_process(output, &graph.metadata.prompt_template).map_err(|e| {
                LlamaCoreError::Operation(format!("Failed to post-process the output. {}", e))
            })?;

            #[cfg(feature = "logging")]
            info!(target: "stdout", "post-processed generation:\n{}", &message);

            let token_info = get_token_info_by_graph(graph)?;

            #[cfg(feature = "logging")]
            info!(target: "stdout", "prompt tokens: {}, completion tokens: {}", token_info.prompt_tokens, token_info.completion_tokens);

            let created = SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .map_err(|e| {
                    let err_msg = format!("Failed to get the current time. Reason: {}", e);

                    #[cfg(feature = "logging")]
                    error!(target: "stdout", "{}", &err_msg);

                    LlamaCoreError::Operation(err_msg)
                })?;

            match tool_use {
                true => {
                    if graph.metadata.prompt_template != PromptTemplateType::MistralTool
                        && graph.metadata.prompt_template != PromptTemplateType::ChatMLTool
                        && graph.metadata.prompt_template != PromptTemplateType::GroqLlama3Tool
                        && graph.metadata.prompt_template != PromptTemplateType::Llama3Tool
                        && graph.metadata.prompt_template != PromptTemplateType::InternLM2Tool
                        && graph.metadata.prompt_template != PromptTemplateType::NemotronTool
                        && graph.metadata.prompt_template != PromptTemplateType::FunctionaryV32
                        && graph.metadata.prompt_template != PromptTemplateType::FunctionaryV31
                        && graph.metadata.prompt_template != PromptTemplateType::MistralSmallTool
                    {
                        let err_msg = format!("Unsupported prompt template: {}. The tool use is only supported for 'mistral-tool', 'chatml-tool', 'groq-llama3-tool', 'llama-3-tool', 'internlm-2-tool', 'nemotron-tool', 'functionary-31', 'functionary-32', and 'mistral-small-tool' prompt templates.", graph.metadata.prompt_template);

                        #[cfg(feature = "logging")]
                        error!(target: "stdout", "{}", &err_msg);

                        return Err(LlamaCoreError::Operation(err_msg));
                    }

                    let parsed_result = parse_tool_calls(&message, graph.metadata.prompt_template)?;

                    let finish_reason = if parsed_result.tool_calls.is_empty() {
                        FinishReason::stop
                    } else {
                        FinishReason::tool_calls
                    };

                    let content = match parsed_result.content {
                        Some(content) => Some(content),
                        None => Some(parsed_result.raw),
                    };

                    Ok(ChatCompletionObject {
                        id: id.into(),
                        object: String::from("chat.completion"),
                        created: created.as_secs(),
                        model: graph.name().to_owned(),
                        choices: vec![ChatCompletionObjectChoice {
                            index: 0,
                            message: ChatCompletionObjectMessage {
                                role: ChatCompletionRole::Assistant,
                                content,
                                tool_calls: parsed_result.tool_calls,
                                function_call: None,
                            },
                            finish_reason,
                            logprobs: None,
                        }],
                        usage: Usage {
                            prompt_tokens: token_info.prompt_tokens,
                            completion_tokens: token_info.completion_tokens,
                            total_tokens: token_info.prompt_tokens + token_info.completion_tokens,
                        },
                    })
                }
                false => {
                    Ok(ChatCompletionObject {
                        id: id.into(),
                        object: String::from("chat.completion"),
                        created: created.as_secs(),
                        model: graph.name().to_owned(),
                        choices: vec![ChatCompletionObjectChoice {
                            index: 0,
                            message: ChatCompletionObjectMessage {
                                role: ChatCompletionRole::Assistant,
                                content: Some(message),
                                tool_calls: vec![],
                                function_call: None,
                            },
                            finish_reason: FinishReason::stop,
                            logprobs: None,
                        }],
                        usage: Usage {
                            prompt_tokens: token_info.prompt_tokens,
                            completion_tokens: token_info.completion_tokens,
                            total_tokens: token_info.prompt_tokens + token_info.completion_tokens,
                        },
                    })
                }
            }
        }
        Err(wasmedge_wasi_nn::Error::BackendError(wasmedge_wasi_nn::BackendError::ContextFull)) => {
            let output_buffer = get_output_buffer(graph, OUTPUT_TENSOR)?;
            let output = std::str::from_utf8(&output_buffer[..]).map_err(|e| {
                let err_msg = format!(
                    "Failed to decode the buffer of the inference result to a utf-8 string. {}",
                    e
                );

                #[cfg(feature = "logging")]
                error!(target: "stdout", "{}", &err_msg);

                LlamaCoreError::Operation(err_msg)
            })?;

            let message = post_process(output, &graph.metadata.prompt_template).map_err(|e| {
                let err_msg = format!("Failed to post-process the output. {}", e);

                #[cfg(feature = "logging")]
                error!(target: "stdout", "{}", &err_msg);

                LlamaCoreError::Operation(err_msg)
            })?;

            let token_info = get_token_info_by_graph(graph)?;

            #[cfg(feature = "logging")]
            info!(target: "stdout", "prompt tokens: {}, completion tokens: {}", token_info.prompt_tokens, token_info.completion_tokens);

            let created = SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .map_err(|e| {
                    let err_msg = format!("Failed to get the current time. Reason: {}", e);

                    #[cfg(feature = "logging")]
                    error!(target: "stdout", "{}", &err_msg);

                    LlamaCoreError::Operation(err_msg)
                })?;

            Ok(ChatCompletionObject {
                id: id.into(),
                object: String::from("chat.completion"),
                created: created.as_secs(),
                model: graph.name().to_owned(),
                choices: vec![ChatCompletionObjectChoice {
                    index: 0,
                    message: ChatCompletionObjectMessage {
                        role: ChatCompletionRole::Assistant,
                        content: Some(message),
                        tool_calls: vec![],
                        function_call: None,
                    },
                    finish_reason: FinishReason::length,
                    logprobs: None,
                }],
                usage: Usage {
                    prompt_tokens: token_info.prompt_tokens,
                    completion_tokens: token_info.completion_tokens,
                    total_tokens: token_info.prompt_tokens + token_info.completion_tokens,
                },
            })
        }
        Err(wasmedge_wasi_nn::Error::BackendError(
            wasmedge_wasi_nn::BackendError::PromptTooLong,
        )) => {
            #[cfg(feature = "logging")]
            warn!(target: "stdout", "The prompt is too long. Please reduce the length of your input and try again.");

            let output_buffer = get_output_buffer(graph, OUTPUT_TENSOR)?;
            let output = std::str::from_utf8(&output_buffer[..]).map_err(|e| {
                let err_msg = format!(
                    "Failed to decode the buffer of the inference result to a utf-8 string. {}",
                    e
                );

                #[cfg(feature = "logging")]
                error!(target: "stdout", "{}", &err_msg);

                LlamaCoreError::Operation(err_msg)
            })?;

            let message = post_process(output, &graph.metadata.prompt_template).map_err(|e| {
                let err_msg = format!("Failed to post-process the output. {}", e);

                #[cfg(feature = "logging")]
                error!(target: "stdout", "{}", &err_msg);

                LlamaCoreError::Operation(err_msg)
            })?;

            let token_info = get_token_info_by_graph(graph)?;

            #[cfg(feature = "logging")]
            info!(target: "stdout", "prompt tokens: {}, completion tokens: {}", token_info.prompt_tokens, token_info.completion_tokens);

            let created = SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .map_err(|e| {
                    let err_msg = format!("Failed to get the current time. Reason: {}", e);

                    #[cfg(feature = "logging")]
                    error!(target: "stdout", "{}", &err_msg);

                    LlamaCoreError::Operation(err_msg)
                })?;

            Ok(ChatCompletionObject {
                id: id.into(),
                object: String::from("chat.completion"),
                created: created.as_secs(),
                model: graph.name().to_owned(),
                choices: vec![ChatCompletionObjectChoice {
                    index: 0,
                    message: ChatCompletionObjectMessage {
                        role: ChatCompletionRole::Assistant,
                        content: Some(message),
                        tool_calls: vec![],
                        function_call: None,
                    },
                    finish_reason: FinishReason::length,
                    logprobs: None,
                }],
                usage: Usage {
                    prompt_tokens: token_info.prompt_tokens,
                    completion_tokens: token_info.completion_tokens,
                    total_tokens: token_info.prompt_tokens + token_info.completion_tokens,
                },
            })
        }
        Err(e) => {
            let err_msg = format!("Failed to compute the chat completion. Reason: {}", e);

            #[cfg(feature = "logging")]
            error!(target: "stdout", "{}", &err_msg);

            Err(LlamaCoreError::Backend(BackendError::Compute(err_msg)))
        }
    }
}
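/// Extracts tool calls from the raw model output according to the conventions of the
/// given prompt template, e.g. `[{...}]` arrays for `mistral-tool`,
/// `<tool_call>...</tool_call>` tags for `chatml-tool` and `groq-llama3-tool`, or
/// `[TOOL_CALLS]` arrays for `mistral-small-tool`.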
fn parse_tool_calls(
    input: &str,
    prompt_template: PromptTemplateType,
) -> Result<ParseResult, LlamaCoreError> {
    match prompt_template {
        PromptTemplateType::MistralTool => match regex::Regex::new(r"\[\{.*?\}\]") {
            Ok(re) => {
                let mut values: Vec<serde_json::Value> = vec![];
                for cap in re.captures_iter(input) {
                    let matched = &cap[0];

                    #[cfg(feature = "logging")]
                    info!(target: "stdout", "captured: {}", matched);

                    match serde_json::from_str::<Vec<serde_json::Value>>(matched) {
                        Ok(group) => values.extend(group),
                        Err(e) => {
                            let err_msg = format!(
                                "Failed to deserialize generated tool calls. Reason: {}",
                                e
                            );

                            #[cfg(feature = "logging")]
                            error!(target: "stdout", "{}", &err_msg);

                            return Err(LlamaCoreError::Operation(err_msg));
                        }
                    }
                }

                let mut tool_calls: Vec<ToolCall> = vec![];
                for value in values.iter() {
                    let name = match value.get("name") {
                        Some(name) => name.to_string().replace("\"", ""),
                        None => {
                            let err_msg = format!(
                                "Failed to get the name of the function. Tool call: {:?}",
                                value
                            );

                            #[cfg(feature = "logging")]
                            error!(target: "stdout", "{}", &err_msg);

                            return Err(LlamaCoreError::Operation(err_msg));
                        }
                    };

                    let arguments = match value.get("arguments") {
                        Some(arguments) => arguments.to_string(),
                        None => {
                            let err_msg = format!(
                                "Failed to get the arguments of the function. Tool call: {:?}",
                                value
                            );

                            #[cfg(feature = "logging")]
                            error!(target: "stdout", "{}", &err_msg);

                            return Err(LlamaCoreError::Operation(err_msg));
                        }
                    };

                    let function = Function { name, arguments };

                    let tool_call = ToolCall {
                        id: "call_abc123".to_string(),
                        ty: "function".to_string(),
                        function,
                    };

                    tool_calls.push(tool_call);
                }

                let parsed = ParseResult {
                    raw: input.to_owned(),
                    content: None,
                    tool_calls,
                };

                #[cfg(feature = "logging")]
                info!(target: "stdout", "parsed result: {:?}", parsed);

                Ok(parsed)
            }
            Err(e) => {
                let err_msg = format!("Failed to create a regex pattern. Reason: {}", e);

                #[cfg(feature = "logging")]
                error!(target: "stdout", "{}", &err_msg);

                Err(LlamaCoreError::Operation(err_msg))
            }
        },
        PromptTemplateType::ChatMLTool => {
            match regex::Regex::new(r"<tool_call>(.*?)</tool_call>") {
                Ok(re) => {
                    let mut values: Vec<serde_json::Value> = vec![];
                    for cap in re.captures_iter(input) {
                        let matched = cap[1].replace("\\n", "");

                        #[cfg(feature = "logging")]
                        info!(target: "stdout", "captured: {}", &matched);

                        match serde_json::from_str::<serde_json::Value>(&matched) {
                            Ok(value) => values.push(value),
                            Err(e) => {
                                let err_msg = format!(
                                    "Failed to deserialize generated tool calls. Reason: {}",
                                    e
                                );

                                #[cfg(feature = "logging")]
                                error!(target: "stdout", "{}", &err_msg);

                                return Err(LlamaCoreError::Operation(err_msg));
                            }
                        }
                    }

                    let mut tool_calls: Vec<ToolCall> = vec![];
                    for value in values.iter() {
                        let name = match value.get("name") {
                            Some(name) => name.to_string().replace("\"", ""),
                            None => {
                                let err_msg = format!(
                                    "Failed to get the name of the function. Tool call: {:?}",
                                    value
                                );

                                #[cfg(feature = "logging")]
                                error!(target: "stdout", "{}", &err_msg);

                                return Err(LlamaCoreError::Operation(err_msg));
                            }
                        };

                        let arguments = match value.get("arguments") {
                            Some(arguments) => arguments.to_string(),
                            None => {
                                let err_msg = format!(
                                    "Failed to get the arguments of the function. Tool call: {:?}",
                                    value
                                );

                                #[cfg(feature = "logging")]
                                error!(target: "stdout", "{}", &err_msg);

                                return Err(LlamaCoreError::Operation(err_msg));
                            }
                        };

                        let function = Function { name, arguments };

                        let tool_call = ToolCall {
                            id: "call_abc123".to_string(),
                            ty: "function".to_string(),
                            function,
                        };

                        tool_calls.push(tool_call);
                    }

                    let parsed = ParseResult {
                        raw: input.to_owned(),
                        content: None,
                        tool_calls,
                    };

                    #[cfg(feature = "logging")]
                    info!(target: "stdout", "parsed result: {:?}", parsed);

                    Ok(parsed)
                }
                Err(e) => {
                    let err_msg = format!("Failed to create a regex pattern. Reason: {}", e);

                    #[cfg(feature = "logging")]
                    error!(target: "stdout", "{}", &err_msg);

                    Err(LlamaCoreError::Operation(err_msg))
                }
            }
        }
        PromptTemplateType::GroqLlama3Tool => {
            #[cfg(feature = "logging")]
            info!(target: "stdout", "raw input: {}", input);

            match regex::Regex::new(r"(?s)<tool_call>((.|\r|\n)*?)</tool_call>") {
                Ok(re) => {
                    let mut values: Vec<serde_json::Value> = vec![];
                    for cap in re.captures_iter(input) {
                        let matched = cap[1].trim();

                        #[cfg(feature = "logging")]
                        info!(target: "stdout", "captured: {}", matched);

                        match serde_json::from_str::<serde_json::Value>(matched) {
                            Ok(value) => values.push(value),
                            Err(e) => {
                                let err_msg = format!(
                                    "Failed to deserialize generated tool calls. Reason: {}",
                                    e
                                );

                                #[cfg(feature = "logging")]
                                error!(target: "stdout", "{}", &err_msg);

                                return Err(LlamaCoreError::Operation(err_msg));
                            }
                        }
                    }

                    let mut tool_calls: Vec<ToolCall> = vec![];
                    for value in values.iter() {
                        let name = match value.get("name") {
                            Some(name) => name.to_string().replace("\"", ""),
                            None => {
                                let err_msg = format!(
                                    "Failed to get the name of the function. Tool call: {:?}",
                                    value
                                );

                                #[cfg(feature = "logging")]
                                error!(target: "stdout", "{}", &err_msg);

                                return Err(LlamaCoreError::Operation(err_msg));
                            }
                        };

                        let arguments = match value.get("arguments") {
                            Some(arguments) => arguments.to_string(),
                            None => {
                                let err_msg = format!(
                                    "Failed to get the arguments of the function. Tool call: {:?}",
                                    value
                                );

                                #[cfg(feature = "logging")]
                                error!(target: "stdout", "{}", &err_msg);

                                return Err(LlamaCoreError::Operation(err_msg));
                            }
                        };

                        let function = Function { name, arguments };

                        let tool_call = ToolCall {
                            id: "call_abc123".to_string(),
                            ty: "function".to_string(),
                            function,
                        };

                        tool_calls.push(tool_call);
                    }

                    let parsed = ParseResult {
                        raw: input.to_owned(),
                        content: None,
                        tool_calls,
                    };

                    #[cfg(feature = "logging")]
                    info!(target: "stdout", "parsed result: {:?}", parsed);

                    Ok(parsed)
                }
                Err(e) => {
                    let err_msg = format!("Failed to create a regex pattern. Reason: {}", e);

                    #[cfg(feature = "logging")]
                    error!(target: "stdout", "{}", &err_msg);

                    Err(LlamaCoreError::Operation(err_msg))
                }
            }
        }
        PromptTemplateType::Llama3Tool => {
            #[cfg(feature = "logging")]
            info!(target: "stdout", "raw input: {}", input);

            let re = match regex::Regex::new(r"^\{(.|\r|\n)*\}$") {
                Ok(re) => re,
                Err(e) => {
                    let err_msg = format!("Failed to create a regex pattern. Reason: {}", e);

                    #[cfg(feature = "logging")]
                    error!(target: "stdout", "{}", &err_msg);

                    return Err(LlamaCoreError::Operation(err_msg));
                }
            };

            if re.is_match(input) {
                match serde_json::from_str::<serde_json::Value>(input) {
                    Ok(value) => {
                        let values: Vec<serde_json::Value> = vec![value];

                        let mut tool_calls: Vec<ToolCall> = vec![];
                        for value in values.iter() {
                            let name = match value.get("name") {
                                Some(name) => name.to_string().replace("\"", ""),
                                None => {
                                    let err_msg = format!(
                                        "Failed to get the name of the function. Tool call: {:?}",
                                        value
                                    );

                                    #[cfg(feature = "logging")]
                                    error!(target: "stdout", "{}", &err_msg);

                                    return Err(LlamaCoreError::Operation(err_msg));
                                }
                            };

                            let arguments = match value.get("parameters") {
                                Some(arguments) => arguments.to_string(),
                                None => {
                                    let err_msg = format!(
                                        "Failed to get the arguments of the function. Tool call: {:?}",
                                        value
                                    );

                                    #[cfg(feature = "logging")]
                                    error!(target: "stdout", "{}", &err_msg);

                                    return Err(LlamaCoreError::Operation(err_msg));
                                }
                            };

                            let function = Function { name, arguments };

                            let tool_call = ToolCall {
                                id: "call_abc123".to_string(),
                                ty: "function".to_string(),
                                function,
                            };

                            tool_calls.push(tool_call);
                        }

                        let parsed = ParseResult {
                            raw: input.to_owned(),
                            content: None,
                            tool_calls,
                        };

                        #[cfg(feature = "logging")]
                        info!(target: "stdout", "parsed result: {:?}", parsed);

                        Ok(parsed)
                    }
                    Err(e) => {
                        let err_msg =
                            format!("Failed to deserialize generated tool calls. Reason: {}", e);

                        #[cfg(feature = "logging")]
                        error!(target: "stdout", "{}", &err_msg);

                        Err(LlamaCoreError::Operation(err_msg))
                    }
                }
            } else {
                let parsed = ParseResult {
                    raw: input.to_owned(),
                    content: None,
                    tool_calls: vec![],
                };

                #[cfg(feature = "logging")]
                info!(target: "stdout", "parsed result: {:?}", parsed);

                Ok(parsed)
            }
        }
        PromptTemplateType::InternLM2Tool => {
            #[cfg(feature = "logging")]
            info!(target: "stdout", "raw input: {}", input);

            let blocks: Vec<&str> = input.trim().split("<|action_start|><|plugin|>").collect();

            #[cfg(feature = "logging")]
            info!(target: "stdout", "blocks: {:?}", blocks);

            let mut tool_calls: Vec<ToolCall> = vec![];
            let mut content = String::new();
            for block in blocks {
                let block = block.trim();
                if !block.is_empty() {
                    if block.ends_with("<|action_end|>") {
                        let value = block.trim().trim_end_matches("<|action_end|>");

                        #[cfg(feature = "logging")]
                        info!(target: "stdout", "tool call: {}", value);

                        match serde_json::from_str::<serde_json::Value>(value) {
                            Ok(value) => {
                                let name = match value.get("name") {
                                    Some(name) => name.to_string().replace("\"", ""),
                                    None => {
                                        let err_msg = format!(
                                            "Failed to get the name of the function. Tool call: {:?}",
                                            value
                                        );

                                        #[cfg(feature = "logging")]
                                        error!(target: "stdout", "{}", &err_msg);

                                        return Err(LlamaCoreError::Operation(err_msg));
                                    }
                                };

                                let arguments = match value.get("parameters") {
                                    Some(arguments) => arguments.to_string(),
                                    None => {
                                        let err_msg = format!(
                                            "Failed to get the arguments of the function. Tool call: {:?}",
                                            value
                                        );

                                        #[cfg(feature = "logging")]
                                        error!(target: "stdout", "{}", &err_msg);

                                        return Err(LlamaCoreError::Operation(err_msg));
                                    }
                                };

                                let function = Function { name, arguments };

                                let tool_call = ToolCall {
                                    id: "call_abc123".to_string(),
                                    ty: "function".to_string(),
                                    function,
                                };

                                tool_calls.push(tool_call);
                            }
                            Err(e) => {
                                let err_msg = format!(
                                    "Failed to deserialize generated tool calls. Reason: {}",
                                    e
                                );

                                #[cfg(feature = "logging")]
                                error!(target: "stdout", "{}", &err_msg);

                                return Err(LlamaCoreError::Operation(err_msg));
                            }
                        }
                    } else {
                        content.push_str(block);
                        content.push('\n');
                    }
                }
            }

            let parsed = match content.is_empty() {
                true => ParseResult {
                    raw: input.to_owned(),
                    content: None,
                    tool_calls,
                },
                false => ParseResult {
                    raw: input.to_owned(),
                    content: Some(content.trim().to_owned()),
                    tool_calls,
                },
            };

            #[cfg(feature = "logging")]
            info!(target: "stdout", "parsed result: {:?}", parsed);

            Ok(parsed)
        }
        PromptTemplateType::NemotronTool => {
            #[cfg(feature = "logging")]
            info!(target: "stdout", "raw input: {}", input);

            match regex::Regex::new(r"(?s)<toolcall>\s*(.*?)\s*</toolcall>") {
                Ok(re) => {
                    let mut values: Vec<serde_json::Value> = vec![];
                    for cap in re.captures_iter(input) {
                        #[cfg(feature = "logging")]
                        info!(target: "stdout", "captured: {}", &cap[0]);

                        #[cfg(feature = "logging")]
                        info!(target: "stdout", "extracted: {}", &cap[1]);

                        let matched = cap[1].trim();

                        #[cfg(feature = "logging")]
                        info!(target: "stdout", "captured: {}", matched);

                        match serde_json::from_str::<serde_json::Value>(matched) {
                            Ok(value) => values.push(value),
                            Err(e) => {
                                let err_msg = format!(
                                    "Failed to deserialize generated tool calls. Reason: {}",
                                    e
                                );

                                #[cfg(feature = "logging")]
                                error!(target: "stdout", "{}", &err_msg);

                                return Err(LlamaCoreError::Operation(err_msg));
                            }
                        }
                    }

                    let mut tool_calls: Vec<ToolCall> = vec![];
                    for value in values.iter() {
                        let name = match value.get("name") {
                            Some(name) => name.to_string().replace("\"", ""),
                            None => {
                                let err_msg = format!(
                                    "Failed to get the name of the function. Tool call: {:?}",
                                    value
                                );

                                #[cfg(feature = "logging")]
                                error!(target: "stdout", "{}", &err_msg);

                                return Err(LlamaCoreError::Operation(err_msg));
                            }
                        };

                        let arguments = match value.get("arguments") {
                            Some(arguments) => arguments.to_string(),
                            None => {
                                let err_msg = format!(
                                    "Failed to get the arguments of the function. Tool call: {:?}",
                                    value
                                );

                                #[cfg(feature = "logging")]
                                error!(target: "stdout", "{}", &err_msg);

                                return Err(LlamaCoreError::Operation(err_msg));
                            }
                        };

                        let function = Function { name, arguments };

                        let tool_call = ToolCall {
                            id: "call_abc123".to_string(),
                            ty: "function".to_string(),
                            function,
                        };

                        tool_calls.push(tool_call);
                    }

                    let parsed = ParseResult {
                        raw: input.to_owned(),
                        content: None,
                        tool_calls,
                    };

                    #[cfg(feature = "logging")]
                    info!(target: "stdout", "parsed result: {:?}", parsed);

                    Ok(parsed)
                }
                Err(e) => {
                    let err_msg = format!("Failed to create a regex pattern. Reason: {}", e);

                    #[cfg(feature = "logging")]
                    error!(target: "stdout", "{}", &err_msg);

                    Err(LlamaCoreError::Operation(err_msg))
                }
            }
        }
        PromptTemplateType::FunctionaryV32 => {
            #[cfg(feature = "logging")]
            info!(target: "stdout", "raw input: {}", input);

            match regex::Regex::new(r">>>\s*(\w+)\s*\{(.*)\}<\|eot_id\|>") {
                Ok(re) => {
                    let mut tool_calls: Vec<ToolCall> = vec![];
                    for cap in re.captures_iter(input) {
                        #[cfg(feature = "logging")]
                        info!(target: "stdout", "func_name: {}", &cap[1]);

                        #[cfg(feature = "logging")]
                        info!(target: "stdout", "arguments: {}", &cap[2]);

                        let tool_call = ToolCall {
                            id: "call_abc123".to_string(),
                            ty: "function".to_string(),
                            function: Function {
                                name: cap[1].to_string(),
                                arguments: cap[2].to_string(),
                            },
                        };

                        tool_calls.push(tool_call);
                    }

                    let parsed = ParseResult {
                        raw: input.to_owned(),
                        content: None,
                        tool_calls,
                    };

                    #[cfg(feature = "logging")]
                    info!(target: "stdout", "parsed result: {:?}", parsed);

                    Ok(parsed)
                }
                Err(e) => {
                    let warn_msg = format!("Failed to create a regex pattern. Reason: {}", e);

                    #[cfg(feature = "logging")]
                    warn!(target: "stdout", "{}", &warn_msg);

                    Ok(ParseResult {
                        raw: input.to_owned(),
                        content: None,
                        tool_calls: vec![],
                    })
                }
            }
        }
        PromptTemplateType::FunctionaryV31 => {
            #[cfg(feature = "logging")]
            info!(target: "stdout", "raw input: {}", input);

            match regex::Regex::new(r"<function=(\w+)>\s*(\{.*?\})</function>") {
                Ok(re) => {
                    let mut tool_calls: Vec<ToolCall> = vec![];
                    for cap in re.captures_iter(input) {
                        #[cfg(feature = "logging")]
                        info!(target: "stdout", "func_name: {}", &cap[1]);

                        #[cfg(feature = "logging")]
                        info!(target: "stdout", "arguments: {}", &cap[2]);

                        let tool_call = ToolCall {
                            id: "call_abc123".to_string(),
                            ty: "function".to_string(),
                            function: Function {
                                name: cap[1].to_string(),
                                arguments: cap[2].to_string(),
                            },
                        };

                        tool_calls.push(tool_call);
                    }

                    let parsed = ParseResult {
                        raw: input.to_owned(),
                        content: None,
                        tool_calls,
                    };

                    #[cfg(feature = "logging")]
                    info!(target: "stdout", "parsed result: {:?}", parsed);

                    Ok(parsed)
                }
                Err(e) => {
                    let warn_msg = format!("Failed to create a regex pattern. Reason: {}", e);

                    #[cfg(feature = "logging")]
                    warn!(target: "stdout", "{}", &warn_msg);

                    Ok(ParseResult {
                        raw: input.to_owned(),
                        content: None,
                        tool_calls: vec![],
                    })
                }
            }
        }
        PromptTemplateType::MistralSmallTool => {
            #[cfg(feature = "logging")]
            info!(target: "stdout", "raw input: {}", input);

            match regex::Regex::new(r"\[TOOL_CALLS\]\s*(\[(.*?)\])") {
                Ok(re) => {
                    let mut values: Vec<serde_json::Value> = vec![];
                    if let Some(cap) = re.captures(input) {
                        let matched = cap[1].trim();

                        #[cfg(feature = "logging")]
                        info!(target: "stdout", "captured: {}", matched);

                        match serde_json::from_str::<Vec<serde_json::Value>>(matched) {
                            Ok(vals) => values = vals,
                            Err(e) => {
                                let err_msg = format!(
                                    "Failed to deserialize generated tool calls. Reason: {}",
                                    e
                                );

                                #[cfg(feature = "logging")]
                                error!(target: "stdout", "{}", &err_msg);

                                return Err(LlamaCoreError::Operation(err_msg));
                            }
                        }
                    };

                    let mut tool_calls: Vec<ToolCall> = vec![];
                    for value in values.iter() {
                        if let Some(object_map) = value.as_object() {
                            if object_map.contains_key("function") {
                                let mut function = Function {
                                    name: String::new(),
                                    arguments: String::new(),
                                };

                                let value = object_map.get("function").unwrap();
                                let func_map = value.as_object().unwrap();
                                if func_map.contains_key("name") {
                                    let func_name = func_map.get("name").unwrap().as_str().unwrap();
                                    println!("Function name: {:?}", func_name);

                                    function.name = func_name.to_string();
                                }
                                if func_map.contains_key("arguments") {
                                    let args = func_map.get("arguments").unwrap();
                                    let arguments = args.to_string();
                                    println!("Arguments: {:?}", arguments);

                                    function.arguments = arguments;
                                }

                                let tool_call = ToolCall {
                                    id: "call_abc123".to_string(),
                                    ty: "function".to_string(),
                                    function,
                                };

                                tool_calls.push(tool_call);
                            } else if object_map.contains_key("name") {
                                let mut function = Function {
                                    name: String::new(),
                                    arguments: String::new(),
                                };

                                let name = object_map.get("name").unwrap().as_str().unwrap();
                                println!("name: {:?}", name);
                                function.name = name.to_string();

                                if object_map.contains_key("arguments") {
                                    let args = object_map.get("arguments").unwrap();
                                    let arguments = args.to_string();
                                    println!("Arguments: {:?}", arguments);

                                    function.arguments = arguments;
                                }

                                let tool_call = ToolCall {
                                    id: "call_abc123".to_string(),
                                    ty: "function".to_string(),
                                    function,
                                };

                                tool_calls.push(tool_call);
                            }
                        }
                    }

                    let parsed = ParseResult {
                        raw: input.to_owned(),
                        content: None,
                        tool_calls,
                    };

                    #[cfg(feature = "logging")]
                    info!(target: "stdout", "parsed result: {:?}", parsed);

                    Ok(parsed)
                }
                Err(e) => {
                    let err_msg = format!("Failed to create a regex pattern. Reason: {}", e);

                    #[cfg(feature = "logging")]
                    error!(target: "stdout", "{}", &err_msg);

                    Err(LlamaCoreError::Operation(err_msg))
                }
            }
        }
        _ => {
            let err_msg = format!(
                "The tool use is only supported for prompt templates: {}, {}, {}, {}, {}, {}, {}, {}, and {}.",
                PromptTemplateType::MistralTool,
                PromptTemplateType::ChatMLTool,
                PromptTemplateType::GroqLlama3Tool,
                PromptTemplateType::Llama3Tool,
                PromptTemplateType::InternLM2Tool,
                PromptTemplateType::NemotronTool,
                PromptTemplateType::FunctionaryV32,
                PromptTemplateType::FunctionaryV31,
                PromptTemplateType::MistralSmallTool,
            );

            #[cfg(feature = "logging")]
            error!(target: "stdout", "{}", &err_msg);

            Err(LlamaCoreError::Operation(err_msg))
        }
    }
}
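/// Compares the request settings (image input, temperature, top_p, frequency and
/// presence penalties) with the current model metadata and persists the metadata if
/// any of them changed.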
async fn check_model_metadata(
    chat_request: &ChatCompletionRequest,
) -> Result<GgmlMetadata, LlamaCoreError> {
    let mut should_update = false;
    let mut metadata = get_model_metadata(chat_request.model.as_ref())?;

    if metadata.prompt_template.is_image_supported() {
        if let Some(ChatCompletionRequestMessage::User(user_message)) = chat_request.messages.last()
        {
            if let ChatCompletionUserMessageContent::Parts(parts) = user_message.content() {
                for part in parts {
                    if let ContentPart::Image(image_part) = part {
                        let image = image_part.image();

                        if image.is_url() {
                            #[cfg(feature = "logging")]
                            info!(target: "stdout", "The image is provided in URL format.");

                            let img_path_str = download_image(&image.url).await?;

                            #[cfg(feature = "logging")]
                            info!(target: "stdout", "The image is saved to {}", img_path_str);

                            metadata.image = Some(img_path_str);

                            if !should_update {
                                should_update = true;
                            }

                            break;
                        } else {
                            #[cfg(feature = "logging")]
                            info!(target: "stdout", "The image is provided in base64 format.");

                            break;
                        }
                    }
                }
            }
        }
    }

    if let Some(temp) = chat_request.temperature {
        if metadata.temperature != temp {
            metadata.temperature = temp;

            if !should_update {
                should_update = true;
            }
        }
    }

    if let Some(top_p) = chat_request.top_p {
        if metadata.top_p != top_p {
            metadata.top_p = top_p;

            if !should_update {
                should_update = true;
            }
        }
    }

    if let Some(frequency_penalty) = chat_request.frequency_penalty {
        if metadata.frequency_penalty != frequency_penalty {
            metadata.frequency_penalty = frequency_penalty;

            if !should_update {
                should_update = true;
            }
        }
    }

    if let Some(presence_penalty) = chat_request.presence_penalty {
        if metadata.presence_penalty != presence_penalty {
            metadata.presence_penalty = presence_penalty;

            if !should_update {
                should_update = true;
            }
        }
    }

    if metadata.embeddings {
        metadata.embeddings = false;

        if !should_update {
            should_update = true;
        }
    }

    if should_update {
        #[cfg(feature = "logging")]
        info!(target: "stdout", "Update the model metadata.");

        update_model_metadata(chat_request.model.as_ref(), &metadata)?;
    }

    Ok(metadata)
}
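/// Adjusts `n_predict` in the model metadata based on `max_completion_tokens` from the
/// request and the number of completion tokens still available in the context window.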
2022async fn update_n_predict(
2023 chat_request: &ChatCompletionRequest,
2024 metadata: &mut GgmlMetadata,
2025 available_completion_tokens: u64,
2026) -> Result<(), LlamaCoreError> {
2027 let mut should_update = false;
2028
2029 #[cfg(feature = "logging")]
2030 info!(target: "stdout", "n_predict: {}", metadata.n_predict);
2031
2032 if let Some(max_completion_tokens) = chat_request.max_completion_tokens {
2038 if metadata.n_predict != max_completion_tokens {
2039 #[cfg(feature = "logging")]
2040 info!(target: "stdout", "Update n_predict with max_completion_tokens from {} to {}", metadata.n_predict, max_completion_tokens);
2041
2042 metadata.n_predict = max_completion_tokens;
2043
2044 if !should_update {
2045 should_update = true;
2046 }
2047 }
2048 }
2049
2050 if metadata.n_predict == -2 {
2052 #[cfg(feature = "logging")]
2053 info!(target: "stdout", "Update n_predict with available_completion_tokens from {} to {}", metadata.n_predict, available_completion_tokens);
2054
2055 metadata.n_predict = available_completion_tokens as i32;
2057
2058 if !should_update {
2059 should_update = true;
2060 }
2061 }
2062
2063 if metadata.n_predict == -1
2064 || (metadata.n_predict > 0 && metadata.n_predict < available_completion_tokens as i32)
2065 || (metadata.n_predict < 0 && metadata.n_predict != -2)
2066 {
2068 #[cfg(feature = "logging")]
2069 info!(target: "stdout", "Update n_predict with available_completion_tokens from {} to {}", metadata.n_predict, available_completion_tokens);
2070
2071 metadata.n_predict = available_completion_tokens as i32;
2073
2074 if !should_update {
2075 should_update = true;
2076 }
2077 }
2078
2079 if should_update {
2080 #[cfg(feature = "logging")]
2081 info!(target: "stdout", "Update the model metadata.");
2082
2083 update_model_metadata(chat_request.model.as_ref(), metadata)?;
2085 }
2086
2087 Ok(())
2088}
2089
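/// Strip template-specific stop tokens and artifacts from the raw model output.
///
/// Each prompt template (e.g. ChatML, Llama-3, Mistral, Gemma) leaves its own
/// end-of-turn markers, such as `<|im_end|>`, `<|eot_id|>`, or `</s>`, in the
/// generated text; this function trims them together with surrounding whitespace.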
2090fn post_process(
2091 output: impl AsRef<str>,
2092 template_ty: &PromptTemplateType,
2093) -> Result<String, String> {
2094 let output = if *template_ty == PromptTemplateType::Baichuan2 {
2095 if output.as_ref().contains("用户:") {
2096 output.as_ref().trim_end_matches("用户:").trim().to_owned()
2097 } else {
2098 output.as_ref().trim().to_owned()
2099 }
2100 } else if *template_ty == PromptTemplateType::OpenChat {
2101 if output.as_ref().contains("<|end_of_turn|>") {
2102 output
2103 .as_ref()
2104 .trim_end_matches("<|end_of_turn|>")
2105 .trim()
2106 .to_owned()
2107 } else {
2108 output.as_ref().trim().to_owned()
2109 }
2110 } else if *template_ty == PromptTemplateType::GemmaInstruct {
2111 let s = output.as_ref().trim();
2112 if s.ends_with("<end_of_turn>") {
2113 s.trim_end_matches("<end_of_turn>").trim().to_owned()
2114 } else {
2115 s.to_owned()
2116 }
2117 } else if *template_ty == PromptTemplateType::ChatML
2118 || *template_ty == PromptTemplateType::ChatMLTool
2119 || *template_ty == PromptTemplateType::InternLM2Tool
2120 || *template_ty == PromptTemplateType::MiniCPMV
2121 {
2122 let mut s = output.as_ref().trim();
2123 if s.ends_with("<|endoftext|>") {
2124 s = s.trim_end_matches("<|endoftext|>").trim();
2125 }
2126
2127 if s.contains("<|im_start|>") && s.contains("<|im_end|>") {
2128 let idx_start = s.find("<|im_start|>").unwrap();
2129 let idx_end = s.find("<|im_end|>").unwrap();
2130
2131 match idx_start <= idx_end {
2132 true => s.split("<|im_start|>").collect::<Vec<_>>()[0]
2133 .trim()
2134 .to_owned(),
2135 false => s.split("<|im_end|>").collect::<Vec<_>>()[0]
2136 .trim()
2137 .to_owned(),
2138 }
2139 } else if s.contains("<|im_start|>") {
2140 s.split("<|im_start|>").collect::<Vec<_>>()[0]
2141 .trim()
2142 .to_owned()
2143 } else if s.contains("<|im_end|>") {
2144 let output = s.trim_end_matches("<|im_end|>").trim();
2145 if output.starts_with(": ") {
2146 output.trim_start_matches(": ").to_owned()
2147 } else {
2148 output.to_owned()
2149 }
2150 } else {
2151 s.to_owned()
2152 }
2153 } else if *template_ty == PromptTemplateType::Zephyr
2154 || *template_ty == PromptTemplateType::MistralLite
2155 || *template_ty == PromptTemplateType::MistralTool
2156 || *template_ty == PromptTemplateType::MistralInstruct
2157 || *template_ty == PromptTemplateType::MistralSmallChat
2158 || *template_ty == PromptTemplateType::MistralSmallTool
2159 || *template_ty == PromptTemplateType::BreezeInstruct
2160 {
2161 if output.as_ref().contains("</s><") {
2162 output.as_ref().trim_end_matches("</s><").trim().to_owned()
2163 } else if output.as_ref().contains("</s>") {
2164 output
2165 .as_ref()
2166 .strip_suffix("</s>")
2167 .unwrap()
2168 .trim()
2169 .to_owned()
2170 } else {
2171 output.as_ref().trim().to_owned()
2172 }
2173 } else if *template_ty == PromptTemplateType::DeepseekChat {
2174 if output.as_ref().contains("<|end_of_sentence|>") {
2175 output
2176 .as_ref()
2177 .trim_end_matches("<|end_of_sentence|>")
2178 .trim()
2179 .replace("<|end_of_sentence|>", " ")
2180 .trim()
2181 .to_owned()
2182 } else {
2183 output.as_ref().trim().to_owned()
2184 }
2185 } else if *template_ty == PromptTemplateType::HumanAssistant {
2186 if output.as_ref().contains("Human:") {
2187 output.as_ref().trim_end_matches("Human:").trim().to_owned()
2188 } else {
2189 output.as_ref().trim().to_owned()
2190 }
2191 } else if *template_ty == PromptTemplateType::SolarInstruct {
2192 let s = output.as_ref().trim();
2193
2194 if s.starts_with("### Answer") {
2195 let s = s.trim_start_matches("###").trim();
2196
2197 if s.starts_with("Answer:\n") {
2198 s.replace("Answer:\n", "Answer: ")
2199 } else {
2200 s.to_owned()
2201 }
2202 } else {
2203 s.to_owned()
2204 }
2205 } else if *template_ty == PromptTemplateType::Llama2Chat
2206 || *template_ty == PromptTemplateType::NemotronTool
2207 || *template_ty == PromptTemplateType::NemotronChat
2208 {
2209 let s = output.as_ref().trim();
2210 if s.ends_with("</s>") {
2211 s.trim_end_matches("</s>").trim().to_owned()
2212 } else {
2213 s.to_owned()
2214 }
2215 } else if *template_ty == PromptTemplateType::Llama3Chat
2216 || *template_ty == PromptTemplateType::GroqLlama3Tool
2217 || *template_ty == PromptTemplateType::Llama3Tool
2218 || *template_ty == PromptTemplateType::FunctionaryV32
2219 {
2220 let s = output.as_ref().trim();
2221 if s.ends_with("<|eot_id|>") {
2222 s.trim_end_matches("<|eot_id|>").trim().to_owned()
2223 } else {
2224 s.to_owned()
2225 }
2226 } else if *template_ty == PromptTemplateType::Phi3Chat {
2227 let s = output.as_ref().trim();
2228 if s.ends_with("<|end|>") {
2229 s.trim_end_matches("<|end|>").trim().to_owned()
2230 } else {
2231 s.to_owned()
2232 }
2233 } else if *template_ty == PromptTemplateType::Phi4Chat {
2234 let s = output.as_ref().trim();
2235 if s.ends_with("<|im_end|>") {
2236 s.trim_end_matches("<|im_end|>").trim().to_owned()
2237 } else {
2238 s.to_owned()
2239 }
2240 } else if *template_ty == PromptTemplateType::FunctionaryV31 {
2241 let mut s = output.as_ref().trim();
2242 if s.ends_with("<|eot_id|>") {
2243 s = s.trim_end_matches("<|eot_id|>").trim();
2244 }
2245 if s.ends_with("<|eom_id|>") {
2246 s = s.trim_end_matches("<|eom_id|>").trim();
2247 }
2248 s.to_owned()
2249 } else if *template_ty == PromptTemplateType::MoxinChat {
2250 let s = output.as_ref().trim();
2251 if s.ends_with("</s>") {
2252 s.trim_end_matches("</s>").trim().to_owned()
2253 } else if s.ends_with("[INST]") {
2254 s.trim_end_matches("[INST]").trim().to_owned()
2255 } else {
2256 s.to_owned()
2257 }
2258 } else if *template_ty == PromptTemplateType::Falcon3 {
2259 let s = output.as_ref().trim();
2260 if s.ends_with("<|endoftext|>") {
2261 s.trim_end_matches("<|endoftext|>").trim().to_owned()
2262 } else {
2263 s.to_owned()
2264 }
2265 } else if *template_ty == PromptTemplateType::Megrez {
2266 let s = output.as_ref().trim();
2267 if s.ends_with("<|turn_end|>") {
2268 s.trim_end_matches("<|turn_end|>").trim().to_owned()
2269 } else {
2270 s.to_owned()
2271 }
2272 } else if *template_ty == PromptTemplateType::Qwen2vl {
2273 let mut s = output.as_ref().trim();
2274
2275 if s.starts_with(":") {
2276 s = s.trim_start_matches(":").trim();
2277 }
2278
2279 if s.ends_with("<|im_end|>") {
2280 s.trim_end_matches("<|im_end|>").trim().to_owned()
2281 } else {
2282 s.to_owned()
2283 }
2284 } else if *template_ty == PromptTemplateType::VicunaLlava {
2285 let s = output.as_ref().trim();
2286 if s.ends_with("</s>") {
2287 s.trim_end_matches("</s>").trim().to_owned()
2288 } else {
2289 s.to_owned()
2290 }
2291 } else {
2292 output.as_ref().trim().to_owned()
2293 };
2294
2295 Ok(output)
2296}
2297
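/// Build the chat prompt for the target (or default) model.
///
/// Returns the rendered prompt, the number of tokens available for the completion,
/// and whether the prompt was built with tool definitions. If the prompt exceeds
/// four fifths of the context window, the oldest messages after the system prompt
/// (if any) are pruned and the prompt is rebuilt until it fits.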
2298fn build_prompt(
2310 model_name: Option<&String>,
2311 chat_request: &mut ChatCompletionRequest,
2312) -> Result<(String, u64, bool), LlamaCoreError> {
2313 let metadata = get_model_metadata(model_name)?;
2314 let ctx_size = metadata.ctx_size as u64;
2315 let chat_prompt = ChatPrompt::from(metadata.prompt_template);
2316
2317 let max_prompt_tokens = ctx_size * 4 / 5;
2319
2320 loop {
2321 {
2323 }
2336
2337 if chat_request.messages.is_empty() {
2338 let err_msg = "The messages in the chat request are empty.";
2339
2340 #[cfg(feature = "logging")]
2341 error!(target: "stdout", "{}", err_msg);
2342
2343 return Err(LlamaCoreError::Operation(err_msg.to_owned()));
2344 }
2345
2346 #[cfg(feature = "logging")]
2347 {
2348 let mut role_chain = String::new();
2349 for (idx, message) in chat_request.messages.iter().enumerate() {
2350 if idx == chat_request.messages.len() - 1 {
2351 role_chain.push_str(&format!("{}", message.role()));
2352 } else {
2353 role_chain.push_str(&format!("{} -> ", message.role()));
2354 }
2355 }
2356 info!(target: "stdout", "Role chain: {}", role_chain);
2357 }
2358
2359 let (prompt, tool_use) = match chat_request.tool_choice.as_ref() {
2360 Some(tool_choice) => match tool_choice {
2361 ToolChoice::None => {
2362 match chat_prompt.build_with_tools(&mut chat_request.messages, Some(&[])) {
2363 Ok(prompt) => (prompt, false),
2364 Err(e) => {
2365 let err_msg = format!("Fail to build chat prompts. Reason: {}", e);
2366
2367 #[cfg(feature = "logging")]
2368 error!(target: "stdout", "{}", &err_msg);
2369
2370 return Err(LlamaCoreError::Operation(err_msg));
2371 }
2372 }
2373 }
2374 _ => match chat_request.tools.as_ref() {
2375 Some(tools) => match chat_prompt
2376 .build_with_tools(&mut chat_request.messages, Some(tools.as_slice()))
2377 {
2378 Ok(prompt) => (prompt, true),
2379 Err(e) => {
2380 let err_msg = format!("Fail to build chat prompts. Reason: {}", e);
2381
2382 #[cfg(feature = "logging")]
2383 error!(target: "stdout", "{}", &err_msg);
2384
2385 return Err(LlamaCoreError::Operation(err_msg));
2386 }
2387 },
2388 None => {
2389 #[cfg(feature = "logging")]
warn!(target: "stdout", "A tool choice was specified but no tools were provided; the prompt will be built without tools.");
2391
2392 match chat_prompt.build_with_tools(&mut chat_request.messages, None) {
2393 Ok(prompt) => (prompt, false),
2394 Err(e) => {
2395 let err_msg = format!("Fail to build chat prompts. Reason: {}", e);
2396
2397 #[cfg(feature = "logging")]
2398 error!(target: "stdout", "{}", &err_msg);
2399
2400 return Err(LlamaCoreError::Operation(err_msg));
2401 }
2402 }
2403 }
2404 },
2405 },
2406 None => match chat_prompt.build_with_tools(&mut chat_request.messages, None) {
2407 Ok(prompt) => (prompt, false),
2408 Err(e) => {
2409 let err_msg = format!("Fail to build chat prompts. Reason: {}", e);
2410
2411 #[cfg(feature = "logging")]
2412 error!(target: "stdout", "{}", &err_msg);
2413
2414 return Err(LlamaCoreError::Operation(err_msg));
2415 }
2416 },
2417 };
2418 #[cfg(feature = "logging")]
2419 info!(target: "stdout", "Try to set prompt: {}", prompt);
2420
2421 set_prompt(model_name, &prompt)?;
2423
2424 let token_info = get_token_info_by_graph_name(model_name)?;
2426
2427 match token_info.prompt_tokens > max_prompt_tokens {
2428 true => {
2429 match chat_request.messages[0].role() {
2430 ChatCompletionRole::System => {
2431 if chat_request.messages.len() > 2 {
2432 #[cfg(feature = "logging")]
2433 info!(target: "stdout", "Prune chat history: current length {}", chat_request.messages.len());
2434
2435 if chat_request.messages[1].role() == ChatCompletionRole::User {
2438 let user_message = chat_request.messages.remove(1);
2439
2440 #[cfg(feature = "logging")]
2441 info!(target: "stdout", "Remove a user message from the chat history: {:?}", user_message);
2442 }
2443
2444 while chat_request.messages[1].role() != ChatCompletionRole::User {
2447 let message = chat_request.messages.remove(1);
2448
2449 #[cfg(feature = "logging")]
2450 info!(target: "stdout", "Remove a {} message from the chat history: {:?}", message.role(), message);
2451
2452 if chat_request.messages.len() == 1 {
let err_msg = format!("No user message remains in the chat history after pruning; the last removed message was a {} message.", message.role());
2454
2455 #[cfg(feature = "logging")]
2456 error!(target: "stdout", "{}", err_msg);
2457
2458 return Err(LlamaCoreError::Operation(err_msg));
2459 }
2460 }
2461 } else if token_info.prompt_tokens > ctx_size {
2462 let err_msg = format!(
2463 "The number of prompt tokens ({}) is greater than the context size ({}). Please increase the context size, or simplify the input message.",
2464 token_info.prompt_tokens, ctx_size
2465 );
2466
2467 #[cfg(feature = "logging")]
2468 error!(target: "stdout", "{}", &err_msg);
2469
2470 return Err(LlamaCoreError::Operation(err_msg));
2471 } else {
2472 return Ok((prompt, ctx_size - token_info.prompt_tokens, tool_use));
2473 }
2474 }
2475 ChatCompletionRole::User => {
2476 if chat_request.messages.len() > 1 {
2477 if chat_request.messages[0].role() == ChatCompletionRole::User {
2482 let user_message = chat_request.messages.remove(0);
2483
2484 #[cfg(feature = "logging")]
2485 info!(target: "stdout", "Remove a user message from the chat history: {:?}", user_message);
2486 }
2487
2488 while chat_request.messages[0].role() != ChatCompletionRole::User {
2491 let message = chat_request.messages.remove(0);
2492
2493 #[cfg(feature = "logging")]
2494 info!(target: "stdout", "Remove a {} message from the chat history: {:?}", message.role(), message);
2495
2496 if chat_request.messages.is_empty() {
let err_msg = format!("No user message remains in the chat history after pruning; the last removed message was a {} message.", message.role());
2498
2499 #[cfg(feature = "logging")]
2500 error!(target: "stdout", "{}", err_msg);
2501
2502 return Err(LlamaCoreError::Operation(err_msg));
2503 }
2504 }
2505 } else if token_info.prompt_tokens > ctx_size {
2506 let err_msg = format!(
2507 "The number of prompt tokens ({}) is greater than the context size ({}). Please increase the context size, or simplify the input message.",
2508 token_info.prompt_tokens, ctx_size
2509 );
2510
2511 #[cfg(feature = "logging")]
2512 error!(target: "stdout", "{}", &err_msg);
2513
2514 return Err(LlamaCoreError::Operation(err_msg));
2515 } else {
2516 return Ok((prompt, ctx_size - token_info.prompt_tokens, tool_use));
2517 }
2518 }
2519 _ => {
2520 #[cfg(feature = "logging")]
info!(target: "stdout", "Remove a {} message from the message queue", chat_request.messages[0].role());
2522
2523 chat_request.messages.remove(0);
2524 }
2525 }
2526
2527 continue;
2528 }
2529 false => return Ok((prompt, ctx_size - max_prompt_tokens, tool_use)),
2530 }
2531 }
2532}
2533
2534async fn download_image(image_url: impl AsRef<str>) -> Result<String, LlamaCoreError> {
2536 let image_url = image_url.as_ref();
2537 let url = reqwest::Url::parse(image_url).map_err(|e| {
2538 let err_msg = format!("Fail to parse the image URL: {}. Reason: {}", image_url, e);
2539
2540 #[cfg(feature = "logging")]
2541 error!(target: "stdout", "{}", &err_msg);
2542
2543 LlamaCoreError::Operation(err_msg)
2544 })?;
2545
2546 let response = reqwest::get(url).await.map_err(|e| {
2547 let err_msg = format!(
2548 "Fail to download the image from the URL: {}. Reason: {}",
2549 image_url, e
2550 );
2551
2552 #[cfg(feature = "logging")]
2553 error!(target: "stdout", "{}", &err_msg);
2554
2555 LlamaCoreError::Operation(err_msg)
2556 })?;
2557
2558 let fname = response
2559 .url()
2560 .path_segments()
2561 .and_then(|mut segments| segments.next_back())
2562 .and_then(|name| if name.is_empty() { None } else { Some(name) })
2563 .ok_or(LlamaCoreError::Operation(format!(
2564 "Fail to get the file name: {}",
2565 image_url
2566 )))?
2567 .to_string();
2568
2569 let id = format!("file_{}", uuid::Uuid::new_v4());
// Create the `archives/<id>` directory for the downloaded image, propagating I/O errors instead of panicking.
let file_path = Path::new("archives").join(&id);
if !file_path.exists() {
    fs::create_dir_all(&file_path).map_err(|e| {
        let err_msg = format!("Failed to create the archive directory {}. Reason: {}", file_path.display(), e);

        #[cfg(feature = "logging")]
        error!(target: "stdout", "{}", &err_msg);

        LlamaCoreError::Operation(err_msg)
    })?;
}
2580 let img_path = file_path.join(&fname);
2581 let mut dest: File = File::create(&img_path).map_err(|e| {
let err_msg = format!("Failed to create the image file {}. Reason: {}", &fname, e);
2583
2584 #[cfg(feature = "logging")]
2586 error!(target: "stdout", "{}", &err_msg);
2587
2588 LlamaCoreError::Operation(err_msg)
2589 })?;
2590
2591 let mut content = response.bytes_stream();
2592 while let Some(Ok(item)) = content.next().await {
2593 std::io::copy(&mut item.as_ref(), &mut dest).map_err(|e| {
2594 let err_msg = format!(
2595 "Fail to write the image content to the file: {}. Reason: {}",
2596 &fname, e
2597 );
2598
2599 #[cfg(feature = "logging")]
2600 error!(target: "stdout", "{}", &err_msg);
2601
2602 LlamaCoreError::Operation(err_msg)
2603 })?;
2604 }
2605
2606 Ok(img_path.as_path().to_string_lossy().to_string())
2607}
2608
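/// Write the rendered prompt into the input tensor (index 0) of the target chat
/// graph, falling back to the first available graph when `model_name` is `None`
/// or unknown.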
2609fn set_prompt(model_name: Option<&String>, prompt: impl AsRef<str>) -> Result<(), LlamaCoreError> {
2610 let chat_graphs = match CHAT_GRAPHS.get() {
2611 Some(chat_graphs) => chat_graphs,
2612 None => {
2613 let err_msg = "Fail to get the underlying value of `CHAT_GRAPHS`.";
2614
2615 #[cfg(feature = "logging")]
2616 error!(target: "stdout", "{}", &err_msg);
2617
2618 return Err(LlamaCoreError::Operation(err_msg.into()));
2619 }
2620 };
2621
2622 let mut chat_graphs = chat_graphs.lock().map_err(|e| {
2623 let err_msg = format!("Fail to acquire the lock of `CHAT_GRAPHS`. {}", e);
2624
2625 #[cfg(feature = "logging")]
2626 error!(target: "stdout", "{}", &err_msg);
2627
2628 LlamaCoreError::Operation(err_msg)
2629 })?;
2630
2631 match model_name {
2632 Some(model_name) => {
2633 #[cfg(feature = "logging")]
2634 info!(target: "stdout", "Set prompt to the chat model named {}", model_name);
2635
2636 match chat_graphs.contains_key(model_name) {
2637 true => {
2638 let graph = chat_graphs.get_mut(model_name).unwrap();
2639 let tensor_data = prompt.as_ref().as_bytes().to_vec();
2640 set_tensor_data_u8(graph, 0, &tensor_data)
2641 }
2642 false => match chat_graphs.iter_mut().next() {
2643 Some((_, graph)) => {
2644 let tensor_data = prompt.as_ref().as_bytes().to_vec();
2645 set_tensor_data_u8(graph, 0, &tensor_data)
2646 }
2647 None => {
2648 let err_msg = "There is no model available in the chat graphs.";
2649
2650 #[cfg(feature = "logging")]
2651 error!(target: "stdout", "{}", &err_msg);
2652
2653 Err(LlamaCoreError::Operation(err_msg.into()))
2654 }
2655 },
2656 }
2657 }
2658 None => {
2659 #[cfg(feature = "logging")]
2660 info!(target: "stdout", "Set prompt to the default chat model.");
2661
2662 match chat_graphs.iter_mut().next() {
2663 Some((_, graph)) => {
2664 let tensor_data = prompt.as_ref().as_bytes().to_vec();
2665 set_tensor_data_u8(graph, 0, &tensor_data)
2666 }
2667 None => {
let err_msg = "There is no model available in the chat graphs while trying to set the prompt for the default chat model.";
2669
2670 #[cfg(feature = "logging")]
2671 error!(target: "stdout", "{}", err_msg);
2672
2673 Err(LlamaCoreError::Operation(err_msg.into()))
2674 }
2675 }
2676 }
2677 }
2678}
2679
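/// Return a clone of the metadata of the target chat graph, falling back to the
/// first available graph when `model_name` is `None` or unknown.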
2680fn get_model_metadata(model_name: Option<&String>) -> Result<GgmlMetadata, LlamaCoreError> {
2682 let chat_graphs = match CHAT_GRAPHS.get() {
2683 Some(chat_graphs) => chat_graphs,
2684 None => {
2685 let err_msg = "Fail to get the underlying value of `CHAT_GRAPHS`.";
2686
2687 #[cfg(feature = "logging")]
2688 error!(target: "stdout", "{}", err_msg);
2689
2690 return Err(LlamaCoreError::Operation(err_msg.into()));
2691 }
2692 };
2693
2694 let chat_graphs = chat_graphs.lock().map_err(|e| {
2695 let err_msg = format!("Fail to acquire the lock of `CHAT_GRAPHS`. {}", e);
2696
2697 #[cfg(feature = "logging")]
2698 error!(target: "stdout", "{}", &err_msg);
2699
2700 LlamaCoreError::Operation(err_msg)
2701 })?;
2702
2703 match model_name {
2704 Some(model_name) => match chat_graphs.contains_key(model_name) {
2705 true => {
2706 let graph = chat_graphs.get(model_name).unwrap();
2707 Ok(graph.metadata.clone())
2708 }
2709 false => match chat_graphs.iter().next() {
2710 Some((_, graph)) => Ok(graph.metadata.clone()),
2711 None => {
2712 let err_msg = "There is no model available in the chat graphs.";
2713
2714 #[cfg(feature = "logging")]
2715 error!(target: "stdout", "{}", &err_msg);
2716
2717 Err(LlamaCoreError::Operation(err_msg.into()))
2718 }
2719 },
2720 },
2721 None => match chat_graphs.iter().next() {
2722 Some((_, graph)) => Ok(graph.metadata.clone()),
2723 None => {
2724 let err_msg = "There is no model available in the chat graphs.";
2725
2726 #[cfg(feature = "logging")]
2727 error!(target: "stdout", "{}", err_msg);
2728
2729 Err(LlamaCoreError::Operation(err_msg.into()))
2730 }
2731 },
2732 }
2733}
2734
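/// Serialize `metadata` to JSON and write it into the config tensor (index 1) of
/// the target chat graph, falling back to the first available graph when
/// `model_name` is `None` or unknown.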
2735fn update_model_metadata(
2736 model_name: Option<&String>,
2737 metadata: &GgmlMetadata,
2738) -> Result<(), LlamaCoreError> {
2739 let config = match serde_json::to_string(metadata) {
2740 Ok(config) => config,
2741 Err(e) => {
2742 let err_msg = format!("Fail to serialize metadata to a JSON string. {}", e);
2743
2744 #[cfg(feature = "logging")]
2745 error!(target: "stdout", "{}", &err_msg);
2746
2747 return Err(LlamaCoreError::Operation(err_msg));
2748 }
2749 };
2750
2751 let chat_graphs = match CHAT_GRAPHS.get() {
2752 Some(chat_graphs) => chat_graphs,
2753 None => {
2754 let err_msg = "Fail to get the underlying value of `CHAT_GRAPHS`.";
2755
2756 #[cfg(feature = "logging")]
2757 error!(target: "stdout", "{}", err_msg);
2758
2759 return Err(LlamaCoreError::Operation(err_msg.into()));
2760 }
2761 };
2762
2763 let mut chat_graphs = chat_graphs.lock().map_err(|e| {
2764 let err_msg = format!("Fail to acquire the lock of `CHAT_GRAPHS`. Reason: {}", e);
2765
2766 #[cfg(feature = "logging")]
2767 error!(target: "stdout", "{}", &err_msg);
2768
2769 LlamaCoreError::Operation(err_msg)
2770 })?;
2771
2772 match model_name {
2773 Some(model_name) => {
2774 match chat_graphs.contains_key(model_name) {
2775 true => {
2776 let graph = chat_graphs.get_mut(model_name).unwrap();
2777 set_tensor_data_u8(graph, 1, config.as_bytes())
2779 }
2780 false => match chat_graphs.iter_mut().next() {
2781 Some((_, graph)) => {
2782 set_tensor_data_u8(graph, 1, config.as_bytes())
2784 }
2785 None => {
2786 let err_msg = "There is no model available in the chat graphs.";
2787
2788 #[cfg(feature = "logging")]
2789 error!(target: "stdout", "{}", &err_msg);
2790
2791 Err(LlamaCoreError::Operation(err_msg.into()))
2792 }
2793 },
2794 }
2795 }
2796 None => {
2797 match chat_graphs.iter_mut().next() {
2798 Some((_, graph)) => {
2799 set_tensor_data_u8(graph, 1, config.as_bytes())
2801 }
2802 None => {
2803 let err_msg = "There is no model available in the chat graphs.";
2804
2805 #[cfg(feature = "logging")]
2806 error!(target: "stdout", "{}", err_msg);
2807
2808 Err(LlamaCoreError::Operation(err_msg.into()))
2809 }
2810 }
2811 }
2812 }
2813}
2814
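/// Re-apply the metadata cached in the chat graph to the model, reverting any
/// per-request overrides made while serving a chat completion request.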
2815fn reset_model_metadata(model_name: Option<&String>) -> Result<(), LlamaCoreError> {
2816 let metadata = get_model_metadata(model_name)?;
2818
2819 update_model_metadata(model_name, &metadata)
2821}
2822
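/// Progress of the "context full" termination sequence of a stream: a chunk with
/// the `length` finish reason, then (if usage was requested) a usage chunk, then
/// `[DONE]`, then the end-of-sequence marker.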
2823#[derive(Debug, Clone, Copy, PartialEq, Eq)]
2824enum ContextFullState {
2825 Message,
2826 Usage,
2827 Done,
2828 EndOfSequence,
2829}
2830
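/// Progress of a normally terminating stream: `Usage`/`NoUsage` while tokens are
/// being produced (depending on whether the client requested usage data), then
/// `Done` for the final `[DONE]` chunk, then `EndOfSequence`.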
2831#[derive(Debug, Clone, Copy, PartialEq, Eq)]
2832enum StreamState {
2833 Usage,
2834 NoUsage,
2835 Done,
2836 EndOfSequence,
2837}
2838
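/// Progress of the "prompt too long" termination sequence of a stream: a chunk
/// with the `length` finish reason, then (if usage was requested) a usage chunk,
/// then `[DONE]`, then the end-of-sequence marker.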
2839#[derive(Debug, Clone, Copy, PartialEq, Eq)]
2840enum PromptTooLongState {
2841 Message,
2842 Usage,
2843 Done,
2844 EndOfSequence,
2845}
2846
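/// A stream of SSE-formatted chat completion chunks.
///
/// When `cache` is `Some`, items are popped from the pre-computed queue; otherwise
/// each poll runs one decoding step on the underlying chat graph.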
2847struct ChatStream {
2848 id: String,
2849 model: Option<String>,
2850 include_usage: bool,
2851 context_full_state: ContextFullState,
2852 prompt_too_long_state: PromptTooLongState,
2853 stream_state: StreamState,
2854 cache: Option<VecDeque<String>>,
2855}
2856impl ChatStream {
2857 fn new(
2858 model: Option<String>,
2859 id: String,
2860 include_usage: bool,
2861 cache: Option<Vec<String>>,
2862 ) -> Self {
2863 let stream_state = if include_usage {
2864 StreamState::Usage
2865 } else {
2866 StreamState::NoUsage
2867 };
2868
2869 ChatStream {
2870 id,
2871 model,
2872 include_usage,
2873 context_full_state: ContextFullState::Message,
2874 prompt_too_long_state: PromptTooLongState::Message,
2875 stream_state,
2876 cache: cache.map(VecDeque::from),
2877 }
2878 }
2879}
2880impl Drop for ChatStream {
2881 fn drop(&mut self) {
2882 if self.cache.is_none() {
2883 #[cfg(feature = "logging")]
info!(target: "stdout", "Clean up the compute context of the streaming session.");
2885
2886 match &self.model {
2887 Some(model_name) => {
2888 match CHAT_GRAPHS.get() {
2889 Some(chat_graphs) => match chat_graphs.lock() {
2890 Ok(mut chat_graphs) => match chat_graphs.contains_key(model_name) {
2891 true => {
2892 let graph = chat_graphs.get_mut(model_name).unwrap();
2893
2894 if let Err(e) = graph.finish_single() {
2895 let err_msg = format!(
2896 "Failed to clean up the context. Reason: {}",
2897 e
2898 );
2899
2900 #[cfg(feature = "logging")]
2901 error!(target: "stdout", "{}", &err_msg);
2902
2903 #[cfg(not(feature = "logging"))]
2904 println!(
2905 "[ERROR][llama_core] Failed to clean up the context. Reason: {}",
2906 &err_msg
2907 );
2908 }
2909 }
2910 false => match chat_graphs.iter_mut().next() {
2911 Some((_, graph)) => {
2912 if let Err(e) = graph.finish_single() {
2913 let err_msg = format!(
2914 "Failed to clean up the context. Reason: {}",
2915 e
2916 );
2917
2918 #[cfg(feature = "logging")]
2919 error!(target: "stdout", "{}", &err_msg);
2920
2921 #[cfg(not(feature = "logging"))]
2922 println!(
2923 "[ERROR][llama_core] Failed to clean up the context. Reason: {}",
2924 &err_msg
2925 );
2926 }
2927 }
2928 None => {
2929 let err_msg =
2930 "There is no model available in the chat graphs.";
2931
2932 #[cfg(feature = "logging")]
2933 error!(target: "stdout", "{}", &err_msg);
2934
2935 #[cfg(not(feature = "logging"))]
2936 println!(
2937 "[ERROR][llama_core] Failed to clean up the context. Reason: {}",
2938 &err_msg
2939 );
2940 }
2941 },
2942 },
2943 Err(e) => {
2944 let err_msg =
2945 format!("Fail to acquire the lock of `CHAT_GRAPHS`. {}", e);
2946
2947 #[cfg(feature = "logging")]
2948 error!(target: "stdout", "{}", &err_msg);
2949
2950 #[cfg(not(feature = "logging"))]
2951 println!(
2952 "[ERROR][llama_core] Failed to clean up the context. Reason: {}",
2953 &err_msg
2954 );
2955 }
2956 },
2957 None => {
2958 let err_msg = "Fail to get the underlying value of `CHAT_GRAPHS`.";
2959
2960 #[cfg(feature = "logging")]
2961 error!(target: "stdout", "{}", &err_msg);
2962
2963 #[cfg(not(feature = "logging"))]
2964 println!(
2965 "[ERROR][llama_core] Failed to clean up the context. Reason: {}",
2966 &err_msg
2967 );
2968 }
2969 };
2970 }
2971 None => {
2972 match CHAT_GRAPHS.get() {
2973 Some(chat_graphs) => match chat_graphs.lock() {
2974 Ok(mut chat_graphs) => match chat_graphs.iter_mut().next() {
2975 Some((_, graph)) => {
2976 if let Err(e) = graph.finish_single() {
2977 let err_msg = format!(
2978 "Failed to clean up the context. Reason: {}",
2979 e
2980 );
2981
2982 #[cfg(feature = "logging")]
2983 error!(target: "stdout", "{}", &err_msg);
2984
2985 #[cfg(not(feature = "logging"))]
2986 println!(
2987 "[ERROR][llama_core] Failed to clean up the context. Reason: {}",
2988 &err_msg
2989 );
2990 }
2991 }
2992 None => {
2993 let err_msg = "There is no model available in the chat graphs.";
2994
2995 #[cfg(feature = "logging")]
2996 error!(target: "stdout", "{}", err_msg);
2997
2998 #[cfg(not(feature = "logging"))]
2999 println!(
3000 "[ERROR][llama_core] Failed to clean up the context. Reason: {}",
3001 err_msg
3002 );
3003 }
3004 },
3005 Err(e) => {
3006 let err_msg =
3007 format!("Fail to acquire the lock of `CHAT_GRAPHS`. {}", e);
3008
3009 #[cfg(feature = "logging")]
3010 error!(target: "stdout", "{}", &err_msg);
3011
3012 #[cfg(not(feature = "logging"))]
3013 println!(
3014 "[ERROR][llama_core] Failed to clean up the context. Reason: {}",
3015 &err_msg
3016 );
3017 }
3018 },
3019 None => {
3020 let err_msg = "Fail to get the underlying value of `CHAT_GRAPHS`.";
3021
3022 #[cfg(feature = "logging")]
3023 error!(target: "stdout", "{}", &err_msg);
3024
3025 #[cfg(not(feature = "logging"))]
3026 println!(
3027 "[ERROR][llama_core] Failed to clean up the context. Reason: {}",
3028 &err_msg
3029 );
3030 }
3031 };
3032 }
3033 }
3034
3035 #[cfg(feature = "logging")]
3036 info!(target: "stdout", "Cleanup done!");
3037 }
3038 }
3039}
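// Yields SSE-formatted chunk strings: either popped from the cache, or produced by
// `compute_stream` until "[GGML] End of sequence" (or an empty string) is returned.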
3040impl futures::Stream for ChatStream {
3041 type Item = Result<String, LlamaCoreError>;
3042
3043 fn poll_next(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
3044 if self.cache.is_none() {
3045 let this = self.get_mut();
3046 let res = compute_stream(
3047 this.model.clone(),
3048 this.id.clone(),
3049 this.include_usage,
3050 &mut this.prompt_too_long_state,
3051 &mut this.context_full_state,
3052 &mut this.stream_state,
3053 );
3054
3055 match res {
3056 Ok(x) => {
3057 #[cfg(feature = "logging")]
3058 info!(target: "stdout", "next item: {}", &x);
3059
3060 if x != "[GGML] End of sequence" && !x.is_empty() {
3061 Poll::Ready(Some(Ok(x)))
3062 } else {
3063 Poll::Ready(None)
3065 }
3066 }
3067 Err(e) => Poll::Ready(Some(Err(e))),
3068 }
3069 } else {
3070 let this = self.get_mut();
3071
3072 let x = this.cache.as_mut().unwrap().pop_front();
3073
3074 #[cfg(feature = "logging")]
3075 info!(target: "stdout", "Get the next item from the cache: {:?}", &x);
3076
3077 match x {
3078 Some(x) => Poll::Ready(Some(Ok(x))),
3079 None => Poll::Ready(None),
3080 }
3081 }
3082 }
3083}
3084
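/// Compute the next chunk of a chat stream.
///
/// Runs a single decoding step (`compute_single`) on the target graph and turns the
/// result into an SSE `data:` line. The end-of-sequence, context-full, and
/// prompt-too-long backend conditions are translated into the corresponding
/// termination sequences tracked by the three state machines.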
3085fn compute_stream(
3086 model_name: Option<String>,
3087 id: String,
3088 include_usage: bool,
3089 prompt_too_long_state: &mut PromptTooLongState,
3090 context_full_state: &mut ContextFullState,
3091 stream_state: &mut StreamState,
3092) -> Result<String, LlamaCoreError> {
3093 #[cfg(feature = "logging")]
3094 info!(target: "stdout", "Compute the chat stream chunk.");
3095
3096 #[cfg(feature = "logging")]
3097 debug!(target: "stdout", "prompt_too_long_state: {:?}", *prompt_too_long_state);
3098 #[cfg(feature = "logging")]
3099 debug!(target: "stdout", "context_full_state: {:?}", *context_full_state);
3100 #[cfg(feature = "logging")]
3101 debug!(target: "stdout", "stream_state: {:?}", *stream_state);
3102
3103 if *prompt_too_long_state == PromptTooLongState::EndOfSequence
3104 || *context_full_state == ContextFullState::EndOfSequence
3105 || *stream_state == StreamState::EndOfSequence
3106 {
3107 #[cfg(feature = "logging")]
info!(target: "stdout", "The stream has reached the end of sequence; return the end marker.");
3109
3110 return Ok("[GGML] End of sequence".to_string());
3111 }
3112
3113 let chat_graphs = match CHAT_GRAPHS.get() {
3114 Some(chat_graphs) => chat_graphs,
3115 None => {
3116 let err_msg = "Fail to get the underlying value of `CHAT_GRAPHS`.";
3117
3118 #[cfg(feature = "logging")]
3119 error!(target: "stdout", "{}", &err_msg);
3120
3121 return Err(LlamaCoreError::Operation(err_msg.into()));
3122 }
3123 };
3124
3125 let mut chat_graphs = chat_graphs.lock().map_err(|e| {
3126 let err_msg = format!("Fail to acquire the lock of `CHAT_GRAPHS`. {}", e);
3127
3128 #[cfg(feature = "logging")]
3129 error!(target: "stdout", "{}", &err_msg);
3130
3131 LlamaCoreError::Operation(err_msg)
3132 })?;
3133
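    // Select the target graph: by model name if one is provided and known,
    // otherwise the first graph registered in `CHAT_GRAPHS`.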
3134 let res = match &model_name {
3136 Some(model_name) => {
3137 match chat_graphs.contains_key(model_name) {
3138 true => {
3139 let graph = chat_graphs.get_mut(model_name).unwrap();
3140 match graph.compute_single() {
3142 Ok(_) => {
3143 #[cfg(feature = "logging")]
3144 debug!(target: "stdout", "Compute the chat stream chunk successfully.");
3145
3146 match stream_state {
3147 StreamState::Usage | StreamState::NoUsage => {
3148 let output_buffer =
3150 get_output_buffer_single(graph, OUTPUT_TENSOR)?;
3151
3152 #[cfg(feature = "logging")]
3153 info!(target: "stdout", "retrieved the output buffer");
3154
3155 let output = match String::from_utf8(output_buffer.clone()) {
3157 Ok(token) => token,
3158 Err(_) => {
3159 let mutex = CACHED_UTF8_ENCODINGS
3160 .get_or_init(|| Mutex::new(Vec::new()));
3161 let mut cached_encodings = mutex.lock().map_err(|e| {
3162 let err_msg = format!(
3163 "Fail to acquire the lock of `UTF8_ENCODINGS`. Reason: {}",
3164 e
3165 );
3166
3167 #[cfg(feature = "logging")]
3168 error!(target: "stdout", "{}", &err_msg);
3169
3170
3171 LlamaCoreError::Operation(err_msg)
3172 })?;
3173
3174 cached_encodings.extend_from_slice(&output_buffer[..]);
3176
3177 match String::from_utf8(cached_encodings.to_vec()) {
3178 Ok(token) => {
3179 cached_encodings.clear();
3181
3182 token
3183 }
3184 Err(e) => {
3185 if cached_encodings.len() > 4 {
let err_msg = format!("Failed to decode the cached bytes as a UTF-8 string: the buffer exceeds 4 bytes, so it cannot be a partial code point. {}", e);
3188
3189 #[cfg(feature = "logging")]
3190 error!(target: "stdout", "{}", &err_msg);
3191
3192 #[cfg(feature = "logging")]
3193 error!(target: "stdout", "The cached buffer: {:?}", &cached_encodings[..]);
3194
3195 cached_encodings.clear();
3202
3203 String::from("")
3204 } else {
let warn_msg = format!("Failed to decode the cached bytes as a UTF-8 string; keep buffering until the code point is complete. {}", e);
3206
3207 #[cfg(feature = "logging")]
3208 warn!(target: "stdout", "{}", &warn_msg);
3209
3210 String::from("")
3211 }
3212 }
3213 }
3214 }
3215 };
3216
3217 #[cfg(feature = "logging")]
3218 info!(target: "stdout", "decoded the output buffer");
3219
3220 let created = SystemTime::now()
3221 .duration_since(std::time::UNIX_EPOCH)
3222 .map_err(|e| {
3223 let err_msg = format!(
3224 "Failed to get the current time. Reason: {}",
3225 e
3226 );
3227
3228 #[cfg(feature = "logging")]
3229 error!(target: "stdout", "{}", &err_msg);
3230
3231 LlamaCoreError::Operation(err_msg)
3232 })?;
3233
3234 let chat_completion_chunk = ChatCompletionChunk {
3235 id,
3236 object: "chat.completion.chunk".to_string(),
3237 created: created.as_secs(),
3238 model: graph.name().to_owned(),
3239 system_fingerprint: "fp_44709d6fcb".to_string(),
3240 choices: vec![ChatCompletionChunkChoice {
3241 index: 0,
3242 delta: ChatCompletionChunkChoiceDelta {
3243 role: ChatCompletionRole::Assistant,
3244 content: Some(output),
3245 tool_calls: vec![],
3246 },
3247 logprobs: None,
3248 finish_reason: None,
3249 }],
3250 usage: None,
3251 };
3252
3253 #[cfg(feature = "logging")]
3254 info!(target: "stdout", "created chat completion chunk");
3255
3256 let chunk_str = serde_json::to_string(&chat_completion_chunk)
3258 .map_err(|e| {
3259 let err_msg = format!(
3260 "Failed to serialize chat completion chunk. Reason: {}",
3261 e
3262 );
3263
3264 #[cfg(feature = "logging")]
3265 error!(target: "stdout", "{}", &err_msg);
3266
3267 LlamaCoreError::Operation(err_msg)
3268 })?;
3269
3270 Ok(format!("data: {}\n\n", chunk_str))
3271 }
3272 StreamState::Done => {
3273 *stream_state = StreamState::EndOfSequence;
3274
3275 Ok("data: [DONE]\n\n".to_string())
3276 }
3277 StreamState::EndOfSequence => {
3278 Ok("[GGML] End of sequence".to_string())
3279 }
3280 }
3281 }
3282 Err(wasmedge_wasi_nn::Error::BackendError(
3283 wasmedge_wasi_nn::BackendError::EndOfSequence,
3284 )) => {
3285 #[cfg(feature = "logging")]
3286 debug!(target: "stdout", "End of sequence");
3287
3288 match stream_state {
3289 StreamState::Usage => {
3290 *stream_state = StreamState::Done;
3291
3292 let token_info = get_token_info_by_graph(graph)?;
3294
3295 let usage = Some(Usage {
3296 prompt_tokens: token_info.prompt_tokens,
3297 completion_tokens: token_info.completion_tokens,
3298 total_tokens: token_info.prompt_tokens
3299 + token_info.completion_tokens,
3300 });
3301
3302 #[cfg(feature = "logging")]
3303 info!(target: "stdout", "token_info: {} prompt tokens, {} completion tokens", token_info.prompt_tokens, token_info.completion_tokens);
3304
3305 let created = SystemTime::now()
3306 .duration_since(std::time::UNIX_EPOCH)
3307 .map_err(|e| {
3308 let err_msg = format!(
3309 "Failed to get the current time. Reason: {}",
3310 e
3311 );
3312
3313 #[cfg(feature = "logging")]
3314 error!(target: "stdout", "{}", &err_msg);
3315
3316 LlamaCoreError::Operation(err_msg)
3317 })?;
3318
3319 let chat_completion_chunk = ChatCompletionChunk {
3320 id,
3321 object: "chat.completion.chunk".to_string(),
3322 created: created.as_secs(),
3323 model: graph.name().to_owned(),
3324 system_fingerprint: "fp_44709d6fcb".to_string(),
3325 choices: vec![],
3326 usage,
3327 };
3328
3329 let chunk_str = serde_json::to_string(&chat_completion_chunk)
3331 .map_err(|e| {
3332 let err_msg = format!(
3333 "Failed to serialize chat completion chunk. Reason: {}",
3334 e
3335 );
3336
3337 #[cfg(feature = "logging")]
3338 error!(target: "stdout", "{}", &err_msg);
3339
3340 LlamaCoreError::Operation(err_msg)
3341 })?;
3342
3343 Ok(format!("data: {}\n\n", chunk_str))
3344 }
3345 StreamState::Done | StreamState::NoUsage => {
3346 *stream_state = StreamState::EndOfSequence;
3347
3348 Ok("data: [DONE]\n\n".to_string())
3349 }
3350 StreamState::EndOfSequence => {
3351 Ok("[GGML] End of sequence".to_string())
3352 }
3353 }
3354 }
3355 Err(wasmedge_wasi_nn::Error::BackendError(
3356 wasmedge_wasi_nn::BackendError::ContextFull,
3357 )) => {
3358 #[cfg(feature = "logging")]
3359 debug!(target: "stdout", "Context full");
3360
3361 match context_full_state {
3362 ContextFullState::Message => {
3363 match include_usage {
3364 true => *context_full_state = ContextFullState::Usage,
3365 false => *context_full_state = ContextFullState::Done,
3366 }
3367
3368 let created = SystemTime::now()
3369 .duration_since(std::time::UNIX_EPOCH)
3370 .map_err(|e| {
3371 let err_msg = format!(
3372 "Failed to get the current time. Reason: {}",
3373 e
3374 );
3375
3376 #[cfg(feature = "logging")]
3377 error!(target: "stdout", "{}", &err_msg);
3378
3379 LlamaCoreError::Operation(err_msg)
3380 })?;
3381
3382 let chat_completion_chunk = ChatCompletionChunk {
3383 id,
3384 object: "chat.completion.chunk".to_string(),
3385 created: created.as_secs(),
3386 model: graph.name().to_owned(),
3387 system_fingerprint: "fp_44709d6fcb".to_string(),
3388 choices: vec![ChatCompletionChunkChoice {
3389 index: 0,
3390 delta: ChatCompletionChunkChoiceDelta {
3391 role: ChatCompletionRole::Assistant,
3392 content: Some(
3393 "<|WASMEDGE-GGML-CONTEXT-FULL|>".to_string(),
3394 ),
3395 tool_calls: vec![],
3396 },
3397 logprobs: None,
3398 finish_reason: Some(FinishReason::length),
3399 }],
3400 usage: None,
3401 };
3402
3403 let chunk_str = serde_json::to_string(&chat_completion_chunk)
3405 .map_err(|e| {
3406 let err_msg = format!(
3407 "Failed to serialize chat completion chunk. Reason: {}",
3408 e
3409 );
3410
3411 #[cfg(feature = "logging")]
3412 error!(target: "stdout", "{}", &err_msg);
3413
3414 LlamaCoreError::Operation(err_msg)
3415 })?;
3416
3417 Ok(format!("data: {}\n\n", chunk_str))
3418 }
3419 ContextFullState::Usage => {
3420 *context_full_state = ContextFullState::Done;
3421
3422 let token_info = get_token_info_by_graph(graph)?;
3424
3425 let usage = Some(Usage {
3426 prompt_tokens: token_info.prompt_tokens,
3427 completion_tokens: token_info.completion_tokens,
3428 total_tokens: token_info.prompt_tokens
3429 + token_info.completion_tokens,
3430 });
3431
3432 let created = SystemTime::now()
3433 .duration_since(std::time::UNIX_EPOCH)
3434 .map_err(|e| {
3435 let err_msg = format!(
3436 "Failed to get the current time. Reason: {}",
3437 e
3438 );
3439
3440 #[cfg(feature = "logging")]
3441 error!(target: "stdout", "{}", &err_msg);
3442
3443 LlamaCoreError::Operation(err_msg)
3444 })?;
3445
3446 let chat_completion_chunk = ChatCompletionChunk {
3447 id,
3448 object: "chat.completion.chunk".to_string(),
3449 created: created.as_secs(),
3450 model: graph.name().to_owned(),
3451 system_fingerprint: "fp_44709d6fcb".to_string(),
3452 choices: vec![],
3453 usage,
3454 };
3455
3456 let chunk_str = serde_json::to_string(&chat_completion_chunk)
3458 .map_err(|e| {
3459 let err_msg = format!(
3460 "Failed to serialize chat completion chunk. Reason: {}",
3461 e
3462 );
3463
3464 #[cfg(feature = "logging")]
3465 error!(target: "stdout", "{}", &err_msg);
3466
3467 LlamaCoreError::Operation(err_msg)
3468 })?;
3469
3470 Ok(format!("data: {}\n\n", chunk_str))
3471 }
3472 ContextFullState::Done => {
3473 *context_full_state = ContextFullState::EndOfSequence;
3474
3475 Ok("data: [DONE]\n\n".to_string())
3476 }
3477 ContextFullState::EndOfSequence => {
3478 Ok("[GGML] End of sequence".to_string())
3479 }
3480 }
3481 }
3482 Err(wasmedge_wasi_nn::Error::BackendError(
3483 wasmedge_wasi_nn::BackendError::PromptTooLong,
3484 )) => {
3485 #[cfg(feature = "logging")]
3486 debug!(target: "stdout", "Prompt too long");
3487
3488 match prompt_too_long_state {
3489 PromptTooLongState::Message => {
3490 match include_usage {
3491 true => *prompt_too_long_state = PromptTooLongState::Usage,
3492 false => *prompt_too_long_state = PromptTooLongState::Done,
3493 }
3494
3495 let created = SystemTime::now()
3496 .duration_since(std::time::UNIX_EPOCH)
3497 .map_err(|e| {
3498 let err_msg = format!(
3499 "Failed to get the current time. Reason: {}",
3500 e
3501 );
3502
3503 #[cfg(feature = "logging")]
3504 error!(target: "stdout", "{}", &err_msg);
3505
3506 LlamaCoreError::Operation(err_msg)
3507 })?;
3508
3509 let chat_completion_chunk = ChatCompletionChunk {
3510 id,
3511 object: "chat.completion.chunk".to_string(),
3512 created: created.as_secs(),
3513 model: graph.name().to_owned(),
3514 system_fingerprint: "fp_44709d6fcb".to_string(),
3515 choices: vec![ChatCompletionChunkChoice {
3516 index: 0,
3517 delta: ChatCompletionChunkChoiceDelta {
3518 role: ChatCompletionRole::Assistant,
3519 content: None,
3520 tool_calls: vec![],
3521 },
3522 logprobs: None,
3523 finish_reason: Some(FinishReason::length),
3524 }],
3525 usage: None,
3526 };
3527
3528 let chunk_str = serde_json::to_string(&chat_completion_chunk)
3530 .map_err(|e| {
3531 let err_msg = format!(
3532 "Failed to serialize chat completion chunk. Reason: {}",
3533 e
3534 );
3535
3536 #[cfg(feature = "logging")]
3537 error!(target: "stdout", "{}", &err_msg);
3538
3539 LlamaCoreError::Operation(err_msg)
3540 })?;
3541
3542 Ok(format!("data: {}\n\n", chunk_str))
3543 }
3544 PromptTooLongState::Usage => {
3545 *prompt_too_long_state = PromptTooLongState::Done;
3546
3547 let token_info = get_token_info_by_graph(graph)?;
3549
3550 let usage = Some(Usage {
3551 prompt_tokens: token_info.prompt_tokens,
3552 completion_tokens: token_info.completion_tokens,
3553 total_tokens: token_info.prompt_tokens
3554 + token_info.completion_tokens,
3555 });
3556
3557 let created = SystemTime::now()
3558 .duration_since(std::time::UNIX_EPOCH)
3559 .map_err(|e| {
3560 let err_msg = format!(
3561 "Failed to get the current time. Reason: {}",
3562 e
3563 );
3564
3565 #[cfg(feature = "logging")]
3566 error!(target: "stdout", "{}", &err_msg);
3567
3568 LlamaCoreError::Operation(err_msg)
3569 })?;
3570
3571 let chat_completion_chunk = ChatCompletionChunk {
3572 id,
3573 object: "chat.completion.chunk".to_string(),
3574 created: created.as_secs(),
3575 model: graph.name().to_owned(),
3576 system_fingerprint: "fp_44709d6fcb".to_string(),
3577 choices: vec![],
3578 usage,
3579 };
3580
3581 let chunk_str = serde_json::to_string(&chat_completion_chunk)
3583 .map_err(|e| {
3584 let err_msg = format!(
3585 "Failed to serialize chat completion chunk. Reason: {}",
3586 e
3587 );
3588
3589 #[cfg(feature = "logging")]
3590 error!(target: "stdout", "{}", &err_msg);
3591
3592 LlamaCoreError::Operation(err_msg)
3593 })?;
3594
3595 Ok(format!("data: {}\n\n", chunk_str))
3596 }
3597 PromptTooLongState::Done => {
3598 *prompt_too_long_state = PromptTooLongState::EndOfSequence;
3599
3600 Ok("data: [DONE]\n\n".to_string())
3601 }
3602 PromptTooLongState::EndOfSequence => {
3603 Ok("[GGML] End of sequence".to_string())
3604 }
3605 }
3606 }
3607 Err(e) => {
3608 let err_msg =
3609 format!("Failed to compute the chat completion. Reason: {}", e);
3610
3611 #[cfg(feature = "logging")]
3612 error!(target: "stdout", "{}", &err_msg);
3613
3614 Err(LlamaCoreError::Backend(BackendError::ComputeSingle(
3615 err_msg,
3616 )))
3617 }
3618 }
3619 }
3620 false => {
3621 match chat_graphs.iter_mut().next() {
3622 Some((_, graph)) => {
3623 match graph.compute_single() {
3625 Ok(_) => {
3626 #[cfg(feature = "logging")]
3627 debug!(target: "stdout", "Compute the chat stream chunk successfully.");
3628
3629 match stream_state {
3630 StreamState::Usage | StreamState::NoUsage => {
3631 let output_buffer =
3633 get_output_buffer_single(graph, OUTPUT_TENSOR)?;
3634
3635 #[cfg(feature = "logging")]
3636 info!(target: "stdout", "retrieved the output buffer");
3637
3638 let output = match String::from_utf8(
3640 output_buffer.clone(),
3641 ) {
3642 Ok(token) => token,
3643 Err(_) => {
3644 let mutex = CACHED_UTF8_ENCODINGS
3645 .get_or_init(|| Mutex::new(Vec::new()));
3646 let mut cached_encodings = mutex.lock().map_err(|e| {
3647 let err_msg = format!(
3648 "Fail to acquire the lock of `UTF8_ENCODINGS`. Reason: {}",
3649 e
3650 );
3651
3652 #[cfg(feature = "logging")]
3653 error!(target: "stdout", "{}", &err_msg);
3654
3655
3656 LlamaCoreError::Operation(err_msg)
3657 })?;
3658
3659 cached_encodings
3661 .extend_from_slice(&output_buffer[..]);
3662
3663 match String::from_utf8(
3664 cached_encodings.to_vec(),
3665 ) {
3666 Ok(token) => {
3667 cached_encodings.clear();
3669
3670 token
3671 }
3672 Err(e) => {
3673 if cached_encodings.len() > 4 {
let err_msg = format!("Failed to decode the cached bytes as a UTF-8 string: the buffer exceeds 4 bytes, so it cannot be a partial code point. {}", e);
3676
3677 #[cfg(feature = "logging")]
3678 error!(target: "stdout", "{}", &err_msg);
3679
3680 #[cfg(feature = "logging")]
3681 error!(target: "stdout", "The cached buffer: {:?}", &cached_encodings[..]);
3682
3683 cached_encodings.clear();
3691
3692 String::from("")
3693 } else {
let warn_msg = format!("Failed to decode the cached bytes as a UTF-8 string; keep buffering until the code point is complete. {}", e);
3695
3696 #[cfg(feature = "logging")]
3697 warn!(target: "stdout", "{}", &warn_msg);
3698
3699 String::from("")
3700 }
3701 }
3702 }
3703 }
3704 };
3705
3706 #[cfg(feature = "logging")]
3707 info!(target: "stdout", "decoded the output buffer");
3708
3709 let created = SystemTime::now()
3710 .duration_since(std::time::UNIX_EPOCH)
3711 .map_err(|e| {
3712 let err_msg = format!(
3713 "Failed to get the current time. Reason: {}",
3714 e
3715 );
3716
3717 #[cfg(feature = "logging")]
3718 error!(target: "stdout", "{}", &err_msg);
3719
3720 LlamaCoreError::Operation(err_msg)
3721 })?;
3722
3723 let chat_completion_chunk = ChatCompletionChunk {
3724 id,
3725 object: "chat.completion.chunk".to_string(),
3726 created: created.as_secs(),
3727 model: graph.name().to_owned(),
3728 system_fingerprint: "fp_44709d6fcb".to_string(),
3729 choices: vec![ChatCompletionChunkChoice {
3730 index: 0,
3731 delta: ChatCompletionChunkChoiceDelta {
3732 role: ChatCompletionRole::Assistant,
3733 content: Some(output),
3734 tool_calls: vec![],
3735 },
3736 logprobs: None,
3737 finish_reason: None,
3738 }],
3739 usage: None,
3740 };
3741
3742 #[cfg(feature = "logging")]
3743 info!(target: "stdout", "created chat completion chunk");
3744
3745 let chunk_str =
3747 serde_json::to_string(&chat_completion_chunk)
3748 .map_err(|e| {
3749 let err_msg = format!(
3750 "Failed to serialize chat completion chunk. Reason: {}",
3751 e
3752 );
3753
3754 #[cfg(feature = "logging")]
3755 error!(target: "stdout", "{}", &err_msg);
3756
3757 LlamaCoreError::Operation(err_msg)
3758 })?;
3759
3760 Ok(format!("data: {}\n\n", chunk_str))
3761 }
3762 StreamState::Done => {
3763 *stream_state = StreamState::EndOfSequence;
3764
3765 Ok("data: [DONE]\n\n".to_string())
3766 }
3767 StreamState::EndOfSequence => {
3768 Ok("[GGML] End of sequence".to_string())
3769 }
3770 }
3771 }
3772 Err(wasmedge_wasi_nn::Error::BackendError(
3773 wasmedge_wasi_nn::BackendError::EndOfSequence,
3774 )) => {
3775 #[cfg(feature = "logging")]
3776 debug!(target: "stdout", "End of sequence");
3777
3778 match stream_state {
3779 StreamState::Usage => {
3780 *stream_state = StreamState::Done;
3781
3782 let token_info = get_token_info_by_graph(graph)?;
3784
3785 let usage = Some(Usage {
3786 prompt_tokens: token_info.prompt_tokens,
3787 completion_tokens: token_info.completion_tokens,
3788 total_tokens: token_info.prompt_tokens
3789 + token_info.completion_tokens,
3790 });
3791
3792 #[cfg(feature = "logging")]
3793 info!(target: "stdout", "token_info: {} prompt tokens, {} completion tokens", token_info.prompt_tokens, token_info.completion_tokens);
3794
3795 let created = SystemTime::now()
3796 .duration_since(std::time::UNIX_EPOCH)
3797 .map_err(|e| {
3798 let err_msg = format!(
3799 "Failed to get the current time. Reason: {}",
3800 e
3801 );
3802
3803 #[cfg(feature = "logging")]
3804 error!(target: "stdout", "{}", &err_msg);
3805
3806 LlamaCoreError::Operation(err_msg)
3807 })?;
3808
3809 let chat_completion_chunk = ChatCompletionChunk {
3810 id,
3811 object: "chat.completion.chunk".to_string(),
3812 created: created.as_secs(),
3813 model: graph.name().to_owned(),
3814 system_fingerprint: "fp_44709d6fcb".to_string(),
3815 choices: vec![],
3816 usage,
3817 };
3818
3819 let chunk_str =
3821 serde_json::to_string(&chat_completion_chunk)
3822 .map_err(|e| {
3823 let err_msg = format!(
3824 "Failed to serialize chat completion chunk. Reason: {}",
3825 e
3826 );
3827
3828 #[cfg(feature = "logging")]
3829 error!(target: "stdout", "{}", &err_msg);
3830
3831 LlamaCoreError::Operation(err_msg)
3832 })?;
3833
3834 Ok(format!("data: {}\n\n", chunk_str))
3835 }
3836 StreamState::Done | StreamState::NoUsage => {
3837 *stream_state = StreamState::EndOfSequence;
3838
3839 Ok("data: [DONE]\n\n".to_string())
3840 }
3841 StreamState::EndOfSequence => {
3842 Ok("[GGML] End of sequence".to_string())
3843 }
3844 }
3845 }
3846 Err(wasmedge_wasi_nn::Error::BackendError(
3847 wasmedge_wasi_nn::BackendError::ContextFull,
3848 )) => {
3849 #[cfg(feature = "logging")]
3850 debug!(target: "stdout", "Context full");
3851
3852 match context_full_state {
3853 ContextFullState::Message => {
3854 match include_usage {
3855 true => {
3856 *context_full_state = ContextFullState::Usage
3857 }
3858 false => {
3859 *context_full_state = ContextFullState::Done
3860 }
3861 }
3862
3863 let created = SystemTime::now()
3864 .duration_since(std::time::UNIX_EPOCH)
3865 .map_err(|e| {
3866 let err_msg = format!(
3867 "Failed to get the current time. Reason: {}",
3868 e
3869 );
3870
3871 #[cfg(feature = "logging")]
3872 error!(target: "stdout", "{}", &err_msg);
3873
3874 LlamaCoreError::Operation(err_msg)
3875 })?;
3876
3877 let chat_completion_chunk = ChatCompletionChunk {
3878 id,
3879 object: "chat.completion.chunk".to_string(),
3880 created: created.as_secs(),
3881 model: graph.name().to_owned(),
3882 system_fingerprint: "fp_44709d6fcb".to_string(),
3883 choices: vec![ChatCompletionChunkChoice {
3884 index: 0,
3885 delta: ChatCompletionChunkChoiceDelta {
3886 role: ChatCompletionRole::Assistant,
3887 content: Some(
3888 "<|WASMEDGE-GGML-CONTEXT-FULL|>"
3889 .to_string(),
3890 ),
3891 tool_calls: vec![],
3892 },
3893 logprobs: None,
3894 finish_reason: Some(FinishReason::length),
3895 }],
3896 usage: None,
3897 };
3898
3899 let chunk_str =
3901 serde_json::to_string(&chat_completion_chunk)
3902 .map_err(|e| {
3903 let err_msg = format!(
3904 "Failed to serialize chat completion chunk. Reason: {}",
3905 e
3906 );
3907
3908 #[cfg(feature = "logging")]
3909 error!(target: "stdout", "{}", &err_msg);
3910
3911 LlamaCoreError::Operation(err_msg)
3912 })?;
3913
3914 Ok(format!("data: {}\n\n", chunk_str))
3915 }
3916 ContextFullState::Usage => {
3917 *context_full_state = ContextFullState::Done;
3918
3919 let token_info = get_token_info_by_graph(graph)?;
3921
3922 let usage = Some(Usage {
3923 prompt_tokens: token_info.prompt_tokens,
3924 completion_tokens: token_info.completion_tokens,
3925 total_tokens: token_info.prompt_tokens
3926 + token_info.completion_tokens,
3927 });
3928
3929 let created = SystemTime::now()
3930 .duration_since(std::time::UNIX_EPOCH)
3931 .map_err(|e| {
3932 let err_msg = format!(
3933 "Failed to get the current time. Reason: {}",
3934 e
3935 );
3936
3937 #[cfg(feature = "logging")]
3938 error!(target: "stdout", "{}", &err_msg);
3939
3940 LlamaCoreError::Operation(err_msg)
3941 })?;
3942
3943 let chat_completion_chunk = ChatCompletionChunk {
3944 id,
3945 object: "chat.completion.chunk".to_string(),
3946 created: created.as_secs(),
3947 model: graph.name().to_owned(),
3948 system_fingerprint: "fp_44709d6fcb".to_string(),
3949 choices: vec![],
3950 usage,
3951 };
3952
3953 let chunk_str =
3955 serde_json::to_string(&chat_completion_chunk)
3956 .map_err(|e| {
3957 let err_msg = format!(
3958 "Failed to serialize chat completion chunk. Reason: {}",
3959 e
3960 );
3961
3962 #[cfg(feature = "logging")]
3963 error!(target: "stdout", "{}", &err_msg);
3964
3965 LlamaCoreError::Operation(err_msg)
3966 })?;
3967
3968 Ok(format!("data: {}\n\n", chunk_str))
3969 }
3970 ContextFullState::Done => {
3971 *context_full_state = ContextFullState::EndOfSequence;
3972
3973 Ok("data: [DONE]\n\n".to_string())
3974 }
3975 ContextFullState::EndOfSequence => {
3976 Ok("[GGML] End of sequence".to_string())
3977 }
3978 }
3979 }
3980 Err(wasmedge_wasi_nn::Error::BackendError(
3981 wasmedge_wasi_nn::BackendError::PromptTooLong,
3982 )) => {
3983 #[cfg(feature = "logging")]
3984 debug!(target: "stdout", "Prompt too long");
3985
3986 match prompt_too_long_state {
3987 PromptTooLongState::Message => {
3988 match include_usage {
3989 true => {
3990 *prompt_too_long_state =
3991 PromptTooLongState::Usage
3992 }
3993 false => {
3994 *prompt_too_long_state =
3995 PromptTooLongState::Done
3996 }
3997 }
3998
3999 let created = SystemTime::now()
4000 .duration_since(std::time::UNIX_EPOCH)
4001 .map_err(|e| {
4002 let err_msg = format!(
4003 "Failed to get the current time. Reason: {}",
4004 e
4005 );
4006
4007 #[cfg(feature = "logging")]
4008 error!(target: "stdout", "{}", &err_msg);
4009
4010 LlamaCoreError::Operation(err_msg)
4011 })?;
4012
4013 let chat_completion_chunk = ChatCompletionChunk {
4014 id,
4015 object: "chat.completion.chunk".to_string(),
4016 created: created.as_secs(),
4017 model: graph.name().to_owned(),
4018 system_fingerprint: "fp_44709d6fcb".to_string(),
4019 choices: vec![ChatCompletionChunkChoice {
4020 index: 0,
4021 delta: ChatCompletionChunkChoiceDelta {
4022 role: ChatCompletionRole::Assistant,
4023 content: None,
4024 tool_calls: vec![],
4025 },
4026 logprobs: None,
4027 finish_reason: Some(FinishReason::length),
4028 }],
4029 usage: None,
4030 };
4031
4032 let chunk_str =
4034 serde_json::to_string(&chat_completion_chunk)
4035 .map_err(|e| {
4036 let err_msg = format!(
4037 "Failed to serialize chat completion chunk. Reason: {}",
4038 e
4039 );
4040
4041 #[cfg(feature = "logging")]
4042 error!(target: "stdout", "{}", &err_msg);
4043
4044 LlamaCoreError::Operation(err_msg)
4045 })?;
4046
4047 Ok(format!("data: {}\n\n", chunk_str))
4048 }
4049 PromptTooLongState::Usage => {
4050 *prompt_too_long_state = PromptTooLongState::Done;
4051
4052 let token_info = get_token_info_by_graph(graph)?;
4054
4055 let usage = Some(Usage {
4056 prompt_tokens: token_info.prompt_tokens,
4057 completion_tokens: token_info.completion_tokens,
4058 total_tokens: token_info.prompt_tokens
4059 + token_info.completion_tokens,
4060 });
4061
4062 let created = SystemTime::now()
4063 .duration_since(std::time::UNIX_EPOCH)
4064 .map_err(|e| {
4065 let err_msg = format!(
4066 "Failed to get the current time. Reason: {}",
4067 e
4068 );
4069
4070 #[cfg(feature = "logging")]
4071 error!(target: "stdout", "{}", &err_msg);
4072
4073 LlamaCoreError::Operation(err_msg)
4074 })?;
4075
4076 let chat_completion_chunk = ChatCompletionChunk {
4077 id,
4078 object: "chat.completion.chunk".to_string(),
4079 created: created.as_secs(),
4080 model: graph.name().to_owned(),
4081 system_fingerprint: "fp_44709d6fcb".to_string(),
4082 choices: vec![],
4083 usage,
4084 };
4085
4086 let chunk_str =
4088 serde_json::to_string(&chat_completion_chunk)
4089 .map_err(|e| {
4090 let err_msg = format!(
4091 "Failed to serialize chat completion chunk. Reason: {}",
4092 e
4093 );
4094
4095 #[cfg(feature = "logging")]
4096 error!(target: "stdout", "{}", &err_msg);
4097
4098 LlamaCoreError::Operation(err_msg)
4099 })?;
4100
4101 Ok(format!("data: {}\n\n", chunk_str))
4102 }
4103 PromptTooLongState::Done => {
4104 *prompt_too_long_state =
4105 PromptTooLongState::EndOfSequence;
4106
4107 Ok("data: [DONE]\n\n".to_string())
4108 }
4109 PromptTooLongState::EndOfSequence => {
4110 Ok("[GGML] End of sequence".to_string())
4111 }
4112 }
4113 }
4114 Err(e) => {
4115 let err_msg = format!(
4116 "Failed to compute the chat completion. Reason: {}",
4117 e
4118 );
4119
4120 #[cfg(feature = "logging")]
4121 error!(target: "stdout", "{}", &err_msg);
4122
4123 Err(LlamaCoreError::Backend(BackendError::ComputeSingle(
4124 err_msg,
4125 )))
4126 }
4127 }
4128 }
4129 None => {
4130 let err_msg = "There is no model available in the chat graphs.";
4131
4132 #[cfg(feature = "logging")]
4133 error!(target: "stdout", "{}", &err_msg);
4134
4135 Err(LlamaCoreError::Operation(err_msg.into()))
4136 }
4137 }
4138 }
4139 }
4140 }
4141 None => {
4142 match chat_graphs.iter_mut().next() {
4143 Some((_, graph)) => {
4144 match graph.compute_single() {
4146 Ok(_) => {
4147 #[cfg(feature = "logging")]
4148 debug!(target: "stdout", "Compute the chat stream chunk successfully.");
4149
                            match stream_state {
                                StreamState::Usage | StreamState::NoUsage => {
                                    let output_buffer =
                                        get_output_buffer_single(graph, OUTPUT_TENSOR)?;

                                    #[cfg(feature = "logging")]
                                    info!(target: "stdout", "retrieved the output buffer");

                                    let output = match String::from_utf8(output_buffer.clone()) {
                                        Ok(token) => token,
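                                        // The backend may split a multi-byte UTF-8 character across
                                        // chunks, so invalid bytes are accumulated in
                                        // CACHED_UTF8_ENCODINGS until they form a valid string
                                        // (a single UTF-8 character is at most 4 bytes).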
                                        Err(_) => {
                                            let mutex = CACHED_UTF8_ENCODINGS
                                                .get_or_init(|| Mutex::new(Vec::new()));
                                            let mut cached_encodings = mutex.lock().map_err(|e| {
                                                let err_msg = format!("Failed to acquire the lock of `CACHED_UTF8_ENCODINGS`. Reason: {}", e);

                                                #[cfg(feature = "logging")]
                                                error!(target: "stdout", "{}", &err_msg);

                                                LlamaCoreError::Operation(err_msg)
                                            })?;

                                            cached_encodings.extend_from_slice(&output_buffer[..]);

                                            match String::from_utf8(cached_encodings.to_vec()) {
                                                Ok(token) => {
                                                    cached_encodings.clear();

                                                    token
                                                }
                                                Err(e) => {
                                                    if cached_encodings.len() > 4 {
                                                        let err_msg = format!("Failed to convert a vector of bytes to string. The length of the utf8 bytes exceeds 4. {}", e);

                                                        #[cfg(feature = "logging")]
                                                        error!(target: "stdout", "{}", &err_msg);

                                                        #[cfg(feature = "logging")]
                                                        error!(target: "stdout", "The cached buffer: {:?}", &cached_encodings[..]);

                                                        cached_encodings.clear();

                                                        String::from("")
                                                    } else {
                                                        let warn_msg = format!("Failed to convert a vector of bytes to string. {}", e);

                                                        #[cfg(feature = "logging")]
                                                        warn!(target: "stdout", "{}", &warn_msg);

                                                        String::from("")
                                                    }
                                                }
                                            }
                                        }
                                    };

                                    #[cfg(feature = "logging")]
                                    info!(target: "stdout", "decoded the output buffer");

                                    let created = SystemTime::now()
                                        .duration_since(std::time::UNIX_EPOCH)
                                        .map_err(|e| {
                                            let err_msg = format!("Failed to get the current time. Reason: {}", e);

                                            #[cfg(feature = "logging")]
                                            error!(target: "stdout", "{}", &err_msg);

                                            LlamaCoreError::Operation(err_msg)
                                        })?;

                                    let chat_completion_chunk = ChatCompletionChunk {
                                        id,
                                        object: "chat.completion.chunk".to_string(),
                                        created: created.as_secs(),
                                        model: graph.name().to_owned(),
                                        system_fingerprint: "fp_44709d6fcb".to_string(),
                                        choices: vec![ChatCompletionChunkChoice {
                                            index: 0,
                                            delta: ChatCompletionChunkChoiceDelta {
                                                role: ChatCompletionRole::Assistant,
                                                content: Some(output),
                                                tool_calls: vec![],
                                            },
                                            logprobs: None,
                                            finish_reason: None,
                                        }],
                                        usage: None,
                                    };

                                    #[cfg(feature = "logging")]
                                    info!(target: "stdout", "created chat completion chunk");

                                    let chunk_str = serde_json::to_string(&chat_completion_chunk)
                                        .map_err(|e| {
                                            let err_msg = format!("Failed to serialize chat completion chunk. Reason: {}", e);

                                            #[cfg(feature = "logging")]
                                            error!(target: "stdout", "{}", &err_msg);

                                            LlamaCoreError::Operation(err_msg)
                                        })?;

                                    Ok(format!("data: {}\n\n", chunk_str))
                                }
                                StreamState::Done => {
                                    *stream_state = StreamState::EndOfSequence;

                                    Ok("data: [DONE]\n\n".to_string())
                                }
                                StreamState::EndOfSequence => {
                                    Ok("[GGML] End of sequence".to_string())
                                }
                            }
                        }
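                        // The backend reports end of sequence: emit the final usage chunk if it was
                        // requested, then [DONE].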
                        Err(wasmedge_wasi_nn::Error::BackendError(
                            wasmedge_wasi_nn::BackendError::EndOfSequence,
                        )) => {
                            #[cfg(feature = "logging")]
                            debug!(target: "stdout", "End of sequence");

                            match stream_state {
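                                // `include_usage` was requested: the last data chunk before [DONE]
                                // carries only the token usage (its `choices` list is empty).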
                                StreamState::Usage => {
                                    *stream_state = StreamState::Done;

                                    let token_info = get_token_info_by_graph(graph)?;

                                    let usage = Some(Usage {
                                        prompt_tokens: token_info.prompt_tokens,
                                        completion_tokens: token_info.completion_tokens,
                                        total_tokens: token_info.prompt_tokens
                                            + token_info.completion_tokens,
                                    });

                                    #[cfg(feature = "logging")]
                                    info!(target: "stdout", "token_info: {} prompt tokens, {} completion tokens", token_info.prompt_tokens, token_info.completion_tokens);

                                    let created = SystemTime::now()
                                        .duration_since(std::time::UNIX_EPOCH)
                                        .map_err(|e| {
                                            let err_msg = format!("Failed to get the current time. Reason: {}", e);

                                            #[cfg(feature = "logging")]
                                            error!(target: "stdout", "{}", &err_msg);

                                            LlamaCoreError::Operation(err_msg)
                                        })?;

                                    let chat_completion_chunk = ChatCompletionChunk {
                                        id,
                                        object: "chat.completion.chunk".to_string(),
                                        created: created.as_secs(),
                                        model: graph.name().to_owned(),
                                        system_fingerprint: "fp_44709d6fcb".to_string(),
                                        choices: vec![],
                                        usage,
                                    };

                                    let chunk_str = serde_json::to_string(&chat_completion_chunk)
                                        .map_err(|e| {
                                            let err_msg = format!("Failed to serialize chat completion chunk. Reason: {}", e);

                                            #[cfg(feature = "logging")]
                                            error!(target: "stdout", "{}", &err_msg);

                                            LlamaCoreError::Operation(err_msg)
                                        })?;

                                    Ok(format!("data: {}\n\n", chunk_str))
                                }
                                StreamState::Done | StreamState::NoUsage => {
                                    *stream_state = StreamState::EndOfSequence;

                                    Ok("data: [DONE]\n\n".to_string())
                                }
                                StreamState::EndOfSequence => {
                                    Ok("[GGML] End of sequence".to_string())
                                }
                            }
                        }
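                        // The context window filled up mid-generation: notify the client with a
                        // `<|WASMEDGE-GGML-CONTEXT-FULL|>` delta and finish_reason `length`, then
                        // wind the stream down.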
                        Err(wasmedge_wasi_nn::Error::BackendError(
                            wasmedge_wasi_nn::BackendError::ContextFull,
                        )) => {
                            #[cfg(feature = "logging")]
                            debug!(target: "stdout", "Context full");

                            match context_full_state {
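                                // State progression: Message -> Usage (when `include_usage`) or Done
                                // -> EndOfSequence.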
                                ContextFullState::Message => {
                                    match include_usage {
                                        true => *context_full_state = ContextFullState::Usage,
                                        false => *context_full_state = ContextFullState::Done,
                                    }

                                    let created = SystemTime::now()
                                        .duration_since(std::time::UNIX_EPOCH)
                                        .map_err(|e| {
                                            let err_msg = format!("Failed to get the current time. Reason: {}", e);

                                            #[cfg(feature = "logging")]
                                            error!(target: "stdout", "{}", &err_msg);

                                            LlamaCoreError::Operation(err_msg)
                                        })?;

                                    let chat_completion_chunk = ChatCompletionChunk {
                                        id,
                                        object: "chat.completion.chunk".to_string(),
                                        created: created.as_secs(),
                                        model: graph.name().to_owned(),
                                        system_fingerprint: "fp_44709d6fcb".to_string(),
                                        choices: vec![ChatCompletionChunkChoice {
                                            index: 0,
                                            delta: ChatCompletionChunkChoiceDelta {
                                                role: ChatCompletionRole::Assistant,
                                                content: Some(
                                                    "<|WASMEDGE-GGML-CONTEXT-FULL|>".to_string(),
                                                ),
                                                tool_calls: vec![],
                                            },
                                            logprobs: None,
                                            finish_reason: Some(FinishReason::length),
                                        }],
                                        usage: None,
                                    };

                                    let chunk_str = serde_json::to_string(&chat_completion_chunk)
                                        .map_err(|e| {
                                            let err_msg = format!("Failed to serialize chat completion chunk. Reason: {}", e);

                                            #[cfg(feature = "logging")]
                                            error!(target: "stdout", "{}", &err_msg);

                                            LlamaCoreError::Operation(err_msg)
                                        })?;

                                    Ok(format!("data: {}\n\n", chunk_str))
                                }
                                ContextFullState::Usage => {
                                    *context_full_state = ContextFullState::Done;

                                    let token_info = get_token_info_by_graph(graph)?;

                                    let usage = Some(Usage {
                                        prompt_tokens: token_info.prompt_tokens,
                                        completion_tokens: token_info.completion_tokens,
                                        total_tokens: token_info.prompt_tokens
                                            + token_info.completion_tokens,
                                    });

                                    let created = SystemTime::now()
                                        .duration_since(std::time::UNIX_EPOCH)
                                        .map_err(|e| {
                                            let err_msg = format!("Failed to get the current time. Reason: {}", e);

                                            #[cfg(feature = "logging")]
                                            error!(target: "stdout", "{}", &err_msg);

                                            LlamaCoreError::Operation(err_msg)
                                        })?;

                                    let chat_completion_chunk = ChatCompletionChunk {
                                        id,
                                        object: "chat.completion.chunk".to_string(),
                                        created: created.as_secs(),
                                        model: graph.name().to_owned(),
                                        system_fingerprint: "fp_44709d6fcb".to_string(),
                                        choices: vec![],
                                        usage,
                                    };

                                    let chunk_str = serde_json::to_string(&chat_completion_chunk)
                                        .map_err(|e| {
                                            let err_msg = format!("Failed to serialize chat completion chunk. Reason: {}", e);

                                            #[cfg(feature = "logging")]
                                            error!(target: "stdout", "{}", &err_msg);

                                            LlamaCoreError::Operation(err_msg)
                                        })?;

                                    Ok(format!("data: {}\n\n", chunk_str))
                                }
                                ContextFullState::Done => {
                                    *context_full_state = ContextFullState::EndOfSequence;

                                    Ok("data: [DONE]\n\n".to_string())
                                }
                                ContextFullState::EndOfSequence => {
                                    Ok("[GGML] End of sequence".to_string())
                                }
                            }
                        }
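                        // The prompt alone exceeded the context window: report finish_reason `length`
                        // with an empty delta, then wind the stream down as above.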
                        Err(wasmedge_wasi_nn::Error::BackendError(
                            wasmedge_wasi_nn::BackendError::PromptTooLong,
                        )) => {
                            #[cfg(feature = "logging")]
                            debug!(target: "stdout", "Prompt too long");

                            match prompt_too_long_state {
                                PromptTooLongState::Message => {
                                    match include_usage {
                                        true => *prompt_too_long_state = PromptTooLongState::Usage,
                                        false => *prompt_too_long_state = PromptTooLongState::Done,
                                    }

                                    let created = SystemTime::now()
                                        .duration_since(std::time::UNIX_EPOCH)
                                        .map_err(|e| {
                                            let err_msg = format!("Failed to get the current time. Reason: {}", e);

                                            #[cfg(feature = "logging")]
                                            error!(target: "stdout", "{}", &err_msg);

                                            LlamaCoreError::Operation(err_msg)
                                        })?;

                                    let chat_completion_chunk = ChatCompletionChunk {
                                        id,
                                        object: "chat.completion.chunk".to_string(),
                                        created: created.as_secs(),
                                        model: graph.name().to_owned(),
                                        system_fingerprint: "fp_44709d6fcb".to_string(),
                                        choices: vec![ChatCompletionChunkChoice {
                                            index: 0,
                                            delta: ChatCompletionChunkChoiceDelta {
                                                role: ChatCompletionRole::Assistant,
                                                content: None,
                                                tool_calls: vec![],
                                            },
                                            logprobs: None,
                                            finish_reason: Some(FinishReason::length),
                                        }],
                                        usage: None,
                                    };

                                    let chunk_str = serde_json::to_string(&chat_completion_chunk)
                                        .map_err(|e| {
                                            let err_msg = format!("Failed to serialize chat completion chunk. Reason: {}", e);

                                            #[cfg(feature = "logging")]
                                            error!(target: "stdout", "{}", &err_msg);

                                            LlamaCoreError::Operation(err_msg)
                                        })?;

                                    Ok(format!("data: {}\n\n", chunk_str))
                                }
                                PromptTooLongState::Usage => {
                                    *prompt_too_long_state = PromptTooLongState::Done;

                                    let token_info = get_token_info_by_graph(graph)?;

                                    let usage = Some(Usage {
                                        prompt_tokens: token_info.prompt_tokens,
                                        completion_tokens: token_info.completion_tokens,
                                        total_tokens: token_info.prompt_tokens
                                            + token_info.completion_tokens,
                                    });

                                    let created = SystemTime::now()
                                        .duration_since(std::time::UNIX_EPOCH)
                                        .map_err(|e| {
                                            let err_msg = format!("Failed to get the current time. Reason: {}", e);

                                            #[cfg(feature = "logging")]
                                            error!(target: "stdout", "{}", &err_msg);

                                            LlamaCoreError::Operation(err_msg)
                                        })?;

                                    let chat_completion_chunk = ChatCompletionChunk {
                                        id,
                                        object: "chat.completion.chunk".to_string(),
                                        created: created.as_secs(),
                                        model: graph.name().to_owned(),
                                        system_fingerprint: "fp_44709d6fcb".to_string(),
                                        choices: vec![],
                                        usage,
                                    };

                                    let chunk_str = serde_json::to_string(&chat_completion_chunk)
                                        .map_err(|e| {
                                            let err_msg = format!("Failed to serialize chat completion chunk. Reason: {}", e);

                                            #[cfg(feature = "logging")]
                                            error!(target: "stdout", "{}", &err_msg);

                                            LlamaCoreError::Operation(err_msg)
                                        })?;

                                    Ok(format!("data: {}\n\n", chunk_str))
                                }
                                PromptTooLongState::Done => {
                                    *prompt_too_long_state = PromptTooLongState::EndOfSequence;

                                    Ok("data: [DONE]\n\n".to_string())
                                }
                                PromptTooLongState::EndOfSequence => {
                                    Ok("[GGML] End of sequence".to_string())
                                }
                            }
                        }
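                        // Any other backend error aborts the stream.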
                        Err(e) => {
                            let err_msg =
                                format!("Failed to compute the chat completion. Reason: {}", e);

                            #[cfg(feature = "logging")]
                            error!(target: "stdout", "{}", &err_msg);

                            Err(LlamaCoreError::Backend(BackendError::ComputeSingle(
                                err_msg,
                            )))
                        }
                    }
                }
                None => {
                    let err_msg = "There is no model available in the chat graphs.";

                    #[cfg(feature = "logging")]
                    error!(target: "stdout", "{}", &err_msg);

                    Err(LlamaCoreError::Operation(err_msg.into()))
                }
            }
        }
    };

    #[cfg(feature = "logging")]
    info!(target: "stdout", "Return the chat stream chunk!");

    res
}

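/// Result of parsing raw model output: the unmodified text, the assistant-visible
/// content (if any), and any tool calls extracted from it.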
#[allow(dead_code)]
#[derive(Debug)]
struct ParseResult {
    raw: String,
    content: Option<String>,
    tool_calls: Vec<ToolCall>,
}