1use crate::{
4 error,
5 metadata::ggml::GgmlMetadata,
6 running_mode,
7 utils::{
8 gen_chat_id, get_output_buffer, get_output_buffer_single, get_token_info_by_graph,
9 get_token_info_by_graph_name, set_tensor_data_u8,
10 },
11 Graph, RunningMode, CHAT_GRAPHS, OUTPUT_TENSOR,
12};
13use chat_prompts::{BuildChatPrompt, ChatPrompt, PromptTemplateType};
14use either::{Either, Left, Right};
15use endpoints::{
16 chat::{
17 ChatCompletionChunk, ChatCompletionChunkChoice, ChatCompletionChunkChoiceDelta,
18 ChatCompletionObject, ChatCompletionObjectChoice, ChatCompletionObjectMessage,
19 ChatCompletionRequest, ChatCompletionRequestMessage, ChatCompletionRole,
20 ChatCompletionUserMessageContent, ContentPart, Function, ToolCall, ToolCallForChunk,
21 ToolChoice,
22 },
23 common::{FinishReason, Usage},
24};
25use error::{BackendError, LlamaCoreError};
26use std::{
27 collections::{HashMap, VecDeque},
28 pin::Pin,
29 sync::{
30 atomic::{AtomicBool, Ordering},
31 Arc, Mutex, OnceLock,
32 },
33 task::{Context, Poll, Waker},
34 time::SystemTime,
35};
36
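/// Per-model lock used to serialize access to a chat model across requests.
///
/// `active` indicates whether the model is currently in use; `waker_queue`
/// stores the wakers of tasks waiting for the lock to be released.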
37pub struct ModelStreamLock {
44 pub active: AtomicBool,
46 pub waker_queue: Mutex<VecDeque<Waker>>,
48}
49
50impl ModelStreamLock {
51 pub fn new() -> Self {
53 Self {
54 active: AtomicBool::new(false),
55 waker_queue: Mutex::new(VecDeque::new()),
56 }
57 }
58
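    /// Attempts to acquire the lock without blocking.
    ///
    /// Returns `true` if the flag was flipped from `false` to `true`, i.e. the
    /// caller now holds the lock.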
59 pub fn try_acquire(&self) -> bool {
61 self.active
62 .compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst)
63 .is_ok()
64 }
65
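    /// Releases the lock and wakes the first queued waiter, if any.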
66 pub fn release(&self) {
68 self.active.store(false, Ordering::SeqCst);
69
70 if let Ok(mut queue) = self.waker_queue.lock() {
71 if let Some(waker) = queue.pop_front() {
72 waker.wake();
73 }
74 }
75 }
76
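    /// Registers a waker to be notified when the lock is released, replacing
    /// any previously registered waker that would wake the same task.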
77 pub fn register_waker(&self, waker: &Waker) {
79 if let Ok(mut queue) = self.waker_queue.lock() {
80 queue.retain(|w| !w.will_wake(waker));
82 queue.push_back(waker.clone());
83 }
84 }
85}
86
87impl Default for ModelStreamLock {
88 fn default() -> Self {
89 Self::new()
90 }
91}
92
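/// RAII guard that releases the wrapped `ModelStreamLock` when dropped.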
93pub struct ModelLockGuard {
97 lock: Arc<ModelStreamLock>,
98}
99
100impl ModelLockGuard {
101 pub fn new(lock: Arc<ModelStreamLock>) -> Self {
104 Self { lock }
105 }
106}
107
108impl Drop for ModelLockGuard {
109 fn drop(&mut self) {
110 self.lock.release();
111
112 #[cfg(feature = "logging")]
113 info!(target: "stdout", "ModelLockGuard: lock released on drop");
114 }
115}
116
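/// Global registry mapping model names to their stream locks, created lazily on first use.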
117static MODEL_STREAM_LOCKS: OnceLock<Mutex<HashMap<String, Arc<ModelStreamLock>>>> = OnceLock::new();
120
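/// Initializes (on first call) and returns the global model-lock registry.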
121fn get_model_locks() -> &'static Mutex<HashMap<String, Arc<ModelStreamLock>>> {
123 MODEL_STREAM_LOCKS.get_or_init(|| {
124 #[cfg(feature = "logging")]
125 info!(target: "stdout", "Initializing model stream locks manager");
126 Mutex::new(HashMap::new())
127 })
128}
129
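/// Returns the stream lock for `model_name`, creating and registering it if it
/// does not exist yet.
///
/// A minimal usage sketch (the model name is illustrative):
///
/// ```ignore
/// let lock = get_or_create_model_lock("my-model")?;
/// if lock.try_acquire() {
///     // The guard releases the lock again when it goes out of scope.
///     let _guard = ModelLockGuard::new(lock.clone());
///     // ... run inference while holding the lock ...
/// }
/// ```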
130pub fn get_or_create_model_lock(model_name: &str) -> Result<Arc<ModelStreamLock>, LlamaCoreError> {
132 let locks = get_model_locks();
133 let mut locks_guard = locks.lock().map_err(|e| {
134 let err_msg = format!("Failed to acquire model locks: {e}");
135
136 #[cfg(feature = "logging")]
137 error!(target: "stdout", "{}", &err_msg);
138
139 LlamaCoreError::Operation(err_msg)
140 })?;
141
142 if !locks_guard.contains_key(model_name) {
143 locks_guard.insert(model_name.to_string(), Arc::new(ModelStreamLock::new()));
144
145 #[cfg(feature = "logging")]
146 info!(target: "stdout", "Created new stream lock for model: {}", model_name);
147 }
148
149 Ok(locks_guard.get(model_name).unwrap().clone())
150}
151
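/// Returns the name of the first model registered in `CHAT_GRAPHS` together
/// with its stream lock. Used when a request does not specify a model.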
152pub fn get_default_model_lock() -> Result<(String, Arc<ModelStreamLock>), LlamaCoreError> {
155 let chat_graphs = CHAT_GRAPHS.get().ok_or_else(|| {
156 let err_msg = "CHAT_GRAPHS not initialized";
157
158 #[cfg(feature = "logging")]
159 error!(target: "stdout", "{}", &err_msg);
160
161 LlamaCoreError::Operation(err_msg.into())
162 })?;
163
164 let chat_graphs = chat_graphs.lock().map_err(|e| {
165 let err_msg = format!("Failed to acquire CHAT_GRAPHS lock: {e}");
166
167 #[cfg(feature = "logging")]
168 error!(target: "stdout", "{}", &err_msg);
169
170 LlamaCoreError::Operation(err_msg)
171 })?;
172
173 let model_name = chat_graphs
174 .keys()
175 .next()
176 .ok_or_else(|| {
177 let err_msg = "No model available";
178
179 #[cfg(feature = "logging")]
180 error!(target: "stdout", "{}", &err_msg);
181
182 LlamaCoreError::Operation(err_msg.into())
183 })?
184 .clone();
185
    drop(chat_graphs);

    let lock = get_or_create_model_lock(&model_name)?;
189 Ok((model_name, lock))
190}
191
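/// Processes a chat completion request, dispatching to the stream or
/// non-stream path based on `chat_request.stream`.
///
/// Returns either an SSE chunk stream (`Left`) or a complete
/// `ChatCompletionObject` (`Right`), plus a flag indicating whether the
/// response contains tool calls.
///
/// Illustrative call-site sketch:
///
/// ```ignore
/// let (result, include_tool_calls) = chat(&mut request).await?;
/// match result {
///     Left(stream) => { /* forward the SSE chunks to the client */ }
///     Right(object) => { /* return the full completion object */ }
/// }
/// ```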
192pub async fn chat(
194 chat_request: &mut ChatCompletionRequest,
195) -> Result<
196 (
197 Either<impl futures::TryStream<Ok = String, Error = LlamaCoreError>, ChatCompletionObject>,
198 bool,
199 ),
200 LlamaCoreError,
201> {
202 #[cfg(feature = "logging")]
203 {
204 debug!(target: "stdout", "tool choice: {:?}", chat_request.tool_choice.as_ref());
205 debug!(target: "stdout", "tools: {:?}", chat_request.tools.as_ref());
206 debug!(target: "stdout", "stream mode: {:?}", chat_request.stream);
207 }
208
209 let result = match chat_request.stream {
210 Some(true) => match chat_stream(chat_request).await {
211 Ok((stream, include_tool_calls)) => Ok((Left(stream), include_tool_calls)),
212 Err(e) => Err(e),
213 },
214 Some(false) | None => match chat_once(chat_request).await {
215 Ok((chat_completion_object, include_tool_calls)) => {
216 Ok((Right(chat_completion_object), include_tool_calls))
217 }
218 Err(e) => Err(e),
219 },
220 };
221
222 #[cfg(feature = "logging")]
223 info!(target: "stdout", "Reset the model metadata");
224
225 result
226}
227
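/// Handles a chat completion request in stream mode: checks the running mode,
/// builds the prompt, feeds it to the model, and returns a `ChatStream`
/// together with a flag indicating whether the stream carries tool calls.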
228async fn chat_stream(
229 chat_request: &mut ChatCompletionRequest,
230) -> Result<
231 (
232 impl futures::TryStream<Ok = String, Error = LlamaCoreError>,
233 bool,
234 ),
235 LlamaCoreError,
236> {
237 #[cfg(feature = "logging")]
238 info!(target: "stdout", "Process chat completion request in the stream mode");
239
240 let running_mode = running_mode()?;
241 if !running_mode.contains(RunningMode::CHAT) && !running_mode.contains(RunningMode::RAG) {
242 let err_msg = "The chat completion is only supported in the chat or rag mode.";
243
244 #[cfg(feature = "logging")]
245 error!(target: "stdout", "{err_msg}");
246
247 return Err(LlamaCoreError::Operation(err_msg.to_string()));
248 }
249
250 let model_name = chat_request.model.clone();
251 let id = match &chat_request.user {
252 Some(id) => id.clone(),
253 None => gen_chat_id(),
254 };
255 #[cfg(feature = "logging")]
256 info!(target: "stdout", "user: {}", &id);
257
258 #[cfg(feature = "logging")]
259 info!(target: "stdout", "Check model metadata");
260
261 let mut metadata = check_model_metadata(chat_request)?;
263
264 let include_usage = match chat_request.stream_options {
266 Some(ref stream_options) => stream_options.include_usage.unwrap_or_default(),
267 None => metadata.include_usage,
268 };
269 #[cfg(feature = "logging")]
270 info!(target: "stdout", "include_usage: {include_usage}");
271
272 #[cfg(feature = "logging")]
273 info!(target: "stdout", "Build the chat prompt");
274
    let (prompt, available_completion_tokens, tool_use) =
        build_prompt(model_name.as_ref(), chat_request)?;
278
279 #[cfg(feature = "logging")]
280 {
281 info!(target: "stdout", "prompt:\n{}", &prompt);
        info!(target: "stdout", "available_completion_tokens: {available_completion_tokens}");
283 info!(target: "stdout", "tool_use: {tool_use}");
284 }
285
286 #[cfg(feature = "logging")]
287 info!(target: "stdout", "Update the n_predict");
288
    update_n_predict(chat_request, &mut metadata, available_completion_tokens)?;
291
292 #[cfg(feature = "logging")]
293 info!(target: "stdout", "Feed the prompt to the model");
294
295 set_prompt(chat_request.model.as_ref(), &prompt)?;
297
298 let stream = match tool_use {
299 false => (ChatStream::new(model_name, id, include_usage, None)?, false),
300 true => {
301 let chat_graphs = match CHAT_GRAPHS.get() {
302 Some(chat_graphs) => chat_graphs,
303 None => {
                    let err_msg = "Failed to get the underlying value of `CHAT_GRAPHS`.";
305
306 #[cfg(feature = "logging")]
307 error!(target: "stdout", "{}", &err_msg);
308
309 return Err(LlamaCoreError::Operation(err_msg.into()));
310 }
311 };
312
313 let mut chat_graphs = chat_graphs.lock().map_err(|e| {
                let err_msg = format!("Failed to acquire the lock of `CHAT_GRAPHS`. {e}");
315
316 #[cfg(feature = "logging")]
317 error!(target: "stdout", "{}", &err_msg);
318
319 LlamaCoreError::Operation(err_msg)
320 })?;
321
322 match model_name {
323 Some(model_name) => match chat_graphs.contains_key(&model_name) {
324 true => {
325 let graph = chat_graphs.get_mut(&model_name).unwrap();
326 chat_stream_for_tool(graph, id, include_usage)?
327 }
328 false => match chat_graphs.iter_mut().next() {
329 Some((_, graph)) => chat_stream_for_tool(graph, id, include_usage)?,
330 None => {
331 let err_msg = "There is no model available in the chat graphs.";
332
333 #[cfg(feature = "logging")]
334 error!(target: "stdout", "{}", &err_msg);
335
336 return Err(LlamaCoreError::Operation(err_msg.into()));
337 }
338 },
339 },
340 None => match chat_graphs.iter_mut().next() {
341 Some((_, graph)) => chat_stream_for_tool(graph, id, include_usage)?,
342 None => {
343 let err_msg = "There is no model available in the chat graphs.";
344
345 #[cfg(feature = "logging")]
346 error!(target: "stdout", "{}", &err_msg);
347
348 return Err(LlamaCoreError::Operation(err_msg.into()));
349 }
350 },
351 }
352 }
353 };
354
355 #[cfg(feature = "logging")]
356 info!(target: "stdout", "End of the chat completion stream.");
357
358 Ok(stream)
359}
360
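/// Runs a single full computation for a tool-call request and wraps the result
/// into a pre-built chunk sequence (tool-call chunk, usage chunk, `[DONE]`)
/// that is served through a `ChatStream`.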
361fn chat_stream_for_tool(
362 graph: &mut Graph<GgmlMetadata>,
363 id: impl Into<String>,
364 include_usage: bool,
365) -> Result<(ChatStream, bool), LlamaCoreError> {
366 #[cfg(feature = "logging")]
367 info!(target: "stdout", "Handle chat request with available tools by the model named {}.", graph.name());
368
369 let id = id.into();
370
371 match graph.compute() {
372 Ok(_) => {
373 let output_buffer = get_output_buffer(graph, OUTPUT_TENSOR)?;
375 let output = std::str::from_utf8(&output_buffer[..]).map_err(|e| {
376 let err_msg = format!(
377 "Failed to decode the buffer of the inference result to a utf-8 string. {e}"
378 );
379
380 #[cfg(feature = "logging")]
381 error!(target: "stdout", "{}", &err_msg);
382
383 LlamaCoreError::Operation(err_msg)
384 })?;
385
386 #[cfg(feature = "logging")]
387 info!(target: "stdout", "raw generation:\n{output}");
388
389 let message = post_process(output, &graph.metadata.prompt_template).map_err(|e| {
391 LlamaCoreError::Operation(format!("Failed to post-process the output. {e}"))
392 })?;
393
394 #[cfg(feature = "logging")]
395 info!(target: "stdout", "post-processed generation:\n{}", &message);
396
397 let token_info = get_token_info_by_graph(graph)?;
399
400 #[cfg(feature = "logging")]
401 info!(target: "stdout", "prompt tokens: {}, completion tokens: {}", token_info.prompt_tokens, token_info.completion_tokens);
402
403 let usage = Some(Usage {
404 prompt_tokens: token_info.prompt_tokens,
405 completion_tokens: token_info.completion_tokens,
406 total_tokens: token_info.prompt_tokens + token_info.completion_tokens,
407 });
408
409 let created = SystemTime::now()
410 .duration_since(std::time::UNIX_EPOCH)
411 .map_err(|e| {
412 let err_msg = format!("Failed to get the current time. Reason: {e}");
413
414 #[cfg(feature = "logging")]
415 error!(target: "stdout", "{}", &err_msg);
416
417 LlamaCoreError::Operation(err_msg)
418 })?;
419
420 if graph.metadata.prompt_template != PromptTemplateType::MistralTool
421 && graph.metadata.prompt_template != PromptTemplateType::ChatMLTool
422 && graph.metadata.prompt_template != PromptTemplateType::GroqLlama3Tool
423 && graph.metadata.prompt_template != PromptTemplateType::Llama3Tool
424 && graph.metadata.prompt_template != PromptTemplateType::InternLM2Tool
425 && graph.metadata.prompt_template != PromptTemplateType::NemotronTool
426 && graph.metadata.prompt_template != PromptTemplateType::FunctionaryV32
427 && graph.metadata.prompt_template != PromptTemplateType::FunctionaryV31
428 && graph.metadata.prompt_template != PromptTemplateType::MistralSmallTool
429 && graph.metadata.prompt_template != PromptTemplateType::Llama4Chat
430 && graph.metadata.prompt_template != PromptTemplateType::Qwen3NoThink
431 && graph.metadata.prompt_template != PromptTemplateType::Smol3NoThink
432 && graph.metadata.prompt_template != PromptTemplateType::Gemma3
433 && graph.metadata.prompt_template != PromptTemplateType::GptOss
434 && graph.metadata.prompt_template != PromptTemplateType::Qwen3Agent
435 && graph.metadata.prompt_template != PromptTemplateType::SeedOssNoThink
436 && graph.metadata.prompt_template != PromptTemplateType::SeedOssThink
437 {
438 let err_msg = format!("Unsupported prompt template: {}. The tool use is only supported for 'mistral-tool', 'chatml-tool', 'groq-llama3-tool', 'llama-3-tool', 'internlm-2-tool', 'nemotron-tool', 'functionary-31', 'functionary-32', 'mistral-small-tool', 'llama-4-chat', 'qwen3-no-think', 'smol-3-no-think', 'gemma-3', 'gpt-oss', 'qwen3-agent', 'seed-oss-no-think', and 'seed-oss-think' prompt templates.", graph.metadata.prompt_template);
439
440 #[cfg(feature = "logging")]
441 error!(target: "stdout", "{}", &err_msg);
442
443 return Err(LlamaCoreError::Operation(err_msg));
444 }
445
446 let parsed_result = parse_tool_calls(&message, graph.metadata.prompt_template)?;
447
448 let content = if parsed_result.tool_calls.is_empty() {
449 Some(parsed_result.raw.clone())
450 } else {
451 parsed_result.content.clone()
452 };
453
454 let (tool_calls, include_tool_calls) = match parsed_result.tool_calls.is_empty() {
455 false => {
456 let tool_calls: Vec<ToolCallForChunk> = parsed_result
457 .tool_calls
458 .into_iter()
459 .enumerate()
460 .map(|(index, tool_call)| ToolCallForChunk {
461 index,
462 id: tool_call.id,
463 ty: tool_call.ty,
464 function: tool_call.function,
465 })
466 .collect();
467 (tool_calls, true)
468 }
469 true => (vec![], false),
470 };
471
472 let tool_call_chunk = {
474 let chat_completion_chunk = ChatCompletionChunk {
475 id: id.clone(),
476 object: "chat.completion.chunk".to_string(),
477 created: created.as_secs(),
478 model: graph.name().to_owned(),
479 system_fingerprint: "fp_44709d6fcb".to_string(),
480 choices: vec![ChatCompletionChunkChoice {
481 index: 0,
482 delta: ChatCompletionChunkChoiceDelta {
483 role: ChatCompletionRole::Assistant,
484 content,
485 tool_calls,
486 },
487 logprobs: None,
488 finish_reason: None,
489 }],
490 usage: None,
491 };
492 let chunk_str = serde_json::to_string(&chat_completion_chunk).map_err(|e| {
493 let err_msg = format!("Failed to serialize chat completion chunk. Reason: {e}");
494
495 #[cfg(feature = "logging")]
496 error!(target: "stdout", "{}", &err_msg);
497
498 LlamaCoreError::Operation(err_msg)
499 })?;
500
501 format!("data: {chunk_str}\n\n")
502 };
503
504 let usage_chunk = {
506 let chat_completion_chunk = ChatCompletionChunk {
507 id: id.clone(),
508 object: "chat.completion.chunk".to_string(),
509 created: created.as_secs(),
510 model: graph.name().to_owned(),
511 system_fingerprint: "fp_44709d6fcb".to_string(),
512 choices: vec![],
513 usage,
514 };
515 let chunk_str = serde_json::to_string(&chat_completion_chunk).map_err(|e| {
516 let err_msg = format!("Failed to serialize chat completion chunk. Reason: {e}");
517
518 #[cfg(feature = "logging")]
519 error!(target: "stdout", "{}", &err_msg);
520
521 LlamaCoreError::Operation(err_msg)
522 })?;
523
524 format!("data: {chunk_str}\n\n")
525 };
526
527 let ending_chunk = "data: [DONE]\n\n".to_string();
529
530 let chunks = vec![tool_call_chunk, usage_chunk, ending_chunk];
531
532 let stream = ChatStream::new(
533 Some(graph.name().to_owned()),
534 id,
535 include_usage,
536 Some(chunks),
537 )?;
538
539 Ok((stream, include_tool_calls))
540 }
541 Err(wasmedge_wasi_nn::Error::BackendError(wasmedge_wasi_nn::BackendError::ContextFull)) => {
542 let output_buffer = get_output_buffer(graph, OUTPUT_TENSOR)?;
544 let output = std::str::from_utf8(&output_buffer[..]).map_err(|e| {
545 let err_msg = format!(
546 "Failed to decode the buffer of the inference result to a utf-8 string. {e}"
547 );
548
549 #[cfg(feature = "logging")]
550 error!(target: "stdout", "{}", &err_msg);
551
552 LlamaCoreError::Operation(err_msg)
553 })?;
554
555 let message = post_process(output, &graph.metadata.prompt_template).map_err(|e| {
557 let err_msg = format!("Failed to post-process the output. {e}");
558
559 #[cfg(feature = "logging")]
560 error!(target: "stdout", "{}", &err_msg);
561
562 LlamaCoreError::Operation(err_msg)
563 })?;
564
565 let token_info = get_token_info_by_graph(graph)?;
567
568 #[cfg(feature = "logging")]
569 info!(target: "stdout", "prompt tokens: {}, completion tokens: {}", token_info.prompt_tokens, token_info.completion_tokens);
570
571 let usage = Some(Usage {
572 prompt_tokens: token_info.prompt_tokens,
573 completion_tokens: token_info.completion_tokens,
574 total_tokens: token_info.prompt_tokens + token_info.completion_tokens,
575 });
576
577 let created = SystemTime::now()
578 .duration_since(std::time::UNIX_EPOCH)
579 .map_err(|e| {
580 let err_msg = format!("Failed to get the current time. Reason: {e}");
581
582 #[cfg(feature = "logging")]
583 error!(target: "stdout", "{}", &err_msg);
584
585 LlamaCoreError::Operation(err_msg)
586 })?;
587
588 let context_full_chunk = {
590 let chat_completion_chunk = ChatCompletionChunk {
591 id: id.clone(),
592 object: "chat.completion.chunk".to_string(),
593 created: created.as_secs(),
594 model: graph.name().to_owned(),
595 system_fingerprint: "fp_44709d6fcb".to_string(),
596 choices: vec![ChatCompletionChunkChoice {
597 index: 0,
598 delta: ChatCompletionChunkChoiceDelta {
599 role: ChatCompletionRole::Assistant,
600 content: Some(message),
601 tool_calls: vec![],
602 },
603 logprobs: None,
604 finish_reason: Some(FinishReason::length),
605 }],
606 usage: None,
607 };
608
609 let chunk_str = serde_json::to_string(&chat_completion_chunk).map_err(|e| {
611 let err_msg = format!("Failed to serialize chat completion chunk. Reason: {e}");
612
613 #[cfg(feature = "logging")]
614 error!(target: "stdout", "{}", &err_msg);
615
616 LlamaCoreError::Operation(err_msg)
617 })?;
618
619 format!("data: {chunk_str}\n\n")
620 };
621
622 let usage_chunk = {
624 let chat_completion_chunk = ChatCompletionChunk {
625 id: id.clone(),
626 object: "chat.completion.chunk".to_string(),
627 created: created.as_secs(),
628 model: graph.name().to_owned(),
629 system_fingerprint: "fp_44709d6fcb".to_string(),
630 choices: vec![],
631 usage,
632 };
633
634 let chunk_str = serde_json::to_string(&chat_completion_chunk).map_err(|e| {
636 let err_msg = format!("Failed to serialize chat completion chunk. Reason: {e}");
637
638 #[cfg(feature = "logging")]
639 error!(target: "stdout", "{}", &err_msg);
640
641 LlamaCoreError::Operation(err_msg)
642 })?;
643
644 format!("data: {chunk_str}\n\n")
645 };
646
647 let ending_chunk = "data: [DONE]\n\n".to_string();
649
650 let chunks = vec![context_full_chunk, usage_chunk, ending_chunk];
651
652 let stream = ChatStream::new(
653 Some(graph.name().to_owned()),
654 id,
655 include_usage,
656 Some(chunks),
657 )?;
658
659 Ok((stream, false))
660 }
661 Err(wasmedge_wasi_nn::Error::BackendError(
662 wasmedge_wasi_nn::BackendError::PromptTooLong,
663 )) => {
664 #[cfg(feature = "logging")]
665 warn!(target: "stdout", "The prompt is too long. Please reduce the length of your input and try again.");
666
667 let output_buffer = get_output_buffer(graph, OUTPUT_TENSOR)?;
669 let output = std::str::from_utf8(&output_buffer[..]).map_err(|e| {
670 let err_msg = format!(
671 "Failed to decode the buffer of the inference result to a utf-8 string. {e}"
672 );
673
674 #[cfg(feature = "logging")]
675 error!(target: "stdout", "{}", &err_msg);
676
677 LlamaCoreError::Operation(err_msg)
678 })?;
679
680 let message = post_process(output, &graph.metadata.prompt_template).map_err(|e| {
682 let err_msg = format!("Failed to post-process the output. {e}");
683
684 #[cfg(feature = "logging")]
685 error!(target: "stdout", "{}", &err_msg);
686
687 LlamaCoreError::Operation(err_msg)
688 })?;
689
690 let token_info = get_token_info_by_graph(graph)?;
692
693 #[cfg(feature = "logging")]
694 info!(target: "stdout", "prompt tokens: {}, completion tokens: {}", token_info.prompt_tokens, token_info.completion_tokens);
695
696 let usage = Some(Usage {
697 prompt_tokens: token_info.prompt_tokens,
698 completion_tokens: token_info.completion_tokens,
699 total_tokens: token_info.prompt_tokens + token_info.completion_tokens,
700 });
701
702 let created = SystemTime::now()
703 .duration_since(std::time::UNIX_EPOCH)
704 .map_err(|e| {
705 let err_msg = format!("Failed to get the current time. Reason: {e}");
706
707 #[cfg(feature = "logging")]
708 error!(target: "stdout", "{}", &err_msg);
709
710 LlamaCoreError::Operation(err_msg)
711 })?;
712
713 let prompt_too_long_chunk = {
715 let chat_completion_chunk = ChatCompletionChunk {
716 id: id.clone(),
717 object: "chat.completion.chunk".to_string(),
718 created: created.as_secs(),
719 model: graph.name().to_owned(),
720 system_fingerprint: "fp_44709d6fcb".to_string(),
721 choices: vec![ChatCompletionChunkChoice {
722 index: 0,
723 delta: ChatCompletionChunkChoiceDelta {
724 role: ChatCompletionRole::Assistant,
725 content: Some(message),
726 tool_calls: vec![],
727 },
728 logprobs: None,
729 finish_reason: Some(FinishReason::length),
730 }],
731 usage: None,
732 };
733
734 let chunk_str = serde_json::to_string(&chat_completion_chunk).map_err(|e| {
736 let err_msg = format!("Failed to serialize chat completion chunk. Reason: {e}");
737
738 #[cfg(feature = "logging")]
739 error!(target: "stdout", "{}", &err_msg);
740
741 LlamaCoreError::Operation(err_msg)
742 })?;
743
744 format!("data: {chunk_str}\n\n")
745 };
746
747 let usage_chunk = {
749 let chat_completion_chunk = ChatCompletionChunk {
750 id: id.clone(),
751 object: "chat.completion.chunk".to_string(),
752 created: created.as_secs(),
753 model: graph.name().to_owned(),
754 system_fingerprint: "fp_44709d6fcb".to_string(),
755 choices: vec![],
756 usage,
757 };
758
759 let chunk_str = serde_json::to_string(&chat_completion_chunk).map_err(|e| {
761 let err_msg = format!("Failed to serialize chat completion chunk. Reason: {e}");
762
763 #[cfg(feature = "logging")]
764 error!(target: "stdout", "{}", &err_msg);
765
766 LlamaCoreError::Operation(err_msg)
767 })?;
768
769 format!("data: {chunk_str}\n\n")
770 };
771
772 let ending_chunk = "data: [DONE]\n\n".to_string();
774
775 let chunks = vec![prompt_too_long_chunk, usage_chunk, ending_chunk];
776
777 let stream = ChatStream::new(
778 Some(graph.name().to_owned()),
779 id,
780 include_usage,
781 Some(chunks),
782 )?;
783
784 Ok((stream, false))
785 }
786 Err(e) => {
787 let err_msg = format!("Failed to compute the chat completion. Reason: {e}");
788
789 #[cfg(feature = "logging")]
790 error!(target: "stdout", "{}", &err_msg);
791
792 Err(LlamaCoreError::Backend(BackendError::Compute(err_msg)))
793 }
794 }
795}
796
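/// Handles a chat completion request in non-stream mode: acquires the
/// per-model lock, builds the prompt, runs the computation, and resets the
/// model metadata before returning the completion object.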
797async fn chat_once(
798 chat_request: &mut ChatCompletionRequest,
799) -> Result<(ChatCompletionObject, bool), LlamaCoreError> {
800 #[cfg(feature = "logging")]
801 info!(target: "stdout", "Processing chat completion request in non-stream mode");
802
803 let running_mode = running_mode()?;
804 if !running_mode.contains(RunningMode::CHAT) && !running_mode.contains(RunningMode::RAG) {
805 let err_msg = "The chat completion is only supported in the chat or rag mode.";
806
807 #[cfg(feature = "logging")]
808 error!(target: "stdout", "{err_msg}");
809
810 return Err(LlamaCoreError::Operation(err_msg.to_string()));
811 }
812
813 let model_name = chat_request.model.clone();
814
815 let (resolved_model, model_lock) = match &model_name {
817 Some(name) => (name.clone(), get_or_create_model_lock(name)?),
818 None => get_default_model_lock()?,
819 };
820
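    // Busy-wait until this model's stream lock becomes free; the guard created
    // below releases it again when dropped.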
821 while !model_lock.try_acquire() {
824 #[cfg(feature = "logging")]
825 debug!(target: "stdout", "Non-stream request waiting for model lock: {}", &resolved_model);
826
827 std::hint::spin_loop();
829 }
830
831 #[cfg(feature = "logging")]
832 info!(target: "stdout", "Non-stream request acquired lock for model: {}", &resolved_model);
833
834 let _lock_guard = ModelLockGuard::new(model_lock);
836
837 let id = match &chat_request.user {
838 Some(id) => id.clone(),
839 None => gen_chat_id(),
840 };
841
842 #[cfg(feature = "logging")]
843 info!(target: "stdout", "user: {}", &id);
844
845 #[cfg(feature = "logging")]
846 info!(target: "stdout", "Check model metadata");
847
848 let mut metadata = check_model_metadata(chat_request)?;
850
851 #[cfg(feature = "logging")]
852 info!(target: "stdout", "Build the chat prompt");
853
    let (prompt, available_completion_tokens, tool_use) =
        build_prompt(Some(&resolved_model), chat_request)?;
857
858 #[cfg(feature = "logging")]
859 {
860 info!(target: "stdout", "prompt:\n{}", &prompt);
        info!(target: "stdout", "available_completion_tokens: {available_completion_tokens}");
862 info!(target: "stdout", "tool_use: {tool_use}");
863 }
864
865 #[cfg(feature = "logging")]
866 info!(target: "stdout", "Update n_predict");
867
    update_n_predict(chat_request, &mut metadata, available_completion_tokens)?;
870
871 #[cfg(feature = "logging")]
872 info!(target: "stdout", "Feed the prompt to the model");
873
874 set_prompt(Some(&resolved_model), &prompt)?;
876
877 #[cfg(feature = "logging")]
878 info!(target: "stdout", "Compute chat completion.");
879
880 let res = compute(Some(&resolved_model), id, tool_use);
882
883 #[cfg(feature = "logging")]
884 info!(target: "stdout", "End of the chat completion");
885
886 reset_model_metadata(Some(&resolved_model))?;
888
889 res
}
892
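/// Resolves the target graph from `CHAT_GRAPHS` (falling back to the first
/// available model) and delegates the computation to `compute_by_graph`.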
893fn compute(
894 model_name: Option<&String>,
895 id: impl Into<String>,
896 tool_use: bool,
897) -> Result<(ChatCompletionObject, bool), LlamaCoreError> {
898 let chat_graphs = match CHAT_GRAPHS.get() {
899 Some(chat_graphs) => chat_graphs,
900 None => {
            let err_msg = "Failed to get the underlying value of `CHAT_GRAPHS`.";
902
903 #[cfg(feature = "logging")]
904 error!(target: "stdout", "{}", &err_msg);
905
906 return Err(LlamaCoreError::Operation(err_msg.into()));
907 }
908 };
909
910 let mut chat_graphs = chat_graphs.lock().map_err(|e| {
        let err_msg = format!("Failed to acquire the lock of `CHAT_GRAPHS`. {e}");
912
913 #[cfg(feature = "logging")]
914 error!(target: "stdout", "{}", &err_msg);
915
916 LlamaCoreError::Operation(err_msg)
917 })?;
918
919 match model_name {
920 Some(model_name) => match chat_graphs.contains_key(model_name) {
921 true => {
922 let graph = chat_graphs.get_mut(model_name).unwrap();
923 compute_by_graph(graph, id, tool_use)
924 }
925 false => match chat_graphs.iter_mut().next() {
926 Some((_, graph)) => compute_by_graph(graph, id, tool_use),
927 None => {
928 let err_msg = "There is no model available in the chat graphs.";
929
930 #[cfg(feature = "logging")]
931 error!(target: "stdout", "{}", &err_msg);
932
933 Err(LlamaCoreError::Operation(err_msg.into()))
934 }
935 },
936 },
937 None => match chat_graphs.iter_mut().next() {
938 Some((_, graph)) => compute_by_graph(graph, id, tool_use),
939 None => {
940 let err_msg = "There is no model available in the chat graphs.";
941
942 #[cfg(feature = "logging")]
943 error!(target: "stdout", "{}", &err_msg);
944
945 Err(LlamaCoreError::Operation(err_msg.into()))
946 }
947 },
948 }
949}
950
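/// Runs inference on the given graph and converts the raw output into a
/// `ChatCompletionObject`. The `ContextFull` and `PromptTooLong` backend
/// errors are mapped to a truncated response with `finish_reason: length`.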
951fn compute_by_graph(
952 graph: &mut Graph<GgmlMetadata>,
953 id: impl Into<String>,
954 tool_use: bool,
955) -> Result<(ChatCompletionObject, bool), LlamaCoreError> {
956 #[cfg(feature = "logging")]
957 info!(target: "stdout", "Compute chat completion by the model named {}.", graph.name());
958
959 match graph.compute() {
960 Ok(_) => {
961 let output_buffer = get_output_buffer(graph, OUTPUT_TENSOR)?;
963 let output = std::str::from_utf8(&output_buffer[..]).map_err(|e| {
964 let err_msg = format!(
965 "Failed to decode the buffer of the inference result to a utf-8 string. {e}"
966 );
967
968 #[cfg(feature = "logging")]
969 error!(target: "stdout", "{}", &err_msg);
970
971 LlamaCoreError::Operation(err_msg)
972 })?;
973
974 #[cfg(feature = "logging")]
975 info!(target: "stdout", "raw generation: {output}");
976
977 let message = post_process(output, &graph.metadata.prompt_template).map_err(|e| {
979 LlamaCoreError::Operation(format!("Failed to post-process the output. {e}"))
980 })?;
981
982 #[cfg(feature = "logging")]
983 info!(target: "stdout", "post-processed generation:\n{}", &message);
984
985 let token_info = get_token_info_by_graph(graph)?;
987
988 #[cfg(feature = "logging")]
989 info!(target: "stdout", "prompt tokens: {}, completion tokens: {}", token_info.prompt_tokens, token_info.completion_tokens);
990
991 let created = SystemTime::now()
992 .duration_since(std::time::UNIX_EPOCH)
993 .map_err(|e| {
994 let err_msg = format!("Failed to get the current time. Reason: {e}");
995
996 #[cfg(feature = "logging")]
997 error!(target: "stdout", "{}", &err_msg);
998
999 LlamaCoreError::Operation(err_msg)
1000 })?;
1001
1002 match tool_use {
1003 true => {
1004 if graph.metadata.prompt_template != PromptTemplateType::MistralTool
1005 && graph.metadata.prompt_template != PromptTemplateType::ChatMLTool
1006 && graph.metadata.prompt_template != PromptTemplateType::GroqLlama3Tool
1007 && graph.metadata.prompt_template != PromptTemplateType::Llama3Tool
1008 && graph.metadata.prompt_template != PromptTemplateType::InternLM2Tool
1009 && graph.metadata.prompt_template != PromptTemplateType::NemotronTool
1010 && graph.metadata.prompt_template != PromptTemplateType::FunctionaryV32
1011 && graph.metadata.prompt_template != PromptTemplateType::FunctionaryV31
1012 && graph.metadata.prompt_template != PromptTemplateType::MistralSmallTool
1013 && graph.metadata.prompt_template != PromptTemplateType::Llama4Chat
1014 && graph.metadata.prompt_template != PromptTemplateType::Qwen3NoThink
1015 && graph.metadata.prompt_template != PromptTemplateType::Smol3NoThink
1016 && graph.metadata.prompt_template != PromptTemplateType::Gemma3
1017 && graph.metadata.prompt_template != PromptTemplateType::GptOss
1018 && graph.metadata.prompt_template != PromptTemplateType::Qwen3Agent
1019 && graph.metadata.prompt_template != PromptTemplateType::SeedOssNoThink
1020 && graph.metadata.prompt_template != PromptTemplateType::SeedOssThink
1021 {
1022 let err_msg = format!("Unsupported prompt template: {}. The tool use is only supported for 'mistral-tool', 'chatml-tool', 'groq-llama3-tool', 'llama-3-tool', 'internlm-2-tool', 'nemotron-tool', 'functionary-31', 'functionary-32', 'mistral-small-tool', 'llama-4-chat', 'qwen3-no-think', 'smol-3-no-think', 'gemma-3', 'gpt-oss', 'qwen3-agent', 'seed-oss-no-think', and 'seed-oss-think' prompt templates.", graph.metadata.prompt_template);
1023
1024 #[cfg(feature = "logging")]
1025 error!(target: "stdout", "{}", &err_msg);
1026
1027 return Err(LlamaCoreError::Operation(err_msg));
1028 }
1029
1030 let parsed_result = parse_tool_calls(&message, graph.metadata.prompt_template)?;
1031
1032 let (finish_reason, content, include_tool_calls) =
1033 if parsed_result.tool_calls.is_empty() {
1034 (FinishReason::stop, Some(parsed_result.raw.clone()), false)
1035 } else if graph.metadata.prompt_template != PromptTemplateType::Qwen3Agent {
1036 (
1037 FinishReason::tool_calls,
1038 Some(parsed_result.raw.clone()),
1039 true,
1040 )
1041 } else {
1042 (
1043 FinishReason::tool_calls,
1044 parsed_result.content.clone(),
1045 true,
1046 )
1047 };
1048
1049 let res = ChatCompletionObject {
1050 id: id.into(),
1051 object: String::from("chat.completion"),
1052 created: created.as_secs(),
1053 model: graph.name().to_owned(),
1054 choices: vec![ChatCompletionObjectChoice {
1055 index: 0,
1056 message: ChatCompletionObjectMessage {
1057 role: ChatCompletionRole::Assistant,
1058 content,
1059 tool_calls: parsed_result.tool_calls,
1060 function_call: None,
1061 },
1062 finish_reason,
1063 logprobs: None,
1064 }],
1065 usage: Usage {
1066 prompt_tokens: token_info.prompt_tokens,
1067 completion_tokens: token_info.completion_tokens,
1068 total_tokens: token_info.prompt_tokens + token_info.completion_tokens,
1069 },
1070 };
1071
1072 Ok((res, include_tool_calls))
1074 }
1075 false => {
1076 let res = ChatCompletionObject {
1078 id: id.into(),
1079 object: String::from("chat.completion"),
1080 created: created.as_secs(),
1081 model: graph.name().to_owned(),
1082 choices: vec![ChatCompletionObjectChoice {
1083 index: 0,
1084 message: ChatCompletionObjectMessage {
1085 role: ChatCompletionRole::Assistant,
1086 content: Some(message),
1087 tool_calls: vec![],
1088 function_call: None,
1089 },
1090 finish_reason: FinishReason::stop,
1091 logprobs: None,
1092 }],
1093 usage: Usage {
1094 prompt_tokens: token_info.prompt_tokens,
1095 completion_tokens: token_info.completion_tokens,
1096 total_tokens: token_info.prompt_tokens + token_info.completion_tokens,
1097 },
1098 };
1099
1100 Ok((res, false))
1101 }
1102 }
1103 }
1104 Err(wasmedge_wasi_nn::Error::BackendError(wasmedge_wasi_nn::BackendError::ContextFull)) => {
1105 let output_buffer = get_output_buffer(graph, OUTPUT_TENSOR)?;
1107 let output = std::str::from_utf8(&output_buffer[..]).map_err(|e| {
1108 let err_msg = format!(
1109 "Failed to decode the buffer of the inference result to a utf-8 string. {e}"
1110 );
1111
1112 #[cfg(feature = "logging")]
1113 error!(target: "stdout", "{}", &err_msg);
1114
1115 LlamaCoreError::Operation(err_msg)
1116 })?;
1117
1118 let message = post_process(output, &graph.metadata.prompt_template).map_err(|e| {
1120 let err_msg = format!("Failed to post-process the output. {e}");
1121
1122 #[cfg(feature = "logging")]
1123 error!(target: "stdout", "{}", &err_msg);
1124
1125 LlamaCoreError::Operation(err_msg)
1126 })?;
1127
1128 let token_info = get_token_info_by_graph(graph)?;
1130
1131 #[cfg(feature = "logging")]
1132 info!(target: "stdout", "prompt tokens: {}, completion tokens: {}", token_info.prompt_tokens, token_info.completion_tokens);
1133
1134 let created = SystemTime::now()
1135 .duration_since(std::time::UNIX_EPOCH)
1136 .map_err(|e| {
1137 let err_msg = format!("Failed to get the current time. Reason: {e}");
1138
1139 #[cfg(feature = "logging")]
1140 error!(target: "stdout", "{}", &err_msg);
1141
1142 LlamaCoreError::Operation(err_msg)
1143 })?;
1144
1145 let res = ChatCompletionObject {
1147 id: id.into(),
1148 object: String::from("chat.completion"),
1149 created: created.as_secs(),
1150 model: graph.name().to_owned(),
1151 choices: vec![ChatCompletionObjectChoice {
1152 index: 0,
1153 message: ChatCompletionObjectMessage {
1154 role: ChatCompletionRole::Assistant,
1155 content: Some(message),
1156 tool_calls: vec![],
1157 function_call: None,
1158 },
1159 finish_reason: FinishReason::length,
1160 logprobs: None,
1161 }],
1162 usage: Usage {
1163 prompt_tokens: token_info.prompt_tokens,
1164 completion_tokens: token_info.completion_tokens,
1165 total_tokens: token_info.prompt_tokens + token_info.completion_tokens,
1166 },
1167 };
1168
1169 Ok((res, false))
1170 }
1171 Err(wasmedge_wasi_nn::Error::BackendError(
1172 wasmedge_wasi_nn::BackendError::PromptTooLong,
1173 )) => {
1174 #[cfg(feature = "logging")]
1175 warn!(target: "stdout", "The prompt is too long. Please reduce the length of your input and try again.");
1176
1177 let output_buffer = get_output_buffer(graph, OUTPUT_TENSOR)?;
1179 let output = std::str::from_utf8(&output_buffer[..]).map_err(|e| {
1180 let err_msg = format!(
1181 "Failed to decode the buffer of the inference result to a utf-8 string. {e}"
1182 );
1183
1184 #[cfg(feature = "logging")]
1185 error!(target: "stdout", "{}", &err_msg);
1186
1187 LlamaCoreError::Operation(err_msg)
1188 })?;
1189
1190 let message = post_process(output, &graph.metadata.prompt_template).map_err(|e| {
1192 let err_msg = format!("Failed to post-process the output. {e}");
1193
1194 #[cfg(feature = "logging")]
1195 error!(target: "stdout", "{}", &err_msg);
1196
1197 LlamaCoreError::Operation(err_msg)
1198 })?;
1199
1200 let token_info = get_token_info_by_graph(graph)?;
1202
1203 #[cfg(feature = "logging")]
1204 info!(target: "stdout", "prompt tokens: {}, completion tokens: {}", token_info.prompt_tokens, token_info.completion_tokens);
1205
1206 let usage = Usage {
1207 prompt_tokens: token_info.prompt_tokens,
1208 completion_tokens: token_info.completion_tokens,
1209 total_tokens: token_info.prompt_tokens + token_info.completion_tokens,
1210 };
1211
1212 let created = SystemTime::now()
1213 .duration_since(std::time::UNIX_EPOCH)
1214 .map_err(|e| {
1215 let err_msg = format!("Failed to get the current time. Reason: {e}");
1216
1217 #[cfg(feature = "logging")]
1218 error!(target: "stdout", "{}", &err_msg);
1219
1220 LlamaCoreError::Operation(err_msg)
1221 })?;
1222
1223 let res = ChatCompletionObject {
1225 id: id.into(),
1226 object: String::from("chat.completion"),
1227 created: created.as_secs(),
1228 model: graph.name().to_owned(),
1229 choices: vec![ChatCompletionObjectChoice {
1230 index: 0,
1231 message: ChatCompletionObjectMessage {
1232 role: ChatCompletionRole::Assistant,
1233 content: Some(message),
1234 tool_calls: vec![],
1235 function_call: None,
1236 },
1237 finish_reason: FinishReason::length,
1238 logprobs: None,
1239 }],
1240 usage,
1241 };
1242
1243 Ok((res, false))
1244 }
1245 Err(e) => {
1246 let err_msg = format!("Failed to compute the chat completion. Reason: {e}");
1247
1248 #[cfg(feature = "logging")]
1249 error!(target: "stdout", "{}", &err_msg);
1250
1251 Err(LlamaCoreError::Backend(BackendError::Compute(err_msg)))
1252 }
1253 }
1254}
1255
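/// Extracts tool calls from the raw model output according to the tool-call
/// syntax of the given prompt template (e.g. `<tool_call>...</tool_call>` for
/// ChatML-style templates, `[TOOL_CALLS] [...]` for Mistral-Small).
///
/// Illustrative sketch for a ChatML-style output (values are made up):
///
/// ```ignore
/// let parsed = parse_tool_calls(
///     r#"<tool_call>{"name": "get_weather", "arguments": {"city": "Berlin"}}</tool_call>"#,
///     PromptTemplateType::ChatMLTool,
/// )?;
/// assert_eq!(parsed.tool_calls.len(), 1);
/// ```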
1256fn parse_tool_calls(
1257 input: &str,
1258 prompt_template: PromptTemplateType,
1259) -> Result<ParseResult, LlamaCoreError> {
1260 match prompt_template {
1261 PromptTemplateType::MistralTool => match regex::Regex::new(r"\[\{.*?\}\]") {
1262 Ok(re) => {
1263 let mut values: Vec<serde_json::Value> = vec![];
1264 for cap in re.captures_iter(input) {
1265 let matched = &cap[0];
1266
1267 #[cfg(feature = "logging")]
1268 info!(target: "stdout", "captured: {matched}");
1269
1270 match serde_json::from_str::<Vec<serde_json::Value>>(matched) {
1271 Ok(group) => values.extend(group),
1272 Err(e) => {
1273 let err_msg =
1274 format!("Failed to deserialize generated tool calls. Reason: {e}");
1275
1276 #[cfg(feature = "logging")]
1277 error!(target: "stdout", "{}", &err_msg);
1278
1279 return Err(LlamaCoreError::Operation(err_msg));
1280 }
1281 }
1282 }
1283
1284 let mut tool_calls: Vec<ToolCall> = vec![];
1285 for value in values.iter() {
1286 let name = match value.get("name") {
1287 Some(name) => name.to_string().replace("\"", ""),
1288 None => {
1289 let err_msg = format!(
1290 "Failed to get the name of the function. Tool call: {value:?}"
1291 );
1292
1293 #[cfg(feature = "logging")]
1294 error!(target: "stdout", "{}", &err_msg);
1295
1296 return Err(LlamaCoreError::Operation(err_msg));
1297 }
1298 };
1299
1300 let arguments = match value.get("arguments") {
1301 Some(arguments) => arguments.to_string(),
1302 None => {
1303 let err_msg = format!(
1304 "Failed to get the arguments of the function. Tool call: {value:?}"
1305 );
1306
1307 #[cfg(feature = "logging")]
1308 error!(target: "stdout", "{}", &err_msg);
1309
1310 return Err(LlamaCoreError::Operation(err_msg));
1311 }
1312 };
1313
1314 let function = Function { name, arguments };
1315
1316 let tool_call = ToolCall {
1317 id: "call_abc123".to_string(),
1318 ty: "function".to_string(),
1319 function,
1320 };
1321
1322 tool_calls.push(tool_call);
1323 }
1324
1325 let parsed = ParseResult {
1326 raw: input.to_owned(),
1327 content: None,
1328 tool_calls,
1329 };
1330
1331 #[cfg(feature = "logging")]
1332 info!(target: "stdout", "parsed result: {parsed:?}");
1333
1334 Ok(parsed)
1335 }
1336 Err(e) => {
1337 let err_msg = format!("Failed to create a regex pattern. Reason: {e}");
1338
1339 #[cfg(feature = "logging")]
1340 error!(target: "stdout", "{}", &err_msg);
1341
1342 Err(LlamaCoreError::Operation(err_msg))
1343 }
1344 },
1345 PromptTemplateType::ChatMLTool => {
1346 match regex::Regex::new(r"<tool_call>(.*?)</tool_call>") {
1347 Ok(re) => {
1348 let mut values: Vec<serde_json::Value> = vec![];
1349 for cap in re.captures_iter(input) {
                        let matched = cap[1].replace("\\n", "");

                        #[cfg(feature = "logging")]
1353 info!(target: "stdout", "captured: {}", &matched);
1354
1355 match serde_json::from_str::<serde_json::Value>(&matched) {
1356 Ok(value) => values.push(value),
1357 Err(e) => {
1358 let err_msg = format!(
1359 "Failed to deserialize generated tool calls. Reason: {e}"
1360 );
1361
1362 #[cfg(feature = "logging")]
1363 error!(target: "stdout", "{}", &err_msg);
1364
1365 return Err(LlamaCoreError::Operation(err_msg));
1366 }
1367 }
1368 }
1369
1370 let mut tool_calls: Vec<ToolCall> = vec![];
1371 for value in values.iter() {
1372 let name = match value.get("name") {
1373 Some(name) => name.to_string().replace("\"", ""),
1374 None => {
1375 let err_msg = format!(
1376 "Failed to get the name of the function. Tool call: {value:?}"
1377 );
1378
1379 #[cfg(feature = "logging")]
1380 error!(target: "stdout", "{}", &err_msg);
1381
1382 return Err(LlamaCoreError::Operation(err_msg));
1383 }
1384 };
1385
1386 let arguments = match value.get("arguments") {
1387 Some(arguments) => arguments.to_string(),
1388 None => {
1389 let err_msg = format!(
1390 "Failed to get the arguments of the function. Tool call: {value:?}"
1391 );
1392
1393 #[cfg(feature = "logging")]
1394 error!(target: "stdout", "{}", &err_msg);
1395
1396 return Err(LlamaCoreError::Operation(err_msg));
1397 }
1398 };
1399
1400 let function = Function { name, arguments };
1401
1402 let tool_call = ToolCall {
1403 id: "call_abc123".to_string(),
1404 ty: "function".to_string(),
1405 function,
1406 };
1407
1408 tool_calls.push(tool_call);
1409 }
1410
1411 let parsed = ParseResult {
1412 raw: input.to_owned(),
1413 content: None,
1414 tool_calls,
1415 };
1416
1417 #[cfg(feature = "logging")]
1418 info!(target: "stdout", "parsed result: {parsed:?}");
1419
1420 Ok(parsed)
1421 }
1422 Err(e) => {
1423 let err_msg = format!("Failed to create a regex pattern. Reason: {e}");
1424
1425 #[cfg(feature = "logging")]
1426 error!(target: "stdout", "{}", &err_msg);
1427
1428 Err(LlamaCoreError::Operation(err_msg))
1429 }
1430 }
1431 }
1432 PromptTemplateType::GroqLlama3Tool => {
1433 #[cfg(feature = "logging")]
1434 info!(target: "stdout", "raw input: {input}");
1435
1436 match regex::Regex::new(r"(?s)<tool_call>((.|\r|\n)*?)</tool_call>") {
1437 Ok(re) => {
1438 let mut values: Vec<serde_json::Value> = vec![];
1439 for cap in re.captures_iter(input) {
1440 let matched = cap[1].trim();
1441
1442 #[cfg(feature = "logging")]
1443 info!(target: "stdout", "captured: {matched}");
1444
1445 match serde_json::from_str::<serde_json::Value>(matched) {
1446 Ok(value) => values.push(value),
1447 Err(e) => {
1448 let err_msg = format!(
1449 "Failed to deserialize generated tool calls. Reason: {e}"
1450 );
1451
1452 #[cfg(feature = "logging")]
1453 error!(target: "stdout", "{}", &err_msg);
1454
1455 return Err(LlamaCoreError::Operation(err_msg));
1456 }
1457 }
1458 }
1459
1460 let mut tool_calls: Vec<ToolCall> = vec![];
1461 for value in values.iter() {
1462 let name = match value.get("name") {
1463 Some(name) => name.to_string().replace("\"", ""),
1464 None => {
1465 let err_msg = format!(
1466 "Failed to get the name of the function. Tool call: {value:?}"
1467 );
1468
1469 #[cfg(feature = "logging")]
1470 error!(target: "stdout", "{}", &err_msg);
1471
1472 return Err(LlamaCoreError::Operation(err_msg));
1473 }
1474 };
1475
1476 let arguments = match value.get("arguments") {
1477 Some(arguments) => {
1478 if arguments.is_string() {
1479 arguments.as_str().unwrap().to_string()
1480 } else if arguments.is_object() {
1481 let map = arguments.as_object().unwrap();
1482
1483 #[cfg(feature = "logging")]
1484 info!(target: "stdout", "func arguments: {map:?}");
1485
1486 serde_json::to_string(map).unwrap()
1487 } else {
1488 serde_json::to_string(arguments).unwrap()
1489 }
1490 }
1491 None => {
1492 let err_msg = format!(
1493 "Failed to get the arguments of the function. Tool call: {value:?}"
1494 );
1495
1496 #[cfg(feature = "logging")]
1497 error!(target: "stdout", "{}", &err_msg);
1498
1499 return Err(LlamaCoreError::Operation(err_msg));
1500 }
1501 };
1502
1503 let function = Function { name, arguments };
1504
1505 let tool_call = ToolCall {
1506 id: "call_abc123".to_string(),
1507 ty: "function".to_string(),
1508 function,
1509 };
1510
1511 tool_calls.push(tool_call);
1512 }
1513
1514 let parsed = if tool_calls.is_empty() {
1515 ParseResult {
1516 raw: input.to_owned(),
1517 content: Some(input.to_owned()),
1518 tool_calls: vec![],
1519 }
1520 } else {
1521 ParseResult {
1522 raw: input.to_owned(),
1523 content: None,
1524 tool_calls,
1525 }
1526 };
1527
1528 #[cfg(feature = "logging")]
1529 info!(target: "stdout", "parsed result: {parsed:?}");
1530
1531 Ok(parsed)
1532 }
1533 Err(e) => {
1534 let err_msg = format!("Failed to create a regex pattern. Reason: {e}");
1535
1536 #[cfg(feature = "logging")]
1537 error!(target: "stdout", "{}", &err_msg);
1538
1539 Err(LlamaCoreError::Operation(err_msg))
1540 }
1541 }
1542 }
1543 PromptTemplateType::Llama3Tool => {
1544 #[cfg(feature = "logging")]
1545 info!(target: "stdout", "raw input: {input}");
1546
1547 let re = match regex::Regex::new(r"^\{(.|\r|\n)*\}$") {
1548 Ok(re) => re,
1549 Err(e) => {
1550 let err_msg = format!("Failed to create a regex pattern. Reason: {e}");
1551
1552 #[cfg(feature = "logging")]
1553 error!(target: "stdout", "{}", &err_msg);
1554
1555 return Err(LlamaCoreError::Operation(err_msg));
1556 }
1557 };
1558
1559 if re.is_match(input) {
1560 match serde_json::from_str::<serde_json::Value>(input) {
1561 Ok(value) => {
1562 let values: Vec<serde_json::Value> = vec![value];
1563
1564 let mut tool_calls: Vec<ToolCall> = vec![];
1565 for value in values.iter() {
1566 let name = match value.get("name") {
1567 Some(name) => name.to_string().replace("\"", ""),
1568 None => {
1569 let err_msg = format!(
1570 "Failed to get the name of the function. Tool call: {value:?}"
1571 );
1572
1573 #[cfg(feature = "logging")]
1574 error!(target: "stdout", "{}", &err_msg);
1575
1576 return Err(LlamaCoreError::Operation(err_msg));
1577 }
1578 };
1579
1580 let arguments = match value.get("parameters") {
1581 Some(arguments) => arguments.to_string(),
1582 None => {
1583 let err_msg = format!(
1584 "Failed to get the arguments of the function. Tool call: {value:?}"
1585 );
1586
1587 #[cfg(feature = "logging")]
1588 error!(target: "stdout", "{}", &err_msg);
1589
1590 return Err(LlamaCoreError::Operation(err_msg));
1591 }
1592 };
1593
1594 let function = Function { name, arguments };
1595
1596 let tool_call = ToolCall {
1597 id: "call_abc123".to_string(),
1598 ty: "function".to_string(),
1599 function,
1600 };
1601
1602 tool_calls.push(tool_call);
1603 }
1604
1605 let parsed = ParseResult {
1606 raw: input.to_owned(),
1607 content: None,
1608 tool_calls,
1609 };
1610
1611 #[cfg(feature = "logging")]
1612 info!(target: "stdout", "parsed result: {parsed:?}");
1613
1614 Ok(parsed)
1615 }
1616 Err(e) => {
1617 let err_msg =
1618 format!("Failed to deserialize generated tool calls. Reason: {e}");
1619
1620 #[cfg(feature = "logging")]
1621 error!(target: "stdout", "{}", &err_msg);
1622
1623 Err(LlamaCoreError::Operation(err_msg))
1624 }
1625 }
1626 } else {
1627 let parsed = ParseResult {
1628 raw: input.to_owned(),
1629 content: None,
1630 tool_calls: vec![],
1631 };
1632
1633 #[cfg(feature = "logging")]
1634 info!(target: "stdout", "parsed result: {parsed:?}");
1635
1636 Ok(parsed)
1637 }
1638 }
1639 PromptTemplateType::InternLM2Tool => {
1640 #[cfg(feature = "logging")]
1641 info!(target: "stdout", "raw input: {input}");
1642
1643 let blocks: Vec<&str> = input.trim().split("<|action_start|><|plugin|>").collect();
1644
1645 #[cfg(feature = "logging")]
1646 info!(target: "stdout", "blocks: {blocks:?}");
1647
1648 let mut tool_calls: Vec<ToolCall> = vec![];
1649 let mut content = String::new();
1650 for block in blocks {
1651 let block = block.trim();
1652 if !block.is_empty() {
1653 if block.ends_with("<|action_end|>") {
1654 let value = block.trim().trim_end_matches("<|action_end|>");
1655
1656 #[cfg(feature = "logging")]
1657 info!(target: "stdout", "tool call: {value}");
1658
1659 match serde_json::from_str::<serde_json::Value>(value) {
1660 Ok(value) => {
1661 let name = match value.get("name") {
1662 Some(name) => name.to_string().replace("\"", ""),
1663 None => {
1664 let err_msg = format!(
1665 "Failed to get the name of the function. Tool call: {value:?}"
1666 );
1667
1668 #[cfg(feature = "logging")]
1669 error!(target: "stdout", "{}", &err_msg);
1670
1671 return Err(LlamaCoreError::Operation(err_msg));
1672 }
1673 };
1674
1675 let arguments = match value.get("parameters") {
1676 Some(arguments) => arguments.to_string(),
1677 None => {
1678 let err_msg = format!(
1679 "Failed to get the arguments of the function. Tool call: {value:?}"
1680 );
1681
1682 #[cfg(feature = "logging")]
1683 error!(target: "stdout", "{}", &err_msg);
1684
1685 return Err(LlamaCoreError::Operation(err_msg));
1686 }
1687 };
1688
1689 let function = Function { name, arguments };
1690
1691 let tool_call = ToolCall {
1692 id: "call_abc123".to_string(),
1693 ty: "function".to_string(),
1694 function,
1695 };
1696
1697 tool_calls.push(tool_call);
1698 }
1699 Err(e) => {
1700 let err_msg = format!(
1701 "Failed to deserialize generated tool calls. Reason: {e}"
1702 );
1703
1704 #[cfg(feature = "logging")]
1705 error!(target: "stdout", "{}", &err_msg);
1706
1707 return Err(LlamaCoreError::Operation(err_msg));
1708 }
1709 }
1710 } else {
1711 content.push_str(block);
1712 content.push('\n');
1713 }
1714 }
1715 }
1716
1717 let parsed = match content.is_empty() {
1718 true => ParseResult {
1719 raw: input.to_owned(),
1720 content: None,
1721 tool_calls,
1722 },
1723 false => ParseResult {
1724 raw: input.to_owned(),
1725 content: Some(content.trim().to_owned()),
1726 tool_calls,
1727 },
1728 };
1729
1730 #[cfg(feature = "logging")]
1731 info!(target: "stdout", "parsed result: {parsed:?}");
1732
1733 Ok(parsed)
1734 }
1735 PromptTemplateType::NemotronTool => {
1736 #[cfg(feature = "logging")]
1737 info!(target: "stdout", "raw input: {input}");
1738
1739 match regex::Regex::new(r"(?s)<toolcall>\s*(.*?)\s*</toolcall>") {
1740 Ok(re) => {
1741 let mut values: Vec<serde_json::Value> = vec![];
1742 for cap in re.captures_iter(input) {
1743 #[cfg(feature = "logging")]
1744 info!(target: "stdout", "captured: {}", &cap[0]);
1745
1746 #[cfg(feature = "logging")]
1747 info!(target: "stdout", "extracted: {}", &cap[1]);
1748
1749 let matched = cap[1].trim();
1750
1751 #[cfg(feature = "logging")]
1752 info!(target: "stdout", "captured: {matched}");
1753
1754 match serde_json::from_str::<serde_json::Value>(matched) {
1755 Ok(value) => values.push(value),
1756 Err(e) => {
1757 let err_msg = format!(
1758 "Failed to deserialize generated tool calls. Reason: {e}"
1759 );
1760
1761 #[cfg(feature = "logging")]
1762 error!(target: "stdout", "{}", &err_msg);
1763
1764 return Err(LlamaCoreError::Operation(err_msg));
1765 }
1766 }
1767 }
1768
1769 let mut tool_calls: Vec<ToolCall> = vec![];
1770 for value in values.iter() {
1771 let name = match value.get("name") {
1772 Some(name) => name.to_string().replace("\"", ""),
1773 None => {
1774 let err_msg = format!(
1775 "Failed to get the name of the function. Tool call: {value:?}"
1776 );
1777
1778 #[cfg(feature = "logging")]
1779 error!(target: "stdout", "{}", &err_msg);
1780
1781 return Err(LlamaCoreError::Operation(err_msg));
1782 }
1783 };
1784
1785 let arguments = match value.get("arguments") {
1786 Some(arguments) => arguments.to_string(),
1787 None => {
1788 let err_msg = format!(
1789 "Failed to get the arguments of the function. Tool call: {value:?}"
1790 );
1791
1792 #[cfg(feature = "logging")]
1793 error!(target: "stdout", "{}", &err_msg);
1794
1795 return Err(LlamaCoreError::Operation(err_msg));
1796 }
1797 };
1798
1799 let function = Function { name, arguments };
1800
1801 let tool_call = ToolCall {
1802 id: "call_abc123".to_string(),
1803 ty: "function".to_string(),
1804 function,
1805 };
1806
1807 tool_calls.push(tool_call);
1808 }
1809
1810 let parsed = ParseResult {
1811 raw: input.to_owned(),
1812 content: None,
1813 tool_calls,
1814 };
1815
1816 #[cfg(feature = "logging")]
1817 info!(target: "stdout", "parsed result: {parsed:?}");
1818
1819 Ok(parsed)
1820 }
1821 Err(e) => {
1822 let err_msg = format!("Failed to create a regex pattern. Reason: {e}");
1823
1824 #[cfg(feature = "logging")]
1825 error!(target: "stdout", "{}", &err_msg);
1826
1827 Err(LlamaCoreError::Operation(err_msg))
1828 }
1829 }
1830 }
1831 PromptTemplateType::FunctionaryV32 => {
1832 #[cfg(feature = "logging")]
1833 info!(target: "stdout", "raw input: {input}");
1834
1835 match regex::Regex::new(r">>>\s*(\w+)\s*\{(.*)\}<\|eot_id\|>") {
1836 Ok(re) => {
1837 let mut tool_calls: Vec<ToolCall> = vec![];
1838 for cap in re.captures_iter(input) {
1839 #[cfg(feature = "logging")]
1840 info!(target: "stdout", "func_name: {}", &cap[1]);
1841
1842 #[cfg(feature = "logging")]
1843 info!(target: "stdout", "arguments: {}", &cap[2]);
1844
1845 let tool_call = ToolCall {
1846 id: "call_abc123".to_string(),
1847 ty: "function".to_string(),
1848 function: Function {
1849 name: cap[1].to_string(),
1850 arguments: cap[2].to_string(),
1851 },
1852 };
1853
1854 tool_calls.push(tool_call);
1855 }
1856
1857 let parsed = ParseResult {
1858 raw: input.to_owned(),
1859 content: None,
1860 tool_calls,
1861 };
1862
1863 #[cfg(feature = "logging")]
1864 info!(target: "stdout", "parsed result: {parsed:?}");
1865
1866 Ok(parsed)
1867 }
1868 Err(e) => {
1869 let warn_msg = format!("Failed to create a regex pattern. Reason: {e}");
1870
1871 #[cfg(feature = "logging")]
1872 warn!(target: "stdout", "{}", &warn_msg);
1873
1874 Ok(ParseResult {
1875 raw: input.to_owned(),
1876 content: None,
1877 tool_calls: vec![],
1878 })
1879 }
1880 }
1881 }
1882 PromptTemplateType::FunctionaryV31 => {
1883 #[cfg(feature = "logging")]
1884 info!(target: "stdout", "raw input: {input}");
1885
1886 match regex::Regex::new(r"<function=(\w+)>\s*(\{.*?\})</function>") {
1887 Ok(re) => {
1888 let mut tool_calls: Vec<ToolCall> = vec![];
1889 for cap in re.captures_iter(input) {
1890 #[cfg(feature = "logging")]
1891 info!(target: "stdout", "func_name: {}", &cap[1]);
1892
1893 #[cfg(feature = "logging")]
1894 info!(target: "stdout", "arguments: {}", &cap[2]);
1895
1896 let tool_call = ToolCall {
1897 id: "call_abc123".to_string(),
1898 ty: "function".to_string(),
1899 function: Function {
1900 name: cap[1].to_string(),
1901 arguments: cap[2].to_string(),
1902 },
1903 };
1904
1905 tool_calls.push(tool_call);
1906 }
1907
1908 let parsed = ParseResult {
1909 raw: input.to_owned(),
1910 content: None,
1911 tool_calls,
1912 };
1913
1914 #[cfg(feature = "logging")]
1915 info!(target: "stdout", "parsed result: {parsed:?}");
1916
1917 Ok(parsed)
1918 }
1919 Err(e) => {
1920 let warn_msg = format!("Failed to create a regex pattern. Reason: {e}");
1921
1922 #[cfg(feature = "logging")]
1923 warn!(target: "stdout", "{}", &warn_msg);
1924
1925 Ok(ParseResult {
1926 raw: input.to_owned(),
1927 content: None,
1928 tool_calls: vec![],
1929 })
1930 }
1931 }
1932 }
1933 PromptTemplateType::MistralSmallTool => {
1934 #[cfg(feature = "logging")]
1935 info!(target: "stdout", "raw input: {input}");
1936
1937 match regex::Regex::new(r"\[TOOL_CALLS\]\s*(\[(.*?)\])") {
1938 Ok(re) => {
1939 let mut values: Vec<serde_json::Value> = vec![];
1940 if let Some(cap) = re.captures(input) {
1941 let matched = cap[1].trim();
1942
1943 #[cfg(feature = "logging")]
1944 info!(target: "stdout", "captured: {matched}");
1945
1946 match serde_json::from_str::<Vec<serde_json::Value>>(matched) {
1947 Ok(vals) => values = vals,
1948 Err(e) => {
1949 let err_msg = format!(
1950 "Failed to deserialize generated tool calls. Reason: {e}"
1951 );
1952
1953 #[cfg(feature = "logging")]
1954 error!(target: "stdout", "{}", &err_msg);
1955
1956 return Err(LlamaCoreError::Operation(err_msg));
1957 }
1958 }
1959 };
1960
1961 let mut tool_calls: Vec<ToolCall> = vec![];
1962 for value in values.iter() {
1963 if let Some(object_map) = value.as_object() {
1964 if object_map.contains_key("function") {
1965 let mut function = Function {
1966 name: String::new(),
1967 arguments: String::new(),
1968 };
1969
1970 let value = object_map.get("function").unwrap();
1971 let func_map = value.as_object().unwrap();
1972 if func_map.contains_key("name") {
1973 let func_name = func_map.get("name").unwrap().as_str().unwrap();
1974 println!("Function name: {func_name:?}");
1975
1976 function.name = func_name.to_string();
1977 }
1978 if func_map.contains_key("arguments") {
1979 let args = func_map.get("arguments").unwrap();
1980 let arguments = args.to_string();
1981 #[cfg(feature = "logging")]
info!(target: "stdout", "Arguments: {arguments:?}");
1982
1983 function.arguments = arguments;
1984 }
1985
1986 let tool_call = ToolCall {
1987 id: "call_abc123".to_string(),
1988 ty: "function".to_string(),
1989 function,
1990 };
1991
1992 tool_calls.push(tool_call);
1993 } else if object_map.contains_key("name") {
1994 let mut function = Function {
1995 name: String::new(),
1996 arguments: String::new(),
1997 };
1998
1999 let name = object_map.get("name").unwrap().as_str().unwrap();
2000 #[cfg(feature = "logging")]
info!(target: "stdout", "name: {name:?}");
2001 function.name = name.to_string();
2002
2003 if object_map.contains_key("arguments") {
2004 let args = object_map.get("arguments").unwrap();
2005 let arguments = args.to_string();
2006 #[cfg(feature = "logging")]
info!(target: "stdout", "Arguments: {arguments:?}");
2007
2008 function.arguments = arguments;
2009 }
2010
2011 let tool_call = ToolCall {
2012 id: "call_abc123".to_string(),
2013 ty: "function".to_string(),
2014 function,
2015 };
2016
2017 tool_calls.push(tool_call);
2018 }
2019 }
2020 }
2021
2022 let parsed = ParseResult {
2023 raw: input.to_owned(),
2024 content: None,
2025 tool_calls,
2026 };
2027
2028 #[cfg(feature = "logging")]
2029 info!(target: "stdout", "parsed result: {parsed:?}");
2030
2031 Ok(parsed)
2032 }
2033 Err(e) => {
2034 let err_msg = format!("Failed to create a regex pattern. Reason: {e}");
2035
2036 #[cfg(feature = "logging")]
2037 error!(target: "stdout", "{}", &err_msg);
2038
2039 Err(LlamaCoreError::Operation(err_msg))
2040 }
2041 }
2042 }
2043 PromptTemplateType::Llama4Chat => {
2044 #[cfg(feature = "logging")]
2045 info!(target: "stdout", "raw input: {input:?}");
2046
2047 let mut tool_calls: Vec<ToolCall> = vec![];
2048 if let Ok(value) = serde_json::from_str::<serde_json::Value>(input) {
2049 match value.as_object() {
2050 Some(object_map) => {
2051 #[cfg(feature = "logging")]
2052 debug!(target: "stdout", "object_map: {object_map:?}");
2053
2054 if object_map.contains_key("name") {
2056 let name = object_map.get("name").unwrap().as_str().unwrap();
2057
2058 #[cfg(feature = "logging")]
2059 debug!(target: "stdout", "name: {name:?}");
2060
2061 let mut function = Function {
2062 name: name.to_string(),
2063 arguments: String::new(),
2064 };
2065
2066 if object_map.contains_key("parameters") {
2068 let args = object_map.get("parameters").unwrap();
2069 let arguments = args.to_string();
2070
2071 #[cfg(feature = "logging")]
2072 debug!(target: "stdout", "arguments: {:?}", &arguments);
2073
2074 function.arguments = arguments;
2075 }
2076
2077 tool_calls.push(ToolCall {
2078 id: "call_abc123".to_string(),
2079 ty: "function".to_string(),
2080 function,
2081 });
2082 } else {
2083 let err_msg = format!(
2084 "Failed to get the name of the function. raw input: {input:?}"
2085 );
2086
2087 #[cfg(feature = "logging")]
2088 error!(target: "stdout", "{}", &err_msg);
2089
2090 return Err(LlamaCoreError::Operation(err_msg));
2091 }
2092 }
2093 None => {
2094 let err_msg = format!("Failed to parse the JSON string. JSON: {input}");
2095
2096 #[cfg(feature = "logging")]
2097 error!(target: "stdout", "{}", &err_msg);
2098
2099 return Err(LlamaCoreError::Operation(err_msg));
2100 }
2101 }
2102 }
2103
2104 let parsed = ParseResult {
2105 raw: input.to_owned(),
2106 content: None,
2107 tool_calls,
2108 };
2109
2110 #[cfg(feature = "logging")]
2111 info!(target: "stdout", "parsed result: {parsed:?}");
2112
2113 Ok(parsed)
2114 }
2115 PromptTemplateType::Qwen3NoThink | PromptTemplateType::Smol3NoThink => {
2116 #[cfg(feature = "logging")]
2117 info!(target: "stdout", "raw input: {input:?}");
2118
2119 match regex::Regex::new(r"(?s)<tool_call>((.|\r|\n)*?)</tool_call>") {
2120 Ok(re) => {
2121 let mut values: Vec<serde_json::Value> = vec![];
2122 for cap in re.captures_iter(input) {
2123 let mut matched = cap[1].trim();
2124
2125 if matched.starts_with("\\n") {
2126 matched = matched.trim_start_matches("\\n");
2127 }
2128
2129 if matched.ends_with("\\n") {
2130 matched = matched.trim_end_matches("\\n");
2131 }
2132
2133 #[cfg(feature = "logging")]
2134 info!(target: "stdout", "captured: {matched:#?}");
2135
2136 if !matched.is_empty() {
2137 match serde_json::from_str::<serde_json::Value>(matched) {
2138 Ok(value) => values.push(value),
2139 Err(e) => {
2140 let err_msg = format!(
2141 "Failed to deserialize generated tool calls: {matched:#?}. Reason: {e}"
2142 );
2143
2144 #[cfg(feature = "logging")]
2145 error!(target: "stdout", "{}", &err_msg);
2146
2147 return Err(LlamaCoreError::Operation(err_msg));
2148 }
2149 }
2150 }
2151 }
2152
2153 let mut tool_calls: Vec<ToolCall> = vec![];
2154 for value in values.iter() {
2155 let name = match value.get("name") {
2156 Some(name) => name.to_string().replace("\"", ""),
2157 None => {
2158 let err_msg = format!(
2159 "Failed to get the name of the function. Tool call: {value:?}"
2160 );
2161
2162 #[cfg(feature = "logging")]
2163 error!(target: "stdout", "{}", &err_msg);
2164
2165 return Err(LlamaCoreError::Operation(err_msg));
2166 }
2167 };
2168
2169 let arguments = match value.get("arguments") {
2170 Some(arguments) => {
2171 if arguments.is_string() {
2172 arguments.as_str().unwrap().to_string()
2173 } else if arguments.is_object() {
2174 let map = arguments.as_object().unwrap();
2175
2176 #[cfg(feature = "logging")]
2177 info!(target: "stdout", "func arguments: {map:?}");
2178
2179 serde_json::to_string(map).unwrap()
2180 } else {
2181 serde_json::to_string(arguments).unwrap()
2182 }
2183 }
2184 None => {
2185 let err_msg = format!(
2186 "Failed to get the arguments of the function. Tool call: {value:?}"
2187 );
2188
2189 #[cfg(feature = "logging")]
2190 error!(target: "stdout", "{}", &err_msg);
2191
2192 return Err(LlamaCoreError::Operation(err_msg));
2193 }
2194 };
2195
2196 let function = Function { name, arguments };
2197
2198 let tool_call = ToolCall {
2199 id: "call_abc123".to_string(),
2200 ty: "function".to_string(),
2201 function,
2202 };
2203
2204 tool_calls.push(tool_call);
2205 }
2206
2207 let parsed = if tool_calls.is_empty() {
2208 ParseResult {
2209 raw: input.to_owned(),
2210 content: Some(input.to_owned()),
2211 tool_calls: vec![],
2212 }
2213 } else {
2214 ParseResult {
2215 raw: input.to_owned(),
2216 content: None,
2217 tool_calls,
2218 }
2219 };
2220
2221 #[cfg(feature = "logging")]
2222 info!(target: "stdout", "parsed result: {parsed:?}");
2223
2224 Ok(parsed)
2225 }
2226 Err(e) => {
2227 let err_msg = format!("Failed to create a regex pattern. Reason: {e}");
2228
2229 #[cfg(feature = "logging")]
2230 error!(target: "stdout", "{}", &err_msg);
2231
2232 Err(LlamaCoreError::Operation(err_msg))
2233 }
2234 }
2235 }
2236 PromptTemplateType::Gemma3 => {
2237 #[cfg(feature = "logging")]
2238 info!(target: "stdout", "raw input: {input:?}");
2239
2240 match regex::Regex::new(r"(?s)```json\s*(.*?)\s*```") {
2241 Ok(re) => {
2242 let mut values: Vec<serde_json::Value> = vec![];
2243 for cap in re.captures_iter(input) {
2244 let mut matched = cap[1].trim();
2245
2246 if matched.starts_with("\\n") {
2247 matched = matched.trim_start_matches("\\n");
2248 }
2249
2250 if matched.ends_with("\\n") {
2251 matched = matched.trim_end_matches("\\n");
2252 }
2253
2254 #[cfg(feature = "logging")]
2255 info!(target: "stdout", "captured: {matched:#?}");
2256
2257 if !matched.is_empty() {
2258 match serde_json::from_str::<serde_json::Value>(matched) {
2259 Ok(value) => values.push(value),
2260 Err(e) => {
2261 let err_msg = format!(
2262 "Failed to deserialize generated tool calls: {matched:#?}. Reason: {e}"
2263 );
2264
2265 #[cfg(feature = "logging")]
2266 error!(target: "stdout", "{}", &err_msg);
2267
2268 return Err(LlamaCoreError::Operation(err_msg));
2269 }
2270 }
2271 }
2272 }
2273
2274 let mut tool_calls: Vec<ToolCall> = vec![];
2275 for value in values.iter() {
2276 let name = match value.get("name") {
2277 Some(name) => name.to_string().replace("\"", ""),
2278 None => {
2279 let err_msg = format!(
2280 "Failed to get the name of the function. Tool call: {value:?}"
2281 );
2282
2283 #[cfg(feature = "logging")]
2284 error!(target: "stdout", "{}", &err_msg);
2285
2286 return Err(LlamaCoreError::Operation(err_msg));
2287 }
2288 };
2289
2290 let arguments = match value.get("arguments") {
2291 Some(arguments) => {
2292 if arguments.is_string() {
2293 arguments.as_str().unwrap().to_string()
2294 } else if arguments.is_object() {
2295 let map = arguments.as_object().unwrap();
2296
2297 #[cfg(feature = "logging")]
2298 info!(target: "stdout", "func arguments: {map:?}");
2299
2300 serde_json::to_string(map).unwrap()
2301 } else {
2302 serde_json::to_string(arguments).unwrap()
2303 }
2304 }
2305 None => {
2306 let err_msg = format!(
2307 "Failed to get the arguments of the function. Tool call: {value:?}"
2308 );
2309
2310 #[cfg(feature = "logging")]
2311 error!(target: "stdout", "{}", &err_msg);
2312
2313 return Err(LlamaCoreError::Operation(err_msg));
2314 }
2315 };
2316
2317 let function = Function { name, arguments };
2318
2319 let tool_call = ToolCall {
2320 id: "call_abc123".to_string(),
2321 ty: "function".to_string(),
2322 function,
2323 };
2324
2325 tool_calls.push(tool_call);
2326 }
2327
2328 let parsed = if tool_calls.is_empty() {
2329 ParseResult {
2330 raw: input.to_owned(),
2331 content: Some(input.to_owned()),
2332 tool_calls: vec![],
2333 }
2334 } else {
2335 ParseResult {
2336 raw: input.to_owned(),
2337 content: None,
2338 tool_calls,
2339 }
2340 };
2341
2342 #[cfg(feature = "logging")]
2343 info!(target: "stdout", "parsed result: {parsed:?}");
2344
2345 Ok(parsed)
2346 }
2347 Err(e) => {
2348 let err_msg = format!("Failed to create a regex pattern. Reason: {e}");
2349
2350 #[cfg(feature = "logging")]
2351 error!(target: "stdout", "{}", &err_msg);
2352
2353 Err(LlamaCoreError::Operation(err_msg))
2354 }
2355 }
2356 }
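        // GPT-OSS: first try the `<|channel|>commentary to=functions.<name> ... <|call|>`
        // tool-call pattern; if it does not match, fall back to scanning for JSON code fences.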
2357 PromptTemplateType::GptOss => {
2358 #[cfg(feature = "logging")]
2359 info!(target: "stdout", "raw input: {input:?}");
2360
2361 match regex::Regex::new(
2363 r"<\|channel\|>commentary to=functions\.([^<\s]+)\s*<\|constrain\|>json<\|message\|>([^<]*)<\|call\|>$",
2364 ) {
2365 Ok(re) => {
2366 if let Some(cap) = re.captures(input) {
2367 let function_name = cap[1].trim();
2368 let arguments = cap[2].trim();
2369
2370 #[cfg(feature = "logging")]
2371 info!(target: "stdout", "extracted function_name: {function_name}, arguments: {arguments}");
2372
2373 let function = Function {
2374 name: function_name.to_string(),
2375 arguments: arguments.to_string(),
2376 };
2377
2378 let tool_call = ToolCall {
2379 id: "call_abc123".to_string(),
2380 ty: "function".to_string(),
2381 function,
2382 };
2383
2384 let parsed = ParseResult {
2385 raw: input.to_owned(),
2386 content: None,
2387 tool_calls: vec![tool_call],
2388 };
2389
2390 #[cfg(feature = "logging")]
2391 info!(target: "stdout", "parsed result: {parsed:?}");
2392
2393 Ok(parsed)
2394 } else {
2395 match regex::Regex::new(r"(?s)```json\s*(.*?)\s*```") {
2396 Ok(re) => {
2397 let mut values: Vec<serde_json::Value> = vec![];
2398 for cap in re.captures_iter(input) {
2399 let mut matched = cap[1].trim();
2400
2401 if matched.starts_with("\\n") {
2402 matched = matched.trim_start_matches("\\n");
2403 }
2404
2405 if matched.ends_with("\\n") {
2406 matched = matched.trim_end_matches("\\n");
2407 }
2408
2409 #[cfg(feature = "logging")]
2410 info!(target: "stdout", "captured: {matched:#?}");
2411
2412 if !matched.is_empty() {
2413 match serde_json::from_str::<serde_json::Value>(matched) {
2414 Ok(value) => values.push(value),
2415 Err(e) => {
2416 let err_msg = format!(
2417 "Failed to deserialize generated tool calls: {matched:#?}. Reason: {e}"
2418 );
2419
2420 #[cfg(feature = "logging")]
2421 error!(target: "stdout", "{}", &err_msg);
2422
2423 return Err(LlamaCoreError::Operation(err_msg));
2424 }
2425 }
2426 }
2427 }
2428
2429 let mut tool_calls: Vec<ToolCall> = vec![];
2430 for value in values.iter() {
2431 let name = match value.get("name") {
2432 Some(name) => name.to_string().replace("\"", ""),
2433 None => {
2434 let err_msg = format!(
2435 "Failed to get the name of the function. Tool call: {value:?}"
2436 );
2437
2438 #[cfg(feature = "logging")]
2439 error!(target: "stdout", "{}", &err_msg);
2440
2441 return Err(LlamaCoreError::Operation(err_msg));
2442 }
2443 };
2444
2445 let arguments = match value.get("arguments") {
2446 Some(arguments) => {
2447 if arguments.is_string() {
2448 arguments.as_str().unwrap().to_string()
2449 } else if arguments.is_object() {
2450 let map = arguments.as_object().unwrap();
2451
2452 #[cfg(feature = "logging")]
2453 info!(target: "stdout", "func arguments: {map:?}");
2454
2455 serde_json::to_string(map).unwrap()
2456 } else {
2457 serde_json::to_string(arguments).unwrap()
2458 }
2459 }
2460 None => {
2461 let err_msg = format!(
2462 "Failed to get the arguments of the function. Tool call: {value:?}"
2463 );
2464
2465 #[cfg(feature = "logging")]
2466 error!(target: "stdout", "{}", &err_msg);
2467
2468 return Err(LlamaCoreError::Operation(err_msg));
2469 }
2470 };
2471
2472 let function = Function { name, arguments };
2473
2474 let tool_call = ToolCall {
2475 id: "call_abc123".to_string(),
2476 ty: "function".to_string(),
2477 function,
2478 };
2479
2480 tool_calls.push(tool_call);
2481 }
2482
2483 let parsed = if tool_calls.is_empty() {
2484 ParseResult {
2485 raw: input.to_owned(),
2486 content: Some(input.to_owned()),
2487 tool_calls: vec![],
2488 }
2489 } else {
2490 ParseResult {
2491 raw: input.to_owned(),
2492 content: Some(input.to_owned()),
2493 tool_calls,
2494 }
2495 };
2496
2497 #[cfg(feature = "logging")]
2498 info!(target: "stdout", "parsed result: {parsed:?}");
2499
2500 Ok(parsed)
2501 }
2502 Err(e) => {
2503 let err_msg =
2504 format!("Failed to create a regex pattern. Reason: {e}");
2505
2506 #[cfg(feature = "logging")]
2507 error!(target: "stdout", "{}", &err_msg);
2508
2509 Err(LlamaCoreError::Operation(err_msg))
2510 }
2511 }
2512 }
2513 }
2514 Err(e) => {
2515 let err_msg = format!("Failed to create a regex pattern. Reason: {e}");
2516
2517 #[cfg(feature = "logging")]
2518 error!(target: "stdout", "{}", &err_msg);
2519
2520 Err(LlamaCoreError::Operation(err_msg))
2521 }
2522 }
2523 }
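        // Qwen3-Agent: a tool call is wrapped in <action>...</action>; when no action is present,
        // the text is returned as content (wrapped in <final_answer> tags if they are missing).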
2524 PromptTemplateType::Qwen3Agent => {
2525 #[cfg(feature = "logging")]
2526 info!(target: "stdout", "Raw input to tool call parser: {input:?}");
2527
2528 match regex::Regex::new(r"<action>(.*?)</action>")
2530 .unwrap()
2531 .captures(input)
2532 {
2533 Some(captures) => {
2534 let action = captures.get(1).unwrap().as_str();
2535
2536 #[cfg(feature = "logging")]
2537 info!(target: "stdout", "Action: {action}");
2538
2539 match serde_json::from_str::<serde_json::Value>(action) {
2540 Ok(value) => {
2541 let name = match value.get("name") {
2542 Some(name) => name.to_string().replace("\"", ""),
2543 None => {
2544 let err_msg = format!(
2545 "Failed to get the name of the function. Tool call: {value:?}"
2546 );
2547
2548 #[cfg(feature = "logging")]
2549 error!(target: "stdout", "{}", &err_msg);
2550
2551 return Err(LlamaCoreError::Operation(err_msg));
2552 }
2553 };
2554
2555 let arguments = match value.get("arguments") {
2556 Some(arguments) => {
2557 if arguments.is_string() {
2558 arguments.as_str().unwrap().to_string()
2559 } else if arguments.is_object() {
2560 let map = arguments.as_object().unwrap();
2561
2562 #[cfg(feature = "logging")]
2563 info!(target: "stdout", "func arguments: {map:?}");
2564
2565 serde_json::to_string(map).unwrap()
2566 } else {
2567 serde_json::to_string(arguments).unwrap()
2568 }
2569 }
2570 None => {
2571 let err_msg = format!(
2572 "Failed to get the arguments of the function. Tool call: {value:?}"
2573 );
2574
2575 #[cfg(feature = "logging")]
2576 error!(target: "stdout", "{}", &err_msg);
2577
2578 return Err(LlamaCoreError::Operation(err_msg));
2579 }
2580 };
2581
2582 let function = Function { name, arguments };
2583
2584 let tool_call = ToolCall {
2585 id: "call_abc123".to_string(),
2586 ty: "function".to_string(),
2587 function,
2588 };
2589
2590 Ok(ParseResult {
2591 raw: input.to_owned(),
2592 content: Some(input.to_owned()),
2593 tool_calls: vec![tool_call],
2594 })
2595 }
2596 Err(e) => {
2597 let err_msg = format!(
2598 "Failed to deserialize generated tool calls: {action:#?}. Reason: {e}"
2599 );
2600
2601 #[cfg(feature = "logging")]
2602 error!(target: "stdout", "{}", &err_msg);
2603
2604 Err(LlamaCoreError::Operation(err_msg))
2605 }
2606 }
2607 }
2608 None => match input.contains("<final_answer>") {
2609 true => Ok(ParseResult {
2610 raw: input.to_owned(),
2611 content: Some(input.to_owned()),
2612 tool_calls: vec![],
2613 }),
2614 false => {
2615 let content = format!("<final_answer>{}</final_answer>", input.trim());
2616
2617 Ok(ParseResult {
2618 raw: input.to_owned(),
2619 content: Some(content),
2620 tool_calls: vec![],
2621 })
2622 }
2623 },
2624 }
2625 }
2626 PromptTemplateType::SeedOssNoThink | PromptTemplateType::SeedOssThink => {
2627 #[cfg(feature = "logging")]
2628 info!(target: "stdout", "Raw input to tool call parser: {input:?}");
2629
2630 match regex::Regex::new(r"```json\n([\s\S]*?)\n") {
2631 Ok(re) => {
2632 let mut values: Vec<serde_json::Value> = vec![];
2633 for cap in re.captures_iter(input) {
2634 let mut matched = cap[1].trim();
2635
2636 if matched.starts_with("\\n") {
2637 matched = matched.trim_start_matches("\\n");
2638 }
2639
2640 if matched.ends_with("\\n") {
2641 matched = matched.trim_end_matches("\\n");
2642 }
2643
2644 #[cfg(feature = "logging")]
2645 info!(target: "stdout", "captured: {matched:#?}");
2646
2647 if !matched.is_empty() {
2648 match serde_json::from_str::<serde_json::Value>(matched) {
2649 Ok(value) => values.push(value),
2650 Err(e) => {
2651 let err_msg = format!(
2652 "Failed to deserialize generated tool calls: {matched:#?}. Reason: {e}"
2653 );
2654
2655 #[cfg(feature = "logging")]
2656 error!(target: "stdout", "{}", &err_msg);
2657
2658 return Err(LlamaCoreError::Operation(err_msg));
2659 }
2660 }
2661 }
2662 }
2663
2664 let mut tool_calls: Vec<ToolCall> = vec![];
2665 for value in values.iter() {
2666 let name = match value.get("name") {
2667 Some(name) => name.to_string().replace("\"", ""),
2668 None => {
2669 let err_msg = format!(
2670 "Failed to get the name of the function. Tool call: {value:?}"
2671 );
2672
2673 #[cfg(feature = "logging")]
2674 error!(target: "stdout", "{}", &err_msg);
2675
2676 return Err(LlamaCoreError::Operation(err_msg));
2677 }
2678 };
2679
2680 let arguments = match value.get("arguments") {
2681 Some(arguments) => {
2682 if arguments.is_string() {
2683 arguments.as_str().unwrap().to_string()
2684 } else if arguments.is_object() {
2685 let map = arguments.as_object().unwrap();
2686
2687 #[cfg(feature = "logging")]
2688 info!(target: "stdout", "func arguments: {map:?}");
2689
2690 serde_json::to_string(map).unwrap()
2691 } else {
2692 serde_json::to_string(arguments).unwrap()
2693 }
2694 }
2695 None => {
2696 let err_msg = format!(
2697 "Failed to get the arguments of the function. Tool call: {value:?}"
2698 );
2699
2700 #[cfg(feature = "logging")]
2701 error!(target: "stdout", "{}", &err_msg);
2702
2703 return Err(LlamaCoreError::Operation(err_msg));
2704 }
2705 };
2706
2707 let function = Function { name, arguments };
2708
2709 let tool_call = ToolCall {
2710 id: "call_abc123".to_string(),
2711 ty: "function".to_string(),
2712 function,
2713 };
2714
2715 tool_calls.push(tool_call);
2716 }
2717
2718 let parsed = if tool_calls.is_empty() {
2719 ParseResult {
2720 raw: input.to_owned(),
2721 content: Some(input.to_owned()),
2722 tool_calls: vec![],
2723 }
2724 } else {
2725 ParseResult {
2726 raw: input.to_owned(),
2727 content: None,
2728 tool_calls,
2729 }
2730 };
2731
2732 #[cfg(feature = "logging")]
2733 info!(target: "stdout", "parsed result: {parsed:?}");
2734
2735 Ok(parsed)
2736 }
2737 Err(e) => {
2738 let err_msg = format!("Failed to create a regex pattern. Reason: {e}");
2739
2740 #[cfg(feature = "logging")]
2741 error!(target: "stdout", "{}", &err_msg);
2742
2743 Err(LlamaCoreError::Operation(err_msg))
2744 }
2745 }
2746 }
2747 _ => {
2748 let err_msg = format!(
2749 "The tool use is only supported for prompt templates: {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, and {}.",
2750 PromptTemplateType::MistralTool,
2751 PromptTemplateType::ChatMLTool,
2752 PromptTemplateType::GroqLlama3Tool,
2753 PromptTemplateType::Llama3Tool,
2754 PromptTemplateType::InternLM2Tool,
2755 PromptTemplateType::NemotronTool,
2756 PromptTemplateType::FunctionaryV32,
2757 PromptTemplateType::MistralSmallTool,
2758 PromptTemplateType::Llama4Chat,
2759 PromptTemplateType::Qwen3NoThink,
2760 PromptTemplateType::Smol3NoThink,
2761 PromptTemplateType::Gemma3,
2762 PromptTemplateType::GptOss,
2763 PromptTemplateType::Qwen3Agent,
2764 PromptTemplateType::SeedOssNoThink,
2765 PromptTemplateType::SeedOssThink
2766 );
2767
2768 #[cfg(feature = "logging")]
2769 error!(target: "stdout", "{}", &err_msg);
2770
2771 Err(LlamaCoreError::Operation(err_msg))
2772 }
2773 }
2774}
2775
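/// Validates the chat request against the current model metadata: rejects image inputs that are
/// not base64-encoded, syncs the request's sampling options (temperature, top_p, frequency and
/// presence penalties) into the metadata, disables embeddings mode, and pushes the updated
/// metadata to the runtime only when something actually changed.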
2776fn check_model_metadata(
2777 chat_request: &ChatCompletionRequest,
2778) -> Result<GgmlMetadata, LlamaCoreError> {
2779 let mut should_update = false;
2780 let mut metadata = get_model_metadata(chat_request.model.as_ref())?;
2781
2782 if metadata.prompt_template.is_image_supported() {
2784 if let Some(ChatCompletionRequestMessage::User(user_message)) = chat_request.messages.last()
2785 {
2786 if let ChatCompletionUserMessageContent::Parts(parts) = user_message.content() {
2787 for part in parts {
2788 if let ContentPart::Image(image_part) = part {
2789 let image = image_part.image();
2790
2791 if image.is_url() {
2792 let err_msg = "The image is provided in URL format. Only base64 format is supported.".to_string();
2793
2794 #[cfg(feature = "logging")]
2795 error!(target: "stdout", "{}", &err_msg);
2796
2797 return Err(LlamaCoreError::Operation(err_msg));
2798 } else {
2799 #[cfg(feature = "logging")]
2800 info!(target: "stdout", "The image is provided in base64 format.");
2801
2802 break;
2805 }
2806 }
2807 }
2808 }
2809 }
2810 }
2811
2812 if let Some(temp) = chat_request.temperature {
2814 if metadata.temperature != temp {
2815 metadata.temperature = temp;
2817
2818 if !should_update {
2819 should_update = true;
2820 }
2821 }
2822 }
2823
2824 if let Some(top_p) = chat_request.top_p {
2826 if metadata.top_p != top_p {
2827 metadata.top_p = top_p;
2829
2830 if !should_update {
2831 should_update = true;
2832 }
2833 }
2834 }
2835
2836 if let Some(frequency_penalty) = chat_request.frequency_penalty {
2838 if metadata.frequency_penalty != frequency_penalty {
2839 metadata.frequency_penalty = frequency_penalty;
2841
2842 if !should_update {
2843 should_update = true;
2844 }
2845 }
2846 }
2847
2848 if let Some(presence_penalty) = chat_request.presence_penalty {
2850 if metadata.presence_penalty != presence_penalty {
2851 metadata.presence_penalty = presence_penalty;
2853
2854 if !should_update {
2855 should_update = true;
2856 }
2857 }
2858 }
2859
2860 if metadata.embeddings {
2862 metadata.embeddings = false;
2863
2864 if !should_update {
2865 should_update = true;
2866 }
2867 }
2868
2869 if should_update {
2870 #[cfg(feature = "logging")]
2871 info!(target: "stdout", "Update the model metadata.");
2872
2873 update_model_metadata(chat_request.model.as_ref(), &metadata)?;
2875 }
2876
2877 Ok(metadata)
2878}
2879
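/// Reconciles `n_predict` with the request's `max_completion_tokens` and the number of completion
/// tokens still available in the context, pushing the updated metadata to the runtime when the
/// effective value changes.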
2880fn update_n_predict(
2881 chat_request: &ChatCompletionRequest,
2882 metadata: &mut GgmlMetadata,
2883 available_completion_tokens: u64,
2884) -> Result<(), LlamaCoreError> {
2885 let mut should_update = false;
2886
2887 #[cfg(feature = "logging")]
2888 info!(target: "stdout", "n_predict: {}", metadata.n_predict);
2889
2890 if let Some(max_completion_tokens) = chat_request.max_completion_tokens {
2896 if metadata.n_predict != max_completion_tokens {
2897 #[cfg(feature = "logging")]
2898 info!(target: "stdout", "Update n_predict with max_completion_tokens from {} to {}", metadata.n_predict, max_completion_tokens);
2899
2900 metadata.n_predict = max_completion_tokens;
2901
2902 if !should_update {
2903 should_update = true;
2904 }
2905 }
2906 }
2907
2908 if metadata.n_predict == -2 {
2910 #[cfg(feature = "logging")]
2911 info!(target: "stdout", "Update n_predict with available_completion_tokens from {} to {}", metadata.n_predict, available_completion_tokens);
2912
2913 metadata.n_predict = available_completion_tokens as i32;
2915
2916 if !should_update {
2917 should_update = true;
2918 }
2919 }
2920
2921 if metadata.n_predict == -1
2922 || (metadata.n_predict > 0 && metadata.n_predict < available_completion_tokens as i32)
2923 || (metadata.n_predict < 0 && metadata.n_predict != -2)
2924 {
2926 #[cfg(feature = "logging")]
2927 info!(target: "stdout", "Update n_predict with available_completion_tokens from {} to {}", metadata.n_predict, available_completion_tokens);
2928
2929 metadata.n_predict = available_completion_tokens as i32;
2931
2932 if !should_update {
2933 should_update = true;
2934 }
2935 }
2936
2937 if should_update {
2938 #[cfg(feature = "logging")]
2939 info!(target: "stdout", "Update the model metadata.");
2940
2941 update_model_metadata(chat_request.model.as_ref(), metadata)?;
2943 }
2944
2945 Ok(())
2946}
2947
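/// Strips template-specific control tokens and stop words (e.g. `<|im_end|>`, `</s>`,
/// `<end_of_turn>`) from the raw model output according to the prompt template type.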
2948fn post_process(
2950 output: impl AsRef<str>,
2951 template_ty: &PromptTemplateType,
2952) -> Result<String, String> {
2953 let output = if *template_ty == PromptTemplateType::Baichuan2 {
2954 if output.as_ref().contains("用户:") {
2955 output.as_ref().trim_end_matches("用户:").trim().to_owned()
2956 } else {
2957 output.as_ref().trim().to_owned()
2958 }
2959 } else if *template_ty == PromptTemplateType::OpenChat {
2960 if output.as_ref().contains("<|end_of_turn|>") {
2961 output
2962 .as_ref()
2963 .trim_end_matches("<|end_of_turn|>")
2964 .trim()
2965 .to_owned()
2966 } else {
2967 output.as_ref().trim().to_owned()
2968 }
2969 } else if *template_ty == PromptTemplateType::GemmaInstruct
2970 || *template_ty == PromptTemplateType::Gemma3
2971 {
2972 let s = output.as_ref().trim();
2973 if s.ends_with("<end_of_turn>") {
2974 s.trim_end_matches("<end_of_turn>").trim().to_owned()
2975 } else {
2976 s.to_owned()
2977 }
2978 } else if *template_ty == PromptTemplateType::ChatML
2979 || *template_ty == PromptTemplateType::ChatMLTool
2980 || *template_ty == PromptTemplateType::InternLM2Tool
2981 || *template_ty == PromptTemplateType::MiniCPMV
2982 {
2983 let mut s = output.as_ref().trim();
2984 if s.ends_with("<|endoftext|>") {
2985 s = s.trim_end_matches("<|endoftext|>").trim();
2986 }
2987
2988 if s.starts_with(":") {
2989 s = s.trim_start_matches(":").trim();
2990 }
2991
2992 let x = {
2994 let pat = r#"<think>
2995
2996</think>
2997"#;
2998 if s.contains(pat) {
2999 let x = s.replace(pat, "");
3000 if x.starts_with("()") {
3001 x.trim_start_matches("()").to_owned()
3002 } else {
3003 x.to_owned()
3004 }
3005 } else {
3006 s.to_owned()
3007 }
3008 };
3009 s = x.trim();
3010
3011 if s.contains("<|im_start|>") && s.contains("<|im_end|>") {
3012 let idx_start = s.find("<|im_start|>").unwrap();
3013 let idx_end = s.find("<|im_end|>").unwrap();
3014
3015 match idx_start <= idx_end {
3016 true => s.split("<|im_start|>").collect::<Vec<_>>()[0]
3017 .trim()
3018 .to_owned(),
3019 false => s.split("<|im_end|>").collect::<Vec<_>>()[0]
3020 .trim()
3021 .to_owned(),
3022 }
3023 } else if s.contains("<|im_start|>") {
3024 s.split("<|im_start|>").collect::<Vec<_>>()[0]
3025 .trim()
3026 .to_owned()
3027 } else if s.contains("<|im_end|>") {
3028 let output = s.trim_end_matches("<|im_end|>").trim();
3029 if output.starts_with(": ") {
3030 output.trim_start_matches(": ").to_owned()
3031 } else {
3032 output.to_owned()
3033 }
3034 } else {
3035 s.to_owned()
3036 }
3037 } else if *template_ty == PromptTemplateType::Zephyr
3038 || *template_ty == PromptTemplateType::MistralLite
3039 || *template_ty == PromptTemplateType::MistralTool
3040 || *template_ty == PromptTemplateType::MistralInstruct
3041 || *template_ty == PromptTemplateType::MistralSmallChat
3042 || *template_ty == PromptTemplateType::MistralSmallTool
3043 || *template_ty == PromptTemplateType::BreezeInstruct
3044 {
3045 if output.as_ref().contains("</s><") {
3046 output.as_ref().trim_end_matches("</s><").trim().to_owned()
3047 } else if output.as_ref().contains("</s>") {
3048 output
3049 .as_ref()
3050 .strip_suffix("</s>")
3051 .unwrap()
3052 .trim()
3053 .to_owned()
3054 } else {
3055 output.as_ref().trim().to_owned()
3056 }
3057 } else if *template_ty == PromptTemplateType::DeepseekChat {
3058 if output.as_ref().contains("<|end_of_sentence|>") {
3059 output
3060 .as_ref()
3061 .trim_end_matches("<|end_of_sentence|>")
3062 .trim()
3063 .replace("<|end_of_sentence|>", " ")
3064 .trim()
3065 .to_owned()
3066 } else {
3067 output.as_ref().trim().to_owned()
3068 }
3069 } else if *template_ty == PromptTemplateType::HumanAssistant {
3070 if output.as_ref().contains("Human:") {
3071 output.as_ref().trim_end_matches("Human:").trim().to_owned()
3072 } else {
3073 output.as_ref().trim().to_owned()
3074 }
3075 } else if *template_ty == PromptTemplateType::SolarInstruct {
3076 let s = output.as_ref().trim();
3077
3078 if s.starts_with("### Answer") {
3079 let s = s.trim_start_matches("###").trim();
3080
3081 if s.starts_with("Answer:\n") {
3082 s.replace("Answer:\n", "Answer: ")
3083 } else {
3084 s.to_owned()
3085 }
3086 } else {
3087 s.to_owned()
3088 }
3089 } else if *template_ty == PromptTemplateType::Llama2Chat
3090 || *template_ty == PromptTemplateType::NemotronTool
3091 || *template_ty == PromptTemplateType::NemotronChat
3092 {
3093 let s = output.as_ref().trim();
3094 if s.ends_with("</s>") {
3095 s.trim_end_matches("</s>").trim().to_owned()
3096 } else {
3097 s.to_owned()
3098 }
3099 } else if *template_ty == PromptTemplateType::Llama3Chat
3100 || *template_ty == PromptTemplateType::GroqLlama3Tool
3101 || *template_ty == PromptTemplateType::Llama3Tool
3102 || *template_ty == PromptTemplateType::FunctionaryV32
3103 {
3104 let s = output.as_ref().trim();
3105 if s.ends_with("<|eot_id|>") {
3106 s.trim_end_matches("<|eot_id|>").trim().to_owned()
3107 } else {
3108 s.to_owned()
3109 }
3110 } else if *template_ty == PromptTemplateType::Phi3Chat {
3111 let s = output.as_ref().trim();
3112 if s.ends_with("<|end|>") {
3113 s.trim_end_matches("<|end|>").trim().to_owned()
3114 } else {
3115 s.to_owned()
3116 }
3117 } else if *template_ty == PromptTemplateType::Phi4Chat {
3118 let mut s = output.as_ref().trim();
3119
3120 if s.starts_with("think>") {
3121 s = s.trim_start_matches("think>").trim();
3122 }
3123
3124 if s.ends_with("<|im_end|>") {
3125 s.trim_end_matches("<|im_end|>").trim().to_owned()
3126 } else if s.ends_with("<|end|>") {
3127 s.trim_end_matches("<|end|>").trim().to_owned()
3128 } else {
3129 s.to_owned()
3130 }
3131 } else if *template_ty == PromptTemplateType::FunctionaryV31 {
3132 let mut s = output.as_ref().trim();
3133 if s.ends_with("<|eot_id|>") {
3134 s = s.trim_end_matches("<|eot_id|>").trim();
3135 }
3136 if s.ends_with("<|eom_id|>") {
3137 s = s.trim_end_matches("<|eom_id|>").trim();
3138 }
3139 s.to_owned()
3140 } else if *template_ty == PromptTemplateType::MoxinChat
3141 || *template_ty == PromptTemplateType::MoxinInstruct
3142 {
3143 let s = output.as_ref().trim();
3144 if s.ends_with("</s>") {
3145 s.trim_end_matches("</s>").trim().to_owned()
3146 } else if s.ends_with("[INST]") {
3147 s.trim_end_matches("[INST]").trim().to_owned()
3148 } else {
3149 s.to_owned()
3150 }
3151 } else if *template_ty == PromptTemplateType::Falcon3 {
3152 let s = output.as_ref().trim();
3153 if s.ends_with("<|endoftext|>") {
3154 s.trim_end_matches("<|endoftext|>").trim().to_owned()
3155 } else {
3156 s.to_owned()
3157 }
3158 } else if *template_ty == PromptTemplateType::Megrez {
3159 let s = output.as_ref().trim();
3160 if s.ends_with("<|turn_end|>") {
3161 s.trim_end_matches("<|turn_end|>").trim().to_owned()
3162 } else {
3163 s.to_owned()
3164 }
3165 } else if *template_ty == PromptTemplateType::Qwen2vl
3166 || *template_ty == PromptTemplateType::Qwen3NoThink
3167 || *template_ty == PromptTemplateType::ChatMLThink
3168 {
3169 let mut s = output.as_ref().trim();
3170
3171 if s.starts_with(":") {
3172 s = s.trim_start_matches(":").trim();
3173 }
3174
3175 if s.starts_with("</think>") {
3176 s = s.trim_start_matches("</think>").trim();
3177 }
3178
3179 if s.ends_with("<|im_end|>") {
3180 s.trim_end_matches("<|im_end|>").trim().to_owned()
3181 } else {
3182 s.to_owned()
3183 }
3184 } else if *template_ty == PromptTemplateType::VicunaLlava {
3185 let s = output.as_ref().trim();
3186 if s.ends_with("</s>") {
3187 s.trim_end_matches("</s>").trim().to_owned()
3188 } else {
3189 s.to_owned()
3190 }
3191 } else if *template_ty == PromptTemplateType::ExaoneDeepChat
3192 || *template_ty == PromptTemplateType::ExaoneChat
3193 {
3194 let mut s = output.as_ref().trim();
3195
3196 if s.ends_with("[|endofturn|]") {
3197 s = s.trim_end_matches("[|endofturn|]").trim();
3198 }
3199
3200 s.to_owned()
3201 } else if *template_ty == PromptTemplateType::Llama4Chat {
3202 let mut s = output.as_ref().trim();
3203
3204 if s.ends_with("<|eot|>") {
3205 s = s.trim_end_matches("<|eot|>").trim();
3206 }
3207
3208 s.to_owned()
3209 } else if *template_ty == PromptTemplateType::Smolvl {
3210 let mut s = output.as_ref().trim();
3211
3212 if s.starts_with(":") {
3213 s = s.trim_start_matches(":").trim();
3214 }
3215
3216 if s.ends_with("<end_of_utterance>") {
3217 s = s.trim_end_matches("<end_of_utterance>").trim();
3218 }
3219
3220 if s.contains("<end_of_utterance>:") {
3221 let parts = s.split("<end_of_utterance>:").collect::<Vec<_>>();
3222 parts.last().unwrap().trim().to_owned()
3223 } else {
3224 s.to_owned()
3225 }
3226 } else if *template_ty == PromptTemplateType::Smol3NoThink {
3227 let mut s = output.as_ref().trim();
3228
3229 if s.ends_with("<|im_end|>") {
3230 s = s.trim_end_matches("<|im_end|>").trim();
3231 }
3232
3233 let re = regex::Regex::new(r"(?s)^<think>.*?</think>\s*").unwrap();
3234 re.replace(s, "").to_string()
3235 } else if *template_ty == PromptTemplateType::GptOss {
3236 let s = output.as_ref().trim();
3237
3238 let re =
3239 regex::Regex::new(r"(?s).*<\|channel\|>final<\|message\|>(.*?)<\|return\|>$").unwrap();
3240
3241 if let Some(caps) = re.captures(s) {
3242 let extracted = &caps[1];
3243 extracted.to_owned()
3244 } else {
3245 s.to_owned()
3246 }
3247 } else if *template_ty == PromptTemplateType::Qwen3Agent {
3248 let mut s = output.as_ref().trim();
3249
3250 if s.starts_with(":") {
3251 s = s.trim_start_matches(":").trim();
3252 }
3253
3254 if s.starts_with("</think>") {
3255 s = s.trim_start_matches("</think>").trim();
3256 }
3257
3258 if s.ends_with("<|im_end|>") {
3259 s = s.trim_end_matches("<|im_end|>").trim();
3260 }
3261
3262 if s.contains("<final_answer>") && !s.contains("</final_answer>") {
3263 format!("{s}</final_answer>")
3264 } else {
3265 s.to_owned()
3266 }
3267 } else if *template_ty == PromptTemplateType::SeedOssNoThink {
3268 let s = output.as_ref().trim();
3269
3270 let re = regex::Regex::new(r"(?s)</seed:think>\s*(.*?)\s*<seed:eos>").unwrap();
3271
3272 if let Some(caps) = re.captures(s) {
3273 let extracted = &caps[1];
3274 extracted.to_owned()
3275 } else {
3276 s.to_owned()
3277 }
3278 } else {
3279 output.as_ref().trim().to_owned()
3280 };
3281
3282 Ok(output)
3283}
3284
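/// Builds the chat prompt for the target model, pruning the oldest conversation turns until the
/// prompt fits within 4/5 of the context size. Returns the prompt, the number of tokens left for
/// the completion, and whether tools are enabled for this request.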
3285fn build_prompt(
3297 model_name: Option<&String>,
3298 chat_request: &mut ChatCompletionRequest,
3299) -> Result<(String, u64, bool), LlamaCoreError> {
3300 let metadata = get_model_metadata(model_name)?;
3301 let ctx_size = metadata.ctx_size as u64;
3302 let chat_prompt = ChatPrompt::from(metadata.prompt_template);
3303
3304 let max_prompt_tokens = ctx_size * 4 / 5;
3306
3307 loop {
3308 {
3310 }
3323
3324 if chat_request.messages.is_empty() {
3325 let err_msg = "The messages in the chat request are empty.";
3326
3327 #[cfg(feature = "logging")]
3328 error!(target: "stdout", "{err_msg}");
3329
3330 return Err(LlamaCoreError::Operation(err_msg.to_owned()));
3331 }
3332
3333 #[cfg(feature = "logging")]
3334 {
3335 let mut role_chain = String::new();
3336 for (idx, message) in chat_request.messages.iter().enumerate() {
3337 if idx == chat_request.messages.len() - 1 {
3338 role_chain.push_str(&format!("{}", message.role()));
3339 } else {
3340 role_chain.push_str(&format!("{} -> ", message.role()));
3341 }
3342 }
3343 info!(target: "stdout", "Role chain: {role_chain}");
3344 }
3345
3346 let (prompt, tool_use) = match chat_request.tool_choice.as_ref() {
3347 Some(tool_choice) => match tool_choice {
3348 ToolChoice::None => {
3349 match chat_prompt.build_with_tools(&mut chat_request.messages, Some(&[])) {
3350 Ok(prompt) => (prompt, false),
3351 Err(e) => {
3352 let err_msg = format!("Fail to build chat prompts. Reason: {e}");
3353
3354 #[cfg(feature = "logging")]
3355 error!(target: "stdout", "{}", &err_msg);
3356
3357 return Err(LlamaCoreError::Operation(err_msg));
3358 }
3359 }
3360 }
3361 _ => match chat_request.tools.as_ref() {
3362 Some(tools) => match chat_prompt
3363 .build_with_tools(&mut chat_request.messages, Some(tools.as_slice()))
3364 {
3365 Ok(prompt) => (prompt, true),
3366 Err(e) => {
3367 let err_msg = format!("Fail to build chat prompts. Reason: {e}");
3368
3369 #[cfg(feature = "logging")]
3370 error!(target: "stdout", "{}", &err_msg);
3371
3372 return Err(LlamaCoreError::Operation(err_msg));
3373 }
3374 },
3375 None => {
3376 #[cfg(feature = "logging")]
3377 warn!(target: "stdout", "The tool choice without tools is not supported.");
3378
3379 match chat_prompt.build_with_tools(&mut chat_request.messages, None) {
3380 Ok(prompt) => (prompt, false),
3381 Err(e) => {
3382 let err_msg = format!("Fail to build chat prompts. Reason: {e}");
3383
3384 #[cfg(feature = "logging")]
3385 error!(target: "stdout", "{}", &err_msg);
3386
3387 return Err(LlamaCoreError::Operation(err_msg));
3388 }
3389 }
3390 }
3391 },
3392 },
3393 None => match chat_prompt.build_with_tools(&mut chat_request.messages, None) {
3394 Ok(prompt) => (prompt, false),
3395 Err(e) => {
3396 let err_msg = format!("Fail to build chat prompts. Reason: {e}");
3397
3398 #[cfg(feature = "logging")]
3399 error!(target: "stdout", "{}", &err_msg);
3400
3401 return Err(LlamaCoreError::Operation(err_msg));
3402 }
3403 },
3404 };
3405 #[cfg(feature = "logging")]
3406 info!(target: "stdout", "Try to set prompt: {prompt}");
3407
3408 set_prompt(model_name, &prompt)?;
3410
3411 let token_info = get_token_info_by_graph_name(model_name)?;
3413
3414 match token_info.prompt_tokens > max_prompt_tokens {
3415 true => {
3416 match chat_request.messages[0].role() {
3417 ChatCompletionRole::System => {
3418 if chat_request.messages.len() == 4
3420 && chat_request.messages[1].role() == ChatCompletionRole::User
3421 && chat_request.messages[2].role() == ChatCompletionRole::Assistant
3422 && chat_request.messages[3].role() == ChatCompletionRole::Tool
3423 {
3424 let err_msg = format!(
3425 "The number of prompt tokens ({}) is greater than the max prompt tokens ({}). Please increase the context size.",
3426 token_info.prompt_tokens, max_prompt_tokens
3427 );
3428
3429 #[cfg(feature = "logging")]
3430 error!(target: "stdout", "{}", &err_msg);
3431
3432 return Err(LlamaCoreError::Operation(err_msg));
3433 }
3434
3435 if chat_request.messages.len() > 2 {
3436 #[cfg(feature = "logging")]
3437 info!(target: "stdout", "Prune chat history: current length {}", chat_request.messages.len());
3438
3439 if chat_request.messages[1].role() == ChatCompletionRole::User {
3442 let user_message = chat_request.messages.remove(1);
3443
3444 #[cfg(feature = "logging")]
3445 info!(target: "stdout", "Remove a user message from the chat history: {user_message:?}");
3446 }
3447
3448 while chat_request.messages[1].role() != ChatCompletionRole::User {
3451 let message = chat_request.messages.remove(1);
3452
3453 #[cfg(feature = "logging")]
3454 info!(target: "stdout", "Remove a {} message from the chat history: {:?}", message.role(), message);
3455
3456 if chat_request.messages.len() == 1 {
3457 let err_msg = format!("The last message in the chat history should be a user message, but found a {} message.", message.role());
3458
3459 #[cfg(feature = "logging")]
3460 error!(target: "stdout", "{err_msg}");
3461
3462 return Err(LlamaCoreError::Operation(err_msg));
3463 }
3464 }
3465 } else if token_info.prompt_tokens > ctx_size {
3466 let err_msg = format!(
3467 "The number of prompt tokens ({}) is greater than the context size ({}). Please increase the context size, or simplify the input message.",
3468 token_info.prompt_tokens, ctx_size
3469 );
3470
3471 #[cfg(feature = "logging")]
3472 error!(target: "stdout", "{}", &err_msg);
3473
3474 return Err(LlamaCoreError::Operation(err_msg));
3475 } else {
3476 return Ok((prompt, ctx_size - token_info.prompt_tokens, tool_use));
3477 }
3478 }
3479 ChatCompletionRole::User => {
3480 if chat_request.messages.len() == 3
3482 && chat_request.messages[1].role() == ChatCompletionRole::Assistant
3483 && chat_request.messages[2].role() == ChatCompletionRole::Tool
3485 {
3486 let err_msg = format!(
3487 "The number of prompt tokens ({}) is greater than the max prompt tokens ({}). Please increase the context size.",
3488 token_info.prompt_tokens, max_prompt_tokens
3489 );
3490
3491 #[cfg(feature = "logging")]
3492 error!(target: "stdout", "{}", &err_msg);
3493
3494 return Err(LlamaCoreError::Operation(err_msg));
3495 }
3496
3497 if chat_request.messages.len() > 1 {
3498 if chat_request.messages[0].role() == ChatCompletionRole::User {
3503 let user_message = chat_request.messages.remove(0);
3504
3505 #[cfg(feature = "logging")]
3506 info!(target: "stdout", "Remove a user message from the chat history: {user_message:?}");
3507 }
3508
3509 while chat_request.messages[0].role() != ChatCompletionRole::User {
3512 let message = chat_request.messages.remove(0);
3513
3514 #[cfg(feature = "logging")]
3515 info!(target: "stdout", "Remove a {} message from the chat history: {:?}", message.role(), message);
3516
3517 if chat_request.messages.is_empty() {
3518 let err_msg = format!("The last message in the chat history should be a user message, but found a {} message.", message.role());
3519
3520 #[cfg(feature = "logging")]
3521 error!(target: "stdout", "{err_msg}");
3522
3523 return Err(LlamaCoreError::Operation(err_msg));
3524 }
3525 }
3526 } else if token_info.prompt_tokens > ctx_size {
3527 let err_msg = format!(
3528 "The number of prompt tokens ({}) is greater than the context size ({}). Please increase the context size, or simplify the input message.",
3529 token_info.prompt_tokens, ctx_size
3530 );
3531
3532 #[cfg(feature = "logging")]
3533 error!(target: "stdout", "{}", &err_msg);
3534
3535 return Err(LlamaCoreError::Operation(err_msg));
3536 } else {
3537 return Ok((prompt, ctx_size - token_info.prompt_tokens, tool_use));
3538 }
3539 }
3540 _ => {
3541 #[cfg(feature = "logging")]
3542 info!(target: "stdout", "remove a {} message from the message queue", chat_request.messages[0].role());
3543
3544 chat_request.messages.remove(0);
3545 }
3546 }
3547
3548 continue;
3549 }
3550 false => return Ok((prompt, ctx_size - max_prompt_tokens, tool_use)),
3551 }
3552 }
3553}
3554
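/// Writes the prompt bytes into input tensor 0 of the named chat graph, falling back to the first
/// available graph when the model name is missing or unknown.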
3555fn set_prompt(model_name: Option<&String>, prompt: impl AsRef<str>) -> Result<(), LlamaCoreError> {
3556 let chat_graphs = match CHAT_GRAPHS.get() {
3557 Some(chat_graphs) => chat_graphs,
3558 None => {
3559 let err_msg = "Fail to get the underlying value of `CHAT_GRAPHS`.";
3560
3561 #[cfg(feature = "logging")]
3562 error!(target: "stdout", "{}", &err_msg);
3563
3564 return Err(LlamaCoreError::Operation(err_msg.into()));
3565 }
3566 };
3567
3568 let mut chat_graphs = chat_graphs.lock().map_err(|e| {
3569 let err_msg = format!("Fail to acquire the lock of `CHAT_GRAPHS`. {e}");
3570
3571 #[cfg(feature = "logging")]
3572 error!(target: "stdout", "{}", &err_msg);
3573
3574 LlamaCoreError::Operation(err_msg)
3575 })?;
3576
3577 match model_name {
3578 Some(model_name) => {
3579 #[cfg(feature = "logging")]
3580 info!(target: "stdout", "Set prompt to the chat model named {model_name}");
3581
3582 match chat_graphs.contains_key(model_name) {
3583 true => {
3584 let graph = chat_graphs.get_mut(model_name).unwrap();
3585 let tensor_data = prompt.as_ref().as_bytes().to_vec();
3586 set_tensor_data_u8(graph, 0, &tensor_data)
3587 }
3588 false => match chat_graphs.iter_mut().next() {
3589 Some((_, graph)) => {
3590 let tensor_data = prompt.as_ref().as_bytes().to_vec();
3591 set_tensor_data_u8(graph, 0, &tensor_data)
3592 }
3593 None => {
3594 let err_msg = "There is no model available in the chat graphs.";
3595
3596 #[cfg(feature = "logging")]
3597 error!(target: "stdout", "{}", &err_msg);
3598
3599 Err(LlamaCoreError::Operation(err_msg.into()))
3600 }
3601 },
3602 }
3603 }
3604 None => {
3605 #[cfg(feature = "logging")]
3606 info!(target: "stdout", "Set prompt to the default chat model.");
3607
3608 match chat_graphs.iter_mut().next() {
3609 Some((_, graph)) => {
3610 let tensor_data = prompt.as_ref().as_bytes().to_vec();
3611 set_tensor_data_u8(graph, 0, &tensor_data)
3612 }
3613 None => {
3614 let err_msg = "There is no model available in the chat graphs while trying to set prompt to the default model.";
3615
3616 #[cfg(feature = "logging")]
3617 error!(target: "stdout", "{err_msg}");
3618
3619 Err(LlamaCoreError::Operation(err_msg.into()))
3620 }
3621 }
3622 }
3623 }
3624}
3625
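/// Returns a clone of the metadata of the named chat model, or of the first available model when
/// no name is given or the name is unknown.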
3626fn get_model_metadata(model_name: Option<&String>) -> Result<GgmlMetadata, LlamaCoreError> {
3628 let chat_graphs = match CHAT_GRAPHS.get() {
3629 Some(chat_graphs) => chat_graphs,
3630 None => {
3631 let err_msg = "Fail to get the underlying value of `CHAT_GRAPHS`.";
3632
3633 #[cfg(feature = "logging")]
3634 error!(target: "stdout", "{err_msg}");
3635
3636 return Err(LlamaCoreError::Operation(err_msg.into()));
3637 }
3638 };
3639
3640 let chat_graphs = chat_graphs.lock().map_err(|e| {
3641 let err_msg = format!("Fail to acquire the lock of `CHAT_GRAPHS`. {e}");
3642
3643 #[cfg(feature = "logging")]
3644 error!(target: "stdout", "{}", &err_msg);
3645
3646 LlamaCoreError::Operation(err_msg)
3647 })?;
3648
3649 match model_name {
3650 Some(model_name) => match chat_graphs.contains_key(model_name) {
3651 true => {
3652 let graph = chat_graphs.get(model_name).unwrap();
3653 Ok(graph.metadata.clone())
3654 }
3655 false => match chat_graphs.iter().next() {
3656 Some((_, graph)) => Ok(graph.metadata.clone()),
3657 None => {
3658 let err_msg = "There is no model available in the chat graphs.";
3659
3660 #[cfg(feature = "logging")]
3661 error!(target: "stdout", "{}", &err_msg);
3662
3663 Err(LlamaCoreError::Operation(err_msg.into()))
3664 }
3665 },
3666 },
3667 None => match chat_graphs.iter().next() {
3668 Some((_, graph)) => Ok(graph.metadata.clone()),
3669 None => {
3670 let err_msg = "There is no model available in the chat graphs.";
3671
3672 #[cfg(feature = "logging")]
3673 error!(target: "stdout", "{err_msg}");
3674
3675 Err(LlamaCoreError::Operation(err_msg.into()))
3676 }
3677 },
3678 }
3679}
3680
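/// Serializes the metadata to JSON and writes it into input tensor 1 of the target chat graph so
/// the runtime applies the new settings.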
3681fn update_model_metadata(
3682 model_name: Option<&String>,
3683 metadata: &GgmlMetadata,
3684) -> Result<(), LlamaCoreError> {
3685 let config = match serde_json::to_string(metadata) {
3686 Ok(config) => config,
3687 Err(e) => {
3688 let err_msg = format!("Fail to serialize metadata to a JSON string. {e}");
3689
3690 #[cfg(feature = "logging")]
3691 error!(target: "stdout", "{}", &err_msg);
3692
3693 return Err(LlamaCoreError::Operation(err_msg));
3694 }
3695 };
3696
3697 let chat_graphs = match CHAT_GRAPHS.get() {
3698 Some(chat_graphs) => chat_graphs,
3699 None => {
3700 let err_msg = "Fail to get the underlying value of `CHAT_GRAPHS`.";
3701
3702 #[cfg(feature = "logging")]
3703 error!(target: "stdout", "{err_msg}");
3704
3705 return Err(LlamaCoreError::Operation(err_msg.into()));
3706 }
3707 };
3708
3709 let mut chat_graphs = chat_graphs.lock().map_err(|e| {
3710 let err_msg = format!("Fail to acquire the lock of `CHAT_GRAPHS`. Reason: {e}");
3711
3712 #[cfg(feature = "logging")]
3713 error!(target: "stdout", "{}", &err_msg);
3714
3715 LlamaCoreError::Operation(err_msg)
3716 })?;
3717
3718 match model_name {
3719 Some(model_name) => {
3720 match chat_graphs.contains_key(model_name) {
3721 true => {
3722 let graph = chat_graphs.get_mut(model_name).unwrap();
3723 set_tensor_data_u8(graph, 1, config.as_bytes())
3725 }
3726 false => match chat_graphs.iter_mut().next() {
3727 Some((_, graph)) => {
3728 set_tensor_data_u8(graph, 1, config.as_bytes())
3730 }
3731 None => {
3732 let err_msg = "There is no model available in the chat graphs.";
3733
3734 #[cfg(feature = "logging")]
3735 error!(target: "stdout", "{}", &err_msg);
3736
3737 Err(LlamaCoreError::Operation(err_msg.into()))
3738 }
3739 },
3740 }
3741 }
3742 None => {
3743 match chat_graphs.iter_mut().next() {
3744 Some((_, graph)) => {
3745 set_tensor_data_u8(graph, 1, config.as_bytes())
3747 }
3748 None => {
3749 let err_msg = "There is no model available in the chat graphs.";
3750
3751 #[cfg(feature = "logging")]
3752 error!(target: "stdout", "{err_msg}");
3753
3754 Err(LlamaCoreError::Operation(err_msg.into()))
3755 }
3756 }
3757 }
3758 }
3759}
3760
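/// Re-applies the metadata currently cached in the chat graph, reverting any per-request
/// adjustments made while serving the completion.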
3761fn reset_model_metadata(model_name: Option<&String>) -> Result<(), LlamaCoreError> {
3762 let metadata = get_model_metadata(model_name)?;
3764
3765 update_model_metadata(model_name, &metadata)
3767}
3768
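/// Stages of the chunk sequence emitted when the model reports a full context.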
3769#[derive(Debug, Clone, Copy, PartialEq, Eq)]
3770enum ContextFullState {
3771 Message,
3772 Usage,
3773 Done,
3774 EndOfSequence,
3775}
3776
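/// Stages of the regular streaming sequence, with or without a trailing usage chunk.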
3777#[derive(Debug, Clone, Copy, PartialEq, Eq)]
3778enum StreamState {
3779 Usage,
3780 NoUsage,
3781 Done,
3782 EndOfSequence,
3783}
3784
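/// Stages of the chunk sequence emitted when the prompt exceeds the token limit.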
3785#[derive(Debug, Clone, Copy, PartialEq, Eq)]
3786enum PromptTooLongState {
3787 Message,
3788 Usage,
3789 Done,
3790 EndOfSequence,
3791}
3792
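/// A stream of chat completion chunks for a single request. Each stream holds (or waits for) the
/// per-model stream lock so that only one request drives a model at a time; dropping the stream
/// cleans up the model context and releases the lock.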
3793struct ChatStream {
3794 id: String,
3795 model: Option<String>,
3796 include_usage: bool,
3797 context_full_state: ContextFullState,
3798 prompt_too_long_state: PromptTooLongState,
3799 stream_state: StreamState,
3800 cache: Option<VecDeque<String>>,
3801 is_waiting: bool,
3802 has_lock: bool,
3803 model_lock: Arc<ModelStreamLock>,
3805 utf8_cache: Vec<u8>,
3806}
3807impl ChatStream {
3808 fn new(
3809 model: Option<String>,
3810 id: String,
3811 include_usage: bool,
3812 cache: Option<Vec<String>>,
3813 ) -> Result<Self, LlamaCoreError> {
3814 let (resolved_model, model_lock) = match &model {
3816 Some(name) => (name.clone(), get_or_create_model_lock(name)?),
3817 None => get_default_model_lock()?,
3818 };
3819
3820 let has_lock = model_lock.try_acquire();
3822
3823 #[cfg(feature = "logging")]
3824 if !has_lock {
3825 info!(target: "stdout", "Lock acquisition failed for model {}, creating with waiting status", &resolved_model);
3826 } else {
3827 info!(target: "stdout", "Lock acquired for model {}", &resolved_model);
3828 }
3829
3830 Ok(ChatStream {
3831 id,
3832 model: Some(resolved_model),
3833 include_usage,
3834 context_full_state: ContextFullState::Message,
3835 prompt_too_long_state: PromptTooLongState::Message,
3836 stream_state: if include_usage {
3837 StreamState::Usage
3838 } else {
3839 StreamState::NoUsage
3840 },
3841 cache: cache.map(VecDeque::from),
3842 is_waiting: !has_lock,
3843 has_lock,
3844 model_lock,
3845 utf8_cache: Vec::new(),
3846 })
3847 }
3848
3849 fn try_acquire_lock(&mut self) -> bool {
3851 if self.has_lock {
3852 return true;
3853 }
3854
3855 let acquired = self.model_lock.try_acquire();
3856
3857 if acquired {
3858 self.has_lock = true;
3859 self.is_waiting = false;
3860
3861 #[cfg(feature = "logging")]
3862 info!(target: "stdout", "ChatStream {} acquired lock for model {:?}", &self.id, &self.model);
3863 }
3864
3865 acquired
3866 }
3867}
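// Dropping a ChatStream finishes the pending single-turn computation on the underlying graph,
// resets the model metadata, and releases the per-model stream lock so the next waiting stream
// is woken.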
3868impl Drop for ChatStream {
3869 fn drop(&mut self) {
3870 if self.has_lock || (self.cache.is_none() && !self.is_waiting) {
3872 #[cfg(feature = "logging")]
3873 info!(target: "stdout", "Cleaning up context for ChatStream {}", &self.id);
3874
3875 match &self.model {
3876 Some(model_name) => {
3877 match CHAT_GRAPHS.get() {
3878 Some(chat_graphs) => {
3879 match chat_graphs.lock() {
3880 Ok(mut chat_graphs) => match chat_graphs.contains_key(model_name) {
3881 true => {
3882 let graph = chat_graphs.get_mut(model_name).unwrap();
3883
3884 if let Err(e) = graph.finish_single() {
3886 let err_msg = format!(
3887 "Failed to clean up the context. Reason: {e}"
3888 );
3889
3890 #[cfg(feature = "logging")]
3891 error!(target: "stdout", "{}", &err_msg);
3892
3893 #[cfg(not(feature = "logging"))]
3894 println!(
3895 "[ERROR][llama_core] Failed to clean up the context. Reason: {}",
3896 &err_msg
3897 );
3898 }
3899 }
3900 false => match chat_graphs.iter_mut().next() {
3901 Some((_, graph)) => {
3902 if let Err(e) = graph.finish_single() {
3904 let err_msg = format!(
3905 "Failed to clean up the context. Reason: {e}"
3906 );
3907
3908 #[cfg(feature = "logging")]
3909 error!(target: "stdout", "{}", &err_msg);
3910
3911 #[cfg(not(feature = "logging"))]
3912 println!(
3913 "[ERROR][llama_core] Failed to clean up the context. Reason: {}",
3914 &err_msg
3915 );
3916 }
3917 }
3918 None => {
3919 let err_msg =
3920 "There is no model available in the chat graphs.";
3921
3922 #[cfg(feature = "logging")]
3923 error!(target: "stdout", "{}", &err_msg);
3924
3925 #[cfg(not(feature = "logging"))]
3926 println!(
3927 "[ERROR][llama_core] Failed to clean up the context. Reason: {}",
3928 &err_msg
3929 );
3930 }
3931 },
3932 },
3933 Err(e) => {
3934 let err_msg =
3935 format!("Fail to acquire the lock of `CHAT_GRAPHS`. {e}");
3936
3937 #[cfg(feature = "logging")]
3938 error!(target: "stdout", "{}", &err_msg);
3939
3940 #[cfg(not(feature = "logging"))]
3941 println!(
3942 "[ERROR][llama_core] Failed to clean up the context. Reason: {}",
3943 &err_msg
3944 );
3945 }
3946 }
3947 }
3948 None => {
3949 let err_msg = "Fail to get the underlying value of `CHAT_GRAPHS`.";
3950
3951 #[cfg(feature = "logging")]
3952 error!(target: "stdout", "{}", &err_msg);
3953
3954 #[cfg(not(feature = "logging"))]
3955 println!(
3956 "[ERROR][llama_core] Failed to clean up the context. Reason: {}",
3957 &err_msg
3958 );
3959 }
3960 };
3961 }
3962 None => {
3963 match CHAT_GRAPHS.get() {
3964 Some(chat_graphs) => {
3965 match chat_graphs.lock() {
3966 Ok(mut chat_graphs) => match chat_graphs.iter_mut().next() {
3967 Some((_, graph)) => {
3968 if let Err(e) = graph.finish_single() {
3970 let err_msg = format!(
3971 "Failed to clean up the context. Reason: {e}"
3972 );
3973
3974 #[cfg(feature = "logging")]
3975 error!(target: "stdout", "{}", &err_msg);
3976
3977 #[cfg(not(feature = "logging"))]
3978 println!(
3979 "[ERROR][llama_core] Failed to clean up the context. Reason: {}",
3980 &err_msg
3981 );
3982 }
3983 }
3984 None => {
3985 let err_msg =
3986 "There is no model available in the chat graphs.";
3987
3988 #[cfg(feature = "logging")]
3989 error!(target: "stdout", "{err_msg}");
3990
3991 #[cfg(not(feature = "logging"))]
3992 println!(
3993 "[ERROR][llama_core] Failed to clean up the context. Reason: {}",
3994 err_msg
3995 );
3996 }
3997 },
3998 Err(e) => {
3999 let err_msg =
4000 format!("Fail to acquire the lock of `CHAT_GRAPHS`. {e}");
4001
4002 #[cfg(feature = "logging")]
4003 error!(target: "stdout", "{}", &err_msg);
4004
4005 #[cfg(not(feature = "logging"))]
4006 println!(
4007 "[ERROR][llama_core] Failed to clean up the context. Reason: {}",
4008 &err_msg
4009 );
4010 }
4011 }
4012 }
4013 None => {
4014 let err_msg = "Fail to get the underlying value of `CHAT_GRAPHS`.";
4015
4016 #[cfg(feature = "logging")]
4017 error!(target: "stdout", "{}", &err_msg);
4018
4019 #[cfg(not(feature = "logging"))]
4020 println!(
4021 "[ERROR][llama_core] Failed to clean up the context. Reason: {}",
4022 &err_msg
4023 );
4024 }
4025 };
4026 }
4027 }
4028
4029 #[cfg(feature = "logging")]
4030 info!(target: "stdout", "Model context cleanup done!");
4031 }
4032
4033 if let Err(e) = reset_model_metadata(self.model.as_ref()) {
4035 let err_msg = format!("Fail to reset model metadata. Reason: {e}");
4036
4037 #[cfg(feature = "logging")]
4038 error!(target: "stdout", "{}", &err_msg);
4039
4040 #[cfg(not(feature = "logging"))]
4041 println!("[ERROR][llama_core] {}", &err_msg);
4042 }
4043 #[cfg(feature = "logging")]
4044 info!(target: "stdout", "Model metadata reset done!");
4045
4046 if self.has_lock {
4048 self.model_lock.release();
4049
4050 #[cfg(feature = "logging")]
4051 info!(target: "stdout", "Lock released for ChatStream {} (model {:?})", &self.id, &self.model);
4052 }
4053 }
4054}
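// `poll_next` first ensures the stream holds the model lock (registering a waker and returning
// `Pending` otherwise), then either drains cached chunks or produces the next chunk via
// `compute_stream`.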
4055impl futures::Stream for ChatStream {
4056 type Item = Result<String, LlamaCoreError>;
4057
4058 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
4059 let this = self.get_mut();
4060
4061 if this.is_waiting {
4063 if !this.try_acquire_lock() {
4064 this.model_lock.register_waker(cx.waker());
4066
4067 #[cfg(feature = "logging")]
4068 debug!(target: "stdout", "ChatStream {} waiting for model {:?}", &this.id, &this.model);
4069
4070 return Poll::Pending;
4071 }
4072
4073 #[cfg(feature = "logging")]
4074 info!(target: "stdout", "ChatStream {} acquired lock and is now active", &this.id);
4075 }
4077
4078 if !this.has_lock && !this.try_acquire_lock() {
4080 this.is_waiting = true;
4082
4083 this.model_lock.register_waker(cx.waker());
4085
4086 return Poll::Pending;
4087 }
4088
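        // If earlier processing left a cache of ready-made chunks, drain it one
        // item per poll; an empty cache ends the stream.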
4089 if let Some(cache) = &mut this.cache {
4090 let x = cache.pop_front();
4091
4092 #[cfg(feature = "logging")]
4093 info!(target: "stdout", "Get the next item from the cache for ChatStream {}: {:?}", &this.id, &x);
4094
4095 match x {
4096 Some(x) => Poll::Ready(Some(Ok(x))),
4097 None => Poll::Ready(None),
4098 }
4099 } else {
4100 let res = compute_stream(
4101 this.model.clone(),
4102 this.id.clone(),
4103 this.include_usage,
4104 &mut this.prompt_too_long_state,
4105 &mut this.context_full_state,
4106 &mut this.stream_state,
4107 &mut this.utf8_cache,
4108 );
4109
4110 match res {
4111 Ok(x) => {
4112 #[cfg(feature = "logging")]
4113 info!(target: "stdout", "next item for ChatStream {}: {}", &this.id, &x);
4114
4115 if x != "[GGML] End of sequence" && !x.is_empty() {
4116 Poll::Ready(Some(Ok(x)))
4117 } else {
4118 Poll::Ready(None)
4120 }
4121 }
4122 Err(e) => Poll::Ready(Some(Err(e))),
4123 }
4124 }
4125 }
4126}
4127
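/// Computes the next stream chunk for the given model (or an arbitrary loaded
/// graph when no model name is provided) and encodes it as an SSE `data:`
/// frame. Backend conditions such as end-of-sequence, a full context, or an
/// over-long prompt are translated into final chunks, an optional usage chunk,
/// and a terminating `data: [DONE]` frame via the three state machines passed
/// in by the caller.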
4128fn compute_stream(
4129 model_name: Option<String>,
4130 id: String,
4131 include_usage: bool,
4132 prompt_too_long_state: &mut PromptTooLongState,
4133 context_full_state: &mut ContextFullState,
4134 stream_state: &mut StreamState,
4135 utf8_cache: &mut Vec<u8>,
4136) -> Result<String, LlamaCoreError> {
4137 #[cfg(feature = "logging")]
4138 info!(target: "stdout", "Computing stream chunk for ChatStream {}", &id);
4139
4140 #[cfg(feature = "logging")]
4141 debug!(target: "stdout", "prompt_too_long_state: {:?}", *prompt_too_long_state);
4142 #[cfg(feature = "logging")]
4143 debug!(target: "stdout", "context_full_state: {:?}", *context_full_state);
4144 #[cfg(feature = "logging")]
4145 debug!(target: "stdout", "stream_state: {:?}", *stream_state);
4146
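    // Once any state machine has reached `EndOfSequence`, emit the sentinel
    // string; `poll_next` treats it as the end of the stream.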
4147 if *prompt_too_long_state == PromptTooLongState::EndOfSequence
4148 || *context_full_state == ContextFullState::EndOfSequence
4149 || *stream_state == StreamState::EndOfSequence
4150 {
4151 #[cfg(feature = "logging")]
4152 info!(target: "stdout", "Return the chat stream chunk!");
4153
4154 return Ok("[GGML] End of sequence".to_string());
4155 }
4156
4157 let chat_graphs = match CHAT_GRAPHS.get() {
4158 Some(chat_graphs) => chat_graphs,
4159 None => {
4160             let err_msg = "Failed to get the underlying value of `CHAT_GRAPHS`.";
4161
4162 #[cfg(feature = "logging")]
4163 error!(target: "stdout", "{}", &err_msg);
4164
4165 return Err(LlamaCoreError::Operation(err_msg.into()));
4166 }
4167 };
4168
4169 let mut chat_graphs = chat_graphs.lock().map_err(|e| {
4171         let err_msg = format!("Failed to acquire the lock of `CHAT_GRAPHS`. {e}");
4172
4173 #[cfg(feature = "logging")]
4174 error!(target: "stdout", "{}", &err_msg);
4175
4176 LlamaCoreError::Operation(err_msg)
4177 })?;
4178
4179 let res = match &model_name {
4181 Some(model_name) => {
4182 match chat_graphs.contains_key(model_name) {
4183 true => {
4184 let graph = chat_graphs.get_mut(model_name).unwrap();
4185 match graph.compute_single() {
4187 Ok(_) => {
4188 #[cfg(feature = "logging")]
4189 debug!(target: "stdout", "Compute the chat stream chunk successfully.");
4190
4191 match stream_state {
4193 StreamState::Usage | StreamState::NoUsage => {
4194 let output_buffer =
4196 get_output_buffer_single(graph, OUTPUT_TENSOR)?;
4197
4198 #[cfg(feature = "logging")]
4199 info!(target: "stdout", "retrieved the output buffer");
4200
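                            // A decode step may end mid-way through a multi-byte
                            // UTF-8 character. Buffer the bytes and retry on the
                            // next chunk; give up and clear the cache once it
                            // exceeds the 4-byte maximum length of a UTF-8
                            // sequence.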
4201 let output = match String::from_utf8(output_buffer.clone()) {
4203 Ok(token) => token,
4204 Err(_) => {
4205 utf8_cache.extend_from_slice(&output_buffer[..]);
4207
4208 match String::from_utf8(utf8_cache.clone()) {
4209 Ok(token) => {
4210 utf8_cache.clear();
4211 token
4212 }
4213 Err(e) => {
4214 if utf8_cache.len() > 4 {
4215 #[cfg(feature = "logging")]
4216 error!(target: "stdout", "UTF-8 decode failed, cache too long: {e}");
4217 #[cfg(feature = "logging")]
4218 error!(target: "stdout", "The cached buffer: {:?}", &utf8_cache[..]);
4219 utf8_cache.clear();
4220 } else {
4221 #[cfg(feature = "logging")]
4222 warn!(target: "stdout", "UTF-8 decode incomplete: {e}");
4223 }
4224 String::new()
4225 }
4226 }
4227 }
4228 };
4229
4230 #[cfg(feature = "logging")]
4231 info!(target: "stdout", "decoded the output buffer");
4232
4233 let created = SystemTime::now()
4234 .duration_since(std::time::UNIX_EPOCH)
4235 .map_err(|e| {
4236 let err_msg = format!(
4237 "Failed to get the current time. Reason: {e}"
4238 );
4239
4240 #[cfg(feature = "logging")]
4241 error!(target: "stdout", "{}", &err_msg);
4242
4243 LlamaCoreError::Operation(err_msg)
4244 })?;
4245
4246 let chat_completion_chunk = ChatCompletionChunk {
4247 id,
4248 object: "chat.completion.chunk".to_string(),
4249 created: created.as_secs(),
4250 model: graph.name().to_owned(),
4251 system_fingerprint: "fp_44709d6fcb".to_string(),
4252 choices: vec![ChatCompletionChunkChoice {
4253 index: 0,
4254 delta: ChatCompletionChunkChoiceDelta {
4255 role: ChatCompletionRole::Assistant,
4256 content: Some(output),
4257 tool_calls: vec![],
4258 },
4259 logprobs: None,
4260 finish_reason: None,
4261 }],
4262 usage: None,
4263 };
4264
4265 #[cfg(feature = "logging")]
4266 info!(target: "stdout", "created chat completion chunk");
4267
4268 let chunk_str = serde_json::to_string(&chat_completion_chunk)
4270 .map_err(|e| {
4271 let err_msg = format!(
4272 "Failed to serialize chat completion chunk. Reason: {e}"
4273 );
4274
4275 #[cfg(feature = "logging")]
4276 error!(target: "stdout", "{}", &err_msg);
4277
4278 LlamaCoreError::Operation(err_msg)
4279 })?;
4280
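                                // Frame the chunk as a Server-Sent Events message:
                                // a `data:` line followed by a blank line.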
4281 Ok(format!("data: {chunk_str}\n\n"))
4282 }
4283 StreamState::Done => {
4284 *stream_state = StreamState::EndOfSequence;
4285
4286 Ok("data: [DONE]\n\n".to_string())
4287 }
4288 StreamState::EndOfSequence => {
4289 Ok("[GGML] End of sequence".to_string())
4290 }
4291 }
4292 }
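                    // Normal completion: the backend has produced its last token.
                    // If the client requested usage, first emit a usage-only chunk,
                    // then `data: [DONE]` on a later poll.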
4293 Err(wasmedge_wasi_nn::Error::BackendError(
4294 wasmedge_wasi_nn::BackendError::EndOfSequence,
4295 )) => {
4296 #[cfg(feature = "logging")]
4297 debug!(target: "stdout", "End of sequence");
4298
4299 match stream_state {
4300 StreamState::Usage => {
4301 *stream_state = StreamState::Done;
4302
4303 let token_info = get_token_info_by_graph(graph)?;
4305
4306 let usage = Some(Usage {
4307 prompt_tokens: token_info.prompt_tokens,
4308 completion_tokens: token_info.completion_tokens,
4309 total_tokens: token_info.prompt_tokens
4310 + token_info.completion_tokens,
4311 });
4312
4313 #[cfg(feature = "logging")]
4314 info!(target: "stdout", "token_info: {} prompt tokens, {} completion tokens", token_info.prompt_tokens, token_info.completion_tokens);
4315
4316 let created = SystemTime::now()
4317 .duration_since(std::time::UNIX_EPOCH)
4318 .map_err(|e| {
4319 let err_msg = format!(
4320 "Failed to get the current time. Reason: {e}"
4321 );
4322
4323 #[cfg(feature = "logging")]
4324 error!(target: "stdout", "{}", &err_msg);
4325
4326 LlamaCoreError::Operation(err_msg)
4327 })?;
4328
4329 let chat_completion_chunk = ChatCompletionChunk {
4330 id,
4331 object: "chat.completion.chunk".to_string(),
4332 created: created.as_secs(),
4333 model: graph.name().to_owned(),
4334 system_fingerprint: "fp_44709d6fcb".to_string(),
4335 choices: vec![],
4336 usage,
4337 };
4338
4339 let chunk_str = serde_json::to_string(&chat_completion_chunk)
4341 .map_err(|e| {
4342 let err_msg = format!(
4343 "Failed to serialize chat completion chunk. Reason: {e}"
4344 );
4345
4346 #[cfg(feature = "logging")]
4347 error!(target: "stdout", "{}", &err_msg);
4348
4349 LlamaCoreError::Operation(err_msg)
4350 })?;
4351
4352 Ok(format!("data: {chunk_str}\n\n"))
4353 }
4354 StreamState::Done | StreamState::NoUsage => {
4355 *stream_state = StreamState::EndOfSequence;
4356
4357 Ok("data: [DONE]\n\n".to_string())
4358 }
4359 StreamState::EndOfSequence => {
4360 Ok("[GGML] End of sequence".to_string())
4361 }
4362 }
4363 }
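                    // The context is exhausted. Surface this to the client as a
                    // chunk carrying the `<|WASMEDGE-GGML-CONTEXT-FULL|>` marker
                    // with `finish_reason: length`, optionally followed by a usage
                    // chunk, then `data: [DONE]`.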
4364 Err(wasmedge_wasi_nn::Error::BackendError(
4365 wasmedge_wasi_nn::BackendError::ContextFull,
4366 )) => {
4367 #[cfg(feature = "logging")]
4368 debug!(target: "stdout", "Context full");
4369
4370 match context_full_state {
4371 ContextFullState::Message => {
4372 match include_usage {
4373 true => *context_full_state = ContextFullState::Usage,
4374 false => *context_full_state = ContextFullState::Done,
4375 }
4376
4377 let created = SystemTime::now()
4378 .duration_since(std::time::UNIX_EPOCH)
4379 .map_err(|e| {
4380 let err_msg = format!(
4381 "Failed to get the current time. Reason: {e}"
4382 );
4383
4384 #[cfg(feature = "logging")]
4385 error!(target: "stdout", "{}", &err_msg);
4386
4387 LlamaCoreError::Operation(err_msg)
4388 })?;
4389
4390 let chat_completion_chunk = ChatCompletionChunk {
4391 id,
4392 object: "chat.completion.chunk".to_string(),
4393 created: created.as_secs(),
4394 model: graph.name().to_owned(),
4395 system_fingerprint: "fp_44709d6fcb".to_string(),
4396 choices: vec![ChatCompletionChunkChoice {
4397 index: 0,
4398 delta: ChatCompletionChunkChoiceDelta {
4399 role: ChatCompletionRole::Assistant,
4400 content: Some(
4401 "<|WASMEDGE-GGML-CONTEXT-FULL|>".to_string(),
4402 ),
4403 tool_calls: vec![],
4404 },
4405 logprobs: None,
4406 finish_reason: Some(FinishReason::length),
4407 }],
4408 usage: None,
4409 };
4410
4411 let chunk_str = serde_json::to_string(&chat_completion_chunk)
4413 .map_err(|e| {
4414 let err_msg = format!(
4415 "Failed to serialize chat completion chunk. Reason: {e}"
4416 );
4417
4418 #[cfg(feature = "logging")]
4419 error!(target: "stdout", "{}", &err_msg);
4420
4421 LlamaCoreError::Operation(err_msg)
4422 })?;
4423
4424 Ok(format!("data: {chunk_str}\n\n"))
4425 }
4426 ContextFullState::Usage => {
4427 *context_full_state = ContextFullState::Done;
4428
4429 let token_info = get_token_info_by_graph(graph)?;
4431
4432 let usage = Some(Usage {
4433 prompt_tokens: token_info.prompt_tokens,
4434 completion_tokens: token_info.completion_tokens,
4435 total_tokens: token_info.prompt_tokens
4436 + token_info.completion_tokens,
4437 });
4438
4439 let created = SystemTime::now()
4440 .duration_since(std::time::UNIX_EPOCH)
4441 .map_err(|e| {
4442 let err_msg = format!(
4443 "Failed to get the current time. Reason: {e}"
4444 );
4445
4446 #[cfg(feature = "logging")]
4447 error!(target: "stdout", "{}", &err_msg);
4448
4449 LlamaCoreError::Operation(err_msg)
4450 })?;
4451
4452 let chat_completion_chunk = ChatCompletionChunk {
4453 id,
4454 object: "chat.completion.chunk".to_string(),
4455 created: created.as_secs(),
4456 model: graph.name().to_owned(),
4457 system_fingerprint: "fp_44709d6fcb".to_string(),
4458 choices: vec![],
4459 usage,
4460 };
4461
4462 let chunk_str = serde_json::to_string(&chat_completion_chunk)
4464 .map_err(|e| {
4465 let err_msg = format!(
4466 "Failed to serialize chat completion chunk. Reason: {e}"
4467 );
4468
4469 #[cfg(feature = "logging")]
4470 error!(target: "stdout", "{}", &err_msg);
4471
4472 LlamaCoreError::Operation(err_msg)
4473 })?;
4474
4475 Ok(format!("data: {chunk_str}\n\n"))
4476 }
4477 ContextFullState::Done => {
4478 *context_full_state = ContextFullState::EndOfSequence;
4479
4480 Ok("data: [DONE]\n\n".to_string())
4481 }
4482 ContextFullState::EndOfSequence => {
4483 Ok("[GGML] End of sequence".to_string())
4484 }
4485 }
4486 }
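                    // The prompt alone exceeded the context window. Emit an empty
                    // delta with `finish_reason: length`, optionally a usage chunk,
                    // then `data: [DONE]`.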
4487 Err(wasmedge_wasi_nn::Error::BackendError(
4488 wasmedge_wasi_nn::BackendError::PromptTooLong,
4489 )) => {
4490 #[cfg(feature = "logging")]
4491 debug!(target: "stdout", "Prompt too long");
4492
4493 match prompt_too_long_state {
4494 PromptTooLongState::Message => {
4495 match include_usage {
4496 true => *prompt_too_long_state = PromptTooLongState::Usage,
4497 false => *prompt_too_long_state = PromptTooLongState::Done,
4498 }
4499
4500 let created = SystemTime::now()
4501 .duration_since(std::time::UNIX_EPOCH)
4502 .map_err(|e| {
4503 let err_msg = format!(
4504 "Failed to get the current time. Reason: {e}"
4505 );
4506
4507 #[cfg(feature = "logging")]
4508 error!(target: "stdout", "{}", &err_msg);
4509
4510 LlamaCoreError::Operation(err_msg)
4511 })?;
4512
4513 let chat_completion_chunk = ChatCompletionChunk {
4514 id,
4515 object: "chat.completion.chunk".to_string(),
4516 created: created.as_secs(),
4517 model: graph.name().to_owned(),
4518 system_fingerprint: "fp_44709d6fcb".to_string(),
4519 choices: vec![ChatCompletionChunkChoice {
4520 index: 0,
4521 delta: ChatCompletionChunkChoiceDelta {
4522 role: ChatCompletionRole::Assistant,
4523 content: None,
4524 tool_calls: vec![],
4525 },
4526 logprobs: None,
4527 finish_reason: Some(FinishReason::length),
4528 }],
4529 usage: None,
4530 };
4531
4532 let chunk_str = serde_json::to_string(&chat_completion_chunk)
4534 .map_err(|e| {
4535 let err_msg = format!(
4536 "Failed to serialize chat completion chunk. Reason: {e}"
4537 );
4538
4539 #[cfg(feature = "logging")]
4540 error!(target: "stdout", "{}", &err_msg);
4541
4542 LlamaCoreError::Operation(err_msg)
4543 })?;
4544
4545 Ok(format!("data: {chunk_str}\n\n"))
4546 }
4547 PromptTooLongState::Usage => {
4548 *prompt_too_long_state = PromptTooLongState::Done;
4549
4550 let token_info = get_token_info_by_graph(graph)?;
4552
4553 let usage = Some(Usage {
4554 prompt_tokens: token_info.prompt_tokens,
4555 completion_tokens: token_info.completion_tokens,
4556 total_tokens: token_info.prompt_tokens
4557 + token_info.completion_tokens,
4558 });
4559
4560 let created = SystemTime::now()
4561 .duration_since(std::time::UNIX_EPOCH)
4562 .map_err(|e| {
4563 let err_msg = format!(
4564 "Failed to get the current time. Reason: {e}"
4565 );
4566
4567 #[cfg(feature = "logging")]
4568 error!(target: "stdout", "{}", &err_msg);
4569
4570 LlamaCoreError::Operation(err_msg)
4571 })?;
4572
4573 let chat_completion_chunk = ChatCompletionChunk {
4574 id,
4575 object: "chat.completion.chunk".to_string(),
4576 created: created.as_secs(),
4577 model: graph.name().to_owned(),
4578 system_fingerprint: "fp_44709d6fcb".to_string(),
4579 choices: vec![],
4580 usage,
4581 };
4582
4583 let chunk_str = serde_json::to_string(&chat_completion_chunk)
4585 .map_err(|e| {
4586 let err_msg = format!(
4587 "Failed to serialize chat completion chunk. Reason: {e}"
4588 );
4589
4590 #[cfg(feature = "logging")]
4591 error!(target: "stdout", "{}", &err_msg);
4592
4593 LlamaCoreError::Operation(err_msg)
4594 })?;
4595
4596 Ok(format!("data: {chunk_str}\n\n"))
4597 }
4598 PromptTooLongState::Done => {
4599 *prompt_too_long_state = PromptTooLongState::EndOfSequence;
4600
4601 Ok("data: [DONE]\n\n".to_string())
4602 }
4603 PromptTooLongState::EndOfSequence => {
4604 Ok("[GGML] End of sequence".to_string())
4605 }
4606 }
4607 }
4608 Err(e) => {
4609 let err_msg =
4610 format!("Failed to compute the chat completion. Reason: {e}");
4611
4612 #[cfg(feature = "logging")]
4613 error!(target: "stdout", "{}", &err_msg);
4614
4615 Err(LlamaCoreError::Backend(BackendError::ComputeSingle(
4616 err_msg,
4617 )))
4618 }
4619 }
4620 }
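                // The requested model is not loaded; fall back to whichever graph
                // the `CHAT_GRAPHS` iterator yields first.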
4621 false => {
4622 match chat_graphs.iter_mut().next() {
4623 Some((_, graph)) => {
4624 match graph.compute_single() {
4626 Ok(_) => {
4627 #[cfg(feature = "logging")]
4628 debug!(target: "stdout", "Compute the chat stream chunk successfully.");
4629
4630 match stream_state {
4631 StreamState::Usage | StreamState::NoUsage => {
4632 let output_buffer =
4634 get_output_buffer_single(graph, OUTPUT_TENSOR)?;
4635
4636 #[cfg(feature = "logging")]
4637 info!(target: "stdout", "retrieved the output buffer");
4638
4639 let output = match String::from_utf8(
4641 output_buffer.clone(),
4642 ) {
4643 Ok(token) => token,
4644 Err(_) => {
4645 utf8_cache
4647 .extend_from_slice(&output_buffer[..]);
4648
4649 match String::from_utf8(utf8_cache.clone()) {
4650 Ok(token) => {
4651 utf8_cache.clear();
4652 token
4653 }
4654 Err(e) => {
4655 if utf8_cache.len() > 4 {
4656 #[cfg(feature = "logging")]
4657 error!(target: "stdout", "UTF-8 decode failed, cache too long: {e}");
4658 #[cfg(feature = "logging")]
4659 error!(target: "stdout", "The cached buffer: {:?}", &utf8_cache[..]);
4660 utf8_cache.clear();
4661 } else {
4662 #[cfg(feature = "logging")]
4663 warn!(target: "stdout", "UTF-8 decode incomplete: {e}");
4664 }
4665 String::new()
4666 }
4667 }
4668 }
4669 };
4670
4671 #[cfg(feature = "logging")]
4672 info!(target: "stdout", "decoded the output buffer");
4673
4674 let created = SystemTime::now()
4675 .duration_since(std::time::UNIX_EPOCH)
4676 .map_err(|e| {
4677 let err_msg = format!(
4678 "Failed to get the current time. Reason: {e}"
4679 );
4680
4681 #[cfg(feature = "logging")]
4682 error!(target: "stdout", "{}", &err_msg);
4683
4684 LlamaCoreError::Operation(err_msg)
4685 })?;
4686
4687 let chat_completion_chunk = ChatCompletionChunk {
4688 id,
4689 object: "chat.completion.chunk".to_string(),
4690 created: created.as_secs(),
4691 model: graph.name().to_owned(),
4692 system_fingerprint: "fp_44709d6fcb".to_string(),
4693 choices: vec![ChatCompletionChunkChoice {
4694 index: 0,
4695 delta: ChatCompletionChunkChoiceDelta {
4696 role: ChatCompletionRole::Assistant,
4697 content: Some(output),
4698 tool_calls: vec![],
4699 },
4700 logprobs: None,
4701 finish_reason: None,
4702 }],
4703 usage: None,
4704 };
4705
4706 #[cfg(feature = "logging")]
4707 info!(target: "stdout", "created chat completion chunk");
4708
4709 let chunk_str =
4711 serde_json::to_string(&chat_completion_chunk)
4712 .map_err(|e| {
4713 let err_msg = format!(
4714 "Failed to serialize chat completion chunk. Reason: {e}"
4715 );
4716
4717 #[cfg(feature = "logging")]
4718 error!(target: "stdout", "{}", &err_msg);
4719
4720 LlamaCoreError::Operation(err_msg)
4721 })?;
4722
4723 Ok(format!("data: {chunk_str}\n\n"))
4724 }
4725 StreamState::Done => {
4726 *stream_state = StreamState::EndOfSequence;
4727
4728 Ok("data: [DONE]\n\n".to_string())
4729 }
4730 StreamState::EndOfSequence => {
4731 Ok("[GGML] End of sequence".to_string())
4732 }
4733 }
4734 }
4735 Err(wasmedge_wasi_nn::Error::BackendError(
4736 wasmedge_wasi_nn::BackendError::EndOfSequence,
4737 )) => {
4738 #[cfg(feature = "logging")]
4739 debug!(target: "stdout", "End of sequence");
4740
4741 match stream_state {
4742 StreamState::Usage => {
4743 *stream_state = StreamState::Done;
4744
4745 let token_info = get_token_info_by_graph(graph)?;
4747
4748 let usage = Some(Usage {
4749 prompt_tokens: token_info.prompt_tokens,
4750 completion_tokens: token_info.completion_tokens,
4751 total_tokens: token_info.prompt_tokens
4752 + token_info.completion_tokens,
4753 });
4754
4755 #[cfg(feature = "logging")]
4756 info!(target: "stdout", "token_info: {} prompt tokens, {} completion tokens", token_info.prompt_tokens, token_info.completion_tokens);
4757
4758 let created = SystemTime::now()
4759 .duration_since(std::time::UNIX_EPOCH)
4760 .map_err(|e| {
4761 let err_msg = format!(
4762 "Failed to get the current time. Reason: {e}"
4763 );
4764
4765 #[cfg(feature = "logging")]
4766 error!(target: "stdout", "{}", &err_msg);
4767
4768 LlamaCoreError::Operation(err_msg)
4769 })?;
4770
4771 let chat_completion_chunk = ChatCompletionChunk {
4772 id,
4773 object: "chat.completion.chunk".to_string(),
4774 created: created.as_secs(),
4775 model: graph.name().to_owned(),
4776 system_fingerprint: "fp_44709d6fcb".to_string(),
4777 choices: vec![],
4778 usage,
4779 };
4780
4781 let chunk_str =
4783 serde_json::to_string(&chat_completion_chunk)
4784 .map_err(|e| {
4785 let err_msg = format!(
4786 "Failed to serialize chat completion chunk. Reason: {e}"
4787 );
4788
4789 #[cfg(feature = "logging")]
4790 error!(target: "stdout", "{}", &err_msg);
4791
4792 LlamaCoreError::Operation(err_msg)
4793 })?;
4794
4795 Ok(format!("data: {chunk_str}\n\n"))
4796 }
4797 StreamState::Done | StreamState::NoUsage => {
4798 *stream_state = StreamState::EndOfSequence;
4799
4800 Ok("data: [DONE]\n\n".to_string())
4801 }
4802 StreamState::EndOfSequence => {
4803 Ok("[GGML] End of sequence".to_string())
4804 }
4805 }
4806 }
4807 Err(wasmedge_wasi_nn::Error::BackendError(
4808 wasmedge_wasi_nn::BackendError::ContextFull,
4809 )) => {
4810 #[cfg(feature = "logging")]
4811 debug!(target: "stdout", "Context full");
4812
4813 match context_full_state {
4814 ContextFullState::Message => {
4815 match include_usage {
4816 true => {
4817 *context_full_state = ContextFullState::Usage
4818 }
4819 false => {
4820 *context_full_state = ContextFullState::Done
4821 }
4822 }
4823
4824 let created = SystemTime::now()
4825 .duration_since(std::time::UNIX_EPOCH)
4826 .map_err(|e| {
4827 let err_msg = format!(
4828 "Failed to get the current time. Reason: {e}"
4829 );
4830
4831 #[cfg(feature = "logging")]
4832 error!(target: "stdout", "{}", &err_msg);
4833
4834 LlamaCoreError::Operation(err_msg)
4835 })?;
4836
4837 let chat_completion_chunk = ChatCompletionChunk {
4838 id,
4839 object: "chat.completion.chunk".to_string(),
4840 created: created.as_secs(),
4841 model: graph.name().to_owned(),
4842 system_fingerprint: "fp_44709d6fcb".to_string(),
4843 choices: vec![ChatCompletionChunkChoice {
4844 index: 0,
4845 delta: ChatCompletionChunkChoiceDelta {
4846 role: ChatCompletionRole::Assistant,
4847 content: Some(
4848 "<|WASMEDGE-GGML-CONTEXT-FULL|>"
4849 .to_string(),
4850 ),
4851 tool_calls: vec![],
4852 },
4853 logprobs: None,
4854 finish_reason: Some(FinishReason::length),
4855 }],
4856 usage: None,
4857 };
4858
4859 let chunk_str =
4861 serde_json::to_string(&chat_completion_chunk)
4862 .map_err(|e| {
4863 let err_msg = format!(
4864 "Failed to serialize chat completion chunk. Reason: {e}"
4865 );
4866
4867 #[cfg(feature = "logging")]
4868 error!(target: "stdout", "{}", &err_msg);
4869
4870 LlamaCoreError::Operation(err_msg)
4871 })?;
4872
4873 Ok(format!("data: {chunk_str}\n\n"))
4874 }
4875 ContextFullState::Usage => {
4876 *context_full_state = ContextFullState::Done;
4877
4878 let token_info = get_token_info_by_graph(graph)?;
4880
4881 let usage = Some(Usage {
4882 prompt_tokens: token_info.prompt_tokens,
4883 completion_tokens: token_info.completion_tokens,
4884 total_tokens: token_info.prompt_tokens
4885 + token_info.completion_tokens,
4886 });
4887
4888 let created = SystemTime::now()
4889 .duration_since(std::time::UNIX_EPOCH)
4890 .map_err(|e| {
4891 let err_msg = format!(
4892 "Failed to get the current time. Reason: {e}"
4893 );
4894
4895 #[cfg(feature = "logging")]
4896 error!(target: "stdout", "{}", &err_msg);
4897
4898 LlamaCoreError::Operation(err_msg)
4899 })?;
4900
4901 let chat_completion_chunk = ChatCompletionChunk {
4902 id,
4903 object: "chat.completion.chunk".to_string(),
4904 created: created.as_secs(),
4905 model: graph.name().to_owned(),
4906 system_fingerprint: "fp_44709d6fcb".to_string(),
4907 choices: vec![],
4908 usage,
4909 };
4910
4911 let chunk_str =
4913 serde_json::to_string(&chat_completion_chunk)
4914 .map_err(|e| {
4915 let err_msg = format!(
4916 "Failed to serialize chat completion chunk. Reason: {e}"
4917 );
4918
4919 #[cfg(feature = "logging")]
4920 error!(target: "stdout", "{}", &err_msg);
4921
4922 LlamaCoreError::Operation(err_msg)
4923 })?;
4924
4925 Ok(format!("data: {chunk_str}\n\n"))
4926 }
4927 ContextFullState::Done => {
4928 *context_full_state = ContextFullState::EndOfSequence;
4929
4930 Ok("data: [DONE]\n\n".to_string())
4931 }
4932 ContextFullState::EndOfSequence => {
4933 Ok("[GGML] End of sequence".to_string())
4934 }
4935 }
4936 }
4937 Err(wasmedge_wasi_nn::Error::BackendError(
4938 wasmedge_wasi_nn::BackendError::PromptTooLong,
4939 )) => {
4940 #[cfg(feature = "logging")]
4941 debug!(target: "stdout", "Prompt too long");
4942
4943 match prompt_too_long_state {
4944 PromptTooLongState::Message => {
4945 match include_usage {
4946 true => {
4947 *prompt_too_long_state =
4948 PromptTooLongState::Usage
4949 }
4950 false => {
4951 *prompt_too_long_state =
4952 PromptTooLongState::Done
4953 }
4954 }
4955
4956 let created = SystemTime::now()
4957 .duration_since(std::time::UNIX_EPOCH)
4958 .map_err(|e| {
4959 let err_msg = format!(
4960 "Failed to get the current time. Reason: {e}"
4961 );
4962
4963 #[cfg(feature = "logging")]
4964 error!(target: "stdout", "{}", &err_msg);
4965
4966 LlamaCoreError::Operation(err_msg)
4967 })?;
4968
4969 let chat_completion_chunk = ChatCompletionChunk {
4970 id,
4971 object: "chat.completion.chunk".to_string(),
4972 created: created.as_secs(),
4973 model: graph.name().to_owned(),
4974 system_fingerprint: "fp_44709d6fcb".to_string(),
4975 choices: vec![ChatCompletionChunkChoice {
4976 index: 0,
4977 delta: ChatCompletionChunkChoiceDelta {
4978 role: ChatCompletionRole::Assistant,
4979 content: None,
4980 tool_calls: vec![],
4981 },
4982 logprobs: None,
4983 finish_reason: Some(FinishReason::length),
4984 }],
4985 usage: None,
4986 };
4987
4988 let chunk_str =
4990 serde_json::to_string(&chat_completion_chunk)
4991 .map_err(|e| {
4992 let err_msg = format!(
4993 "Failed to serialize chat completion chunk. Reason: {e}"
4994 );
4995
4996 #[cfg(feature = "logging")]
4997 error!(target: "stdout", "{}", &err_msg);
4998
4999 LlamaCoreError::Operation(err_msg)
5000 })?;
5001
5002 Ok(format!("data: {chunk_str}\n\n"))
5003 }
5004 PromptTooLongState::Usage => {
5005 *prompt_too_long_state = PromptTooLongState::Done;
5006
5007 let token_info = get_token_info_by_graph(graph)?;
5009
5010 let usage = Some(Usage {
5011 prompt_tokens: token_info.prompt_tokens,
5012 completion_tokens: token_info.completion_tokens,
5013 total_tokens: token_info.prompt_tokens
5014 + token_info.completion_tokens,
5015 });
5016
5017 let created = SystemTime::now()
5018 .duration_since(std::time::UNIX_EPOCH)
5019 .map_err(|e| {
5020 let err_msg = format!(
5021 "Failed to get the current time. Reason: {e}"
5022 );
5023
5024 #[cfg(feature = "logging")]
5025 error!(target: "stdout", "{}", &err_msg);
5026
5027 LlamaCoreError::Operation(err_msg)
5028 })?;
5029
5030 let chat_completion_chunk = ChatCompletionChunk {
5031 id,
5032 object: "chat.completion.chunk".to_string(),
5033 created: created.as_secs(),
5034 model: graph.name().to_owned(),
5035 system_fingerprint: "fp_44709d6fcb".to_string(),
5036 choices: vec![],
5037 usage,
5038 };
5039
5040 let chunk_str =
5042 serde_json::to_string(&chat_completion_chunk)
5043 .map_err(|e| {
5044 let err_msg = format!(
5045 "Failed to serialize chat completion chunk. Reason: {e}"
5046 );
5047
5048 #[cfg(feature = "logging")]
5049 error!(target: "stdout", "{}", &err_msg);
5050
5051 LlamaCoreError::Operation(err_msg)
5052 })?;
5053
5054 Ok(format!("data: {chunk_str}\n\n"))
5055 }
5056 PromptTooLongState::Done => {
5057 *prompt_too_long_state =
5058 PromptTooLongState::EndOfSequence;
5059
5060 Ok("data: [DONE]\n\n".to_string())
5061 }
5062 PromptTooLongState::EndOfSequence => {
5063 Ok("[GGML] End of sequence".to_string())
5064 }
5065 }
5066 }
5067 Err(e) => {
5068 let err_msg = format!(
5069 "Failed to compute the chat completion. Reason: {e}"
5070 );
5071
5072 #[cfg(feature = "logging")]
5073 error!(target: "stdout", "{}", &err_msg);
5074
5075 Err(LlamaCoreError::Backend(BackendError::ComputeSingle(
5076 err_msg,
5077 )))
5078 }
5079 }
5080 }
5081 None => {
5082 let err_msg = "There is no model available in the chat graphs.";
5083
5084 #[cfg(feature = "logging")]
5085 error!(target: "stdout", "{}", &err_msg);
5086
5087 Err(LlamaCoreError::Operation(err_msg.into()))
5088 }
5089 }
5090 }
5091 }
5092 }
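        // No model name in the request: stream from whichever graph the
        // `CHAT_GRAPHS` iterator yields first.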
5093 None => {
5094 match chat_graphs.iter_mut().next() {
5095 Some((_, graph)) => {
5096 match graph.compute_single() {
5098 Ok(_) => {
5099 #[cfg(feature = "logging")]
5100 debug!(target: "stdout", "Compute the chat stream chunk successfully.");
5101
5102 match stream_state {
5103 StreamState::Usage | StreamState::NoUsage => {
5104 let output_buffer =
5106 get_output_buffer_single(graph, OUTPUT_TENSOR)?;
5107
5108 #[cfg(feature = "logging")]
5109 info!(target: "stdout", "retrieved the output buffer");
5110
5111 let output = match String::from_utf8(output_buffer.clone()) {
5113 Ok(token) => token,
5114 Err(_) => {
5115 utf8_cache.extend_from_slice(&output_buffer[..]);
5117
5118 match String::from_utf8(utf8_cache.clone()) {
5119 Ok(token) => {
5120 utf8_cache.clear();
5121 token
5122 }
5123 Err(e) => {
5124 if utf8_cache.len() > 4 {
5125 #[cfg(feature = "logging")]
5126 error!(target: "stdout", "UTF-8 decode failed, cache too long: {e}");
5127 #[cfg(feature = "logging")]
5128 error!(target: "stdout", "The cached buffer: {:?}", &utf8_cache[..]);
5129 utf8_cache.clear();
5130 } else {
5131 #[cfg(feature = "logging")]
5132 warn!(target: "stdout", "UTF-8 decode incomplete: {e}");
5133 }
5134 String::new()
5135 }
5136 }
5137 }
5138 };
5139
5140 #[cfg(feature = "logging")]
5141 info!(target: "stdout", "decoded the output buffer");
5142
5143 let created = SystemTime::now()
5144 .duration_since(std::time::UNIX_EPOCH)
5145 .map_err(|e| {
5146 let err_msg = format!(
5147 "Failed to get the current time. Reason: {e}"
5148 );
5149
5150 #[cfg(feature = "logging")]
5151 error!(target: "stdout", "{}", &err_msg);
5152
5153 LlamaCoreError::Operation(err_msg)
5154 })?;
5155
5156 let chat_completion_chunk = ChatCompletionChunk {
5157 id,
5158 object: "chat.completion.chunk".to_string(),
5159 created: created.as_secs(),
5160 model: graph.name().to_owned(),
5161 system_fingerprint: "fp_44709d6fcb".to_string(),
5162 choices: vec![ChatCompletionChunkChoice {
5163 index: 0,
5164 delta: ChatCompletionChunkChoiceDelta {
5165 role: ChatCompletionRole::Assistant,
5166 content: Some(output),
5167 tool_calls: vec![],
5168 },
5169 logprobs: None,
5170 finish_reason: None,
5171 }],
5172 usage: None,
5173 };
5174
5175 #[cfg(feature = "logging")]
5176 info!(target: "stdout", "created chat completion chunk");
5177
5178 let chunk_str = serde_json::to_string(&chat_completion_chunk)
5180 .map_err(|e| {
5181 let err_msg = format!(
5182 "Failed to serialize chat completion chunk. Reason: {e}"
5183 );
5184
5185 #[cfg(feature = "logging")]
5186 error!(target: "stdout", "{}", &err_msg);
5187
5188 LlamaCoreError::Operation(err_msg)
5189 })?;
5190
5191 Ok(format!("data: {chunk_str}\n\n"))
5192 }
5193 StreamState::Done => {
5194 *stream_state = StreamState::EndOfSequence;
5195
5196 Ok("data: [DONE]\n\n".to_string())
5197 }
5198 StreamState::EndOfSequence => {
5199 Ok("[GGML] End of sequence".to_string())
5200 }
5201 }
5202 }
5203 Err(wasmedge_wasi_nn::Error::BackendError(
5204 wasmedge_wasi_nn::BackendError::EndOfSequence,
5205 )) => {
5206 #[cfg(feature = "logging")]
5207 debug!(target: "stdout", "End of sequence");
5208
5209 match stream_state {
5210 StreamState::Usage => {
5211 *stream_state = StreamState::Done;
5212
5213 let token_info = get_token_info_by_graph(graph)?;
5215
5216 let usage = Some(Usage {
5217 prompt_tokens: token_info.prompt_tokens,
5218 completion_tokens: token_info.completion_tokens,
5219 total_tokens: token_info.prompt_tokens
5220 + token_info.completion_tokens,
5221 });
5222
5223 #[cfg(feature = "logging")]
5224 info!(target: "stdout", "token_info: {} prompt tokens, {} completion tokens", token_info.prompt_tokens, token_info.completion_tokens);
5225
5226 let created = SystemTime::now()
5227 .duration_since(std::time::UNIX_EPOCH)
5228 .map_err(|e| {
5229 let err_msg = format!(
5230 "Failed to get the current time. Reason: {e}"
5231 );
5232
5233 #[cfg(feature = "logging")]
5234 error!(target: "stdout", "{}", &err_msg);
5235
5236 LlamaCoreError::Operation(err_msg)
5237 })?;
5238
5239 let chat_completion_chunk = ChatCompletionChunk {
5240 id,
5241 object: "chat.completion.chunk".to_string(),
5242 created: created.as_secs(),
5243 model: graph.name().to_owned(),
5244 system_fingerprint: "fp_44709d6fcb".to_string(),
5245 choices: vec![],
5246 usage,
5247 };
5248
5249 let chunk_str = serde_json::to_string(&chat_completion_chunk)
5251 .map_err(|e| {
5252 let err_msg = format!(
5253 "Failed to serialize chat completion chunk. Reason: {e}"
5254 );
5255
5256 #[cfg(feature = "logging")]
5257 error!(target: "stdout", "{}", &err_msg);
5258
5259 LlamaCoreError::Operation(err_msg)
5260 })?;
5261
5262 Ok(format!("data: {chunk_str}\n\n"))
5263 }
5264 StreamState::Done | StreamState::NoUsage => {
5265 *stream_state = StreamState::EndOfSequence;
5266
5267 Ok("data: [DONE]\n\n".to_string())
5268 }
5269 StreamState::EndOfSequence => {
5270 Ok("[GGML] End of sequence".to_string())
5271 }
5272 }
5273 }
5274 Err(wasmedge_wasi_nn::Error::BackendError(
5275 wasmedge_wasi_nn::BackendError::ContextFull,
5276 )) => {
5277 #[cfg(feature = "logging")]
5278 debug!(target: "stdout", "Context full");
5279
5280 match context_full_state {
5281 ContextFullState::Message => {
5282 match include_usage {
5283 true => *context_full_state = ContextFullState::Usage,
5284 false => *context_full_state = ContextFullState::Done,
5285 }
5286
5287 let created = SystemTime::now()
5288 .duration_since(std::time::UNIX_EPOCH)
5289 .map_err(|e| {
5290 let err_msg = format!(
5291 "Failed to get the current time. Reason: {e}"
5292 );
5293
5294 #[cfg(feature = "logging")]
5295 error!(target: "stdout", "{}", &err_msg);
5296
5297 LlamaCoreError::Operation(err_msg)
5298 })?;
5299
5300 let chat_completion_chunk = ChatCompletionChunk {
5301 id,
5302 object: "chat.completion.chunk".to_string(),
5303 created: created.as_secs(),
5304 model: graph.name().to_owned(),
5305 system_fingerprint: "fp_44709d6fcb".to_string(),
5306 choices: vec![ChatCompletionChunkChoice {
5307 index: 0,
5308 delta: ChatCompletionChunkChoiceDelta {
5309 role: ChatCompletionRole::Assistant,
5310 content: Some(
5311 "<|WASMEDGE-GGML-CONTEXT-FULL|>".to_string(),
5312 ),
5313 tool_calls: vec![],
5314 },
5315 logprobs: None,
5316 finish_reason: Some(FinishReason::length),
5317 }],
5318 usage: None,
5319 };
5320
5321 let chunk_str = serde_json::to_string(&chat_completion_chunk)
5323 .map_err(|e| {
5324 let err_msg = format!(
5325 "Failed to serialize chat completion chunk. Reason: {e}"
5326 );
5327
5328 #[cfg(feature = "logging")]
5329 error!(target: "stdout", "{}", &err_msg);
5330
5331 LlamaCoreError::Operation(err_msg)
5332 })?;
5333
5334 Ok(format!("data: {chunk_str}\n\n"))
5335 }
5336 ContextFullState::Usage => {
5337 *context_full_state = ContextFullState::Done;
5338
5339 let token_info = get_token_info_by_graph(graph)?;
5341
5342 let usage = Some(Usage {
5343 prompt_tokens: token_info.prompt_tokens,
5344 completion_tokens: token_info.completion_tokens,
5345 total_tokens: token_info.prompt_tokens
5346 + token_info.completion_tokens,
5347 });
5348
5349 let created = SystemTime::now()
5350 .duration_since(std::time::UNIX_EPOCH)
5351 .map_err(|e| {
5352 let err_msg = format!(
5353 "Failed to get the current time. Reason: {e}"
5354 );
5355
5356 #[cfg(feature = "logging")]
5357 error!(target: "stdout", "{}", &err_msg);
5358
5359 LlamaCoreError::Operation(err_msg)
5360 })?;
5361
5362 let chat_completion_chunk = ChatCompletionChunk {
5363 id,
5364 object: "chat.completion.chunk".to_string(),
5365 created: created.as_secs(),
5366 model: graph.name().to_owned(),
5367 system_fingerprint: "fp_44709d6fcb".to_string(),
5368 choices: vec![],
5369 usage,
5370 };
5371
5372 let chunk_str = serde_json::to_string(&chat_completion_chunk)
5374 .map_err(|e| {
5375 let err_msg = format!(
5376 "Failed to serialize chat completion chunk. Reason: {e}"
5377 );
5378
5379 #[cfg(feature = "logging")]
5380 error!(target: "stdout", "{}", &err_msg);
5381
5382 LlamaCoreError::Operation(err_msg)
5383 })?;
5384
5385 Ok(format!("data: {chunk_str}\n\n"))
5386 }
5387 ContextFullState::Done => {
5388 *context_full_state = ContextFullState::EndOfSequence;
5389
5390 Ok("data: [DONE]\n\n".to_string())
5391 }
5392 ContextFullState::EndOfSequence => {
5393 Ok("[GGML] End of sequence".to_string())
5394 }
5395 }
5396 }
5397 Err(wasmedge_wasi_nn::Error::BackendError(
5398 wasmedge_wasi_nn::BackendError::PromptTooLong,
5399 )) => {
5400 #[cfg(feature = "logging")]
5401 debug!(target: "stdout", "Prompt too long");
5402
5403 match prompt_too_long_state {
5404 PromptTooLongState::Message => {
5405 match include_usage {
5406 true => *prompt_too_long_state = PromptTooLongState::Usage,
5407 false => *prompt_too_long_state = PromptTooLongState::Done,
5408 }
5409
5410 let created = SystemTime::now()
5411 .duration_since(std::time::UNIX_EPOCH)
5412 .map_err(|e| {
5413 let err_msg = format!(
5414 "Failed to get the current time. Reason: {e}"
5415 );
5416
5417 #[cfg(feature = "logging")]
5418 error!(target: "stdout", "{}", &err_msg);
5419
5420 LlamaCoreError::Operation(err_msg)
5421 })?;
5422
5423 let chat_completion_chunk = ChatCompletionChunk {
5424 id,
5425 object: "chat.completion.chunk".to_string(),
5426 created: created.as_secs(),
5427 model: graph.name().to_owned(),
5428 system_fingerprint: "fp_44709d6fcb".to_string(),
5429 choices: vec![ChatCompletionChunkChoice {
5430 index: 0,
5431 delta: ChatCompletionChunkChoiceDelta {
5432 role: ChatCompletionRole::Assistant,
5433 content: None,
5434 tool_calls: vec![],
5435 },
5436 logprobs: None,
5437 finish_reason: Some(FinishReason::length),
5438 }],
5439 usage: None,
5440 };
5441
5442 let chunk_str = serde_json::to_string(&chat_completion_chunk)
5444 .map_err(|e| {
5445 let err_msg = format!(
5446 "Failed to serialize chat completion chunk. Reason: {e}"
5447 );
5448
5449 #[cfg(feature = "logging")]
5450 error!(target: "stdout", "{}", &err_msg);
5451
5452 LlamaCoreError::Operation(err_msg)
5453 })?;
5454
5455 Ok(format!("data: {chunk_str}\n\n"))
5456 }
5457 PromptTooLongState::Usage => {
5458 *prompt_too_long_state = PromptTooLongState::Done;
5459
5460 let token_info = get_token_info_by_graph(graph)?;
5462
5463 let usage = Some(Usage {
5464 prompt_tokens: token_info.prompt_tokens,
5465 completion_tokens: token_info.completion_tokens,
5466 total_tokens: token_info.prompt_tokens
5467 + token_info.completion_tokens,
5468 });
5469
5470 let created = SystemTime::now()
5471 .duration_since(std::time::UNIX_EPOCH)
5472 .map_err(|e| {
5473 let err_msg = format!(
5474 "Failed to get the current time. Reason: {e}"
5475 );
5476
5477 #[cfg(feature = "logging")]
5478 error!(target: "stdout", "{}", &err_msg);
5479
5480 LlamaCoreError::Operation(err_msg)
5481 })?;
5482
5483 let chat_completion_chunk = ChatCompletionChunk {
5484 id,
5485 object: "chat.completion.chunk".to_string(),
5486 created: created.as_secs(),
5487 model: graph.name().to_owned(),
5488 system_fingerprint: "fp_44709d6fcb".to_string(),
5489 choices: vec![],
5490 usage,
5491 };
5492
5493 let chunk_str = serde_json::to_string(&chat_completion_chunk)
5495 .map_err(|e| {
5496 let err_msg = format!(
5497 "Failed to serialize chat completion chunk. Reason: {e}"
5498 );
5499
5500 #[cfg(feature = "logging")]
5501 error!(target: "stdout", "{}", &err_msg);
5502
5503 LlamaCoreError::Operation(err_msg)
5504 })?;
5505
5506 Ok(format!("data: {chunk_str}\n\n"))
5507 }
5508 PromptTooLongState::Done => {
5509 *prompt_too_long_state = PromptTooLongState::EndOfSequence;
5510
5511 Ok("data: [DONE]\n\n".to_string())
5512 }
5513 PromptTooLongState::EndOfSequence => {
5514 Ok("[GGML] End of sequence".to_string())
5515 }
5516 }
5517 }
5518 Err(e) => {
5519 let err_msg =
5520 format!("Failed to compute the chat completion. Reason: {e}");
5521
5522 #[cfg(feature = "logging")]
5523 error!(target: "stdout", "{}", &err_msg);
5524
5525 Err(LlamaCoreError::Backend(BackendError::ComputeSingle(
5526 err_msg,
5527 )))
5528 }
5529 }
5530 }
5531 None => {
5532 let err_msg = "There is no model available in the chat graphs.";
5533
5534 #[cfg(feature = "logging")]
5535 error!(target: "stdout", "{}", &err_msg);
5536
5537 Err(LlamaCoreError::Operation(err_msg.into()))
5538 }
5539 }
5540 }
5541 };
5542
5543 #[cfg(feature = "logging")]
5544 info!(target: "stdout", "Return the chat stream chunk!");
5545
5546 res
5547}
5548
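/// Holds a raw assistant response alongside its parsed text content (if any)
/// and any tool calls extracted from it.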
5549#[allow(dead_code)]
5550#[derive(Debug)]
5551struct ParseResult {
5552 raw: String,
5553 content: Option<String>,
5554 tool_calls: Vec<ToolCall>,
5555}
5556
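// Unit tests for the per-model stream lock: acquire/release semantics,
// isolation between models, and guard-based release on drop.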
5557#[cfg(test)]
5562mod tests {
5563 use super::*;
5564
5565 #[test]
5566 fn test_model_stream_lock_new() {
5567 let lock = ModelStreamLock::new();
5568 assert!(!lock.active.load(Ordering::SeqCst));
5569 assert!(lock.waker_queue.lock().unwrap().is_empty());
5570 }
5571
5572 #[test]
5573 fn test_model_stream_lock_acquire_release() {
5574 let lock = ModelStreamLock::new();
5575
5576 assert!(lock.try_acquire());
5578 assert!(lock.active.load(Ordering::SeqCst));
5579
5580 assert!(!lock.try_acquire());
5582
5583 lock.release();
5585 assert!(!lock.active.load(Ordering::SeqCst));
5586
5587 assert!(lock.try_acquire());
5589 }
5590
5591 #[test]
5592 fn test_different_models_can_acquire_parallel() {
5593 let lock_a = Arc::new(ModelStreamLock::new());
5594 let lock_b = Arc::new(ModelStreamLock::new());
5595
5596 assert!(lock_a.try_acquire());
5598 assert!(lock_b.try_acquire());
5599
5600 assert!(lock_a.active.load(Ordering::SeqCst));
5602 assert!(lock_b.active.load(Ordering::SeqCst));
5603
5604 lock_a.release();
5606 lock_b.release();
5607
5608 assert!(!lock_a.active.load(Ordering::SeqCst));
5609 assert!(!lock_b.active.load(Ordering::SeqCst));
5610 }
5611
5612 #[test]
5613 fn test_model_stream_lock_default() {
5614 let lock = ModelStreamLock::default();
5615 assert!(!lock.active.load(Ordering::SeqCst));
5616 }
5617
5618 #[test]
5619 fn test_get_or_create_model_lock() {
5620 let lock1 = get_or_create_model_lock("test-model-1").unwrap();
5622
5623 let lock2 = get_or_create_model_lock("test-model-1").unwrap();
5625
5626 assert!(Arc::ptr_eq(&lock1, &lock2));
5628
5629 let lock3 = get_or_create_model_lock("test-model-2").unwrap();
5631 assert!(!Arc::ptr_eq(&lock1, &lock3));
5632
5633 assert!(lock1.try_acquire());
5635         assert!(lock3.try_acquire());
        lock1.release();
5638 lock3.release();
5639 }
5640
5641 #[test]
5642 fn test_concurrent_access_same_model() {
5643 let lock = get_or_create_model_lock("test-concurrent-model").unwrap();
5644
5645 assert!(lock.try_acquire());
5647
5648 let lock_clone = lock.clone();
5650 assert!(!lock_clone.try_acquire());
5651
5652 lock.release();
5654
5655 assert!(lock_clone.try_acquire());
5657
5658 lock_clone.release();
5659 }
5660
5661 #[test]
5662 fn test_model_lock_guard() {
5663 let lock = Arc::new(ModelStreamLock::new());
5664
5665 assert!(lock.try_acquire());
5667 assert!(lock.active.load(Ordering::SeqCst));
5668
5669 {
5671 let _guard = ModelLockGuard::new(lock.clone());
5672 assert!(lock.active.load(Ordering::SeqCst));
5674 }
5675 assert!(!lock.active.load(Ordering::SeqCst));
5677
5678 assert!(lock.try_acquire());
5680 lock.release();
5681 }
5682
5683 #[test]
5684 fn test_model_lock_guard_ensures_release() {
5685 let lock = get_or_create_model_lock("test-guard-model").unwrap();
5686
5687 assert!(lock.try_acquire());
5689 let _guard = ModelLockGuard::new(lock.clone());
5690
5691 let lock2 = get_or_create_model_lock("test-guard-model").unwrap();
5693 assert!(!lock2.try_acquire());
5694
5695 drop(_guard);
5697
5698 assert!(lock2.try_acquire());
5700 lock2.release();
5701 }
5702}